MLIR  19.0.0git
SymbolTable.cpp
Go to the documentation of this file.
1 //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/SymbolTable.h"
10 #include "mlir/IR/Builders.h"
12 #include "llvm/ADT/SetVector.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallString.h"
15 #include "llvm/ADT/StringSwitch.h"
16 #include <optional>
17 
18 using namespace mlir;
19 
20 /// Return true if the given operation is unknown and may potentially define a
21 /// symbol table.
23  return op->getNumRegions() == 1 && !op->getDialect();
24 }
25 
26 /// Returns the string name of the given symbol, or null if this is not a
27 /// symbol.
28 static StringAttr getNameIfSymbol(Operation *op) {
29  return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
30 }
31 static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
32  return op->getAttrOfType<StringAttr>(symbolAttrNameId);
33 }
34 
35 /// Computes the nested symbol reference attributes for the symbol 'symbolName'
36 /// that are usable within the symbol table operations from 'symbol' as far up
37 /// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
38 /// Returns success if all references up to 'within' could be computed.
39 static LogicalResult
40 collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
41  Operation *within,
43  assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
44  MLIRContext *ctx = symbol->getContext();
45 
46  auto leafRef = FlatSymbolRefAttr::get(symbolName);
47  results.push_back(leafRef);
48 
49  // Early exit for when 'within' is the parent of 'symbol'.
50  Operation *symbolTableOp = symbol->getParentOp();
51  if (within == symbolTableOp)
52  return success();
53 
54  // Collect references until 'symbolTableOp' reaches 'within'.
55  SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
56  StringAttr symbolNameId =
58  do {
59  // Each parent of 'symbol' should define a symbol table.
60  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
61  return failure();
62  // Each parent of 'symbol' should also be a symbol.
63  StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
64  if (!symbolTableName)
65  return failure();
66  results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
67 
68  symbolTableOp = symbolTableOp->getParentOp();
69  if (symbolTableOp == within)
70  break;
71  nestedRefs.insert(nestedRefs.begin(),
72  FlatSymbolRefAttr::get(symbolTableName));
73  } while (true);
74  return success();
75 }
76 
77 /// Walk all of the operations within the given set of regions, without
78 /// traversing into any nested symbol tables. Stops walking if the result of the
79 /// callback is anything other than `WalkResult::advance`.
80 static std::optional<WalkResult>
82  function_ref<std::optional<WalkResult>(Operation *)> callback) {
83  SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
84  while (!worklist.empty()) {
85  for (Operation &op : worklist.pop_back_val()->getOps()) {
86  std::optional<WalkResult> result = callback(&op);
87  if (result != WalkResult::advance())
88  return result;
89 
90  // If this op defines a new symbol table scope, we can't traverse. Any
91  // symbol references nested within 'op' are different semantically.
92  if (!op.hasTrait<OpTrait::SymbolTable>()) {
93  for (Region &region : op.getRegions())
94  worklist.push_back(&region);
95  }
96  }
97  }
98  return WalkResult::advance();
99 }
100 
101 /// Walk all of the operations nested under, and including, the given operation,
102 /// without traversing into any nested symbol tables. Stops walking if the
103 /// result of the callback is anything other than `WalkResult::advance`.
104 static std::optional<WalkResult>
106  function_ref<std::optional<WalkResult>(Operation *)> callback) {
107  std::optional<WalkResult> result = callback(op);
108  if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>())
109  return result;
110  return walkSymbolTable(op->getRegions(), callback);
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // SymbolTable
115 //===----------------------------------------------------------------------===//
116 
117 /// Build a symbol table with the symbols within the given operation.
119  : symbolTableOp(symbolTableOp) {
120  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
121  "expected operation to have SymbolTable trait");
122  assert(symbolTableOp->getNumRegions() == 1 &&
123  "expected operation to have a single region");
124  assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) &&
125  "expected operation to have a single block");
126 
127  StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
129  for (auto &op : symbolTableOp->getRegion(0).front()) {
130  StringAttr name = getNameIfSymbol(&op, symbolNameId);
131  if (!name)
132  continue;
133 
134  auto inserted = symbolTable.insert({name, &op});
135  (void)inserted;
136  assert(inserted.second &&
137  "expected region to contain uniquely named symbol operations");
138  }
139 }
140 
141 /// Look up a symbol with the specified name, returning null if no such name
142 /// exists. Names never include the @ on them.
143 Operation *SymbolTable::lookup(StringRef name) const {
144  return lookup(StringAttr::get(symbolTableOp->getContext(), name));
145 }
146 Operation *SymbolTable::lookup(StringAttr name) const {
147  return symbolTable.lookup(name);
148 }
149 
151  StringAttr name = getNameIfSymbol(op);
152  assert(name && "expected valid 'name' attribute");
153  assert(op->getParentOp() == symbolTableOp &&
154  "expected this operation to be inside of the operation with this "
155  "SymbolTable");
156 
157  auto it = symbolTable.find(name);
158  if (it != symbolTable.end() && it->second == op)
159  symbolTable.erase(it);
160 }
161 
163  remove(symbol);
164  symbol->erase();
165 }
166 
167 // TODO: Consider if this should be renamed to something like insertOrUpdate
168 /// Insert a new symbol into the table and associated operation if not already
169 /// there and rename it as necessary to avoid collisions. Return the name of
170 /// the symbol after insertion as attribute.
171 StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
172  // The symbol cannot be the child of another op and must be the child of the
173  // symbolTableOp after this.
174  //
175  // TODO: consider if SymbolTable's constructor should behave the same.
176  if (!symbol->getParentOp()) {
177  auto &body = symbolTableOp->getRegion(0).front();
178  if (insertPt == Block::iterator()) {
179  insertPt = Block::iterator(body.end());
180  } else {
181  assert((insertPt == body.end() ||
182  insertPt->getParentOp() == symbolTableOp) &&
183  "expected insertPt to be in the associated module operation");
184  }
185  // Insert before the terminator, if any.
186  if (insertPt == Block::iterator(body.end()) && !body.empty() &&
187  std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
188  insertPt = std::prev(body.end());
189 
190  body.getOperations().insert(insertPt, symbol);
191  }
192  assert(symbol->getParentOp() == symbolTableOp &&
193  "symbol is already inserted in another op");
194 
195  // Add this symbol to the symbol table, uniquing the name if a conflict is
196  // detected.
197  StringAttr name = getSymbolName(symbol);
198  if (symbolTable.insert({name, symbol}).second)
199  return name;
200  // If the symbol was already in the table, also return.
201  if (symbolTable.lookup(name) == symbol)
202  return name;
203 
204  MLIRContext *context = symbol->getContext();
205  SmallString<128> nameBuffer = generateSymbolName<128>(
206  name.getValue(),
207  [&](StringRef candidate) {
208  return !symbolTable
209  .insert({StringAttr::get(context, candidate), symbol})
210  .second;
211  },
212  uniquingCounter);
213  setSymbolName(symbol, nameBuffer);
214  return getSymbolName(symbol);
215 }
216 
217 LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
218  Operation *op = lookup(from);
219  return rename(op, to);
220 }
221 
223  StringAttr from = getNameIfSymbol(op);
224  (void)from;
225 
226  assert(from && "expected valid 'name' attribute");
227  assert(op->getParentOp() == symbolTableOp &&
228  "expected this operation to be inside of the operation with this "
229  "SymbolTable");
230  assert(lookup(from) == op && "current name does not resolve to op");
231  assert(lookup(to) == nullptr && "new name already exists");
232 
234  return failure();
235 
236  // Remove op with old name, change name, add with new name. The order is
237  // important here due to how `remove` and `insert` rely on the op name.
238  remove(op);
239  setSymbolName(op, to);
240  insert(op);
241 
242  assert(lookup(to) == op && "new name does not resolve to renamed op");
243  assert(lookup(from) == nullptr && "old name still exists");
244 
245  return success();
246 }
247 
248 LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
249  auto toAttr = StringAttr::get(getOp()->getContext(), to);
250  return rename(from, toAttr);
251 }
252 
254  auto toAttr = StringAttr::get(getOp()->getContext(), to);
255  return rename(op, toAttr);
256 }
257 
259 SymbolTable::renameToUnique(StringAttr oldName,
260  ArrayRef<SymbolTable *> others) {
261 
262  // Determine new name that is unique in all symbol tables.
263  StringAttr newName;
264  {
265  MLIRContext *context = oldName.getContext();
266  SmallString<64> prefix = oldName.getValue();
267  int uniqueId = 0;
268  prefix.push_back('_');
269  while (true) {
270  newName = StringAttr::get(context, prefix + Twine(uniqueId++));
271  auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); };
272  if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) {
273  break;
274  }
275  }
276  }
277 
278  // Apply renaming.
279  if (failed(rename(oldName, newName)))
280  return failure();
281  return newName;
282 }
283 
286  StringAttr from = getNameIfSymbol(op);
287  assert(from && "expected valid 'name' attribute");
288  return renameToUnique(from, others);
289 }
290 
291 /// Returns the name of the given symbol operation.
293  StringAttr name = getNameIfSymbol(symbol);
294  assert(name && "expected valid symbol name");
295  return name;
296 }
297 
298 /// Sets the name of the given symbol operation.
299 void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
300  symbol->setAttr(getSymbolAttrName(), name);
301 }
302 
303 /// Returns the visibility of the given symbol operation.
305  // If the attribute doesn't exist, assume public.
306  StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
307  if (!vis)
308  return Visibility::Public;
309 
310  // Otherwise, switch on the string value.
311  return StringSwitch<Visibility>(vis.getValue())
312  .Case("private", Visibility::Private)
313  .Case("nested", Visibility::Nested)
314  .Case("public", Visibility::Public);
315 }
316 /// Sets the visibility of the given symbol operation.
318  MLIRContext *ctx = symbol->getContext();
319 
320  // If the visibility is public, just drop the attribute as this is the
321  // default.
322  if (vis == Visibility::Public) {
324  return;
325  }
326 
327  // Otherwise, update the attribute.
328  assert((vis == Visibility::Private || vis == Visibility::Nested) &&
329  "unknown symbol visibility kind");
330 
331  StringRef visName = vis == Visibility::Private ? "private" : "nested";
332  symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
333 }
334 
335 /// Returns the nearest symbol table from a given operation `from`. Returns
336 /// nullptr if no valid parent symbol table could be found.
338  assert(from && "expected valid operation");
340  return nullptr;
341 
342  while (!from->hasTrait<OpTrait::SymbolTable>()) {
343  from = from->getParentOp();
344 
345  // Check that this is a valid op and isn't an unknown symbol table.
346  if (!from || isPotentiallyUnknownSymbolTable(from))
347  return nullptr;
348  }
349  return from;
350 }
351 
352 /// Walks all symbol table operations nested within, and including, `op`. For
353 /// each symbol table operation, the provided callback is invoked with the op
354 /// and a boolean signifying if the symbols within that symbol table can be
355 /// treated as if all uses are visible. `allSymUsesVisible` identifies whether
356 /// all of the symbol uses of symbols within `op` are visible.
358  Operation *op, bool allSymUsesVisible,
359  function_ref<void(Operation *, bool)> callback) {
360  bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
361  if (isSymbolTable) {
362  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
363  allSymUsesVisible |= !symbol || symbol.isPrivate();
364  } else {
365  // Otherwise if 'op' is not a symbol table, any nested symbols are
366  // guaranteed to be hidden.
367  allSymUsesVisible = true;
368  }
369 
370  for (Region &region : op->getRegions())
371  for (Block &block : region)
372  for (Operation &nestedOp : block)
373  walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
374 
375  // If 'op' had the symbol table trait, visit it after any nested symbol
376  // tables.
377  if (isSymbolTable)
378  callback(op, allSymUsesVisible);
379 }
380 
381 /// Returns the operation registered with the given symbol name within the
382 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
383 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
384 /// was found.
386  StringAttr symbol) {
387  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
388  Region &region = symbolTableOp->getRegion(0);
389  if (region.empty())
390  return nullptr;
391 
392  // Look for a symbol with the given name.
393  StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
395  for (auto &op : region.front())
396  if (getNameIfSymbol(&op, symbolNameId) == symbol)
397  return &op;
398  return nullptr;
399 }
401  SymbolRefAttr symbol) {
402  SmallVector<Operation *, 4> resolvedSymbols;
403  if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
404  return nullptr;
405  return resolvedSymbols.back();
406 }
407 
408 /// Internal implementation of `lookupSymbolIn` that allows for specialized
409 /// implementations of the lookup function.
411  Operation *symbolTableOp, SymbolRefAttr symbol,
413  function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
414  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
415 
416  // Lookup the root reference for this symbol.
417  symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
418  if (!symbolTableOp)
419  return failure();
420  symbols.push_back(symbolTableOp);
421 
422  // If there are no nested references, just return the root symbol directly.
423  ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
424  if (nestedRefs.empty())
425  return success();
426 
427  // Verify that the root is also a symbol table.
428  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
429  return failure();
430 
431  // Otherwise, lookup each of the nested non-leaf references and ensure that
432  // each corresponds to a valid symbol table.
433  for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
434  symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
435  if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
436  return failure();
437  symbols.push_back(symbolTableOp);
438  }
439  symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
440  return success(symbols.back());
441 }
442 
444 SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
445  SmallVectorImpl<Operation *> &symbols) {
446  auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
447  return lookupSymbolIn(symbolTableOp, symbol);
448  };
449  return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
450 }
451 
452 /// Returns the operation registered with the given symbol name within the
453 /// closest parent operation with the 'OpTrait::SymbolTable' trait. Returns
454 /// nullptr if no valid symbol was found.
456  StringAttr symbol) {
457  Operation *symbolTableOp = getNearestSymbolTable(from);
458  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
459 }
461  SymbolRefAttr symbol) {
462  Operation *symbolTableOp = getNearestSymbolTable(from);
463  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
464 }
465 
466 raw_ostream &mlir::operator<<(raw_ostream &os,
467  SymbolTable::Visibility visibility) {
468  switch (visibility) {
470  return os << "public";
472  return os << "private";
474  return os << "nested";
475  }
476  llvm_unreachable("Unexpected visibility");
477 }
478 
479 //===----------------------------------------------------------------------===//
480 // SymbolTable Trait Types
481 //===----------------------------------------------------------------------===//
482 
484  if (op->getNumRegions() != 1)
485  return op->emitOpError()
486  << "Operations with a 'SymbolTable' must have exactly one region";
487  if (!llvm::hasSingleElement(op->getRegion(0)))
488  return op->emitOpError()
489  << "Operations with a 'SymbolTable' must have exactly one block";
490 
491  // Check that all symbols are uniquely named within child regions.
492  DenseMap<Attribute, Location> nameToOrigLoc;
493  for (auto &block : op->getRegion(0)) {
494  for (auto &op : block) {
495  // Check for a symbol name attribute.
496  auto nameAttr =
498  if (!nameAttr)
499  continue;
500 
501  // Try to insert this symbol into the table.
502  auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
503  if (!it.second)
504  return op.emitError()
505  .append("redefinition of symbol named '", nameAttr.getValue(), "'")
506  .attachNote(it.first->second)
507  .append("see existing symbol definition here");
508  }
509  }
510 
511  // Verify any nested symbol user operations.
512  SymbolTableCollection symbolTable;
513  auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
514  if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
515  return WalkResult(user.verifySymbolUses(symbolTable));
516  return WalkResult::advance();
517  };
518 
519  std::optional<WalkResult> result =
520  walkSymbolTable(op->getRegions(), verifySymbolUserFn);
521  return success(result && !result->wasInterrupted());
522 }
523 
525  // Verify the name attribute.
526  if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
527  return op->emitOpError() << "requires string attribute '"
529 
530  // Verify the visibility attribute.
532  StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
533  if (!visStrAttr)
534  return op->emitOpError() << "requires visibility attribute '"
536  << "' to be a string attribute, but got " << vis;
537 
538  if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
539  visStrAttr.getValue()))
540  return op->emitOpError()
541  << "visibility expected to be one of [\"public\", \"private\", "
542  "\"nested\"], but got "
543  << visStrAttr;
544  }
545  return success();
546 }
547 
548 //===----------------------------------------------------------------------===//
549 // Symbol Use Lists
550 //===----------------------------------------------------------------------===//
551 
552 /// Walk all of the symbol references within the given operation, invoking the
553 /// provided callback for each found use. The callback takes the use of the
554 /// symbol.
555 static WalkResult
558  return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
559  [&](SymbolRefAttr symbolRef) {
560  if (callback({op, symbolRef}).wasInterrupted())
561  return WalkResult::interrupt();
562 
563  // Don't walk nested references.
564  return WalkResult::skip();
565  });
566 }
567 
568 /// Walk all of the uses, for any symbol, that are nested within the given
569 /// regions, invoking the provided callback for each. This does not traverse
570 /// into any nested symbol tables.
571 static std::optional<WalkResult>
574  return walkSymbolTable(regions,
575  [&](Operation *op) -> std::optional<WalkResult> {
576  // Check that this isn't a potentially unknown symbol
577  // table.
579  return std::nullopt;
580 
581  return walkSymbolRefs(op, callback);
582  });
583 }
584 /// Walk all of the uses, for any symbol, that are nested within the given
585 /// operation 'from', invoking the provided callback for each. This does not
586 /// traverse into any nested symbol tables.
587 static std::optional<WalkResult>
590  // If this operation has regions, and it, as well as its dialect, isn't
591  // registered then conservatively fail. The operation may define a
592  // symbol table, so we can't opaquely know if we should traverse to find
593  // nested uses.
595  return std::nullopt;
596 
597  // Walk the uses on this operation.
598  if (walkSymbolRefs(from, callback).wasInterrupted())
599  return WalkResult::interrupt();
600 
601  // Only recurse if this operation is not a symbol table. A symbol table
602  // defines a new scope, so we can't walk the attributes from within the symbol
603  // table op.
604  if (!from->hasTrait<OpTrait::SymbolTable>())
605  return walkSymbolUses(from->getRegions(), callback);
606  return WalkResult::advance();
607 }
608 
609 namespace {
610 /// This class represents a single symbol scope. A symbol scope represents the
611 /// set of operations nested within a symbol table that may reference symbols
612 /// within that table. A symbol scope does not contain the symbol table
613 /// operation itself, just its contained operations. A scope ends at leaf
614 /// operations or another symbol table operation.
615 struct SymbolScope {
616  /// Walk the symbol uses within this scope, invoking the given callback.
617  /// This variant is used when the callback type matches that expected by
618  /// 'walkSymbolUses'.
619  template <typename CallbackT,
620  std::enable_if_t<!std::is_same<
621  typename llvm::function_traits<CallbackT>::result_t,
622  void>::value> * = nullptr>
623  std::optional<WalkResult> walk(CallbackT cback) {
624  if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
625  return walkSymbolUses(*region, cback);
626  return walkSymbolUses(limit.get<Operation *>(), cback);
627  }
628  /// This variant is used when the callback type matches a stripped down type:
629  /// void(SymbolTable::SymbolUse use)
630  template <typename CallbackT,
631  std::enable_if_t<std::is_same<
632  typename llvm::function_traits<CallbackT>::result_t,
633  void>::value> * = nullptr>
634  std::optional<WalkResult> walk(CallbackT cback) {
635  return walk([=](SymbolTable::SymbolUse use) {
636  return cback(use), WalkResult::advance();
637  });
638  }
639 
640  /// Walk all of the operations nested under the current scope without
641  /// traversing into any nested symbol tables.
642  template <typename CallbackT>
643  std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
644  if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
645  return ::walkSymbolTable(*region, cback);
646  return ::walkSymbolTable(limit.get<Operation *>(), cback);
647  }
648 
649  /// The representation of the symbol within this scope.
650  SymbolRefAttr symbol;
651 
652  /// The IR unit representing this scope.
654 };
655 } // namespace
656 
657 /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
659  Operation *limit) {
660  StringAttr symName = SymbolTable::getSymbolName(symbol);
661  assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
662 
663  // Compute the ancestors of 'limit'.
666  limitAncestors;
667  Operation *limitAncestor = limit;
668  do {
669  // Check to see if 'symbol' is an ancestor of 'limit'.
670  if (limitAncestor == symbol) {
671  // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
672  // doesn't support parent references.
674  symbol->getParentOp())
675  return {{SymbolRefAttr::get(symName), limit}};
676  return {};
677  }
678 
679  limitAncestors.insert(limitAncestor);
680  } while ((limitAncestor = limitAncestor->getParentOp()));
681 
682  // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
683  Operation *commonAncestor = symbol->getParentOp();
684  do {
685  if (limitAncestors.count(commonAncestor))
686  break;
687  } while ((commonAncestor = commonAncestor->getParentOp()));
688  assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
689 
690  // Compute the set of valid nested references for 'symbol' as far up to the
691  // common ancestor as possible.
693  bool collectedAllReferences = succeeded(
694  collectValidReferencesFor(symbol, symName, commonAncestor, references));
695 
696  // Handle the case where the common ancestor is 'limit'.
697  if (commonAncestor == limit) {
699 
700  // Walk each of the ancestors of 'symbol', calling the compute function for
701  // each one.
702  Operation *limitIt = symbol->getParentOp();
703  for (size_t i = 0, e = references.size(); i != e;
704  ++i, limitIt = limitIt->getParentOp()) {
705  assert(limitIt->hasTrait<OpTrait::SymbolTable>());
706  scopes.push_back({references[i], &limitIt->getRegion(0)});
707  }
708  return scopes;
709  }
710 
711  // Otherwise, we just need the symbol reference for 'symbol' that will be
712  // used within 'limit'. This is the last reference in the list we computed
713  // above if we were able to collect all references.
714  if (!collectedAllReferences)
715  return {};
716  return {{references.back(), limit}};
717 }
719  Region *limit) {
720  auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
721 
722  // If we collected some scopes to walk, make sure to constrain the one for
723  // limit to the specific region requested.
724  if (!scopes.empty())
725  scopes.back().limit = limit;
726  return scopes;
727 }
729  Region *limit) {
730  return {{SymbolRefAttr::get(symbol), limit}};
731 }
732 
734  Operation *limit) {
736  auto symbolRef = SymbolRefAttr::get(symbol);
737  for (auto &region : limit->getRegions())
738  scopes.push_back({symbolRef, &region});
739  return scopes;
740 }
741 
742 /// Returns true if the given reference 'SubRef' is a sub reference of the
743 /// reference 'ref', i.e. 'ref' is a further qualified reference.
744 static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
745  if (ref == subRef)
746  return true;
747 
748  // If the references are not pointer equal, check to see if `subRef` is a
749  // prefix of `ref`.
750  if (llvm::isa<FlatSymbolRefAttr>(ref) ||
751  ref.getRootReference() != subRef.getRootReference())
752  return false;
753 
754  auto refLeafs = ref.getNestedReferences();
755  auto subRefLeafs = subRef.getNestedReferences();
756  return subRefLeafs.size() < refLeafs.size() &&
757  subRefLeafs == refLeafs.take_front(subRefLeafs.size());
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // SymbolTable::getSymbolUses
762 
763 /// The implementation of SymbolTable::getSymbolUses below.
764 template <typename FromT>
765 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
766  std::vector<SymbolTable::SymbolUse> uses;
767  auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
768  uses.push_back(symbolUse);
769  return WalkResult::advance();
770  };
771  auto result = walkSymbolUses(from, walkFn);
772  return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
773  : std::nullopt;
774 }
775 
776 /// Get an iterator range for all of the uses, for any symbol, that are nested
777 /// within the given operation 'from'. This does not traverse into any nested
778 /// symbol tables, and will also only return uses on 'from' if it does not
779 /// also define a symbol table. This is because we treat the region as the
780 /// boundary of the symbol table, and not the op itself. This function returns
781 /// std::nullopt if there are any unknown operations that may potentially be
782 /// symbol tables.
783 auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
784  return getSymbolUsesImpl(from);
785 }
786 auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // SymbolTable::getSymbolUses
792 
793 /// The implementation of SymbolTable::getSymbolUses below.
794 template <typename SymbolT, typename IRUnitT>
795 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
796  IRUnitT *limit) {
797  std::vector<SymbolTable::SymbolUse> uses;
798  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
799  if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
800  if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
801  uses.push_back(symbolUse);
802  }))
803  return std::nullopt;
804  }
805  return SymbolTable::UseRange(std::move(uses));
806 }
807 
808 /// Get all of the uses of the given symbol that are nested within the given
809 /// operation 'from', invoking the provided callback for each. This does not
810 /// traverse into any nested symbol tables. This function returns std::nullopt
811 /// if there are any unknown operations that may potentially be symbol tables.
812 auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
813  -> std::optional<UseRange> {
814  return getSymbolUsesImpl(symbol, from);
815 }
817  -> std::optional<UseRange> {
818  return getSymbolUsesImpl(symbol, from);
819 }
820 auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
821  -> std::optional<UseRange> {
822  return getSymbolUsesImpl(symbol, from);
823 }
825  -> std::optional<UseRange> {
826  return getSymbolUsesImpl(symbol, from);
827 }
828 
829 //===----------------------------------------------------------------------===//
830 // SymbolTable::symbolKnownUseEmpty
831 
832 /// The implementation of SymbolTable::symbolKnownUseEmpty below.
833 template <typename SymbolT, typename IRUnitT>
834 static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
835  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
836  // Walk all of the symbol uses looking for a reference to 'symbol'.
837  if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
838  return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
839  ? WalkResult::interrupt()
840  : WalkResult::advance();
841  }) != WalkResult::advance())
842  return false;
843  }
844  return true;
845 }
846 
847 /// Return if the given symbol is known to have no uses that are nested within
848 /// the given operation 'from'. This does not traverse into any nested symbol
849 /// tables. This function will also return false if there are any unknown
850 /// operations that may potentially be symbol tables.
851 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
852  return symbolKnownUseEmptyImpl(symbol, from);
853 }
855  return symbolKnownUseEmptyImpl(symbol, from);
856 }
857 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
858  return symbolKnownUseEmptyImpl(symbol, from);
859 }
861  return symbolKnownUseEmptyImpl(symbol, from);
862 }
863 
864 //===----------------------------------------------------------------------===//
865 // SymbolTable::replaceAllSymbolUses
866 
867 /// Generates a new symbol reference attribute with a new leaf reference.
868 static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
869  FlatSymbolRefAttr newLeafAttr) {
870  if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
871  return newLeafAttr;
872  auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
873  nestedRefs.back() = newLeafAttr;
874  return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
875 }
876 
877 /// The implementation of SymbolTable::replaceAllSymbolUses below.
878 template <typename SymbolT, typename IRUnitT>
879 static LogicalResult
880 replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
881  // Generate a new attribute to replace the given attribute.
882  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
883  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
884  SymbolRefAttr oldAttr = scope.symbol;
885  SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
886  AttrTypeReplacer replacer;
887  replacer.addReplacement(
888  [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
889  // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
890  // want to accidentally replace an inner reference.
891  if (attr == oldAttr)
892  return {newAttr, WalkResult::skip()};
893  // Handle prefix matches.
894  if (isReferencePrefixOf(oldAttr, attr)) {
895  auto oldNestedRefs = oldAttr.getNestedReferences();
896  auto nestedRefs = attr.getNestedReferences();
897  if (oldNestedRefs.empty())
898  return {SymbolRefAttr::get(newSymbol, nestedRefs),
899  WalkResult::skip()};
900 
901  auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
902  newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
903  return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
904  WalkResult::skip()};
905  }
906  return {attr, WalkResult::skip()};
907  });
908 
909  auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
910  replacer.replaceElementsIn(op);
911  return WalkResult::advance();
912  };
913  if (!scope.walkSymbolTable(walkFn))
914  return failure();
915  }
916  return success();
917 }
918 
919 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
920 /// provided symbol 'newSymbol' that are nested within the given operation
921 /// 'from'. This does not traverse into any nested symbol tables. If there are
922 /// any unknown operations that may potentially be symbol tables, no uses are
923 /// replaced and failure is returned.
925  StringAttr newSymbol,
926  Operation *from) {
927  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
928 }
930  StringAttr newSymbol,
931  Operation *from) {
932  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
933 }
935  StringAttr newSymbol,
936  Region *from) {
937  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
938 }
940  StringAttr newSymbol,
941  Region *from) {
942  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
943 }
944 
945 //===----------------------------------------------------------------------===//
946 // SymbolTableCollection
947 //===----------------------------------------------------------------------===//
948 
950  StringAttr symbol) {
951  return getSymbolTable(symbolTableOp).lookup(symbol);
952 }
954  SymbolRefAttr name) {
956  if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
957  return nullptr;
958  return symbols.back();
959 }
960 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
961 /// a given SymbolRefAttr. Returns failure if any of the nested references could
962 /// not be resolved.
965  SymbolRefAttr name,
966  SmallVectorImpl<Operation *> &symbols) {
967  auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
968  return lookupSymbolIn(symbolTableOp, symbol);
969  };
970  return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
971 }
972 
973 /// Returns the operation registered with the given symbol name within the
974 /// closest parent operation of, or including, 'from' with the
975 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
976 /// found.
978  StringAttr symbol) {
979  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
980  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
981 }
982 Operation *
984  SymbolRefAttr symbol) {
985  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
986  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
987 }
988 
989 /// Lookup, or create, a symbol table for an operation.
991  auto it = symbolTables.try_emplace(op, nullptr);
992  if (it.second)
993  it.first->second = std::make_unique<SymbolTable>(op);
994  return *it.first->second;
995 }
996 
997 //===----------------------------------------------------------------------===//
998 // LockedSymbolTableCollection
999 //===----------------------------------------------------------------------===//
1000 
1002  StringAttr symbol) {
1003  return getSymbolTable(symbolTableOp).lookup(symbol);
1004 }
1005 
1006 Operation *
1008  FlatSymbolRefAttr symbol) {
1009  return lookupSymbolIn(symbolTableOp, symbol.getAttr());
1010 }
1011 
1013  SymbolRefAttr name) {
1014  SmallVector<Operation *> symbols;
1015  if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
1016  return nullptr;
1017  return symbols.back();
1018 }
1019 
1021  Operation *symbolTableOp, SymbolRefAttr name,
1022  SmallVectorImpl<Operation *> &symbols) {
1023  auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
1024  return lookupSymbolIn(symbolTableOp, symbol);
1025  };
1026  return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
1027 }
1028 
1029 SymbolTable &
1030 LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
1031  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
1032  // Try to find an existing symbol table.
1033  {
1034  llvm::sys::SmartScopedReader<true> lock(mutex);
1035  auto it = collection.symbolTables.find(symbolTableOp);
1036  if (it != collection.symbolTables.end())
1037  return *it->second;
1038  }
1039  // Create a symbol table for the operation. Perform construction outside of
1040  // the critical section.
1041  auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
1042  // Insert the constructed symbol table.
1043  llvm::sys::SmartScopedWriter<true> lock(mutex);
1044  return *collection.symbolTables
1045  .insert({symbolTableOp, std::move(symbolTable)})
1046  .first->second;
1047 }
1048 
1049 //===----------------------------------------------------------------------===//
1050 // SymbolUserMap
1051 //===----------------------------------------------------------------------===//
1052 
1054  Operation *symbolTableOp)
1055  : symbolTable(symbolTable) {
1056  // Walk each of the symbol tables looking for discardable callgraph nodes.
1057  SmallVector<Operation *> symbols;
1058  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1059  for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
1060  auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
1061  assert(symbolUses && "expected uses to be valid");
1062 
1063  for (const SymbolTable::SymbolUse &use : *symbolUses) {
1064  symbols.clear();
1065  (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
1066  symbols);
1067  for (Operation *symbolOp : symbols)
1068  symbolToUsers[symbolOp].insert(use.getUser());
1069  }
1070  }
1071  };
1072  // We just set `allSymUsesVisible` to false here because it isn't necessary
1073  // for building the user map.
1074  SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
1075  walkFn);
1076 }
1077 
1079  StringAttr newSymbolName) {
1080  auto it = symbolToUsers.find(symbol);
1081  if (it == symbolToUsers.end())
1082  return;
1083 
1084  // Replace the uses within the users of `symbol`.
1085  for (Operation *user : it->second)
1086  (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1087 
1088  // Move the current users of `symbol` to the new symbol if it is in the
1089  // symbol table.
1090  Operation *newSymbol =
1091  symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1092  if (newSymbol != symbol) {
1093  // Transfer over the users to the new symbol. The reference to the old one
1094  // is fetched again as the iterator is invalidated during the insertion.
1095  auto newIt = symbolToUsers.try_emplace(newSymbol, SetVector<Operation *>{});
1096  auto oldIt = symbolToUsers.find(symbol);
1097  assert(oldIt != symbolToUsers.end() && "missing old users list");
1098  if (newIt.second)
1099  newIt.first->second = std::move(oldIt->second);
1100  else
1101  newIt.first->second.set_union(oldIt->second);
1102  symbolToUsers.erase(oldIt);
1103  }
1104 }
1105 
1106 //===----------------------------------------------------------------------===//
1107 // Visibility parsing implementation.
1108 //===----------------------------------------------------------------------===//
1109 
1111  NamedAttrList &attrs) {
1112  StringRef visibility;
1113  if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"}))
1114  return failure();
1115 
1116  StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1117  attrs.push_back(parser.getBuilder().getNamedAttr(
1118  SymbolTable::getVisibilityAttrName(), visibilityAttr));
1119  return success();
1120 }
1121 
1122 //===----------------------------------------------------------------------===//
1123 // Symbol Interfaces
1124 //===----------------------------------------------------------------------===//
1125 
1126 /// Include the generated symbol interfaces.
1127 #include "mlir/IR/SymbolInterfaces.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::optional< WalkResult > walkSymbolTable(MutableArrayRef< Region > regions, function_ref< std::optional< WalkResult >(Operation *)> callback)
Walk all of the operations within the given set of regions, without traversing into any nested symbol...
Definition: SymbolTable.cpp:81
static LogicalResult collectValidReferencesFor(Operation *symbol, StringAttr symbolName, Operation *within, SmallVectorImpl< SymbolRefAttr > &results)
Computes the nested symbol reference attribute for the symbol 'symbolName' that are usable within the...
Definition: SymbolTable.cpp:40
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit)
The implementation of SymbolTable::symbolKnownUseEmpty below.
static SmallVector< SymbolScope, 2 > collectSymbolScopes(Operation *symbol, Operation *limit)
Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
static WalkResult walkSymbolRefs(Operation *op, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the symbol references within the given operation, invoking the provided callback for each...
static StringAttr getNameIfSymbol(Operation *op)
Returns the string name of the given symbol, or null if this is not a symbol.
Definition: SymbolTable.cpp:28
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref)
Returns true if the given reference 'SubRef' is a sub reference of the reference 'ref',...
static std::optional< WalkResult > walkSymbolUses(MutableArrayRef< Region > regions, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the uses, for any symbol, that are nested within the given regions, invoking the provided...
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, FlatSymbolRefAttr newLeafAttr)
Generates a new symbol reference attribute with a new leaf reference.
static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit)
The implementation of SymbolTable::replaceAllSymbolUses below.
static LogicalResult lookupSymbolInImpl(Operation *symbolTableOp, SymbolRefAttr symbol, SmallVectorImpl< Operation * > &symbols, function_ref< Operation *(Operation *, StringAttr)> lookupSymbolFn)
Internal implementation of lookupSymbolIn that allows for specialized implementations of the lookup f...
static std::optional< SymbolTable::UseRange > getSymbolUsesImpl(FromT from)
The implementation of SymbolTable::getSymbolUses below.
static bool isPotentiallyUnknownSymbolTable(Operation *op)
Return true if the given operation is unknown and may potentially define a symbol table.
Definition: SymbolTable.cpp:22
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:137
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:110
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
Definition: Diagnostics.h:334
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:762
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:296
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:595
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class represents a specific symbol use.
Definition: SymbolTable.h:183
This class implements a range of SymbolRef uses.
Definition: SymbolTable.h:203
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
Operation * getOp() const
Returns the associated operation.
Definition: SymbolTable.h:79
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition: SymbolTable.h:90
@ Nested
The symbol is visible to the current IR, which may include operations in symbol tables above the one ...
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static StringRef getVisibilityAttrName()
Return the name of the attribute used for symbol visibility.
Definition: SymbolTable.h:82
LogicalResult rename(StringAttr from, StringAttr to)
Renames the given op or the op referred to by the given name to the given new name and updates the sym...
void erase(Operation *symbol)
Erase the given symbol from the table and delete the operation.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
SymbolTable(Operation *symbolTableOp)
Build a symbol table with the symbols within the given operation.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from)
Return if the given symbol is known to have no uses that are nested within the given operation 'from'...
FailureOr< StringAttr > renameToUnique(StringAttr from, ArrayRef< SymbolTable * > others)
Renames the given op or the op referred to by the given name to a name that is unique within this ...
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName)
Replace all of the uses of the given symbol with newSymbolName.
SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp)
Build a user map for all of the symbols defined in regions nested under 'symbolTableOp'.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
static WalkResult interrupt()
Definition: Visitors.h:51
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
LogicalResult verifySymbol(Operation *op)
LogicalResult verifySymbolTable(Operation *op)
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, NamedAttrList &attrs)
Parse an optional visibility attribute keyword (i.e., public, private, or nested) without quotes in a...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26