MLIR 23.0.0git
SymbolTable.cpp
Go to the documentation of this file.
1//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Builders.h"
12#include "llvm/ADT/SetVector.h"
13#include "llvm/ADT/SmallString.h"
14#include "llvm/ADT/StringSwitch.h"
15#include <optional>
16
17using namespace mlir;
18
19/// Return true if the given operation is unknown and may potentially define a
20/// symbol table.
22 return op->getNumRegions() == 1 && !op->getDialect();
23}
24
25/// Returns the string name of the given symbol, or null if this is not a
26/// symbol.
27static StringAttr getNameIfSymbol(Operation *op) {
28 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
29}
30static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
31 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
32}
33
34/// Computes the nested symbol reference attributes for the symbol 'symbolName'
35/// that are usable within the symbol table operations from 'symbol' as far up
36/// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
37/// Returns success if all references up to 'within' could be computed.
38static LogicalResult
39collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
40 Operation *within,
42 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
43 MLIRContext *ctx = symbol->getContext();
44
45 auto leafRef = FlatSymbolRefAttr::get(symbolName);
46 results.push_back(leafRef);
47
48 // Early exit for when 'within' is the parent of 'symbol'.
49 Operation *symbolTableOp = symbol->getParentOp();
50 if (within == symbolTableOp)
51 return success();
52
53 // Collect references until 'symbolTableOp' reaches 'within'.
54 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
55 StringAttr symbolNameId =
56 StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
57 do {
58 // Each parent of 'symbol' should define a symbol table.
59 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
60 return failure();
61 // Each parent of 'symbol' should also be a symbol.
62 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
63 if (!symbolTableName)
64 return failure();
65 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
66
67 symbolTableOp = symbolTableOp->getParentOp();
68 if (symbolTableOp == within)
69 break;
70 nestedRefs.insert(nestedRefs.begin(),
71 FlatSymbolRefAttr::get(symbolTableName));
72 } while (true);
73 return success();
74}
75
76/// Walk all of the operations within the given set of regions, without
77/// traversing into any nested symbol tables. Stops walking if the result of the
78/// callback is anything other than `WalkResult::advance`.
79static std::optional<WalkResult>
81 function_ref<std::optional<WalkResult>(Operation *)> callback) {
82 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
83 while (!worklist.empty()) {
84 for (Operation &op : worklist.pop_back_val()->getOps()) {
85 std::optional<WalkResult> result = callback(&op);
87 return result;
88
89 // If this op defines a new symbol table scope, we can't traverse. Any
90 // symbol references nested within 'op' are different semantically.
91 if (!op.hasTrait<OpTrait::SymbolTable>()) {
92 for (Region &region : op.getRegions())
93 worklist.push_back(&region);
94 }
95 }
96 }
97 return WalkResult::advance();
98}
99
100/// Walk all of the operations nested under, and including, the given operation,
101/// without traversing into any nested symbol tables. Stops walking if the
102/// result of the callback is anything other than `WalkResult::advance`.
103static std::optional<WalkResult>
105 function_ref<std::optional<WalkResult>(Operation *)> callback) {
106 std::optional<WalkResult> result = callback(op);
108 return result;
109 return walkSymbolTable(op->getRegions(), callback);
110}
111
112//===----------------------------------------------------------------------===//
113// SymbolTable
114//===----------------------------------------------------------------------===//
115
116/// Build a symbol table with the symbols within the given operation.
118 : symbolTableOp(symbolTableOp) {
119 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
120 "expected operation to have SymbolTable trait");
121 assert(symbolTableOp->getNumRegions() == 1 &&
122 "expected operation to have a single region");
123 assert(symbolTableOp->getRegion(0).hasOneBlock() &&
124 "expected operation to have a single block");
125
126 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
128 for (auto &op : symbolTableOp->getRegion(0).front()) {
129 StringAttr name = getNameIfSymbol(&op, symbolNameId);
130 if (!name)
131 continue;
132
133 // Silently skip duplicate symbol names. Duplicate symbols are an
134 // invalid IR condition diagnosed by the SymbolTable trait's
135 // verifyRegionTrait. The constructor may be called before verification
136 // completes (e.g., when IsolatedFromAbove ops look up symbols in an
137 // ancestor symbol table during verification), so an assert here would
138 // crash instead of producing a proper diagnostic.
139 symbolTable.try_emplace(name, &op);
140 }
141}
142
143/// Look up a symbol with the specified name, returning null if no such name
144/// exists. Names never include the @ on them.
145Operation *SymbolTable::lookup(StringRef name) const {
146 return lookup(StringAttr::get(symbolTableOp->getContext(), name));
147}
148Operation *SymbolTable::lookup(StringAttr name) const {
149 return symbolTable.lookup(name);
150}
151
153 StringAttr name = getNameIfSymbol(op);
154 assert(name && "expected valid 'name' attribute");
155 assert(op->getParentOp() == symbolTableOp &&
156 "expected this operation to be inside of the operation with this "
157 "SymbolTable");
158
159 auto it = symbolTable.find(name);
160 if (it != symbolTable.end() && it->second == op)
161 symbolTable.erase(it);
162}
163
165 remove(symbol);
166 symbol->erase();
167}
168
169// TODO: Consider if this should be renamed to something like insertOrUpdate
170/// Insert a new symbol into the table and associated operation if not already
171/// there and rename it as necessary to avoid collisions. Return the name of
172/// the symbol after insertion as attribute.
173StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
174 // The symbol cannot be the child of another op and must be the child of the
175 // symbolTableOp after this.
176 //
177 // TODO: consider if SymbolTable's constructor should behave the same.
178 if (!symbol->getParentOp()) {
179 auto &body = symbolTableOp->getRegion(0).front();
180 if (insertPt == Block::iterator()) {
181 insertPt = Block::iterator(body.end());
182 } else {
183 assert((insertPt == body.end() ||
184 insertPt->getParentOp() == symbolTableOp) &&
185 "expected insertPt to be in the associated module operation");
186 }
187 // Insert before the terminator, if any.
188 if (insertPt == Block::iterator(body.end()) && !body.empty() &&
189 std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
190 insertPt = std::prev(body.end());
191
192 body.getOperations().insert(insertPt, symbol);
193 }
194 assert(symbol->getParentOp() == symbolTableOp &&
195 "symbol is already inserted in another op");
196
197 // Add this symbol to the symbol table, uniquing the name if a conflict is
198 // detected.
199 StringAttr name = getSymbolName(symbol);
200 if (symbolTable.insert({name, symbol}).second)
201 return name;
202 // If the symbol was already in the table, also return.
203 if (symbolTable.lookup(name) == symbol)
204 return name;
205
206 MLIRContext *context = symbol->getContext();
208 name.getValue(),
209 [&](StringRef candidate) {
210 return !symbolTable
211 .insert({StringAttr::get(context, candidate), symbol})
212 .second;
213 },
214 uniquingCounter);
215 setSymbolName(symbol, nameBuffer);
216 return getSymbolName(symbol);
217}
218
219LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
220 Operation *op = lookup(from);
221 return rename(op, to);
222}
223
224LogicalResult SymbolTable::rename(Operation *op, StringAttr to) {
225 StringAttr from = getNameIfSymbol(op);
226 (void)from;
227
228 assert(from && "expected valid 'name' attribute");
229 assert(op->getParentOp() == symbolTableOp &&
230 "expected this operation to be inside of the operation with this "
231 "SymbolTable");
232 assert(lookup(from) == op && "current name does not resolve to op");
233 assert(lookup(to) == nullptr && "new name already exists");
234
235 if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp())))
236 return failure();
237
238 // Remove op with old name, change name, add with new name. The order is
239 // important here due to how `remove` and `insert` rely on the op name.
240 remove(op);
241 setSymbolName(op, to);
242 insert(op);
243
244 assert(lookup(to) == op && "new name does not resolve to renamed op");
245 assert(lookup(from) == nullptr && "old name still exists");
246
247 return success();
248}
249
250LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
251 auto toAttr = StringAttr::get(getOp()->getContext(), to);
252 return rename(from, toAttr);
253}
254
255LogicalResult SymbolTable::rename(Operation *op, StringRef to) {
256 auto toAttr = StringAttr::get(getOp()->getContext(), to);
257 return rename(op, toAttr);
258}
259
260FailureOr<StringAttr>
263
264 // Determine new name that is unique in all symbol tables.
265 StringAttr newName;
266 {
267 MLIRContext *context = oldName.getContext();
268 SmallString<64> prefix = oldName.getValue();
269 int uniqueId = 0;
270 prefix.push_back('_');
271 while (true) {
272 newName = StringAttr::get(context, prefix + Twine(uniqueId++));
273 auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); };
274 if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) {
275 break;
276 }
277 }
278 }
279
280 // Apply renaming.
281 if (failed(rename(oldName, newName)))
282 return failure();
283 return newName;
284}
285
286FailureOr<StringAttr>
288 StringAttr from = getNameIfSymbol(op);
289 assert(from && "expected valid 'name' attribute");
290 return renameToUnique(from, others);
291}
292
293/// Returns the name of the given symbol operation.
295 StringAttr name = getNameIfSymbol(symbol);
296 assert(name && "expected valid symbol name");
297 return name;
298}
299
300/// Sets the name of the given symbol operation.
301void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
302 symbol->setAttr(getSymbolAttrName(), name);
303}
304
305/// Returns the visibility of the given symbol operation.
307 // If the attribute doesn't exist, assume public.
308 StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
309 if (!vis)
310 return Visibility::Public;
311
312 // Otherwise, switch on the string value.
313 return StringSwitch<Visibility>(vis.getValue())
314 .Case("private", Visibility::Private)
315 .Case("nested", Visibility::Nested)
316 .Case("public", Visibility::Public);
317}
318/// Sets the visibility of the given symbol operation.
320 MLIRContext *ctx = symbol->getContext();
321
322 // If the visibility is public, just drop the attribute as this is the
323 // default.
324 if (vis == Visibility::Public) {
325 symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName()));
326 return;
327 }
328
329 // Otherwise, update the attribute.
330 assert((vis == Visibility::Private || vis == Visibility::Nested) &&
331 "unknown symbol visibility kind");
332
333 StringRef visName = vis == Visibility::Private ? "private" : "nested";
334 symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
335}
336
337/// Returns the nearest symbol table from a given operation `from`. Returns
338/// nullptr if no valid parent symbol table could be found.
340 assert(from && "expected valid operation");
342 return nullptr;
343
344 while (!from->hasTrait<OpTrait::SymbolTable>()) {
345 from = from->getParentOp();
346
347 // Check that this is a valid op and isn't an unknown symbol table.
348 if (!from || isPotentiallyUnknownSymbolTable(from))
349 return nullptr;
350 }
351 return from;
352}
353
354/// Walks all symbol table operations nested within, and including, `op`. For
355/// each symbol table operation, the provided callback is invoked with the op
356/// and a boolean signifying if the symbols within that symbol table can be
357/// treated as if all uses are visible. `allSymUsesVisible` identifies whether
358/// all of the symbol uses of symbols within `op` are visible.
360 Operation *op, bool allSymUsesVisible,
361 function_ref<void(Operation *, bool)> callback) {
362 bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
363 if (isSymbolTable) {
364 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
365 allSymUsesVisible |= !symbol || symbol.isPrivate();
366 } else {
367 // Otherwise if 'op' is not a symbol table, any nested symbols are
368 // guaranteed to be hidden.
369 allSymUsesVisible = true;
370 }
371
372 for (Region &region : op->getRegions())
373 for (Block &block : region)
374 for (Operation &nestedOp : block)
375 walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
376
377 // If 'op' had the symbol table trait, visit it after any nested symbol
378 // tables.
379 if (isSymbolTable)
380 callback(op, allSymUsesVisible);
381}
382
383/// Returns the operation registered with the given symbol name with the
384/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
385/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
386/// was found.
388 StringAttr symbol) {
389 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
390 Region &region = symbolTableOp->getRegion(0);
391 if (region.empty())
392 return nullptr;
393
394 // Look for a symbol with the given name.
395 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
397 for (auto &op : region.front())
398 if (getNameIfSymbol(&op, symbolNameId) == symbol)
399 return &op;
400 return nullptr;
401}
403 SymbolRefAttr symbol) {
404 SmallVector<Operation *, 4> resolvedSymbols;
405 if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
406 return nullptr;
407 return resolvedSymbols.back();
408}
409
410/// Internal implementation of `lookupSymbolIn` that allows for specialized
411/// implementations of the lookup function.
412static LogicalResult lookupSymbolInImpl(
413 Operation *symbolTableOp, SymbolRefAttr symbol,
415 function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
416 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
417
418 // Lookup the root reference for this symbol.
419 auto *symbolOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
420 if (!symbolOp)
421 return failure();
422 symbols.push_back(symbolOp);
423
424 // Lookup each of the nested references.
425 for (FlatSymbolRefAttr ref : symbol.getNestedReferences()) {
426 // Check that we have a valid symbol table to lookup ref.
427 if (!symbolOp->hasTrait<OpTrait::SymbolTable>())
428 return failure();
429 symbolOp = lookupSymbolFn(symbolOp, ref.getAttr());
430 // If the nested symbol is private, lookup failed.
431 if (!symbolOp || SymbolTable::getSymbolVisibility(symbolOp) ==
433 return failure();
434 symbols.push_back(symbolOp);
435 }
436 return success();
437}
438
439LogicalResult
440SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
442 auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
443 return lookupSymbolIn(symbolTableOp, symbol);
444 };
445 return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
446}
447
448/// Returns the operation registered with the given symbol name within the
449/// closest parent operation with the 'OpTrait::SymbolTable' trait. Returns
450/// nullptr if no valid symbol was found.
452 StringAttr symbol) {
453 Operation *symbolTableOp = getNearestSymbolTable(from);
454 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
455}
457 SymbolRefAttr symbol) {
458 Operation *symbolTableOp = getNearestSymbolTable(from);
459 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
460}
461
463 SymbolTable::Visibility visibility) {
464 switch (visibility) {
466 return os << "public";
468 return os << "private";
470 return os << "nested";
471 }
472 llvm_unreachable("Unexpected visibility");
473}
474
475//===----------------------------------------------------------------------===//
476// SymbolTable Trait Types
477//===----------------------------------------------------------------------===//
478
480 if (op->getNumRegions() != 1)
481 return op->emitOpError()
482 << "Operations with a 'SymbolTable' must have exactly one region";
483 if (!op->getRegion(0).hasOneBlock())
484 return op->emitOpError()
485 << "Operations with a 'SymbolTable' must have exactly one block";
486
487 // Check that all symbols are uniquely named within child regions.
488 DenseMap<Attribute, Location> nameToOrigLoc;
489 for (auto &block : op->getRegion(0)) {
490 for (auto &op : block) {
491 // Check for a symbol name attribute.
492 auto nameAttr =
494 if (!nameAttr)
495 continue;
496
497 // Try to insert this symbol into the table.
498 auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
499 if (!it.second)
500 return op.emitError()
501 .append("redefinition of symbol named '", nameAttr.getValue(), "'")
502 .attachNote(it.first->second)
503 .append("see existing symbol definition here");
504 }
505 }
506
507 // Verify any nested symbol user operations.
508 SymbolTableCollection symbolTable;
509 auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
510 if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
511 if (failed(user.verifySymbolUses(symbolTable)))
512 return WalkResult::interrupt();
513 for (auto &attr : op->getDiscardableAttrs()) {
514 if (auto user = dyn_cast<SymbolUserAttrInterface>(attr.getValue())) {
515 if (failed(user.verifySymbolUses(op, symbolTable)))
516 return WalkResult::interrupt();
517 }
518 }
519 return WalkResult::advance();
520 };
521
522 std::optional<WalkResult> result =
523 walkSymbolTable(op->getRegions(), verifySymbolUserFn);
524 return success(result && !result->wasInterrupted());
525}
526
527LogicalResult detail::verifySymbol(Operation *op) {
528 // Verify the name attribute.
530 return op->emitOpError() << "requires string attribute '"
532
533 // Verify the visibility attribute.
535 StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
536 if (!visStrAttr)
537 return op->emitOpError() << "requires visibility attribute '"
539 << "' to be a string attribute, but got " << vis;
540
541 if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
542 visStrAttr.getValue()))
543 return op->emitOpError()
544 << "visibility expected to be one of [\"public\", \"private\", "
545 "\"nested\"], but got "
546 << visStrAttr;
547 }
548 return success();
549}
550
551//===----------------------------------------------------------------------===//
552// Symbol Use Lists
553//===----------------------------------------------------------------------===//
554
555/// Walk all of the symbol references within the given operation, invoking the
556/// provided callback for each found use. The callback takes the use of the
557/// symbol.
558static WalkResult
561 return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
562 [&](SymbolRefAttr symbolRef) {
563 if (callback({op, symbolRef}).wasInterrupted())
564 return WalkResult::interrupt();
565
566 // Don't walk nested references.
567 return WalkResult::skip();
568 });
569}
570
571/// Walk all of the uses, for any symbol, that are nested within the given
572/// regions, invoking the provided callback for each. This does not traverse
573/// into any nested symbol tables.
574static std::optional<WalkResult>
577 return walkSymbolTable(regions,
578 [&](Operation *op) -> std::optional<WalkResult> {
579 // Check that this isn't a potentially unknown symbol
580 // table.
582 return std::nullopt;
583
584 return walkSymbolRefs(op, callback);
585 });
586}
587/// Walk all of the uses, for any symbol, that are nested within the given
588/// operation 'from', invoking the provided callback for each. This does not
589/// traverse into any nested symbol tables.
590static std::optional<WalkResult>
593 // If this operation has regions, and it, as well as its dialect, isn't
594 // registered then conservatively fail. The operation may define a
595 // symbol table, so we can't opaquely know if we should traverse to find
596 // nested uses.
598 return std::nullopt;
599
600 // Walk the uses on this operation.
601 if (walkSymbolRefs(from, callback).wasInterrupted())
602 return WalkResult::interrupt();
603
604 // Only recurse if this operation is not a symbol table. A symbol table
605 // defines a new scope, so we can't walk the attributes from within the symbol
606 // table op.
607 if (!from->hasTrait<OpTrait::SymbolTable>())
608 return walkSymbolUses(from->getRegions(), callback);
609 return WalkResult::advance();
610}
611
612namespace {
613/// This class represents a single symbol scope. A symbol scope represents the
614/// set of operations nested within a symbol table that may reference symbols
615/// within that table. A symbol scope does not contain the symbol table
616/// operation itself, just its contained operations. A scope ends at leaf
617/// operations or another symbol table operation.
618struct SymbolScope {
619 /// Walk the symbol uses within this scope, invoking the given callback.
620 /// This variant is used when the callback type matches that expected by
621 /// 'walkSymbolUses'.
622 template <typename CallbackT,
623 std::enable_if_t<!std::is_same<
624 typename llvm::function_traits<CallbackT>::result_t,
625 void>::value> * = nullptr>
626 std::optional<WalkResult> walk(CallbackT cback) {
627 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
628 return walkSymbolUses(*region, cback);
629 return walkSymbolUses(cast<Operation *>(limit), cback);
630 }
631 /// This variant is used when the callback type matches a stripped down type:
632 /// void(SymbolTable::SymbolUse use)
633 template <typename CallbackT,
634 std::enable_if_t<std::is_same<
635 typename llvm::function_traits<CallbackT>::result_t,
636 void>::value> * = nullptr>
637 std::optional<WalkResult> walk(CallbackT cback) {
638 return walk([=](SymbolTable::SymbolUse use) {
639 return cback(use), WalkResult::advance();
640 });
641 }
642
643 /// Walk all of the operations nested under the current scope without
644 /// traversing into any nested symbol tables.
645 template <typename CallbackT>
646 std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
647 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
648 return ::walkSymbolTable(*region, cback);
649 return ::walkSymbolTable(cast<Operation *>(limit), cback);
650 }
651
652 /// The representation of the symbol within this scope.
653 SymbolRefAttr symbol;
654
655 /// The IR unit representing this scope.
656 llvm::PointerUnion<Operation *, Region *> limit;
657};
658} // namespace
659
660/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
662 Operation *limit) {
663 StringAttr symName = SymbolTable::getSymbolName(symbol);
664 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
665
666 // Compute the ancestors of 'limit'.
669 limitAncestors;
670 Operation *limitAncestor = limit;
671 do {
672 // Check to see if 'symbol' is an ancestor of 'limit'.
673 if (limitAncestor == symbol) {
674 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
675 // doesn't support parent references.
677 symbol->getParentOp())
678 return {{SymbolRefAttr::get(symName), limit}};
679 return {};
680 }
681
682 limitAncestors.insert(limitAncestor);
683 } while ((limitAncestor = limitAncestor->getParentOp()));
684
685 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
686 Operation *commonAncestor = symbol->getParentOp();
687 do {
688 if (limitAncestors.count(commonAncestor))
689 break;
690 } while ((commonAncestor = commonAncestor->getParentOp()));
691 assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
692
693 // Compute the set of valid nested references for 'symbol' as far up to the
694 // common ancestor as possible.
696 bool collectedAllReferences = succeeded(
697 collectValidReferencesFor(symbol, symName, commonAncestor, references));
698
699 // Handle the case where the common ancestor is 'limit'.
700 if (commonAncestor == limit) {
702
703 // Walk each of the ancestors of 'symbol', calling the compute function for
704 // each one.
705 Operation *limitIt = symbol->getParentOp();
706 for (size_t i = 0, e = references.size(); i != e;
707 ++i, limitIt = limitIt->getParentOp()) {
708 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
709 scopes.push_back({references[i], &limitIt->getRegion(0)});
710 }
711 return scopes;
712 }
713
714 // Otherwise, we just need the symbol reference for 'symbol' that will be
715 // used within 'limit'. This is the last reference in the list we computed
716 // above if we were able to collect all references.
717 if (!collectedAllReferences)
718 return {};
719 return {{references.back(), limit}};
720}
722 Region *limit) {
723 auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
724
725 // If we collected some scopes to walk, make sure to constrain the one for
726 // limit to the specific region requested.
727 if (!scopes.empty())
728 scopes.back().limit = limit;
729 return scopes;
730}
732 Region *limit) {
733 return {{SymbolRefAttr::get(symbol), limit}};
734}
735
737 Operation *limit) {
739 auto symbolRef = SymbolRefAttr::get(symbol);
740 for (auto &region : limit->getRegions())
741 scopes.push_back({symbolRef, &region});
742 return scopes;
743}
744
745/// Returns true if the given reference 'SubRef' is a sub reference of the
746/// reference 'ref', i.e. 'ref' is a further qualified reference.
747static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
748 if (ref == subRef)
749 return true;
750
751 // If the references are not pointer equal, check to see if `subRef` is a
752 // prefix of `ref`.
753 if (llvm::isa<FlatSymbolRefAttr>(ref) ||
754 ref.getRootReference() != subRef.getRootReference())
755 return false;
756
757 auto refLeafs = ref.getNestedReferences();
758 auto subRefLeafs = subRef.getNestedReferences();
759 return subRefLeafs.size() < refLeafs.size() &&
760 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
761}
762
763//===----------------------------------------------------------------------===//
764// SymbolTable::getSymbolUses
765//===----------------------------------------------------------------------===//
766
767/// The implementation of SymbolTable::getSymbolUses below.
768template <typename FromT>
769static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
770 std::vector<SymbolTable::SymbolUse> uses;
771 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
772 uses.push_back(symbolUse);
773 return WalkResult::advance();
774 };
775 auto result = walkSymbolUses(from, walkFn);
776 return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
777 : std::nullopt;
778}
779
780/// Get an iterator range for all of the uses, for any symbol, that are nested
781/// within the given operation 'from'. This does not traverse into any nested
782/// symbol tables, and will also only return uses on 'from' if it does not
783/// also define a symbol table. This is because we treat the region as the
784/// boundary of the symbol table, and not the op itself. This function returns
785/// std::nullopt if there are any unknown operations that may potentially be
786/// symbol tables.
787auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
788 return getSymbolUsesImpl(from);
789}
790auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
792}
793
794//===----------------------------------------------------------------------===//
795// SymbolTable::getSymbolUses
796//===----------------------------------------------------------------------===//
797
798/// The implementation of SymbolTable::getSymbolUses below.
799template <typename SymbolT, typename IRUnitT>
800static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
801 IRUnitT *limit) {
802 std::vector<SymbolTable::SymbolUse> uses;
803 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
804 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
805 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
806 uses.push_back(symbolUse);
807 }))
808 return std::nullopt;
809 }
810 return SymbolTable::UseRange(std::move(uses));
811}
812
813/// Get all of the uses of the given symbol that are nested within the given
814/// operation 'from'. This does not traverse into any nested symbol tables.
815/// This function returns std::nullopt if there are any unknown operations that
816/// may potentially be symbol tables.
817auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
818 -> std::optional<UseRange> {
819 return getSymbolUsesImpl(symbol, from);
820}
822 -> std::optional<UseRange> {
823 return getSymbolUsesImpl(symbol, from);
824}
825auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
826 -> std::optional<UseRange> {
827 return getSymbolUsesImpl(symbol, from);
828}
830 -> std::optional<UseRange> {
831 return getSymbolUsesImpl(symbol, from);
832}
833
834//===----------------------------------------------------------------------===//
835// SymbolTable::symbolKnownUseEmpty
836//===----------------------------------------------------------------------===//
837
838/// The implementation of SymbolTable::symbolKnownUseEmpty below.
839template <typename SymbolT, typename IRUnitT>
840static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
841 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
842 // Walk all of the symbol uses looking for a reference to 'symbol'.
843 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
844 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
845 ? WalkResult::interrupt()
846 : WalkResult::advance();
847 }) != WalkResult::advance())
848 return false;
849 }
850 return true;
851}
852
853/// Return if the given symbol is known to have no uses that are nested within
854/// the given operation 'from'. This does not traverse into any nested symbol
855/// tables. This function will also return false if there are any unknown
856/// operations that may potentially be symbol tables.
857bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
858 return symbolKnownUseEmptyImpl(symbol, from);
859}
861 return symbolKnownUseEmptyImpl(symbol, from);
862}
863bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
864 return symbolKnownUseEmptyImpl(symbol, from);
865}
867 return symbolKnownUseEmptyImpl(symbol, from);
868}
869
870//===----------------------------------------------------------------------===//
871// SymbolTable::replaceAllSymbolUses
872//===----------------------------------------------------------------------===//
873
874/// Generates a new symbol reference attribute with a new leaf reference.
875static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
876 FlatSymbolRefAttr newLeafAttr) {
877 if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
878 return newLeafAttr;
879 auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
880 nestedRefs.back() = newLeafAttr;
881 return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
882}
883
884/// The implementation of SymbolTable::replaceAllSymbolUses below.
885template <typename SymbolT, typename IRUnitT>
886static LogicalResult
887replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
888 // Generate a new attribute to replace the given attribute.
889 FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
890 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
891 SymbolRefAttr oldAttr = scope.symbol;
892 SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
893 AttrTypeReplacer replacer;
894 replacer.addReplacement(
895 [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
896 // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
897 // want to accidentally replace an inner reference.
898 if (attr == oldAttr)
899 return {newAttr, WalkResult::skip()};
900 // Handle prefix matches.
901 if (isReferencePrefixOf(oldAttr, attr)) {
902 auto oldNestedRefs = oldAttr.getNestedReferences();
903 auto nestedRefs = attr.getNestedReferences();
904 if (oldNestedRefs.empty())
905 return {SymbolRefAttr::get(newSymbol, nestedRefs),
907
908 auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
909 newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
910 return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
912 }
913 return {attr, WalkResult::skip()};
914 });
915
916 auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
917 replacer.replaceElementsIn(op);
918 return WalkResult::advance();
919 };
920 if (!scope.walkSymbolTable(walkFn))
921 return failure();
922 }
923 return success();
924}
925
926/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
927/// provided symbol 'newSymbol' that are nested within the given operation
928/// 'from'. This does not traverse into any nested symbol tables. If there are
929/// any unknown operations that may potentially be symbol tables, no uses are
930/// replaced and failure is returned.
931LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
932 StringAttr newSymbol,
933 Operation *from) {
934 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
935}
937 StringAttr newSymbol,
938 Operation *from) {
939 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
940}
941LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
942 StringAttr newSymbol,
943 Region *from) {
944 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
945}
947 StringAttr newSymbol,
948 Region *from) {
949 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
950}
951
952//===----------------------------------------------------------------------===//
953// SymbolTableCollection
954//===----------------------------------------------------------------------===//
955
957 StringAttr symbol) {
958 return getSymbolTable(symbolTableOp).lookup(symbol);
959}
961 SymbolRefAttr name) {
963 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
964 return nullptr;
965 return symbols.back();
966}
967/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
968/// a given SymbolRefAttr. Returns failure if any of the nested references could
969/// not be resolved.
970LogicalResult
972 SymbolRefAttr name,
974 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
975 return lookupSymbolIn(symbolTableOp, symbol);
976 };
977 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
978}
979
980/// Returns the operation registered with the given symbol name within the
981/// closest parent operation of, or including, 'from' with the
982/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
983/// found.
985 StringAttr symbol) {
986 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
987 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
988}
989Operation *
991 SymbolRefAttr symbol) {
992 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
993 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
994}
995
996/// Lookup, or create, a symbol table for an operation.
998 auto it = symbolTables.try_emplace(op, nullptr);
999 if (it.second)
1000 it.first->second = std::make_unique<SymbolTable>(op);
1001 return *it.first->second;
1002}
1003
1005 symbolTables.erase(op);
1006}
1007
1008//===----------------------------------------------------------------------===//
1009// LockedSymbolTableCollection
1010//===----------------------------------------------------------------------===//
1011
1013 StringAttr symbol) {
1014 return getSymbolTable(symbolTableOp).lookup(symbol);
1015}
1016
1017Operation *
1019 FlatSymbolRefAttr symbol) {
1020 return lookupSymbolIn(symbolTableOp, symbol.getAttr());
1021}
1022
1024 SymbolRefAttr name) {
1026 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
1027 return nullptr;
1028 return symbols.back();
1029}
1030
1032 Operation *symbolTableOp, SymbolRefAttr name,
1034 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
1035 return lookupSymbolIn(symbolTableOp, symbol);
1036 };
1037 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
1038}
1039
1041LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
1042 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
1043 // Try to find an existing symbol table.
1044 {
1045 llvm::sys::SmartScopedReader<true> lock(mutex);
1046 auto it = collection.symbolTables.find(symbolTableOp);
1047 if (it != collection.symbolTables.end())
1048 return *it->second;
1049 }
1050 // Create a symbol table for the operation. Perform construction outside of
1051 // the critical section.
1052 auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
1053 // Insert the constructed symbol table.
1054 llvm::sys::SmartScopedWriter<true> lock(mutex);
1055 return *collection.symbolTables
1056 .insert({symbolTableOp, std::move(symbolTable)})
1057 .first->second;
1058}
1059
1060//===----------------------------------------------------------------------===//
1061// SymbolUserMap
1062//===----------------------------------------------------------------------===//
1063
1065 Operation *symbolTableOp)
1066 : symbolTable(symbolTable) {
1067 // Walk each of the symbol tables looking for discardable callgraph nodes.
1069 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1070 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
1071 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
1072 assert(symbolUses && "expected uses to be valid");
1073
1074 for (const SymbolTable::SymbolUse &use : *symbolUses) {
1075 symbols.clear();
1076 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
1077 symbols);
1078 for (Operation *symbolOp : symbols)
1079 symbolToUsers[symbolOp].insert(use.getUser());
1080 }
1081 }
1082 };
1083 // We just set `allSymUsesVisible` to false here because it isn't necessary
1084 // for building the user map.
1085 SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
1086 walkFn);
1087}
1088
1090 StringAttr newSymbolName) {
1091 auto it = symbolToUsers.find(symbol);
1092 if (it == symbolToUsers.end())
1093 return;
1094
1095 // Replace the uses within the users of `symbol`.
1096 for (Operation *user : it->second)
1097 (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1098
1099 // Move the current users of `symbol` to the new symbol if it is in the
1100 // symbol table.
1101 Operation *newSymbol =
1102 symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1103 if (newSymbol != symbol) {
1104 // Transfer over the users to the new symbol. The reference to the old one
1105 // is fetched again as the iterator is invalidated during the insertion.
1106 auto newIt = symbolToUsers.try_emplace(newSymbol);
1107 auto oldIt = symbolToUsers.find(symbol);
1108 assert(oldIt != symbolToUsers.end() && "missing old users list");
1109 if (newIt.second)
1110 newIt.first->second = std::move(oldIt->second);
1111 else
1112 newIt.first->second.set_union(oldIt->second);
1113 symbolToUsers.erase(oldIt);
1114 }
1115}
1116
1117//===----------------------------------------------------------------------===//
1118// Visibility parsing implementation.
1119//===----------------------------------------------------------------------===//
1120
1122 NamedAttrList &attrs) {
1123 StringRef visibility;
1124 if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"}))
1125 return failure();
1126
1127 StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1128 attrs.push_back(parser.getBuilder().getNamedAttr(
1129 SymbolTable::getVisibilityAttrName(), visibilityAttr));
1130 return success();
1131}
1132
1133//===----------------------------------------------------------------------===//
1134// Symbol Interfaces
1135//===----------------------------------------------------------------------===//
1136
1137/// Include the generated symbol interfaces.
1138#include "mlir/IR/SymbolInterfaces.cpp.inc"
1139#include "mlir/IR/SymbolInterfacesAttrInterface.cpp.inc"
return success()
b getContext())
static std::optional< WalkResult > walkSymbolTable(MutableArrayRef< Region > regions, function_ref< std::optional< WalkResult >(Operation *)> callback)
Walk all of the operations within the given set of regions, without traversing into any nested symbol...
static std::optional< SymbolTable::UseRange > getSymbolUsesImpl(FromT from)
The implementation of SymbolTable::getSymbolUses below.
static LogicalResult collectValidReferencesFor(Operation *symbol, StringAttr symbolName, Operation *within, SmallVectorImpl< SymbolRefAttr > &results)
Computes the nested symbol reference attribute for the symbol 'symbolName' that are usable within the...
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit)
The implementation of SymbolTable::symbolKnownUseEmpty below.
static SmallVector< SymbolScope, 2 > collectSymbolScopes(Operation *symbol, Operation *limit)
Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
static WalkResult walkSymbolRefs(Operation *op, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the symbol references within the given operation, invoking the provided callback for each...
static StringAttr getNameIfSymbol(Operation *op)
Returns the string name of the given symbol, or null if this is not a symbol.
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref)
Returns true if the given reference 'subRef' is a sub-reference of the reference 'ref',...
static std::optional< WalkResult > walkSymbolUses(MutableArrayRef< Region > regions, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the uses, for any symbol, that are nested within the given regions, invoking the provided...
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, FlatSymbolRefAttr newLeafAttr)
Generates a new symbol reference attribute with a new leaf reference.
static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit)
The implementation of SymbolTable::replaceAllSymbolUses below.
static LogicalResult lookupSymbolInImpl(Operation *symbolTableOp, SymbolRefAttr symbol, SmallVectorImpl< Operation * > &symbols, function_ref< Operation *(Operation *, StringAttr)> lookupSymbolFn)
Internal implementation of lookupSymbolIn that allows for specialized implementations of the lookup f...
static bool isPotentiallyUnknownSymbolTable(Operation *op)
Return true if the given operation is unknown and may potentially define a symbol table.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
This is an attribute/type replacer that is naively cached.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol) override
Look up a symbol with the specified name within the specified symbol table operation,...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:241
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:715
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:778
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:579
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:563
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:703
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:255
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:611
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Definition Operation.h:515
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:706
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:292
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition Operation.h:629
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
virtual void invalidateSymbolTable(Operation *op)
Invalidate the cached symbol table for an operation.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static SmallString< N > generateSymbolName(StringRef name, UniqueChecker uniqueChecker, unsigned &uniquingCounter)
Generate a unique symbol name.
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition SymbolTable.h:90
@ Nested
The symbol is visible to the current IR, which may include operations in symbol tables above the one ...
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
Definition SymbolTable.h:93
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static StringRef getVisibilityAttrName()
Return the name of the attribute used for symbol visibility.
Definition SymbolTable.h:82
LogicalResult rename(StringAttr from, StringAttr to)
Renames the given op or the op referred to by the given name to the given new name and updates the sym...
void erase(Operation *symbol)
Erase the given symbol from the table and delete the operation.
Operation * getOp() const
Returns the associated operation.
Definition SymbolTable.h:79
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
SymbolTable(Operation *symbolTableOp)
Build a symbol table with the symbols within the given operation.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from)
Return if the given symbol is known to have no uses that are nested within the given operation 'from'...
FailureOr< StringAttr > renameToUnique(StringAttr from, ArrayRef< SymbolTable * > others)
Renames the given op or the op referred to by the given name to a name that is unique within this ...
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName)
Replace all of the uses of the given symbol with newSymbolName.
SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp)
Build a user map for all of the symbols defined in regions nested under 'symbolTableOp'.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
void addReplacement(ReplaceFn< Attribute > fn)
AttrTypeReplacerBase.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
LogicalResult verifySymbol(Operation *op)
LogicalResult verifySymbolTable(Operation *op)
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, NamedAttrList &attrs)
Parse an optional visibility attribute keyword (i.e., public, private, or nested) without quotes in a...
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
llvm::StringSwitch< T, R > StringSwitch
Definition LLVM.h:133
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144