MLIR  18.0.0git
SymbolTable.cpp
Go to the documentation of this file.
1 //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/SymbolTable.h"
10 #include "mlir/IR/Builders.h"
12 #include "llvm/ADT/SetVector.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallString.h"
15 #include "llvm/ADT/StringSwitch.h"
16 #include <optional>
17 
18 using namespace mlir;
19 
20 /// Return true if the given operation is unknown and may potentially define a
21 /// symbol table.
23  return op->getNumRegions() == 1 && !op->getDialect();
24 }
25 
26 /// Returns the string name of the given symbol, or null if this is not a
27 /// symbol.
28 static StringAttr getNameIfSymbol(Operation *op) {
29  return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
30 }
31 static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
32  return op->getAttrOfType<StringAttr>(symbolAttrNameId);
33 }
34 
35 /// Computes the nested symbol reference attribute for the symbol 'symbolName'
36 /// that are usable within the symbol table operations from 'symbol' as far up
37 /// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
38 /// Returns success if all references up to 'within' could be computed.
39 static LogicalResult
40 collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
41  Operation *within,
43  assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
44  MLIRContext *ctx = symbol->getContext();
45 
46  auto leafRef = FlatSymbolRefAttr::get(symbolName);
47  results.push_back(leafRef);
48 
49  // Early exit for when 'within' is the parent of 'symbol'.
50  Operation *symbolTableOp = symbol->getParentOp();
51  if (within == symbolTableOp)
52  return success();
53 
54  // Collect references until 'symbolTableOp' reaches 'within'.
55  SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
56  StringAttr symbolNameId =
58  do {
59  // Each parent of 'symbol' should define a symbol table.
60  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
61  return failure();
62  // Each parent of 'symbol' should also be a symbol.
63  StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
64  if (!symbolTableName)
65  return failure();
66  results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
67 
68  symbolTableOp = symbolTableOp->getParentOp();
69  if (symbolTableOp == within)
70  break;
71  nestedRefs.insert(nestedRefs.begin(),
72  FlatSymbolRefAttr::get(symbolTableName));
73  } while (true);
74  return success();
75 }
76 
77 /// Walk all of the operations within the given set of regions, without
78 /// traversing into any nested symbol tables. Stops walking if the result of the
79 /// callback is anything other than `WalkResult::advance`.
80 static std::optional<WalkResult>
82  function_ref<std::optional<WalkResult>(Operation *)> callback) {
83  SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
84  while (!worklist.empty()) {
85  for (Operation &op : worklist.pop_back_val()->getOps()) {
86  std::optional<WalkResult> result = callback(&op);
87  if (result != WalkResult::advance())
88  return result;
89 
90  // If this op defines a new symbol table scope, we can't traverse. Any
91  // symbol references nested within 'op' are different semantically.
92  if (!op.hasTrait<OpTrait::SymbolTable>()) {
93  for (Region &region : op.getRegions())
94  worklist.push_back(&region);
95  }
96  }
97  }
98  return WalkResult::advance();
99 }
100 
101 /// Walk all of the operations nested under, and including, the given operation,
102 /// without traversing into any nested symbol tables. Stops walking if the
103 /// result of the callback is anything other than `WalkResult::advance`.
104 static std::optional<WalkResult>
106  function_ref<std::optional<WalkResult>(Operation *)> callback) {
107  std::optional<WalkResult> result = callback(op);
108  if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>())
109  return result;
110  return walkSymbolTable(op->getRegions(), callback);
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // SymbolTable
115 //===----------------------------------------------------------------------===//
116 
117 /// Build a symbol table with the symbols within the given operation.
119  : symbolTableOp(symbolTableOp) {
120  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
121  "expected operation to have SymbolTable trait");
122  assert(symbolTableOp->getNumRegions() == 1 &&
123  "expected operation to have a single region");
124  assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) &&
125  "expected operation to have a single block");
126 
127  StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
129  for (auto &op : symbolTableOp->getRegion(0).front()) {
130  StringAttr name = getNameIfSymbol(&op, symbolNameId);
131  if (!name)
132  continue;
133 
134  auto inserted = symbolTable.insert({name, &op});
135  (void)inserted;
136  assert(inserted.second &&
137  "expected region to contain uniquely named symbol operations");
138  }
139 }
140 
141 /// Look up a symbol with the specified name, returning null if no such name
142 /// exists. Names never include the @ on them.
143 Operation *SymbolTable::lookup(StringRef name) const {
144  return lookup(StringAttr::get(symbolTableOp->getContext(), name));
145 }
146 Operation *SymbolTable::lookup(StringAttr name) const {
147  return symbolTable.lookup(name);
148 }
149 
151  StringAttr name = getNameIfSymbol(op);
152  assert(name && "expected valid 'name' attribute");
153  assert(op->getParentOp() == symbolTableOp &&
154  "expected this operation to be inside of the operation with this "
155  "SymbolTable");
156 
157  auto it = symbolTable.find(name);
158  if (it != symbolTable.end() && it->second == op)
159  symbolTable.erase(it);
160 }
161 
163  remove(symbol);
164  symbol->erase();
165 }
166 
167 // TODO: Consider if this should be renamed to something like insertOrUpdate
168 /// Insert a new symbol into the table and associated operation if not already
169 /// there and rename it as necessary to avoid collisions. Return the name of
170 /// the symbol after insertion as attribute.
171 StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
172  // The symbol cannot be the child of another op and must be the child of the
173  // symbolTableOp after this.
174  //
175  // TODO: consider if SymbolTable's constructor should behave the same.
176  if (!symbol->getParentOp()) {
177  auto &body = symbolTableOp->getRegion(0).front();
178  if (insertPt == Block::iterator()) {
179  insertPt = Block::iterator(body.end());
180  } else {
181  assert((insertPt == body.end() ||
182  insertPt->getParentOp() == symbolTableOp) &&
183  "expected insertPt to be in the associated module operation");
184  }
185  // Insert before the terminator, if any.
186  if (insertPt == Block::iterator(body.end()) && !body.empty() &&
187  std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
188  insertPt = std::prev(body.end());
189 
190  body.getOperations().insert(insertPt, symbol);
191  }
192  assert(symbol->getParentOp() == symbolTableOp &&
193  "symbol is already inserted in another op");
194 
195  // Add this symbol to the symbol table, uniquing the name if a conflict is
196  // detected.
197  StringAttr name = getSymbolName(symbol);
198  if (symbolTable.insert({name, symbol}).second)
199  return name;
200  // If the symbol was already in the table, also return.
201  if (symbolTable.lookup(name) == symbol)
202  return name;
203  // If a conflict was detected, then the symbol will not have been added to
204  // the symbol table. Try suffixes until we get to a unique name that works.
205  SmallString<128> nameBuffer(name.getValue());
206  unsigned originalLength = nameBuffer.size();
207 
208  MLIRContext *context = symbol->getContext();
209 
210  // Iteratively try suffixes until we find one that isn't used.
211  do {
212  nameBuffer.resize(originalLength);
213  nameBuffer += '_';
214  nameBuffer += std::to_string(uniquingCounter++);
215  } while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
216  .second);
217  setSymbolName(symbol, nameBuffer);
218  return getSymbolName(symbol);
219 }
220 
221 /// Returns the name of the given symbol operation.
223  StringAttr name = getNameIfSymbol(symbol);
224  assert(name && "expected valid symbol name");
225  return name;
226 }
227 
228 /// Sets the name of the given symbol operation.
229 void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
230  symbol->setAttr(getSymbolAttrName(), name);
231 }
232 
233 /// Returns the visibility of the given symbol operation.
235  // If the attribute doesn't exist, assume public.
236  StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
237  if (!vis)
238  return Visibility::Public;
239 
240  // Otherwise, switch on the string value.
241  return StringSwitch<Visibility>(vis.getValue())
242  .Case("private", Visibility::Private)
243  .Case("nested", Visibility::Nested)
244  .Case("public", Visibility::Public);
245 }
246 /// Sets the visibility of the given symbol operation.
248  MLIRContext *ctx = symbol->getContext();
249 
250  // If the visibility is public, just drop the attribute as this is the
251  // default.
252  if (vis == Visibility::Public) {
254  return;
255  }
256 
257  // Otherwise, update the attribute.
258  assert((vis == Visibility::Private || vis == Visibility::Nested) &&
259  "unknown symbol visibility kind");
260 
261  StringRef visName = vis == Visibility::Private ? "private" : "nested";
262  symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
263 }
264 
265 /// Returns the nearest symbol table from a given operation `from`. Returns
266 /// nullptr if no valid parent symbol table could be found.
268  assert(from && "expected valid operation");
270  return nullptr;
271 
272  while (!from->hasTrait<OpTrait::SymbolTable>()) {
273  from = from->getParentOp();
274 
275  // Check that this is a valid op and isn't an unknown symbol table.
276  if (!from || isPotentiallyUnknownSymbolTable(from))
277  return nullptr;
278  }
279  return from;
280 }
281 
282 /// Walks all symbol table operations nested within, and including, `op`. For
283 /// each symbol table operation, the provided callback is invoked with the op
284 /// and a boolean signifying if the symbols within that symbol table can be
285 /// treated as if all uses are visible. `allSymUsesVisible` identifies whether
286 /// all of the symbol uses of symbols within `op` are visible.
288  Operation *op, bool allSymUsesVisible,
289  function_ref<void(Operation *, bool)> callback) {
290  bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
291  if (isSymbolTable) {
292  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
293  allSymUsesVisible |= !symbol || symbol.isPrivate();
294  } else {
295  // Otherwise if 'op' is not a symbol table, any nested symbols are
296  // guaranteed to be hidden.
297  allSymUsesVisible = true;
298  }
299 
300  for (Region &region : op->getRegions())
301  for (Block &block : region)
302  for (Operation &nestedOp : block)
303  walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
304 
305  // If 'op' had the symbol table trait, visit it after any nested symbol
306  // tables.
307  if (isSymbolTable)
308  callback(op, allSymUsesVisible);
309 }
310 
311 /// Returns the operation registered with the given symbol name with the
312 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
313 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
314 /// was found.
316  StringAttr symbol) {
317  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
318  Region &region = symbolTableOp->getRegion(0);
319  if (region.empty())
320  return nullptr;
321 
322  // Look for a symbol with the given name.
323  StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
325  for (auto &op : region.front())
326  if (getNameIfSymbol(&op, symbolNameId) == symbol)
327  return &op;
328  return nullptr;
329 }
331  SymbolRefAttr symbol) {
332  SmallVector<Operation *, 4> resolvedSymbols;
333  if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
334  return nullptr;
335  return resolvedSymbols.back();
336 }
337 
338 /// Internal implementation of `lookupSymbolIn` that allows for specialized
339 /// implementations of the lookup function.
341  Operation *symbolTableOp, SymbolRefAttr symbol,
343  function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
344  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
345 
346  // Lookup the root reference for this symbol.
347  symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
348  if (!symbolTableOp)
349  return failure();
350  symbols.push_back(symbolTableOp);
351 
352  // If there are no nested references, just return the root symbol directly.
353  ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
354  if (nestedRefs.empty())
355  return success();
356 
357  // Verify that the root is also a symbol table.
358  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
359  return failure();
360 
361  // Otherwise, lookup each of the nested non-leaf references and ensure that
362  // each corresponds to a valid symbol table.
363  for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
364  symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
365  if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
366  return failure();
367  symbols.push_back(symbolTableOp);
368  }
369  symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
370  return success(symbols.back());
371 }
372 
374 SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
375  SmallVectorImpl<Operation *> &symbols) {
376  auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
377  return lookupSymbolIn(symbolTableOp, symbol);
378  };
379  return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
380 }
381 
382 /// Returns the operation registered with the given symbol name within the
383 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
384 /// nullptr if no valid symbol was found.
386  StringAttr symbol) {
387  Operation *symbolTableOp = getNearestSymbolTable(from);
388  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
389 }
391  SymbolRefAttr symbol) {
392  Operation *symbolTableOp = getNearestSymbolTable(from);
393  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
394 }
395 
396 raw_ostream &mlir::operator<<(raw_ostream &os,
397  SymbolTable::Visibility visibility) {
398  switch (visibility) {
400  return os << "public";
402  return os << "private";
404  return os << "nested";
405  }
406  llvm_unreachable("Unexpected visibility");
407 }
408 
409 //===----------------------------------------------------------------------===//
410 // SymbolTable Trait Types
411 //===----------------------------------------------------------------------===//
412 
414  if (op->getNumRegions() != 1)
415  return op->emitOpError()
416  << "Operations with a 'SymbolTable' must have exactly one region";
417  if (!llvm::hasSingleElement(op->getRegion(0)))
418  return op->emitOpError()
419  << "Operations with a 'SymbolTable' must have exactly one block";
420 
421  // Check that all symbols are uniquely named within child regions.
422  DenseMap<Attribute, Location> nameToOrigLoc;
423  for (auto &block : op->getRegion(0)) {
424  for (auto &op : block) {
425  // Check for a symbol name attribute.
426  auto nameAttr =
428  if (!nameAttr)
429  continue;
430 
431  // Try to insert this symbol into the table.
432  auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
433  if (!it.second)
434  return op.emitError()
435  .append("redefinition of symbol named '", nameAttr.getValue(), "'")
436  .attachNote(it.first->second)
437  .append("see existing symbol definition here");
438  }
439  }
440 
441  // Verify any nested symbol user operations.
442  SymbolTableCollection symbolTable;
443  auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
444  if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
445  return WalkResult(user.verifySymbolUses(symbolTable));
446  return WalkResult::advance();
447  };
448 
449  std::optional<WalkResult> result =
450  walkSymbolTable(op->getRegions(), verifySymbolUserFn);
451  return success(result && !result->wasInterrupted());
452 }
453 
455  // Verify the name attribute.
456  if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
457  return op->emitOpError() << "requires string attribute '"
459 
460  // Verify the visibility attribute.
462  StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
463  if (!visStrAttr)
464  return op->emitOpError() << "requires visibility attribute '"
466  << "' to be a string attribute, but got " << vis;
467 
468  if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
469  visStrAttr.getValue()))
470  return op->emitOpError()
471  << "visibility expected to be one of [\"public\", \"private\", "
472  "\"nested\"], but got "
473  << visStrAttr;
474  }
475  return success();
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // Symbol Use Lists
480 //===----------------------------------------------------------------------===//
481 
482 /// Walk all of the symbol references within the given operation, invoking the
483 /// provided callback for each found use. The callbacks takes the use of the
484 /// symbol.
485 static WalkResult
488  return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
489  [&](SymbolRefAttr symbolRef) {
490  if (callback({op, symbolRef}).wasInterrupted())
491  return WalkResult::interrupt();
492 
493  // Don't walk nested references.
494  return WalkResult::skip();
495  });
496 }
497 
498 /// Walk all of the uses, for any symbol, that are nested within the given
499 /// regions, invoking the provided callback for each. This does not traverse
500 /// into any nested symbol tables.
501 static std::optional<WalkResult>
504  return walkSymbolTable(regions,
505  [&](Operation *op) -> std::optional<WalkResult> {
506  // Check that this isn't a potentially unknown symbol
507  // table.
509  return std::nullopt;
510 
511  return walkSymbolRefs(op, callback);
512  });
513 }
514 /// Walk all of the uses, for any symbol, that are nested within the given
515 /// operation 'from', invoking the provided callback for each. This does not
516 /// traverse into any nested symbol tables.
517 static std::optional<WalkResult>
520  // If this operation has regions, and it, as well as its dialect, isn't
521  // registered then conservatively fail. The operation may define a
522  // symbol table, so we can't opaquely know if we should traverse to find
523  // nested uses.
525  return std::nullopt;
526 
527  // Walk the uses on this operation.
528  if (walkSymbolRefs(from, callback).wasInterrupted())
529  return WalkResult::interrupt();
530 
531  // Only recurse if this operation is not a symbol table. A symbol table
532  // defines a new scope, so we can't walk the attributes from within the symbol
533  // table op.
534  if (!from->hasTrait<OpTrait::SymbolTable>())
535  return walkSymbolUses(from->getRegions(), callback);
536  return WalkResult::advance();
537 }
538 
539 namespace {
540 /// This class represents a single symbol scope. A symbol scope represents the
541 /// set of operations nested within a symbol table that may reference symbols
542 /// within that table. A symbol scope does not contain the symbol table
543 /// operation itself, just its contained operations. A scope ends at leaf
544 /// operations or another symbol table operation.
545 struct SymbolScope {
546  /// Walk the symbol uses within this scope, invoking the given callback.
547  /// This variant is used when the callback type matches that expected by
548  /// 'walkSymbolUses'.
549  template <typename CallbackT,
550  std::enable_if_t<!std::is_same<
551  typename llvm::function_traits<CallbackT>::result_t,
552  void>::value> * = nullptr>
553  std::optional<WalkResult> walk(CallbackT cback) {
554  if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
555  return walkSymbolUses(*region, cback);
556  return walkSymbolUses(limit.get<Operation *>(), cback);
557  }
558  /// This variant is used when the callback type matches a stripped down type:
559  /// void(SymbolTable::SymbolUse use)
560  template <typename CallbackT,
561  std::enable_if_t<std::is_same<
562  typename llvm::function_traits<CallbackT>::result_t,
563  void>::value> * = nullptr>
564  std::optional<WalkResult> walk(CallbackT cback) {
565  return walk([=](SymbolTable::SymbolUse use) {
566  return cback(use), WalkResult::advance();
567  });
568  }
569 
570  /// Walk all of the operations nested under the current scope without
571  /// traversing into any nested symbol tables.
572  template <typename CallbackT>
573  std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
574  if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
575  return ::walkSymbolTable(*region, cback);
576  return ::walkSymbolTable(limit.get<Operation *>(), cback);
577  }
578 
579  /// The representation of the symbol within this scope.
580  SymbolRefAttr symbol;
581 
582  /// The IR unit representing this scope.
584 };
585 } // namespace
586 
587 /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
589  Operation *limit) {
590  StringAttr symName = SymbolTable::getSymbolName(symbol);
591  assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
592 
593  // Compute the ancestors of 'limit'.
596  limitAncestors;
597  Operation *limitAncestor = limit;
598  do {
599  // Check to see if 'symbol' is an ancestor of 'limit'.
600  if (limitAncestor == symbol) {
601  // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
602  // doesn't support parent references.
604  symbol->getParentOp())
605  return {{SymbolRefAttr::get(symName), limit}};
606  return {};
607  }
608 
609  limitAncestors.insert(limitAncestor);
610  } while ((limitAncestor = limitAncestor->getParentOp()));
611 
612  // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
613  Operation *commonAncestor = symbol->getParentOp();
614  do {
615  if (limitAncestors.count(commonAncestor))
616  break;
617  } while ((commonAncestor = commonAncestor->getParentOp()));
618  assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
619 
620  // Compute the set of valid nested references for 'symbol' as far up to the
621  // common ancestor as possible.
623  bool collectedAllReferences = succeeded(
624  collectValidReferencesFor(symbol, symName, commonAncestor, references));
625 
626  // Handle the case where the common ancestor is 'limit'.
627  if (commonAncestor == limit) {
629 
630  // Walk each of the ancestors of 'symbol', calling the compute function for
631  // each one.
632  Operation *limitIt = symbol->getParentOp();
633  for (size_t i = 0, e = references.size(); i != e;
634  ++i, limitIt = limitIt->getParentOp()) {
635  assert(limitIt->hasTrait<OpTrait::SymbolTable>());
636  scopes.push_back({references[i], &limitIt->getRegion(0)});
637  }
638  return scopes;
639  }
640 
641  // Otherwise, we just need the symbol reference for 'symbol' that will be
642  // used within 'limit'. This is the last reference in the list we computed
643  // above if we were able to collect all references.
644  if (!collectedAllReferences)
645  return {};
646  return {{references.back(), limit}};
647 }
649  Region *limit) {
650  auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
651 
652  // If we collected some scopes to walk, make sure to constrain the one for
653  // limit to the specific region requested.
654  if (!scopes.empty())
655  scopes.back().limit = limit;
656  return scopes;
657 }
658 template <typename IRUnit>
660  IRUnit *limit) {
661  return {{SymbolRefAttr::get(symbol), limit}};
662 }
663 
664 /// Returns true if the given reference 'SubRef' is a sub reference of the
665 /// reference 'ref', i.e. 'ref' is a further qualified reference.
666 static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
667  if (ref == subRef)
668  return true;
669 
670  // If the references are not pointer equal, check to see if `subRef` is a
671  // prefix of `ref`.
672  if (llvm::isa<FlatSymbolRefAttr>(ref) ||
673  ref.getRootReference() != subRef.getRootReference())
674  return false;
675 
676  auto refLeafs = ref.getNestedReferences();
677  auto subRefLeafs = subRef.getNestedReferences();
678  return subRefLeafs.size() < refLeafs.size() &&
679  subRefLeafs == refLeafs.take_front(subRefLeafs.size());
680 }
681 
682 //===----------------------------------------------------------------------===//
683 // SymbolTable::getSymbolUses
684 
685 /// The implementation of SymbolTable::getSymbolUses below.
686 template <typename FromT>
687 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
688  std::vector<SymbolTable::SymbolUse> uses;
689  auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
690  uses.push_back(symbolUse);
691  return WalkResult::advance();
692  };
693  auto result = walkSymbolUses(from, walkFn);
694  return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
695  : std::nullopt;
696 }
697 
698 /// Get an iterator range for all of the uses, for any symbol, that are nested
699 /// within the given operation 'from'. This does not traverse into any nested
700 /// symbol tables, and will also only return uses on 'from' if it does not
701 /// also define a symbol table. This is because we treat the region as the
702 /// boundary of the symbol table, and not the op itself. This function returns
703 /// std::nullopt if there are any unknown operations that may potentially be
704 /// symbol tables.
705 auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
706  return getSymbolUsesImpl(from);
707 }
708 auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
710 }
711 
712 //===----------------------------------------------------------------------===//
713 // SymbolTable::getSymbolUses
714 
715 /// The implementation of SymbolTable::getSymbolUses below.
716 template <typename SymbolT, typename IRUnitT>
717 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
718  IRUnitT *limit) {
719  std::vector<SymbolTable::SymbolUse> uses;
720  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
721  if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
722  if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
723  uses.push_back(symbolUse);
724  }))
725  return std::nullopt;
726  }
727  return SymbolTable::UseRange(std::move(uses));
728 }
729 
730 /// Get all of the uses of the given symbol that are nested within the given
731 /// operation 'from', invoking the provided callback for each. This does not
732 /// traverse into any nested symbol tables. This function returns std::nullopt
733 /// if there are any unknown operations that may potentially be symbol tables.
734 auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
735  -> std::optional<UseRange> {
736  return getSymbolUsesImpl(symbol, from);
737 }
739  -> std::optional<UseRange> {
740  return getSymbolUsesImpl(symbol, from);
741 }
742 auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
743  -> std::optional<UseRange> {
744  return getSymbolUsesImpl(symbol, from);
745 }
747  -> std::optional<UseRange> {
748  return getSymbolUsesImpl(symbol, from);
749 }
750 
751 //===----------------------------------------------------------------------===//
752 // SymbolTable::symbolKnownUseEmpty
753 
754 /// The implementation of SymbolTable::symbolKnownUseEmpty below.
755 template <typename SymbolT, typename IRUnitT>
756 static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
757  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
758  // Walk all of the symbol uses looking for a reference to 'symbol'.
759  if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
760  return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
761  ? WalkResult::interrupt()
762  : WalkResult::advance();
763  }) != WalkResult::advance())
764  return false;
765  }
766  return true;
767 }
768 
769 /// Return if the given symbol is known to have no uses that are nested within
770 /// the given operation 'from'. This does not traverse into any nested symbol
771 /// tables. This function will also return false if there are any unknown
772 /// operations that may potentially be symbol tables.
773 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
774  return symbolKnownUseEmptyImpl(symbol, from);
775 }
777  return symbolKnownUseEmptyImpl(symbol, from);
778 }
779 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
780  return symbolKnownUseEmptyImpl(symbol, from);
781 }
783  return symbolKnownUseEmptyImpl(symbol, from);
784 }
785 
786 //===----------------------------------------------------------------------===//
787 // SymbolTable::replaceAllSymbolUses
788 
789 /// Generates a new symbol reference attribute with a new leaf reference.
790 static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
791  FlatSymbolRefAttr newLeafAttr) {
792  if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
793  return newLeafAttr;
794  auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
795  nestedRefs.back() = newLeafAttr;
796  return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
797 }
798 
799 /// The implementation of SymbolTable::replaceAllSymbolUses below.
800 template <typename SymbolT, typename IRUnitT>
801 static LogicalResult
802 replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
803  // Generate a new attribute to replace the given attribute.
804  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
805  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
806  SymbolRefAttr oldAttr = scope.symbol;
807  SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
808  AttrTypeReplacer replacer;
809  replacer.addReplacement(
810  [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
811  // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
812  // want to accidentally replace an inner reference.
813  if (attr == oldAttr)
814  return {newAttr, WalkResult::skip()};
815  // Handle prefix matches.
816  if (isReferencePrefixOf(oldAttr, attr)) {
817  auto oldNestedRefs = oldAttr.getNestedReferences();
818  auto nestedRefs = attr.getNestedReferences();
819  if (oldNestedRefs.empty())
820  return {SymbolRefAttr::get(newSymbol, nestedRefs),
821  WalkResult::skip()};
822 
823  auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
824  newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
825  return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
826  WalkResult::skip()};
827  }
828  return {attr, WalkResult::skip()};
829  });
830 
831  auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
832  replacer.replaceElementsIn(op);
833  return WalkResult::advance();
834  };
835  if (!scope.walkSymbolTable(walkFn))
836  return failure();
837  }
838  return success();
839 }
840 
841 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
842 /// provided symbol 'newSymbol' that are nested within the given operation
843 /// 'from'. This does not traverse into any nested symbol tables. If there are
844 /// any unknown operations that may potentially be symbol tables, no uses are
845 /// replaced and failure is returned.
847  StringAttr newSymbol,
848  Operation *from) {
849  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
850 }
852  StringAttr newSymbol,
853  Operation *from) {
854  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
855 }
857  StringAttr newSymbol,
858  Region *from) {
859  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
860 }
862  StringAttr newSymbol,
863  Region *from) {
864  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
865 }
866 
867 //===----------------------------------------------------------------------===//
868 // SymbolTableCollection
869 //===----------------------------------------------------------------------===//
870 
872  StringAttr symbol) {
873  return getSymbolTable(symbolTableOp).lookup(symbol);
874 }
876  SymbolRefAttr name) {
878  if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
879  return nullptr;
880  return symbols.back();
881 }
882 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
883 /// a given SymbolRefAttr. Returns failure if any of the nested references could
884 /// not be resolved.
887  SymbolRefAttr name,
888  SmallVectorImpl<Operation *> &symbols) {
889  auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
890  return lookupSymbolIn(symbolTableOp, symbol);
891  };
892  return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
893 }
894 
895 /// Returns the operation registered with the given symbol name within the
896 /// closest parent operation of, or including, 'from' with the
897 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
898 /// found.
900  StringAttr symbol) {
901  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
902  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
903 }
904 Operation *
906  SymbolRefAttr symbol) {
907  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
908  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
909 }
910 
911 /// Lookup, or create, a symbol table for an operation.
913  auto it = symbolTables.try_emplace(op, nullptr);
914  if (it.second)
915  it.first->second = std::make_unique<SymbolTable>(op);
916  return *it.first->second;
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // LockedSymbolTableCollection
921 //===----------------------------------------------------------------------===//
922 
924  StringAttr symbol) {
925  return getSymbolTable(symbolTableOp).lookup(symbol);
926 }
927 
928 Operation *
930  FlatSymbolRefAttr symbol) {
931  return lookupSymbolIn(symbolTableOp, symbol.getAttr());
932 }
933 
935  SymbolRefAttr name) {
936  SmallVector<Operation *> symbols;
937  if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
938  return nullptr;
939  return symbols.back();
940 }
941 
943  Operation *symbolTableOp, SymbolRefAttr name,
944  SmallVectorImpl<Operation *> &symbols) {
945  auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
946  return lookupSymbolIn(symbolTableOp, symbol);
947  };
948  return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
949 }
950 
951 SymbolTable &
952 LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
953  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
954  // Try to find an existing symbol table.
955  {
956  llvm::sys::SmartScopedReader<true> lock(mutex);
957  auto it = collection.symbolTables.find(symbolTableOp);
958  if (it != collection.symbolTables.end())
959  return *it->second;
960  }
961  // Create a symbol table for the operation. Perform construction outside of
962  // the critical section.
963  auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
964  // Insert the constructed symbol table.
965  llvm::sys::SmartScopedWriter<true> lock(mutex);
966  return *collection.symbolTables
967  .insert({symbolTableOp, std::move(symbolTable)})
968  .first->second;
969 }
970 
971 //===----------------------------------------------------------------------===//
972 // SymbolUserMap
973 //===----------------------------------------------------------------------===//
974 
976  Operation *symbolTableOp)
977  : symbolTable(symbolTable) {
978  // Walk each of the symbol tables, recording the users of every referenced symbol.
979  SmallVector<Operation *> symbols;
980  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
981  for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
982  auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
983  assert(symbolUses && "expected uses to be valid");
984 
985  for (const SymbolTable::SymbolUse &use : *symbolUses) {
986  symbols.clear();
987  (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
988  symbols);
989  for (Operation *symbolOp : symbols)
990  symbolToUsers[symbolOp].insert(use.getUser());
991  }
992  }
993  };
994  // We just set `allSymUsesVisible` to false here because it isn't necessary
995  // for building the user map.
996  SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
997  walkFn);
998 }
999 
1001  StringAttr newSymbolName) {
1002  auto it = symbolToUsers.find(symbol);
1003  if (it == symbolToUsers.end())
1004  return;
1005 
1006  // Replace the uses within the users of `symbol`.
1007  for (Operation *user : it->second)
1008  (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1009 
1010  // Move the current users of `symbol` to the new symbol if it is in the
1011  // symbol table.
1012  Operation *newSymbol =
1013  symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1014  if (newSymbol != symbol) {
1015  // Transfer over the users to the new symbol. The reference to the old one
1016  // is fetched again as the iterator is invalidated during the insertion.
1017  auto newIt = symbolToUsers.try_emplace(newSymbol, SetVector<Operation *>{});
1018  auto oldIt = symbolToUsers.find(symbol);
1019  assert(oldIt != symbolToUsers.end() && "missing old users list");
1020  if (newIt.second)
1021  newIt.first->second = std::move(oldIt->second);
1022  else
1023  newIt.first->second.set_union(oldIt->second);
1024  symbolToUsers.erase(oldIt);
1025  }
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
1029 // Visibility parsing implementation.
1030 //===----------------------------------------------------------------------===//
1031 
1033  NamedAttrList &attrs) {
1034  StringRef visibility;
1035  if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"}))
1036  return failure();
1037 
1038  StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1039  attrs.push_back(parser.getBuilder().getNamedAttr(
1040  SymbolTable::getVisibilityAttrName(), visibilityAttr));
1041  return success();
1042 }
1043 
1044 //===----------------------------------------------------------------------===//
1045 // Symbol Interfaces
1046 //===----------------------------------------------------------------------===//
1047 
1048 /// Include the generated symbol interfaces.
1049 #include "mlir/IR/SymbolInterfaces.cpp.inc"
static std::optional< WalkResult > walkSymbolTable(MutableArrayRef< Region > regions, function_ref< std::optional< WalkResult >(Operation *)> callback)
Walk all of the operations within the given set of regions, without traversing into any nested symbol...
Definition: SymbolTable.cpp:81
static LogicalResult collectValidReferencesFor(Operation *symbol, StringAttr symbolName, Operation *within, SmallVectorImpl< SymbolRefAttr > &results)
Computes the nested symbol reference attribute for the symbol 'symbolName' that are usable within the...
Definition: SymbolTable.cpp:40
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit)
The implementation of SymbolTable::symbolKnownUseEmpty below.
static SmallVector< SymbolScope, 2 > collectSymbolScopes(Operation *symbol, Operation *limit)
Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
static WalkResult walkSymbolRefs(Operation *op, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the symbol references within the given operation, invoking the provided callback for each...
static StringAttr getNameIfSymbol(Operation *op)
Returns the string name of the given symbol, or null if this is not a symbol.
Definition: SymbolTable.cpp:28
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref)
Returns true if the given reference 'SubRef' is a sub reference of the reference 'ref',...
static std::optional< WalkResult > walkSymbolUses(MutableArrayRef< Region > regions, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the uses, for any symbol, that are nested within the given regions, invoking the provided...
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, FlatSymbolRefAttr newLeafAttr)
Generates a new symbol reference attribute with a new leaf reference.
static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit)
The implementation of SymbolTable::replaceAllSymbolUses below.
static LogicalResult lookupSymbolInImpl(Operation *symbolTableOp, SymbolRefAttr symbol, SmallVectorImpl< Operation * > &symbols, function_ref< Operation *(Operation *, StringAttr)> lookupSymbolFn)
Internal implementation of lookupSymbolIn that allows for specialized implementations of the lookup f...
static std::optional< SymbolTable::UseRange > getSymbolUsesImpl(FromT from)
The implementation of SymbolTable::getSymbolUses below.
static bool isPotentiallyUnknownSymbolTable(Operation *op)
Return true if the given operation is unknown and may potentially define a symbol table.
Definition: SymbolTable.cpp:22
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:133
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:110
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
IRUnit is a union of the different types of IR objects that constitute the IR structure (other than T...
Definition: Unit.h:28
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
Definition: Diagnostics.h:334
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:757
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:400
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:295
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:528
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:512
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:560
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:578
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:248
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class represents a specific symbol use.
Definition: SymbolTable.h:148
This class implements a range of SymbolRef uses.
Definition: SymbolTable.h:168
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:59
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition: SymbolTable.h:73
@ Nested
The symbol is visible to the current IR, which may include operations in symbol tables above the one ...
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static StringRef getVisibilityAttrName()
Return the name of the attribute used for symbol visibility.
Definition: SymbolTable.h:65
void erase(Operation *symbol)
Erase the given symbol from the table and delete the operation.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
SymbolTable(Operation *symbolTableOp)
Build a symbol table with the symbols within the given operation.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from)
Return if the given symbol is known to have no uses that are nested within the given operation 'from'...
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName)
Replace all of the uses of the given symbol with newSymbolName.
SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp)
Build a user map for all of the symbols defined in regions nested under 'symbolTableOp'.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
static WalkResult interrupt()
Definition: Visitors.h:51
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
LogicalResult verifySymbol(Operation *op)
LogicalResult verifySymbolTable(Operation *op)
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, NamedAttrList &attrs)
Parse an optional visibility attribute keyword (i.e., public, private, or nested) without quotes in a...
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26