MLIR  19.0.0git
Var.cpp
Go to the documentation of this file.
1 //===- Var.cpp ------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Var.h"
10 #include "DimLvlMap.h"
11 
12 using namespace mlir;
13 using namespace mlir::sparse_tensor;
14 using namespace mlir::sparse_tensor::ir_detail;
15 
16 //===----------------------------------------------------------------------===//
17 // `VarKind` helpers.
18 //===----------------------------------------------------------------------===//
19 
20 /// Array of `VarKind` values, iterated with range-based `for` loops to apply
/// a per-kind operation to each kind in turn (e.g. comparing ranks in
/// `Ranks::operator==`, sizing the bitvectors in the `VarSet` constructor).
21 static constexpr const VarKind everyVarKind[] = {
23 
24 //===----------------------------------------------------------------------===//
25 // `Var` implementation.
26 //===----------------------------------------------------------------------===//
27 
28 std::string Var::str() const {
29  std::string str;
30  llvm::raw_string_ostream os(str);
31  print(os);
32  return os.str();
33 }
34 
35 void Var::print(AsmPrinter &printer) const { print(printer.getStream()); }
36 
37 void Var::print(llvm::raw_ostream &os) const {
38  os << toChar(getKind()) << getNum();
39 }
40 
41 void Var::dump() const {
42  print(llvm::errs());
43  llvm::errs() << "\n";
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // `Ranks` implementation.
48 //===----------------------------------------------------------------------===//
49 
50 bool Ranks::operator==(Ranks const &other) const {
51  for (const auto vk : everyVarKind)
52  if (getRank(vk) != other.getRank(vk))
53  return false;
54  return true;
55 }
56 
57 bool Ranks::isValid(DimLvlExpr expr) const {
58  assert(expr);
59  // Compute the maximum identifiers for symbol-vars and dim/lvl-vars
60  // (each `DimLvlExpr` only allows one kind of non-symbol variable).
61  int64_t maxSym = -1, maxVar = -1;
62  mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>({{expr.getAffineExpr()}},
63  maxVar, maxSym);
64  return maxSym < getSymRank() && maxVar < getRank(expr.getAllowedVarKind());
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // `VarSet` implementation.
69 //===----------------------------------------------------------------------===//
70 
71 VarSet::VarSet(Ranks const &ranks) {
72  for (const auto vk : everyVarKind)
73  impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
74  assert(getRanks() == ranks);
75 }
76 
77 bool VarSet::contains(Var var) const {
78  // NOTE: We make sure to return false on OOB, for consistency with
79  // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`.
80  // However beware that, as always with silencing OOB, this can hide
81  // bugs in client code.
82  const llvm::SmallBitVector &bits = impl[var.getKind()];
83  const auto num = var.getNum();
84  return num < bits.size() && bits[num];
85 }
86 
87 void VarSet::add(Var var) {
88  // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
89  impl[var.getKind()][var.getNum()] = true;
90 }
91 
92 void VarSet::add(VarSet const &other) {
93  // NOTE: `SmallBitVector::operator&=` will implicitly resize
94  // the bitvector (unlike `BitVector::operator&=`), so we add an
95  // assertion against OOB for consistency with the implementation
96  // of `VarSet::add(Var)`.
97  for (const auto vk : everyVarKind) {
98  assert(impl[vk].size() >= other.impl[vk].size());
99  impl[vk] &= other.impl[vk];
100  }
101 }
102 
// Adds every variable occurring in `expr` to this set. Binary
// sub-expressions are handled by recursing on both operands.
// NOTE(review): the function header and several `case` labels are missing
// from this view (presumably the Constant/SymbolId/DimId/FloorDiv/CeilDiv
// cases, per the enumerators referenced elsewhere); confirm against the file.
104  if (!expr)
105  return;
106  switch (expr.getAffineKind()) {
// No-variable case (presumably `Constant`): nothing to add.
108  return;
// Symbol occurrence: record the symbol-var.
110  add(expr.castSymVar());
111  return;
// Dim/lvl occurrence: record the dim- or lvl-var.
113  add(expr.castDimLvlVar());
114  return;
115  case AffineExprKind::Add:
116  case AffineExprKind::Mul:
117  case AffineExprKind::Mod:
120  const auto [lhs, op, rhs] = expr.unpackBinop();
// The operator kind is irrelevant here; only the operands' variables matter.
121  (void)op;
122  add(lhs);
123  add(rhs);
124  return;
125  }
126  }
127  llvm_unreachable("unknown AffineExprKind");
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // `VarInfo` implementation.
132 //===----------------------------------------------------------------------===//
133 
// Binds this variable to the number `n`. A number may only be assigned once,
// and `n` must satisfy `Var::isWF_Num`; both are checked by assertion only.
// NOTE(review): the function header (presumably
// `void VarInfo::setNum(Var::Num n)`) is missing from this view.
135  assert(!hasNum() && "Var::Num is already set");
136  assert(Var::isWF_Num(n) && "Var::Num is too large");
137  num = n;
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // `VarEnv` implementation.
142 //===----------------------------------------------------------------------===//
143 
144 /// Helper function for `assertUsageConsistency` to better handle SMLoc
145 /// mismatches.
146 LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
147 minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
148  const auto loc1 = parser.getEncodedSourceLoc(sm1).dyn_cast<FileLineColLoc>();
149  assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
150  const auto loc2 = parser.getEncodedSourceLoc(sm2).dyn_cast<FileLineColLoc>();
151  assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`");
152  if (loc1.getFilename() != loc2.getFilename())
153  return SMLoc();
154  const auto pair1 = std::make_pair(loc1.getLine(), loc1.getColumn());
155  const auto pair2 = std::make_pair(loc2.getLine(), loc2.getColumn());
156  return pair1 <= pair2 ? sm1 : sm2;
157 }
158 
159 bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
160  const auto &var = env.access(id);
161  return (var.getName() == name && var.getID() == id);
162 }
163 
164 bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
165  VarKind vk) {
166  const auto &var = env.access(id);
167  return var.getKind() == vk;
168 }
169 
170 std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
171  const auto iter = ids.find(name);
172  if (iter == ids.end())
173  return std::nullopt;
174  const auto id = iter->second;
175  if (!isInternalConsistent(*this, id, name))
176  return std::nullopt;
177  return id;
178 }
179 
180 std::optional<std::pair<VarInfo::ID, bool>>
181 VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
182  const auto &[iter, didInsert] = ids.try_emplace(name, nextID());
183  const auto id = iter->second;
184  if (didInsert) {
185  vars.emplace_back(id, name, loc, vk);
186  } else {
187  if (!isInternalConsistent(*this, id, name))
188  return std::nullopt;
189  if (verifyUsage)
190  if (!isUsageConsistent(*this, id, loc, vk))
191  return std::nullopt;
192  }
193  return std::make_pair(id, didInsert);
194 }
195 
196 std::optional<std::pair<VarInfo::ID, bool>>
197 VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
198  VarKind vk) {
199  switch (creationPolicy) {
200  case Policy::MustNot: {
201  const auto oid = lookup(name);
202  if (!oid)
203  return std::nullopt; // Doesn't exist, but must not create.
204  if (!isUsageConsistent(*this, *oid, loc, vk))
205  return std::nullopt;
206  return std::make_pair(*oid, false);
207  }
208  case Policy::May:
209  return create(name, loc, vk, /*verifyUsage=*/true);
210  case Policy::Must: {
211  const auto res = create(name, loc, vk, /*verifyUsage=*/false);
212  const auto didCreate = res->second;
213  if (!didCreate)
214  return std::nullopt; // Already exists, but must create.
215  return res;
216  }
217  }
218  llvm_unreachable("unknown Policy");
219 }
220 
221 Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); }
// Binds the variable identified by `id`: allocates the next free number for
// its kind, records that number in its `VarInfo`, and returns the new `Var`.
// NOTE(review): the function header is missing from this view.
223  auto &info = access(id);
224  const auto var = bindUnusedVar(info.getKind());
225  info.setNum(var.getNum());
226  return var;
227 }
228 
// Emits an "Unbound variable" error at the first variable (in creation order)
// that never received a number; otherwise returns a default-constructed
// (empty) `InFlightDiagnostic`.
// NOTE(review): the function header is missing from this view.
230  for (const auto &var : vars)
231  if (!var.hasNum())
232  return parser.emitError(var.getLoc(),
233  "Unbound variable: " + var.getName());
234  return {};
235 }
236 
237 //===----------------------------------------------------------------------===//
static LLVM_ATTRIBUTE_UNUSED llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2)
Helper function for assertUsageConsistency to better handle SMLoc mismatches.
Definition: Var.cpp:147
static constexpr const VarKind everyVarKind[]
For use in foreach loops.
Definition: Var.cpp:21
bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc, VarKind vk)
Definition: Var.cpp:164
bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name)
Definition: Var.cpp:159
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
U dyn_cast() const
Definition: Location.h:85
AffineExprKind getAffineKind() const
Definition: DimLvlMap.h:69
constexpr AffineExpr getAffineExpr() const
Definition: DimLvlMap.h:68
std::tuple< DimLvlExpr, AffineExprKind, DimLvlExpr > unpackBinop() const
Definition: DimLvlMap.cpp:40
constexpr VarKind getAllowedVarKind() const
Definition: DimLvlMap.h:65
constexpr unsigned getRank(VarKind vk) const
Definition: Var.h:228
bool operator==(Ranks const &other) const
Definition: Var.cpp:50
constexpr unsigned getSymRank() const
Definition: Var.h:229
constexpr bool isValid(Var var) const
Definition: Var.h:233
Var bindUnusedVar(VarKind vk)
Creates a new variable of the given kind and immediately binds it.
Definition: Var.cpp:221
VarInfo const & access(VarInfo::ID id) const
Gets the underlying storage for the VarInfo identified by the VarInfo::ID.
Definition: Var.h:338
InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const
Definition: Var.cpp:229
std::optional< std::pair< VarInfo::ID, bool > > lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, VarKind vk)
Looks up or creates a variable according to the given Policy.
Definition: Var.cpp:197
std::optional< VarInfo::ID > lookup(StringRef name) const
Looks up the variable with the given name.
Definition: Var.cpp:170
std::optional< std::pair< VarInfo::ID, bool > > create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage=false)
Creates a new currently-unbound variable.
Definition: Var.cpp:181
Var bindVar(VarInfo::ID id)
Binds the given variable to the next free Var::Num for its VarKind.
Definition: Var.cpp:222
constexpr VarKind getKind() const
Definition: Var.h:304
ID
Newtype for unique identifiers of VarInfo records, to ensure they aren't confused with Var::Num.
Definition: Var.h:279
constexpr bool hasNum() const
Definition: Var.h:306
Efficient representation of a set of Var.
Definition: Var.h:242
bool contains(Var var) const
For the contains method: if variables occurring in the method parameter are OOB for the VarSet,...
Definition: Var.cpp:77
VarSet(Ranks const &ranks)
Definition: Var.cpp:71
void add(Var var)
For the add methods: OOB parameters cause undefined behavior.
Definition: Var.cpp:87
A concrete variable, to be used in our variant of AffineExpr.
Definition: Var.h:61
constexpr Num getNum() const
Definition: Var.h:125
constexpr VarKind getKind() const
Definition: Var.h:124
void print(llvm::raw_ostream &os) const
Definition: Var.cpp:37
std::string str() const
Definition: Var.cpp:28
static constexpr bool isWF_Num(Num n)
Checks whether the number would be accepted by Var(VarKind,Var::Num).
Definition: Var.h:83
unsigned Num
Typedef for the type of variable numbers.
Definition: Var.h:64
VarKind
The three kinds of variables that Var can be.
Definition: Var.h:32
constexpr char toChar(VarKind vk)
Gets the ASCII character used as the prefix when printing Var.
Definition: Var.h:40
Include the generated interface declarations.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.