MLIR  22.0.0git
SymbolTable.cpp
Go to the documentation of this file.
1 //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/SymbolTable.h"
10 #include "mlir/IR/Builders.h"
12 #include "llvm/ADT/SetVector.h"
13 #include "llvm/ADT/SmallString.h"
14 #include "llvm/ADT/StringSwitch.h"
15 #include <optional>
16 
17 using namespace mlir;
18 
19 /// Return true if the given operation is unknown and may potentially define a
20 /// symbol table.
22  return op->getNumRegions() == 1 && !op->getDialect();
23 }
24 
25 /// Returns the string name of the given symbol, or null if this is not a
26 /// symbol.
27 static StringAttr getNameIfSymbol(Operation *op) {
28  return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
29 }
30 static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
31  return op->getAttrOfType<StringAttr>(symbolAttrNameId);
32 }
33 
34 /// Computes the nested symbol reference attribute for the symbol 'symbolName'
35 /// that are usable within the symbol table operations from 'symbol' as far up
36 /// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
37 /// Returns success if all references up to 'within' could be computed.
38 static LogicalResult
39 collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
40  Operation *within,
42  assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
43  MLIRContext *ctx = symbol->getContext();
44 
45  auto leafRef = FlatSymbolRefAttr::get(symbolName);
46  results.push_back(leafRef);
47 
48  // Early exit for when 'within' is the parent of 'symbol'.
49  Operation *symbolTableOp = symbol->getParentOp();
50  if (within == symbolTableOp)
51  return success();
52 
53  // Collect references until 'symbolTableOp' reaches 'within'.
54  SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
55  StringAttr symbolNameId =
57  do {
58  // Each parent of 'symbol' should define a symbol table.
59  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
60  return failure();
61  // Each parent of 'symbol' should also be a symbol.
62  StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
63  if (!symbolTableName)
64  return failure();
65  results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
66 
67  symbolTableOp = symbolTableOp->getParentOp();
68  if (symbolTableOp == within)
69  break;
70  nestedRefs.insert(nestedRefs.begin(),
71  FlatSymbolRefAttr::get(symbolTableName));
72  } while (true);
73  return success();
74 }
75 
76 /// Walk all of the operations within the given set of regions, without
77 /// traversing into any nested symbol tables. Stops walking if the result of the
78 /// callback is anything other than `WalkResult::advance`.
79 static std::optional<WalkResult>
81  function_ref<std::optional<WalkResult>(Operation *)> callback) {
82  SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
83  while (!worklist.empty()) {
84  for (Operation &op : worklist.pop_back_val()->getOps()) {
85  std::optional<WalkResult> result = callback(&op);
86  if (result != WalkResult::advance())
87  return result;
88 
89  // If this op defines a new symbol table scope, we can't traverse. Any
90  // symbol references nested within 'op' are different semantically.
91  if (!op.hasTrait<OpTrait::SymbolTable>()) {
92  for (Region &region : op.getRegions())
93  worklist.push_back(&region);
94  }
95  }
96  }
97  return WalkResult::advance();
98 }
99 
100 /// Walk all of the operations nested under, and including, the given operation,
101 /// without traversing into any nested symbol tables. Stops walking if the
102 /// result of the callback is anything other than `WalkResult::advance`.
103 static std::optional<WalkResult>
105  function_ref<std::optional<WalkResult>(Operation *)> callback) {
106  std::optional<WalkResult> result = callback(op);
107  if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>())
108  return result;
109  return walkSymbolTable(op->getRegions(), callback);
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // SymbolTable
114 //===----------------------------------------------------------------------===//
115 
116 /// Build a symbol table with the symbols within the given operation.
118  : symbolTableOp(symbolTableOp) {
119  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
120  "expected operation to have SymbolTable trait");
121  assert(symbolTableOp->getNumRegions() == 1 &&
122  "expected operation to have a single region");
123  assert(symbolTableOp->getRegion(0).hasOneBlock() &&
124  "expected operation to have a single block");
125 
126  StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
128  for (auto &op : symbolTableOp->getRegion(0).front()) {
129  StringAttr name = getNameIfSymbol(&op, symbolNameId);
130  if (!name)
131  continue;
132 
133  auto inserted = symbolTable.insert({name, &op});
134  (void)inserted;
135  assert(inserted.second &&
136  "expected region to contain uniquely named symbol operations");
137  }
138 }
139 
140 /// Look up a symbol with the specified name, returning null if no such name
141 /// exists. Names never include the @ on them.
142 Operation *SymbolTable::lookup(StringRef name) const {
143  return lookup(StringAttr::get(symbolTableOp->getContext(), name));
144 }
145 Operation *SymbolTable::lookup(StringAttr name) const {
146  return symbolTable.lookup(name);
147 }
148 
150  StringAttr name = getNameIfSymbol(op);
151  assert(name && "expected valid 'name' attribute");
152  assert(op->getParentOp() == symbolTableOp &&
153  "expected this operation to be inside of the operation with this "
154  "SymbolTable");
155 
156  auto it = symbolTable.find(name);
157  if (it != symbolTable.end() && it->second == op)
158  symbolTable.erase(it);
159 }
160 
162  remove(symbol);
163  symbol->erase();
164 }
165 
166 // TODO: Consider if this should be renamed to something like insertOrUpdate
167 /// Insert a new symbol into the table and associated operation if not already
168 /// there and rename it as necessary to avoid collisions. Return the name of
169 /// the symbol after insertion as attribute.
170 StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
171  // The symbol cannot be the child of another op and must be the child of the
172  // symbolTableOp after this.
173  //
174  // TODO: consider if SymbolTable's constructor should behave the same.
175  if (!symbol->getParentOp()) {
176  auto &body = symbolTableOp->getRegion(0).front();
177  if (insertPt == Block::iterator()) {
178  insertPt = Block::iterator(body.end());
179  } else {
180  assert((insertPt == body.end() ||
181  insertPt->getParentOp() == symbolTableOp) &&
182  "expected insertPt to be in the associated module operation");
183  }
184  // Insert before the terminator, if any.
185  if (insertPt == Block::iterator(body.end()) && !body.empty() &&
186  std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
187  insertPt = std::prev(body.end());
188 
189  body.getOperations().insert(insertPt, symbol);
190  }
191  assert(symbol->getParentOp() == symbolTableOp &&
192  "symbol is already inserted in another op");
193 
194  // Add this symbol to the symbol table, uniquing the name if a conflict is
195  // detected.
196  StringAttr name = getSymbolName(symbol);
197  if (symbolTable.insert({name, symbol}).second)
198  return name;
199  // If the symbol was already in the table, also return.
200  if (symbolTable.lookup(name) == symbol)
201  return name;
202 
203  MLIRContext *context = symbol->getContext();
204  SmallString<128> nameBuffer = generateSymbolName<128>(
205  name.getValue(),
206  [&](StringRef candidate) {
207  return !symbolTable
208  .insert({StringAttr::get(context, candidate), symbol})
209  .second;
210  },
211  uniquingCounter);
212  setSymbolName(symbol, nameBuffer);
213  return getSymbolName(symbol);
214 }
215 
216 LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
217  Operation *op = lookup(from);
218  return rename(op, to);
219 }
220 
221 LogicalResult SymbolTable::rename(Operation *op, StringAttr to) {
222  StringAttr from = getNameIfSymbol(op);
223  (void)from;
224 
225  assert(from && "expected valid 'name' attribute");
226  assert(op->getParentOp() == symbolTableOp &&
227  "expected this operation to be inside of the operation with this "
228  "SymbolTable");
229  assert(lookup(from) == op && "current name does not resolve to op");
230  assert(lookup(to) == nullptr && "new name already exists");
231 
233  return failure();
234 
235  // Remove op with old name, change name, add with new name. The order is
236  // important here due to how `remove` and `insert` rely on the op name.
237  remove(op);
238  setSymbolName(op, to);
239  insert(op);
240 
241  assert(lookup(to) == op && "new name does not resolve to renamed op");
242  assert(lookup(from) == nullptr && "old name still exists");
243 
244  return success();
245 }
246 
247 LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
248  auto toAttr = StringAttr::get(getOp()->getContext(), to);
249  return rename(from, toAttr);
250 }
251 
252 LogicalResult SymbolTable::rename(Operation *op, StringRef to) {
253  auto toAttr = StringAttr::get(getOp()->getContext(), to);
254  return rename(op, toAttr);
255 }
256 
257 FailureOr<StringAttr>
258 SymbolTable::renameToUnique(StringAttr oldName,
259  ArrayRef<SymbolTable *> others) {
260 
261  // Determine new name that is unique in all symbol tables.
262  StringAttr newName;
263  {
264  MLIRContext *context = oldName.getContext();
265  SmallString<64> prefix = oldName.getValue();
266  int uniqueId = 0;
267  prefix.push_back('_');
268  while (true) {
269  newName = StringAttr::get(context, prefix + Twine(uniqueId++));
270  auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); };
271  if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) {
272  break;
273  }
274  }
275  }
276 
277  // Apply renaming.
278  if (failed(rename(oldName, newName)))
279  return failure();
280  return newName;
281 }
282 
283 FailureOr<StringAttr>
285  StringAttr from = getNameIfSymbol(op);
286  assert(from && "expected valid 'name' attribute");
287  return renameToUnique(from, others);
288 }
289 
290 /// Returns the name of the given symbol operation.
292  StringAttr name = getNameIfSymbol(symbol);
293  assert(name && "expected valid symbol name");
294  return name;
295 }
296 
297 /// Sets the name of the given symbol operation.
298 void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
299  symbol->setAttr(getSymbolAttrName(), name);
300 }
301 
302 /// Returns the visibility of the given symbol operation.
304  // If the attribute doesn't exist, assume public.
305  StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
306  if (!vis)
307  return Visibility::Public;
308 
309  // Otherwise, switch on the string value.
310  return StringSwitch<Visibility>(vis.getValue())
311  .Case("private", Visibility::Private)
312  .Case("nested", Visibility::Nested)
313  .Case("public", Visibility::Public);
314 }
315 /// Sets the visibility of the given symbol operation.
317  MLIRContext *ctx = symbol->getContext();
318 
319  // If the visibility is public, just drop the attribute as this is the
320  // default.
321  if (vis == Visibility::Public) {
323  return;
324  }
325 
326  // Otherwise, update the attribute.
327  assert((vis == Visibility::Private || vis == Visibility::Nested) &&
328  "unknown symbol visibility kind");
329 
330  StringRef visName = vis == Visibility::Private ? "private" : "nested";
331  symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
332 }
333 
334 /// Returns the nearest symbol table from a given operation `from`. Returns
335 /// nullptr if no valid parent symbol table could be found.
337  assert(from && "expected valid operation");
339  return nullptr;
340 
341  while (!from->hasTrait<OpTrait::SymbolTable>()) {
342  from = from->getParentOp();
343 
344  // Check that this is a valid op and isn't an unknown symbol table.
345  if (!from || isPotentiallyUnknownSymbolTable(from))
346  return nullptr;
347  }
348  return from;
349 }
350 
351 /// Walks all symbol table operations nested within, and including, `op`. For
352 /// each symbol table operation, the provided callback is invoked with the op
353 /// and a boolean signifying if the symbols within that symbol table can be
354 /// treated as if all uses are visible. `allSymUsesVisible` identifies whether
355 /// all of the symbol uses of symbols within `op` are visible.
357  Operation *op, bool allSymUsesVisible,
358  function_ref<void(Operation *, bool)> callback) {
359  bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
360  if (isSymbolTable) {
361  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
362  allSymUsesVisible |= !symbol || symbol.isPrivate();
363  } else {
364  // Otherwise if 'op' is not a symbol table, any nested symbols are
365  // guaranteed to be hidden.
366  allSymUsesVisible = true;
367  }
368 
369  for (Region &region : op->getRegions())
370  for (Block &block : region)
371  for (Operation &nestedOp : block)
372  walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
373 
374  // If 'op' had the symbol table trait, visit it after any nested symbol
375  // tables.
376  if (isSymbolTable)
377  callback(op, allSymUsesVisible);
378 }
379 
380 /// Returns the operation registered with the given symbol name with the
381 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
382 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
383 /// was found.
385  StringAttr symbol) {
386  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
387  Region &region = symbolTableOp->getRegion(0);
388  if (region.empty())
389  return nullptr;
390 
391  // Look for a symbol with the given name.
392  StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
394  for (auto &op : region.front())
395  if (getNameIfSymbol(&op, symbolNameId) == symbol)
396  return &op;
397  return nullptr;
398 }
400  SymbolRefAttr symbol) {
401  SmallVector<Operation *, 4> resolvedSymbols;
402  if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
403  return nullptr;
404  return resolvedSymbols.back();
405 }
406 
407 /// Internal implementation of `lookupSymbolIn` that allows for specialized
408 /// implementations of the lookup function.
409 static LogicalResult lookupSymbolInImpl(
410  Operation *symbolTableOp, SymbolRefAttr symbol,
412  function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
413  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
414 
415  // Lookup the root reference for this symbol.
416  symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
417  if (!symbolTableOp)
418  return failure();
419  symbols.push_back(symbolTableOp);
420 
421  // If there are no nested references, just return the root symbol directly.
422  ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
423  if (nestedRefs.empty())
424  return success();
425 
426  // Verify that the root is also a symbol table.
427  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
428  return failure();
429 
430  // Otherwise, lookup each of the nested non-leaf references and ensure that
431  // each corresponds to a valid symbol table.
432  for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
433  symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
434  if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
435  return failure();
436  symbols.push_back(symbolTableOp);
437  }
438  symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
439  return success(symbols.back());
440 }
441 
442 LogicalResult
443 SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
444  SmallVectorImpl<Operation *> &symbols) {
445  auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
446  return lookupSymbolIn(symbolTableOp, symbol);
447  };
448  return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
449 }
450 
451 /// Returns the operation registered with the given symbol name within the
452 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
453 /// nullptr if no valid symbol was found.
455  StringAttr symbol) {
456  Operation *symbolTableOp = getNearestSymbolTable(from);
457  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
458 }
460  SymbolRefAttr symbol) {
461  Operation *symbolTableOp = getNearestSymbolTable(from);
462  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
463 }
464 
465 raw_ostream &mlir::operator<<(raw_ostream &os,
466  SymbolTable::Visibility visibility) {
467  switch (visibility) {
469  return os << "public";
471  return os << "private";
473  return os << "nested";
474  }
475  llvm_unreachable("Unexpected visibility");
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // SymbolTable Trait Types
480 //===----------------------------------------------------------------------===//
481 
483  if (op->getNumRegions() != 1)
484  return op->emitOpError()
485  << "Operations with a 'SymbolTable' must have exactly one region";
486  if (!op->getRegion(0).hasOneBlock())
487  return op->emitOpError()
488  << "Operations with a 'SymbolTable' must have exactly one block";
489 
490  // Check that all symbols are uniquely named within child regions.
491  DenseMap<Attribute, Location> nameToOrigLoc;
492  for (auto &block : op->getRegion(0)) {
493  for (auto &op : block) {
494  // Check for a symbol name attribute.
495  auto nameAttr =
497  if (!nameAttr)
498  continue;
499 
500  // Try to insert this symbol into the table.
501  auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
502  if (!it.second)
503  return op.emitError()
504  .append("redefinition of symbol named '", nameAttr.getValue(), "'")
505  .attachNote(it.first->second)
506  .append("see existing symbol definition here");
507  }
508  }
509 
510  // Verify any nested symbol user operations.
511  SymbolTableCollection symbolTable;
512  auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
513  if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
514  return WalkResult(user.verifySymbolUses(symbolTable));
515  return WalkResult::advance();
516  };
517 
518  std::optional<WalkResult> result =
519  walkSymbolTable(op->getRegions(), verifySymbolUserFn);
520  return success(result && !result->wasInterrupted());
521 }
522 
523 LogicalResult detail::verifySymbol(Operation *op) {
524  // Verify the name attribute.
525  if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
526  return op->emitOpError() << "requires string attribute '"
528 
529  // Verify the visibility attribute.
531  StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
532  if (!visStrAttr)
533  return op->emitOpError() << "requires visibility attribute '"
535  << "' to be a string attribute, but got " << vis;
536 
537  if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
538  visStrAttr.getValue()))
539  return op->emitOpError()
540  << "visibility expected to be one of [\"public\", \"private\", "
541  "\"nested\"], but got "
542  << visStrAttr;
543  }
544  return success();
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // Symbol Use Lists
549 //===----------------------------------------------------------------------===//
550 
551 /// Walk all of the symbol references within the given operation, invoking the
552 /// provided callback for each found use. The callbacks takes the use of the
553 /// symbol.
554 static WalkResult
557  return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
558  [&](SymbolRefAttr symbolRef) {
559  if (callback({op, symbolRef}).wasInterrupted())
560  return WalkResult::interrupt();
561 
562  // Don't walk nested references.
563  return WalkResult::skip();
564  });
565 }
566 
567 /// Walk all of the uses, for any symbol, that are nested within the given
568 /// regions, invoking the provided callback for each. This does not traverse
569 /// into any nested symbol tables.
570 static std::optional<WalkResult>
573  return walkSymbolTable(regions,
574  [&](Operation *op) -> std::optional<WalkResult> {
575  // Check that this isn't a potentially unknown symbol
576  // table.
578  return std::nullopt;
579 
580  return walkSymbolRefs(op, callback);
581  });
582 }
583 /// Walk all of the uses, for any symbol, that are nested within the given
584 /// operation 'from', invoking the provided callback for each. This does not
585 /// traverse into any nested symbol tables.
586 static std::optional<WalkResult>
589  // If this operation has regions, and it, as well as its dialect, isn't
590  // registered then conservatively fail. The operation may define a
591  // symbol table, so we can't opaquely know if we should traverse to find
592  // nested uses.
594  return std::nullopt;
595 
596  // Walk the uses on this operation.
597  if (walkSymbolRefs(from, callback).wasInterrupted())
598  return WalkResult::interrupt();
599 
600  // Only recurse if this operation is not a symbol table. A symbol table
601  // defines a new scope, so we can't walk the attributes from within the symbol
602  // table op.
603  if (!from->hasTrait<OpTrait::SymbolTable>())
604  return walkSymbolUses(from->getRegions(), callback);
605  return WalkResult::advance();
606 }
607 
608 namespace {
609 /// This class represents a single symbol scope. A symbol scope represents the
610 /// set of operations nested within a symbol table that may reference symbols
611 /// within that table. A symbol scope does not contain the symbol table
612 /// operation itself, just its contained operations. A scope ends at leaf
613 /// operations or another symbol table operation.
614 struct SymbolScope {
615  /// Walk the symbol uses within this scope, invoking the given callback.
616  /// This variant is used when the callback type matches that expected by
617  /// 'walkSymbolUses'.
618  template <typename CallbackT,
619  std::enable_if_t<!std::is_same<
620  typename llvm::function_traits<CallbackT>::result_t,
621  void>::value> * = nullptr>
622  std::optional<WalkResult> walk(CallbackT cback) {
623  if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
624  return walkSymbolUses(*region, cback);
625  return walkSymbolUses(cast<Operation *>(limit), cback);
626  }
627  /// This variant is used when the callback type matches a stripped down type:
628  /// void(SymbolTable::SymbolUse use)
629  template <typename CallbackT,
630  std::enable_if_t<std::is_same<
631  typename llvm::function_traits<CallbackT>::result_t,
632  void>::value> * = nullptr>
633  std::optional<WalkResult> walk(CallbackT cback) {
634  return walk([=](SymbolTable::SymbolUse use) {
635  return cback(use), WalkResult::advance();
636  });
637  }
638 
639  /// Walk all of the operations nested under the current scope without
640  /// traversing into any nested symbol tables.
641  template <typename CallbackT>
642  std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
643  if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
644  return ::walkSymbolTable(*region, cback);
645  return ::walkSymbolTable(cast<Operation *>(limit), cback);
646  }
647 
648  /// The representation of the symbol within this scope.
649  SymbolRefAttr symbol;
650 
651  /// The IR unit representing this scope.
653 };
654 } // namespace
655 
656 /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
658  Operation *limit) {
659  StringAttr symName = SymbolTable::getSymbolName(symbol);
660  assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
661 
662  // Compute the ancestors of 'limit'.
665  limitAncestors;
666  Operation *limitAncestor = limit;
667  do {
668  // Check to see if 'symbol' is an ancestor of 'limit'.
669  if (limitAncestor == symbol) {
670  // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
671  // doesn't support parent references.
673  symbol->getParentOp())
674  return {{SymbolRefAttr::get(symName), limit}};
675  return {};
676  }
677 
678  limitAncestors.insert(limitAncestor);
679  } while ((limitAncestor = limitAncestor->getParentOp()));
680 
681  // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
682  Operation *commonAncestor = symbol->getParentOp();
683  do {
684  if (limitAncestors.count(commonAncestor))
685  break;
686  } while ((commonAncestor = commonAncestor->getParentOp()));
687  assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
688 
689  // Compute the set of valid nested references for 'symbol' as far up to the
690  // common ancestor as possible.
692  bool collectedAllReferences = succeeded(
693  collectValidReferencesFor(symbol, symName, commonAncestor, references));
694 
695  // Handle the case where the common ancestor is 'limit'.
696  if (commonAncestor == limit) {
698 
699  // Walk each of the ancestors of 'symbol', calling the compute function for
700  // each one.
701  Operation *limitIt = symbol->getParentOp();
702  for (size_t i = 0, e = references.size(); i != e;
703  ++i, limitIt = limitIt->getParentOp()) {
704  assert(limitIt->hasTrait<OpTrait::SymbolTable>());
705  scopes.push_back({references[i], &limitIt->getRegion(0)});
706  }
707  return scopes;
708  }
709 
710  // Otherwise, we just need the symbol reference for 'symbol' that will be
711  // used within 'limit'. This is the last reference in the list we computed
712  // above if we were able to collect all references.
713  if (!collectedAllReferences)
714  return {};
715  return {{references.back(), limit}};
716 }
718  Region *limit) {
719  auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
720 
721  // If we collected some scopes to walk, make sure to constrain the one for
722  // limit to the specific region requested.
723  if (!scopes.empty())
724  scopes.back().limit = limit;
725  return scopes;
726 }
728  Region *limit) {
729  return {{SymbolRefAttr::get(symbol), limit}};
730 }
731 
733  Operation *limit) {
735  auto symbolRef = SymbolRefAttr::get(symbol);
736  for (auto &region : limit->getRegions())
737  scopes.push_back({symbolRef, &region});
738  return scopes;
739 }
740 
741 /// Returns true if the given reference 'SubRef' is a sub reference of the
742 /// reference 'ref', i.e. 'ref' is a further qualified reference.
743 static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
744  if (ref == subRef)
745  return true;
746 
747  // If the references are not pointer equal, check to see if `subRef` is a
748  // prefix of `ref`.
749  if (llvm::isa<FlatSymbolRefAttr>(ref) ||
750  ref.getRootReference() != subRef.getRootReference())
751  return false;
752 
753  auto refLeafs = ref.getNestedReferences();
754  auto subRefLeafs = subRef.getNestedReferences();
755  return subRefLeafs.size() < refLeafs.size() &&
756  subRefLeafs == refLeafs.take_front(subRefLeafs.size());
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // SymbolTable::getSymbolUses
761 //===----------------------------------------------------------------------===//
762 
763 /// The implementation of SymbolTable::getSymbolUses below.
764 template <typename FromT>
765 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
766  std::vector<SymbolTable::SymbolUse> uses;
767  auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
768  uses.push_back(symbolUse);
769  return WalkResult::advance();
770  };
771  auto result = walkSymbolUses(from, walkFn);
772  return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
773  : std::nullopt;
774 }
775 
776 /// Get an iterator range for all of the uses, for any symbol, that are nested
777 /// within the given operation 'from'. This does not traverse into any nested
778 /// symbol tables, and will also only return uses on 'from' if it does not
779 /// also define a symbol table. This is because we treat the region as the
780 /// boundary of the symbol table, and not the op itself. This function returns
781 /// std::nullopt if there are any unknown operations that may potentially be
782 /// symbol tables.
783 auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
784  return getSymbolUsesImpl(from);
785 }
786 auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // SymbolTable::getSymbolUses
792 //===----------------------------------------------------------------------===//
793 
794 /// The implementation of SymbolTable::getSymbolUses below.
795 template <typename SymbolT, typename IRUnitT>
796 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
797  IRUnitT *limit) {
798  std::vector<SymbolTable::SymbolUse> uses;
799  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
800  if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
801  if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
802  uses.push_back(symbolUse);
803  }))
804  return std::nullopt;
805  }
806  return SymbolTable::UseRange(std::move(uses));
807 }
808 
809 /// Get all of the uses of the given symbol that are nested within the given
810 /// operation 'from'. This does not traverse into any nested symbol tables.
811 /// This function returns std::nullopt if there are any unknown operations that
812 /// may potentially be symbol tables.
813 auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
814  -> std::optional<UseRange> {
815  return getSymbolUsesImpl(symbol, from);
816 }
818  -> std::optional<UseRange> {
819  return getSymbolUsesImpl(symbol, from);
820 }
821 auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
822  -> std::optional<UseRange> {
823  return getSymbolUsesImpl(symbol, from);
824 }
826  -> std::optional<UseRange> {
827  return getSymbolUsesImpl(symbol, from);
828 }
829 
830 //===----------------------------------------------------------------------===//
831 // SymbolTable::symbolKnownUseEmpty
832 //===----------------------------------------------------------------------===//
833 
834 /// The implementation of SymbolTable::symbolKnownUseEmpty below.
835 template <typename SymbolT, typename IRUnitT>
836 static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
837  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
838  // Walk all of the symbol uses looking for a reference to 'symbol'.
839  if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
840  return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
841  ? WalkResult::interrupt()
842  : WalkResult::advance();
843  }) != WalkResult::advance())
844  return false;
845  }
846  return true;
847 }
848 
849 /// Return if the given symbol is known to have no uses that are nested within
850 /// the given operation 'from'. This does not traverse into any nested symbol
851 /// tables. This function will also return false if there are any unknown
852 /// operations that may potentially be symbol tables.
853 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
854  return symbolKnownUseEmptyImpl(symbol, from);
855 }
857  return symbolKnownUseEmptyImpl(symbol, from);
858 }
859 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
860  return symbolKnownUseEmptyImpl(symbol, from);
861 }
863  return symbolKnownUseEmptyImpl(symbol, from);
864 }
865 
866 //===----------------------------------------------------------------------===//
867 // SymbolTable::replaceAllSymbolUses
868 //===----------------------------------------------------------------------===//
869 
870 /// Generates a new symbol reference attribute with a new leaf reference.
871 static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
872  FlatSymbolRefAttr newLeafAttr) {
873  if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
874  return newLeafAttr;
875  auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
876  nestedRefs.back() = newLeafAttr;
877  return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
878 }
879 
880 /// The implementation of SymbolTable::replaceAllSymbolUses below.
881 template <typename SymbolT, typename IRUnitT>
882 static LogicalResult
883 replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
884  // Generate a new attribute to replace the given attribute.
885  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
886  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
887  SymbolRefAttr oldAttr = scope.symbol;
888  SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
889  AttrTypeReplacer replacer;
890  replacer.addReplacement(
891  [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
892  // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
893  // want to accidentally replace an inner reference.
894  if (attr == oldAttr)
895  return {newAttr, WalkResult::skip()};
896  // Handle prefix matches.
897  if (isReferencePrefixOf(oldAttr, attr)) {
898  auto oldNestedRefs = oldAttr.getNestedReferences();
899  auto nestedRefs = attr.getNestedReferences();
900  if (oldNestedRefs.empty())
901  return {SymbolRefAttr::get(newSymbol, nestedRefs),
902  WalkResult::skip()};
903 
904  auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
905  newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
906  return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
907  WalkResult::skip()};
908  }
909  return {attr, WalkResult::skip()};
910  });
911 
912  auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
913  replacer.replaceElementsIn(op);
914  return WalkResult::advance();
915  };
916  if (!scope.walkSymbolTable(walkFn))
917  return failure();
918  }
919  return success();
920 }
921 
922 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
923 /// provided symbol 'newSymbol' that are nested within the given operation
924 /// 'from'. This does not traverse into any nested symbol tables. If there are
925 /// any unknown operations that may potentially be symbol tables, no uses are
926 /// replaced and failure is returned.
927 LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
928  StringAttr newSymbol,
929  Operation *from) {
930  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
931 }
933  StringAttr newSymbol,
934  Operation *from) {
935  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
936 }
937 LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
938  StringAttr newSymbol,
939  Region *from) {
940  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
941 }
943  StringAttr newSymbol,
944  Region *from) {
945  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
946 }
947 
948 //===----------------------------------------------------------------------===//
949 // SymbolTableCollection
950 //===----------------------------------------------------------------------===//
951 
953  StringAttr symbol) {
954  return getSymbolTable(symbolTableOp).lookup(symbol);
955 }
957  SymbolRefAttr name) {
959  if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
960  return nullptr;
961  return symbols.back();
962 }
963 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
964 /// a given SymbolRefAttr. Returns failure if any of the nested references could
965 /// not be resolved.
966 LogicalResult
968  SymbolRefAttr name,
969  SmallVectorImpl<Operation *> &symbols) {
970  auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
971  return lookupSymbolIn(symbolTableOp, symbol);
972  };
973  return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
974 }
975 
976 /// Returns the operation registered with the given symbol name within the
977 /// closest parent operation of, or including, 'from' with the
978 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
979 /// found.
981  StringAttr symbol) {
982  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
983  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
984 }
985 Operation *
987  SymbolRefAttr symbol) {
988  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
989  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
990 }
991 
992 /// Lookup, or create, a symbol table for an operation.
994  auto it = symbolTables.try_emplace(op, nullptr);
995  if (it.second)
996  it.first->second = std::make_unique<SymbolTable>(op);
997  return *it.first->second;
998 }
999 
1001  symbolTables.erase(op);
1002 }
1003 
1004 //===----------------------------------------------------------------------===//
1005 // LockedSymbolTableCollection
1006 //===----------------------------------------------------------------------===//
1007 
1009  StringAttr symbol) {
1010  return getSymbolTable(symbolTableOp).lookup(symbol);
1011 }
1012 
1013 Operation *
1015  FlatSymbolRefAttr symbol) {
1016  return lookupSymbolIn(symbolTableOp, symbol.getAttr());
1017 }
1018 
1020  SymbolRefAttr name) {
1021  SmallVector<Operation *> symbols;
1022  if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
1023  return nullptr;
1024  return symbols.back();
1025 }
1026 
1028  Operation *symbolTableOp, SymbolRefAttr name,
1029  SmallVectorImpl<Operation *> &symbols) {
1030  auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
1031  return lookupSymbolIn(symbolTableOp, symbol);
1032  };
1033  return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
1034 }
1035 
1036 SymbolTable &
1037 LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
1038  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
1039  // Try to find an existing symbol table.
1040  {
1041  llvm::sys::SmartScopedReader<true> lock(mutex);
1042  auto it = collection.symbolTables.find(symbolTableOp);
1043  if (it != collection.symbolTables.end())
1044  return *it->second;
1045  }
1046  // Create a symbol table for the operation. Perform construction outside of
1047  // the critical section.
1048  auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
1049  // Insert the constructed symbol table.
1050  llvm::sys::SmartScopedWriter<true> lock(mutex);
1051  return *collection.symbolTables
1052  .insert({symbolTableOp, std::move(symbolTable)})
1053  .first->second;
1054 }
1055 
1056 //===----------------------------------------------------------------------===//
1057 // SymbolUserMap
1058 //===----------------------------------------------------------------------===//
1059 
1061  Operation *symbolTableOp)
1062  : symbolTable(symbolTable) {
1063  // Walk each of the symbol tables looking for discardable callgraph nodes.
1064  SmallVector<Operation *> symbols;
1065  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1066  for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
1067  auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
1068  assert(symbolUses && "expected uses to be valid");
1069 
1070  for (const SymbolTable::SymbolUse &use : *symbolUses) {
1071  symbols.clear();
1072  (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
1073  symbols);
1074  for (Operation *symbolOp : symbols)
1075  symbolToUsers[symbolOp].insert(use.getUser());
1076  }
1077  }
1078  };
1079  // We just set `allSymUsesVisible` to false here because it isn't necessary
1080  // for building the user map.
1081  SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
1082  walkFn);
1083 }
1084 
1086  StringAttr newSymbolName) {
1087  auto it = symbolToUsers.find(symbol);
1088  if (it == symbolToUsers.end())
1089  return;
1090 
1091  // Replace the uses within the users of `symbol`.
1092  for (Operation *user : it->second)
1093  (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1094 
1095  // Move the current users of `symbol` to the new symbol if it is in the
1096  // symbol table.
1097  Operation *newSymbol =
1098  symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1099  if (newSymbol != symbol) {
1100  // Transfer over the users to the new symbol. The reference to the old one
1101  // is fetched again as the iterator is invalidated during the insertion.
1102  auto newIt = symbolToUsers.try_emplace(newSymbol);
1103  auto oldIt = symbolToUsers.find(symbol);
1104  assert(oldIt != symbolToUsers.end() && "missing old users list");
1105  if (newIt.second)
1106  newIt.first->second = std::move(oldIt->second);
1107  else
1108  newIt.first->second.set_union(oldIt->second);
1109  symbolToUsers.erase(oldIt);
1110  }
1111 }
1112 
1113 //===----------------------------------------------------------------------===//
1114 // Visibility parsing implementation.
1115 //===----------------------------------------------------------------------===//
1116 
1118  NamedAttrList &attrs) {
1119  StringRef visibility;
1120  if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"}))
1121  return failure();
1122 
1123  StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1124  attrs.push_back(parser.getBuilder().getNamedAttr(
1125  SymbolTable::getVisibilityAttrName(), visibilityAttr));
1126  return success();
1127 }
1128 
1129 //===----------------------------------------------------------------------===//
1130 // Symbol Interfaces
1131 //===----------------------------------------------------------------------===//
1132 
1133 /// Include the generated symbol interfaces.
1134 #include "mlir/IR/SymbolInterfaces.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::optional< WalkResult > walkSymbolTable(MutableArrayRef< Region > regions, function_ref< std::optional< WalkResult >(Operation *)> callback)
Walk all of the operations within the given set of regions, without traversing into any nested symbol...
Definition: SymbolTable.cpp:80
static LogicalResult collectValidReferencesFor(Operation *symbol, StringAttr symbolName, Operation *within, SmallVectorImpl< SymbolRefAttr > &results)
Computes the nested symbol reference attribute for the symbol 'symbolName' that are usable within the...
Definition: SymbolTable.cpp:39
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit)
The implementation of SymbolTable::symbolKnownUseEmpty below.
static SmallVector< SymbolScope, 2 > collectSymbolScopes(Operation *symbol, Operation *limit)
Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
static WalkResult walkSymbolRefs(Operation *op, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the symbol references within the given operation, invoking the provided callback for each...
static StringAttr getNameIfSymbol(Operation *op)
Returns the string name of the given symbol, or null if this is not a symbol.
Definition: SymbolTable.cpp:27
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref)
Returns true if the given reference 'subRef' is a sub reference of the reference 'ref',...
static std::optional< WalkResult > walkSymbolUses(MutableArrayRef< Region > regions, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the uses, for any symbol, that are nested within the given regions, invoking the provided...
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, FlatSymbolRefAttr newLeafAttr)
Generates a new symbol reference attribute with a new leaf reference.
static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit)
The implementation of SymbolTable::replaceAllSymbolUses below.
static LogicalResult lookupSymbolInImpl(Operation *symbolTableOp, SymbolRefAttr symbol, SmallVectorImpl< Operation * > &symbols, function_ref< Operation *(Operation *, StringAttr)> lookupSymbolFn)
Internal implementation of lookupSymbolIn that allows for specialized implementations of the lookup f...
static std::optional< SymbolTable::UseRange > getSymbolUsesImpl(FromT from)
The implementation of SymbolTable::getSymbolUses below.
static bool isPotentiallyUnknownSymbolTable(Operation *op)
Return true if the given operation is unknown and may potentially define a symbol table.
Definition: SymbolTable.cpp:21
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
This is an attribute/type replacer that is naively cached.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:89
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
Definition: Diagnostics.h:340
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol) override
Look up a symbol with the specified name within the specified symbol table operation,...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:295
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:600
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
virtual void invalidateSymbolTable(Operation *op)
Invalidate the cached symbol table for an operation.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class represents a specific symbol use.
Definition: SymbolTable.h:183
This class implements a range of SymbolRef uses.
Definition: SymbolTable.h:203
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
Operation * getOp() const
Returns the associated operation.
Definition: SymbolTable.h:79
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition: SymbolTable.h:90
@ Nested
The symbol is visible to the current IR, which may include operations in symbol tables above the one ...
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static StringRef getVisibilityAttrName()
Return the name of the attribute used for symbol visibility.
Definition: SymbolTable.h:82
LogicalResult rename(StringAttr from, StringAttr to)
Renames the given op or the op referred to by the given name to the given new name and updates the sym...
void erase(Operation *symbol)
Erase the given symbol from the table and delete the operation.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
SymbolTable(Operation *symbolTableOp)
Build a symbol table with the symbols within the given operation.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from)
Return if the given symbol is known to have no uses that are nested within the given operation 'from'...
FailureOr< StringAttr > renameToUnique(StringAttr from, ArrayRef< SymbolTable * > others)
Renames the given op or the op referred to by the given name to a name that is unique within this ...
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName)
Replace all of the uses of the given symbol with newSymbolName.
SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp)
Build a user map for all of the symbols defined in regions nested under 'symbolTableOp'.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult skip()
Definition: WalkResult.h:48
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
LogicalResult verifySymbol(Operation *op)
LogicalResult verifySymbolTable(Operation *op)
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, NamedAttrList &attrs)
Parse an optional visibility attribute keyword (i.e., public, private, or nested) without quotes in a...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78