MLIR  22.0.0git
Var.cpp
Go to the documentation of this file.
1 //===- Var.cpp ------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Var.h"
10 #include "DimLvlMap.h"
11 
12 using namespace mlir;
13 using namespace mlir::sparse_tensor;
14 using namespace mlir::sparse_tensor::ir_detail;
15 
16 //===----------------------------------------------------------------------===//
17 // `VarKind` helpers.
18 //===----------------------------------------------------------------------===//
19 
20 /// For use in foreach loops.
21 static constexpr const VarKind everyVarKind[] = {
23 
24 //===----------------------------------------------------------------------===//
25 // `Var` implementation.
26 //===----------------------------------------------------------------------===//
27 
28 std::string Var::str() const {
29  std::string str;
30  llvm::raw_string_ostream os(str);
31  print(os);
32  return str;
33 }
34 
35 void Var::print(AsmPrinter &printer) const { print(printer.getStream()); }
36 
37 void Var::print(llvm::raw_ostream &os) const {
38  os << toChar(getKind()) << getNum();
39 }
40 
41 void Var::dump() const {
42  print(llvm::errs());
43  llvm::errs() << "\n";
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // `Ranks` implementation.
48 //===----------------------------------------------------------------------===//
49 
50 bool Ranks::operator==(Ranks const &other) const {
51  for (const auto vk : everyVarKind)
52  if (getRank(vk) != other.getRank(vk))
53  return false;
54  return true;
55 }
56 
57 bool Ranks::isValid(DimLvlExpr expr) const {
58  assert(expr);
59  // Compute the maximum identifiers for symbol-vars and dim/lvl-vars
60  // (each `DimLvlExpr` only allows one kind of non-symbol variable).
61  int64_t maxSym = -1, maxVar = -1;
62  mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>({{expr.getAffineExpr()}},
63  maxVar, maxSym);
64  return maxSym < getSymRank() && maxVar < getRank(expr.getAllowedVarKind());
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // `VarSet` implementation.
69 //===----------------------------------------------------------------------===//
70 
71 VarSet::VarSet(Ranks const &ranks) {
72  for (const auto vk : everyVarKind)
73  impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
74  assert(getRanks() == ranks);
75 }
76 
77 bool VarSet::contains(Var var) const {
78  // NOTE: We make sure to return false on OOB, for consistency with
79  // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`.
80  // However beware that, as always with silencing OOB, this can hide
81  // bugs in client code.
82  const llvm::SmallBitVector &bits = impl[var.getKind()];
83  const auto num = var.getNum();
84  return num < bits.size() && bits[num];
85 }
86 
87 void VarSet::add(Var var) {
88  // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
89  impl[var.getKind()][var.getNum()] = true;
90 }
91 
92 void VarSet::add(VarSet const &other) {
93  // NOTE: `SmallBitVector::operator&=` will implicitly resize
94  // the bitvector (unlike `BitVector::operator&=`), so we add an
95  // assertion against OOB for consistency with the implementation
96  // of `VarSet::add(Var)`.
97  for (const auto vk : everyVarKind) {
98  assert(impl[vk].size() >= other.impl[vk].size());
99  impl[vk] &= other.impl[vk];
100  }
101 }
102 
104  if (!expr)
105  return;
106  switch (expr.getAffineKind()) {
108  return;
110  add(expr.castSymVar());
111  return;
113  add(expr.castDimLvlVar());
114  return;
115  case AffineExprKind::Add:
116  case AffineExprKind::Mul:
117  case AffineExprKind::Mod:
120  const auto [lhs, op, rhs] = expr.unpackBinop();
121  (void)op;
122  add(lhs);
123  add(rhs);
124  return;
125  }
126  }
127  llvm_unreachable("unknown AffineExprKind");
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // `VarInfo` implementation.
132 //===----------------------------------------------------------------------===//
133 
135  assert(!hasNum() && "Var::Num is already set");
136  assert(Var::isWF_Num(n) && "Var::Num is too large");
137  num = n;
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // `VarEnv` implementation.
142 //===----------------------------------------------------------------------===//
143 
144 /// Helper function for `assertUsageConsistency` to better handle SMLoc
145 /// mismatches.
146 LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
147 minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
148  const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1));
149  assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
150  const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2));
151  assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`");
152  if (loc1.getFilename() != loc2.getFilename())
153  return SMLoc();
154  const auto pair1 = std::make_pair(loc1.getLine(), loc1.getColumn());
155  const auto pair2 = std::make_pair(loc2.getLine(), loc2.getColumn());
156  return pair1 <= pair2 ? sm1 : sm2;
157 }
158 
159 static bool isInternalConsistent(VarEnv const &env, VarInfo::ID id,
160  StringRef name) {
161  const auto &var = env.access(id);
162  return (var.getName() == name && var.getID() == id);
163 }
164 
165 static bool isUsageConsistent(VarEnv const &env, VarInfo::ID id,
166  llvm::SMLoc loc, VarKind vk) {
167  const auto &var = env.access(id);
168  return var.getKind() == vk;
169 }
170 
171 std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
172  const auto iter = ids.find(name);
173  if (iter == ids.end())
174  return std::nullopt;
175  const auto id = iter->second;
176  if (!isInternalConsistent(*this, id, name))
177  return std::nullopt;
178  return id;
179 }
180 
181 std::optional<std::pair<VarInfo::ID, bool>>
182 VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
183  const auto &[iter, didInsert] = ids.try_emplace(name, nextID());
184  const auto id = iter->second;
185  if (didInsert) {
186  vars.emplace_back(id, name, loc, vk);
187  } else {
188  if (!isInternalConsistent(*this, id, name))
189  return std::nullopt;
190  if (verifyUsage)
191  if (!isUsageConsistent(*this, id, loc, vk))
192  return std::nullopt;
193  }
194  return std::make_pair(id, didInsert);
195 }
196 
197 std::optional<std::pair<VarInfo::ID, bool>>
198 VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
199  VarKind vk) {
200  switch (creationPolicy) {
201  case Policy::MustNot: {
202  const auto oid = lookup(name);
203  if (!oid)
204  return std::nullopt; // Doesn't exist, but must not create.
205  if (!isUsageConsistent(*this, *oid, loc, vk))
206  return std::nullopt;
207  return std::make_pair(*oid, false);
208  }
209  case Policy::May:
210  return create(name, loc, vk, /*verifyUsage=*/true);
211  case Policy::Must: {
212  const auto res = create(name, loc, vk, /*verifyUsage=*/false);
213  const auto didCreate = res->second;
214  if (!didCreate)
215  return std::nullopt; // Already exists, but must create.
216  return res;
217  }
218  }
219  llvm_unreachable("unknown Policy");
220 }
221 
222 Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); }
224  auto &info = access(id);
225  const auto var = bindUnusedVar(info.getKind());
226  info.setNum(var.getNum());
227  return var;
228 }
229 
231  for (const auto &var : vars)
232  if (!var.hasNum())
233  return parser.emitError(var.getLoc(),
234  "Unbound variable: " + var.getName());
235  return {};
236 }
237 
238 //===----------------------------------------------------------------------===//
static LLVM_ATTRIBUTE_UNUSED llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2)
Helper function for assertUsageConsistency to better handle SMLoc mismatches.
Definition: Var.cpp:147
static constexpr const VarKind everyVarKind[]
For use in foreach loops.
Definition: Var.cpp:21
static bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name)
Definition: Var.cpp:159
static bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc, VarKind vk)
Definition: Var.cpp:165
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
AffineExprKind getAffineKind() const
Definition: DimLvlMap.h:69
constexpr AffineExpr getAffineExpr() const
Definition: DimLvlMap.h:68
std::tuple< DimLvlExpr, AffineExprKind, DimLvlExpr > unpackBinop() const
Definition: DimLvlMap.cpp:40
constexpr VarKind getAllowedVarKind() const
Definition: DimLvlMap.h:65
constexpr unsigned getRank(VarKind vk) const
Definition: Var.h:228
bool operator==(Ranks const &other) const
Definition: Var.cpp:50
constexpr unsigned getSymRank() const
Definition: Var.h:229
constexpr bool isValid(Var var) const
Definition: Var.h:233
Var bindUnusedVar(VarKind vk)
Creates a new variable of the given kind and immediately binds it.
Definition: Var.cpp:222
VarInfo const & access(VarInfo::ID id) const
Gets the underlying storage for the VarInfo identified by the VarInfo::ID.
Definition: Var.h:338
InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const
Definition: Var.cpp:230
std::optional< std::pair< VarInfo::ID, bool > > lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, VarKind vk)
Looks up or creates a variable according to the given Policy.
Definition: Var.cpp:198
std::optional< VarInfo::ID > lookup(StringRef name) const
Looks up the variable with the given name.
Definition: Var.cpp:171
std::optional< std::pair< VarInfo::ID, bool > > create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage=false)
Creates a new currently-unbound variable.
Definition: Var.cpp:182
Var bindVar(VarInfo::ID id)
Binds the given variable to the next free Var::Num for its VarKind.
Definition: Var.cpp:223
constexpr VarKind getKind() const
Definition: Var.h:304
ID
Newtype for unique identifiers of VarInfo records, to ensure they aren't confused with Var::Num.
Definition: Var.h:279
constexpr bool hasNum() const
Definition: Var.h:306
Efficient representation of a set of Var.
Definition: Var.h:242
bool contains(Var var) const
For the contains method: if variables occurring in the method parameter are OOB for the VarSet,...
Definition: Var.cpp:77
VarSet(Ranks const &ranks)
Definition: Var.cpp:71
void add(Var var)
For the add methods: OOB parameters cause undefined behavior.
Definition: Var.cpp:87
A concrete variable, to be used in our variant of AffineExpr.
Definition: Var.h:61
constexpr Num getNum() const
Definition: Var.h:125
constexpr VarKind getKind() const
Definition: Var.h:124
void print(llvm::raw_ostream &os) const
Definition: Var.cpp:37
std::string str() const
Definition: Var.cpp:28
static constexpr bool isWF_Num(Num n)
Checks whether the number would be accepted by Var(VarKind,Var::Num).
Definition: Var.h:83
unsigned Num
Typedef for the type of variable numbers.
Definition: Var.h:64
VarKind
The three kinds of variables that Var can be.
Definition: Var.h:32
constexpr char toChar(VarKind vk)
Gets the ASCII character used as the prefix when printing Var.
Definition: Var.h:40
Include the generated interface declarations.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.