MLIR  14.0.0git
SymbolTable.h
Go to the documentation of this file.
1 //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_SYMBOLTABLE_H
10 #define MLIR_IR_SYMBOLTABLE_H
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/StringMap.h"
16 
17 namespace mlir {
18 
19 /// This class allows for representing and managing the symbol table used by
20 /// operations with the 'SymbolTable' trait. Inserting into and erasing from
21 /// this SymbolTable will also insert and erase from the Operation given to it
22 /// at construction.
23 class SymbolTable {
24 public:
25  /// Build a symbol table with the symbols within the given operation.
26  SymbolTable(Operation *symbolTableOp);
27 
28  /// Look up a symbol with the specified name, returning null if no such
29  /// name exists. Names never include the @ on them.
30  Operation *lookup(StringRef name) const;
31  template <typename T>
32  T lookup(StringRef name) const {
33  return dyn_cast_or_null<T>(lookup(name));
34  }
35 
36  /// Look up a symbol with the specified name, returning null if no such
37  /// name exists. Names never include the @ on them.
38  Operation *lookup(StringAttr name) const;
39  template <typename T>
40  T lookup(StringAttr name) const {
41  return dyn_cast_or_null<T>(lookup(name));
42  }
43 
44  /// Erase the given symbol from the table.
45  void erase(Operation *symbol);
46 
47  /// Insert a new symbol into the table, and rename it as necessary to avoid
48  /// collisions. Also insert at the specified location in the body of the
49  /// associated operation if it is not already there. It is asserted that the
50  /// symbol is not inside another operation. Return the name of the symbol
51  /// after insertion as attribute.
52  StringAttr insert(Operation *symbol, Block::iterator insertPt = {});
53 
54  /// Return the name of the attribute used for symbol names.
55  static StringRef getSymbolAttrName() { return "sym_name"; }
56 
57  /// Returns the associated operation.
58  Operation *getOp() const { return symbolTableOp; }
59 
60  /// Return the name of the attribute used for symbol visibility.
61  static StringRef getVisibilityAttrName() { return "sym_visibility"; }
62 
63  //===--------------------------------------------------------------------===//
64  // Symbol Utilities
65  //===--------------------------------------------------------------------===//
66 
67  /// An enumeration detailing the different visibility types that a symbol may
68  /// have.
69  enum class Visibility {
70  /// The symbol is public and may be referenced anywhere internal or external
71  /// to the visible references in the IR.
72  Public,
73 
74  /// The symbol is private and may only be referenced by SymbolRefAttrs local
75  /// to the operations within the current symbol table.
76  Private,
77 
78  /// The symbol is visible to the current IR, which may include operations in
79  /// symbol tables above the one that owns the current symbol. `Nested`
80  /// visibility allows for referencing a symbol outside of its current symbol
81  /// table, while retaining the ability to observe all uses.
82  Nested,
83  };
84 
85  /// Returns the name of the given symbol operation, aborting if no symbol is
86  /// present.
87  static StringAttr getSymbolName(Operation *symbol);
88 
89  /// Sets the name of the given symbol operation.
90  static void setSymbolName(Operation *symbol, StringAttr name);
91  static void setSymbolName(Operation *symbol, StringRef name) {
92  setSymbolName(symbol, StringAttr::get(symbol->getContext(), name));
93  }
94 
95  /// Returns the visibility of the given symbol operation.
97  /// Sets the visibility of the given symbol operation.
98  static void setSymbolVisibility(Operation *symbol, Visibility vis);
99 
100  /// Returns the nearest symbol table from a given operation `from`. Returns
101  /// nullptr if no valid parent symbol table could be found.
103 
104  /// Walks all symbol table operations nested within, and including, `op`. For
105  /// each symbol table operation, the provided callback is invoked with the op
106  /// and a boolean signifying if the symbols within that symbol table can be
107  /// treated as if all uses within the IR are visible to the caller.
108  /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols
109  /// within `op` are visible.
110  static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
111  function_ref<void(Operation *, bool)> callback);
112 
113  /// Returns the operation registered with the given symbol name with the
114  /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
115  /// with the 'OpTrait::SymbolTable' trait.
116  static Operation *lookupSymbolIn(Operation *op, StringAttr symbol);
117  static Operation *lookupSymbolIn(Operation *op, StringRef symbol) {
118  return lookupSymbolIn(op, StringAttr::get(op->getContext(), symbol));
119  }
120  static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
121  /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
122  /// by a given SymbolRefAttr. Returns failure if any of the nested references
123  /// could not be resolved.
124  static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol,
126 
127  /// Returns the operation registered with the given symbol name within the
128  /// closest parent operation of, or including, 'from' with the
129  /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
130  /// found.
131  static Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
133  SymbolRefAttr symbol);
134  template <typename T>
135  static T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
136  return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
137  }
138  template <typename T>
139  static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
140  return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
141  }
142 
143  /// This class represents a specific symbol use.
144  class SymbolUse {
145  public:
146  SymbolUse(Operation *op, SymbolRefAttr symbolRef)
147  : owner(op), symbolRef(symbolRef) {}
148 
149  /// Return the operation user of this symbol reference.
150  Operation *getUser() const { return owner; }
151 
152  /// Return the symbol reference that this use represents.
153  SymbolRefAttr getSymbolRef() const { return symbolRef; }
154 
155  private:
156  /// The operation that this access is held by.
157  Operation *owner;
158 
159  /// The symbol reference that this use represents.
160  SymbolRefAttr symbolRef;
161  };
162 
163  /// This class implements a range of SymbolRef uses.
164  class UseRange {
165  public:
166  UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {}
167 
168  using iterator = std::vector<SymbolUse>::const_iterator;
169  iterator begin() const { return uses.begin(); }
170  iterator end() const { return uses.end(); }
171  bool empty() const { return uses.empty(); }
172 
173  private:
174  std::vector<SymbolUse> uses;
175  };
176 
177  /// Get an iterator range for all of the uses, for any symbol, that are nested
178  /// within the given operation 'from'. This does not traverse into any nested
179  /// symbol tables. This function returns None if there are any unknown
180  /// operations that may potentially be symbol tables.
183 
184  /// Get all of the uses of the given symbol that are nested within the given
185  /// operation 'from'. This does not traverse into any nested symbol tables.
186  /// This function returns None if there are any unknown operations that may
187  /// potentially be symbol tables.
188  static Optional<UseRange> getSymbolUses(StringAttr symbol, Operation *from);
189  static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
190  static Optional<UseRange> getSymbolUses(StringAttr symbol, Region *from);
191  static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from);
192 
193  /// Return if the given symbol is known to have no uses that are nested
194  /// within the given operation 'from'. This does not traverse into any nested
195  /// symbol tables. This function will also return false if there are any
196  /// unknown operations that may potentially be symbol tables. This doesn't
197  /// necessarily mean that there are no uses, we just can't conservatively
198  /// prove it.
199  static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from);
200  static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
201  static bool symbolKnownUseEmpty(StringAttr symbol, Region *from);
202  static bool symbolKnownUseEmpty(Operation *symbol, Region *from);
203 
204  /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
205  /// provided symbol 'newSymbol' that are nested within the given operation
206  /// 'from'. This does not traverse into any nested symbol tables. If there are
207  /// any unknown operations that may potentially be symbol tables, no uses are
208  /// replaced and failure is returned.
209  static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
210  StringAttr newSymbol,
211  Operation *from);
212  static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
213  StringAttr newSymbolName,
214  Operation *from);
215  static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
216  StringAttr newSymbol, Region *from);
217  static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
218  StringAttr newSymbolName,
219  Region *from);
220 
221 private:
222  Operation *symbolTableOp;
223 
224  /// This is a mapping from a name to the symbol with that name. They key is
225  /// always known to be a StringAttr.
227 
228  /// This is used when name conflicts are detected.
229  unsigned uniquingCounter = 0;
230 };
231 
232 raw_ostream &operator<<(raw_ostream &os, SymbolTable::Visibility visibility);
233 
234 //===----------------------------------------------------------------------===//
235 // SymbolTableCollection
236 //===----------------------------------------------------------------------===//
237 
238 /// This class represents a collection of `SymbolTable`s. This simplifies
239 /// certain algorithms that run recursively on nested symbol tables. Symbol
240 /// tables are constructed lazily to reduce the upfront cost of constructing
241 /// unnecessary tables.
243 public:
244  /// Look up a symbol with the specified name within the specified symbol table
245  /// operation, returning null if no such name exists.
246  Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
247  Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
248  template <typename T, typename NameT>
249  T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const {
250  return dyn_cast_or_null<T>(
251  lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
252  }
253  /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
254  /// by a given SymbolRefAttr when resolved within the provided symbol table
255  /// operation. Returns failure if any of the nested references could not be
256  /// resolved.
257  LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
259 
260  /// Returns the operation registered with the given symbol name within the
261  /// closest parent operation of, or including, 'from' with the
262  /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
263  /// found.
264  Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
265  Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
266  template <typename T>
267  T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
268  return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
269  }
270  template <typename T>
271  T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
272  return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
273  }
274 
275  /// Lookup, or create, a symbol table for an operation.
276  SymbolTable &getSymbolTable(Operation *op);
277 
278 private:
279  /// The constructed symbol tables nested within this table.
281 };
282 
283 //===----------------------------------------------------------------------===//
284 // SymbolUserMap
285 //===----------------------------------------------------------------------===//
286 
287 /// This class represents a map of symbols to users, and provides efficient
288 /// implementations of symbol queries related to users; such as collecting the
289 /// users of a symbol, replacing all uses, etc.
291 public:
292  /// Build a user map for all of the symbols defined in regions nested under
293  /// 'symbolTableOp'. A reference to the provided symbol table collection is
294  /// kept by the user map to ensure efficient lookups, thus the lifetime should
295  /// extend beyond that of this map.
296  SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp);
297 
298  /// Return the users of the provided symbol operation.
300  auto it = symbolToUsers.find(symbol);
301  return it != symbolToUsers.end() ? it->second.getArrayRef() : llvm::None;
302  }
303 
304  /// Return true if the given symbol has no uses.
305  bool useEmpty(Operation *symbol) const {
306  return !symbolToUsers.count(symbol);
307  }
308 
309  /// Replace all of the uses of the given symbol with `newSymbolName`.
310  void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName);
311 
312 private:
313  /// A reference to the symbol table used to construct this map.
314  SymbolTableCollection &symbolTable;
315 
316  /// A map of symbol operations to symbol users.
318 };
319 
320 //===----------------------------------------------------------------------===//
321 // SymbolTable Trait Types
322 //===----------------------------------------------------------------------===//
323 
324 namespace detail {
327 } // namespace detail
328 
329 namespace OpTrait {
330 /// A trait used to provide symbol table functionalities to a region operation.
331 /// This operation must hold exactly 1 region. Once attached, all operations
332 /// that are directly within the region, i.e not including those within child
333 /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will
334 /// be verified to ensure that the names are uniqued. These operations must also
335 /// adhere to the constraints defined by the `Symbol` trait, even if they do not
336 /// inherit from it.
337 template <typename ConcreteType>
338 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
339 public:
342  }
343 
344  /// Look up a symbol with the specified name, returning null if no such
345  /// name exists. Symbol names never include the @ on them. Note: This
346  /// performs a linear scan of held symbols.
347  Operation *lookupSymbol(StringAttr name) {
348  return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
349  }
350  template <typename T>
351  T lookupSymbol(StringAttr name) {
352  return dyn_cast_or_null<T>(lookupSymbol(name));
353  }
354  Operation *lookupSymbol(SymbolRefAttr symbol) {
355  return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol);
356  }
357  template <typename T>
358  T lookupSymbol(SymbolRefAttr symbol) {
359  return dyn_cast_or_null<T>(lookupSymbol(symbol));
360  }
361 
362  Operation *lookupSymbol(StringRef name) {
363  return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
364  }
365  template <typename T>
366  T lookupSymbol(StringRef name) {
367  return dyn_cast_or_null<T>(lookupSymbol(name));
368  }
369 };
370 
371 } // namespace OpTrait
372 
373 //===----------------------------------------------------------------------===//
374 // Visibility parsing implementation.
375 //===----------------------------------------------------------------------===//
376 
377 namespace impl {
378 /// Parse an optional visibility attribute keyword (i.e., public, private, or
379 /// nested) without quotes in a string attribute named 'attrName'.
381  NamedAttrList &attrs);
382 } // namespace impl
383 
384 } // namespace mlir
385 
386 /// Include the generated symbol interfaces.
387 #include "mlir/IR/SymbolInterfaces.h.inc"
388 
389 #endif // MLIR_IR_SYMBOLTABLE_H
Include the generated interface declarations.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present. ...
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const
Definition: SymbolTable.h:249
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:55
Operation * getUser() const
Return the operation user of this symbol reference.
Definition: SymbolTable.h:150
SymbolUse(Operation *op, SymbolRefAttr symbolRef)
Definition: SymbolTable.h:146
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
T lookupSymbol(SymbolRefAttr symbol)
Definition: SymbolTable.h:358
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static Optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
T lookup(StringRef name) const
Definition: SymbolTable.h:32
static StringRef getVisibilityAttrName()
Return the name of the attribute used for symbol visibility.
Definition: SymbolTable.h:61
Operation * lookupSymbol(SymbolRefAttr symbol)
Definition: SymbolTable.h:354
static LogicalResult verifyTrait(Operation *op)
Definition: SymbolTable.h:340
The symbol is public and may be referenced anywhere internal or external to the visible references in...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
void erase(Operation *symbol)
Erase the given symbol from the table.
This class implements a range of SymbolRef uses.
Definition: SymbolTable.h:164
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:99
static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol)
Definition: SymbolTable.h:139
Operation * lookupSymbol(StringRef name)
Definition: SymbolTable.h:362
static T lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Definition: SymbolTable.h:135
The symbol is visible to the current IR, which may include operations in symbol tables above the one ...
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition: SymbolTable.h:69
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:242
OpListType::iterator iterator
Definition: Block.h:131
Operation * getOp() const
Returns the associated operation.
Definition: SymbolTable.h:58
static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from)
Return if the given symbol is known to have no uses that are nested within the given operation 'from'...
T lookupSymbol(StringAttr name)
Definition: SymbolTable.h:351
UseRange(std::vector< SymbolUse > &&uses)
Definition: SymbolTable.h:166
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:338
SymbolRefAttr getSymbolRef() const
Return the symbol reference that this use represents.
Definition: SymbolTable.h:153
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, NamedAttrList &attrs)
Parse an optional visibility attribute keyword (i.e., public, private, or nested) without quotes in a...
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
This class represents a map of symbols to users, and provides efficient implementations of symbol que...
Definition: SymbolTable.h:290
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of...
ArrayRef< Operation * > getUsers(Operation *symbol) const
Return the users of the provided symbol operation.
Definition: SymbolTable.h:299
T lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Definition: SymbolTable.h:267
LogicalResult verifySymbolTable(Operation *op)
Helper class for implementing traits.
Definition: OpDefinition.h:291
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
LogicalResult verifySymbol(Operation *op)
bool useEmpty(Operation *symbol) const
Return true if the given symbol has no uses.
Definition: SymbolTable.h:305
static void setSymbolName(Operation *symbol, StringRef name)
Definition: SymbolTable.h:91
SymbolTable(Operation *symbolTableOp)
Build a symbol table with the symbols within the given operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:23
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
T lookup(StringAttr name) const
Definition: SymbolTable.h:40
iterator begin() const
Definition: SymbolTable.h:169
T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol)
Definition: SymbolTable.h:271
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringRef symbol)
Definition: SymbolTable.h:117
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
This class represents a specific symbol use.
Definition: SymbolTable.h:144
std::vector< SymbolUse >::const_iterator iterator
Definition: SymbolTable.h:168
Operation * lookupSymbol(StringAttr name)
Look up a symbol with the specified name, returning null if no such name exists.
Definition: SymbolTable.h:347
T lookupSymbol(StringRef name)
Definition: SymbolTable.h:366