MLIR 22.0.0git
Var.cpp
Go to the documentation of this file.
1//===- Var.cpp ------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Var.h"
10#include "DimLvlMap.h"
11
12using namespace mlir;
13using namespace mlir::sparse_tensor;
14using namespace mlir::sparse_tensor::ir_detail;
15
16//===----------------------------------------------------------------------===//
17// `VarKind` helpers.
18//===----------------------------------------------------------------------===//
19
20/// For use in foreach loops.
23
24//===----------------------------------------------------------------------===//
25// `Var` implementation.
26//===----------------------------------------------------------------------===//
27
28std::string Var::str() const {
29 std::string str;
30 llvm::raw_string_ostream os(str);
31 print(os);
32 return str;
33}
34
35void Var::print(AsmPrinter &printer) const { print(printer.getStream()); }
36
37void Var::print(llvm::raw_ostream &os) const {
38 os << toChar(getKind()) << getNum();
39}
40
41void Var::dump() const {
42 print(llvm::errs());
43 llvm::errs() << "\n";
44}
45
46//===----------------------------------------------------------------------===//
47// `Ranks` implementation.
48//===----------------------------------------------------------------------===//
49
50bool Ranks::operator==(Ranks const &other) const {
51 for (const auto vk : everyVarKind)
52 if (getRank(vk) != other.getRank(vk))
53 return false;
54 return true;
55}
56
57bool Ranks::isValid(DimLvlExpr expr) const {
58 assert(expr);
59 // Compute the maximum identifiers for symbol-vars and dim/lvl-vars
60 // (each `DimLvlExpr` only allows one kind of non-symbol variable).
61 int64_t maxSym = -1, maxVar = -1;
63 maxVar, maxSym);
64 return maxSym < getSymRank() && maxVar < getRank(expr.getAllowedVarKind());
65}
66
67//===----------------------------------------------------------------------===//
68// `VarSet` implementation.
69//===----------------------------------------------------------------------===//
70
71VarSet::VarSet(Ranks const &ranks) {
72 for (const auto vk : everyVarKind)
73 impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
74 assert(getRanks() == ranks);
75}
76
77bool VarSet::contains(Var var) const {
78 // NOTE: We make sure to return false on OOB, for consistency with
79 // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`.
80 // However beware that, as always with silencing OOB, this can hide
81 // bugs in client code.
82 const llvm::SmallBitVector &bits = impl[var.getKind()];
83 const auto num = var.getNum();
84 return num < bits.size() && bits[num];
85}
86
87void VarSet::add(Var var) {
88 // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
89 impl[var.getKind()][var.getNum()] = true;
90}
91
92void VarSet::add(VarSet const &other) {
93 // NOTE: `SmallBitVector::operator&=` will implicitly resize
94 // the bitvector (unlike `BitVector::operator&=`), so we add an
95 // assertion against OOB for consistency with the implementation
96 // of `VarSet::add(Var)`.
97 for (const auto vk : everyVarKind) {
98 assert(impl[vk].size() >= other.impl[vk].size());
99 impl[vk] &= other.impl[vk];
100 }
101}
102
104 if (!expr)
105 return;
106 switch (expr.getAffineKind()) {
108 return;
110 add(expr.castSymVar());
111 return;
113 add(expr.castDimLvlVar());
114 return;
120 const auto [lhs, op, rhs] = expr.unpackBinop();
121 (void)op;
122 add(lhs);
123 add(rhs);
124 return;
125 }
126 }
127 llvm_unreachable("unknown AffineExprKind");
128}
129
130//===----------------------------------------------------------------------===//
131// `VarInfo` implementation.
132//===----------------------------------------------------------------------===//
133
135 assert(!hasNum() && "Var::Num is already set");
136 assert(Var::isWF_Num(n) && "Var::Num is too large");
137 num = n;
138}
139
140//===----------------------------------------------------------------------===//
141// `VarEnv` implementation.
142//===----------------------------------------------------------------------===//
143
144/// Helper function for `assertUsageConsistency` to better handle SMLoc
145/// mismatches.
146[[maybe_unused]] static llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1,
147 llvm::SMLoc sm2) {
148 const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1));
149 assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
150 const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2));
151 assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`");
152 if (loc1.getFilename() != loc2.getFilename())
153 return SMLoc();
154 const auto pair1 = std::make_pair(loc1.getLine(), loc1.getColumn());
155 const auto pair2 = std::make_pair(loc2.getLine(), loc2.getColumn());
156 return pair1 <= pair2 ? sm1 : sm2;
157}
158
159static bool isInternalConsistent(VarEnv const &env, VarInfo::ID id,
160 StringRef name) {
161 const auto &var = env.access(id);
162 return (var.getName() == name && var.getID() == id);
163}
164
165static bool isUsageConsistent(VarEnv const &env, VarInfo::ID id,
166 llvm::SMLoc loc, VarKind vk) {
167 const auto &var = env.access(id);
168 return var.getKind() == vk;
169}
170
171std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
172 const auto iter = ids.find(name);
173 if (iter == ids.end())
174 return std::nullopt;
175 const auto id = iter->second;
176 if (!isInternalConsistent(*this, id, name))
177 return std::nullopt;
178 return id;
179}
180
181std::optional<std::pair<VarInfo::ID, bool>>
182VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
183 const auto &[iter, didInsert] = ids.try_emplace(name, nextID());
184 const auto id = iter->second;
185 if (didInsert) {
186 vars.emplace_back(id, name, loc, vk);
187 } else {
188 if (!isInternalConsistent(*this, id, name))
189 return std::nullopt;
190 if (verifyUsage)
191 if (!isUsageConsistent(*this, id, loc, vk))
192 return std::nullopt;
193 }
194 return std::make_pair(id, didInsert);
195}
196
197std::optional<std::pair<VarInfo::ID, bool>>
198VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
199 VarKind vk) {
200 switch (creationPolicy) {
201 case Policy::MustNot: {
202 const auto oid = lookup(name);
203 if (!oid)
204 return std::nullopt; // Doesn't exist, but must not create.
205 if (!isUsageConsistent(*this, *oid, loc, vk))
206 return std::nullopt;
207 return std::make_pair(*oid, false);
208 }
209 case Policy::May:
210 return create(name, loc, vk, /*verifyUsage=*/true);
211 case Policy::Must: {
212 const auto res = create(name, loc, vk, /*verifyUsage=*/false);
213 const auto didCreate = res->second;
214 if (!didCreate)
215 return std::nullopt; // Already exists, but must create.
216 return res;
217 }
218 }
219 llvm_unreachable("unknown Policy");
220}
221
222Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); }
224 auto &info = access(id);
225 const auto var = bindUnusedVar(info.getKind());
226 info.setNum(var.getNum());
227 return var;
228}
229
231 for (const auto &var : vars)
232 if (!var.hasNum())
233 return parser.emitError(var.getLoc(),
234 "Unbound variable: " + var.getName());
235 return {};
236}
237
238//===----------------------------------------------------------------------===//
lhs
static constexpr const VarKind everyVarKind[]
For use in foreach loops.
Definition Var.cpp:21
static llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2)
Helper function for assertUsageConsistency to better handle SMLoc mismatches.
Definition Var.cpp:146
static bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name)
Definition Var.cpp:159
static bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc, VarKind vk)
Definition Var.cpp:165
#define add(a, b)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
This class represents a diagnostic that is inflight and set to be reported.
constexpr AffineExpr getAffineExpr() const
Definition DimLvlMap.h:68
std::tuple< DimLvlExpr, AffineExprKind, DimLvlExpr > unpackBinop() const
Definition DimLvlMap.cpp:40
constexpr VarKind getAllowedVarKind() const
Definition DimLvlMap.h:65
constexpr unsigned getRank(VarKind vk) const
Definition Var.h:228
bool operator==(Ranks const &other) const
Definition Var.cpp:50
constexpr unsigned getSymRank() const
Definition Var.h:229
constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
Definition Var.h:215
constexpr bool isValid(Var var) const
Definition Var.h:233
Var bindUnusedVar(VarKind vk)
Creates a new variable of the given kind and immediately binds it.
Definition Var.cpp:222
InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const
Definition Var.cpp:230
std::optional< std::pair< VarInfo::ID, bool > > lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, VarKind vk)
Looks up or creates a variable according to the given Policy.
Definition Var.cpp:198
std::optional< VarInfo::ID > lookup(StringRef name) const
Looks up the variable with the given name.
Definition Var.cpp:171
std::optional< std::pair< VarInfo::ID, bool > > create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage=false)
Creates a new currently-unbound variable.
Definition Var.cpp:182
Var bindVar(VarInfo::ID id)
Binds the given variable to the next free Var::Num for its VarKind.
Definition Var.cpp:223
VarInfo const & access(VarInfo::ID id) const
Gets the underlying storage for the VarInfo identified by the VarInfo::ID.
Definition Var.h:338
constexpr VarKind getKind() const
Definition Var.h:304
ID
Newtype for unique identifiers of VarInfo records, to ensure they aren't confused with Var::Num.
Definition Var.h:279
constexpr bool hasNum() const
Definition Var.h:306
bool contains(Var var) const
For the contains method: if variables occurring in the method parameter are OOB for the VarSet,...
Definition Var.cpp:77
void add(Var var)
For the add methods: OOB parameters cause undefined behavior.
Definition Var.cpp:87
A concrete variable, to be used in our variant of AffineExpr.
Definition Var.h:61
constexpr Num getNum() const
Definition Var.h:125
constexpr VarKind getKind() const
Definition Var.h:124
void print(llvm::raw_ostream &os) const
Definition Var.cpp:37
std::string str() const
Definition Var.cpp:28
static constexpr bool isWF_Num(Num n)
Checks whether the number would be accepted by Var(VarKind,Var::Num).
Definition Var.h:83
unsigned Num
Typedef for the type of variable numbers.
Definition Var.h:64
VarKind
The three kinds of variables that Var can be.
Definition Var.h:32
constexpr char toChar(VarKind vk)
Gets the ASCII character used as the prefix when printing Var.
Definition Var.h:40
Include the generated interface declarations.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
Definition AffineExpr.h:50
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
@ Constant
Constant integer.
Definition AffineExpr.h:57
@ SymbolId
Symbolic identifier.
Definition AffineExpr.h:61
static void getMaxDimAndSymbol(ArrayRef< AffineExprContainer > exprsList, int64_t &maxDim, int64_t &maxSym)
Calculates maximum dimension and symbol positions from the expressions in exprsLists and stores them ...
Definition AffineMap.h:697