MLIR 23.0.0git
SymbolTable.cpp
Go to the documentation of this file.
1//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Builders.h"
12#include "llvm/ADT/SetVector.h"
13#include "llvm/ADT/SmallString.h"
14#include "llvm/ADT/StringSwitch.h"
15#include <optional>
16
17using namespace mlir;
18
19/// Return true if the given operation is unknown and may potentially define a
20/// symbol table.
22 return op->getNumRegions() == 1 && !op->getDialect();
23}
24
25/// Returns the string name of the given symbol, or null if this is not a
26/// symbol.
27static StringAttr getNameIfSymbol(Operation *op) {
28 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
29}
30static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
31 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
32}
33
34/// Computes the nested symbol reference attributes for the symbol 'symbolName'
35/// that are usable within the symbol table operations from 'symbol' as far up
36/// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
37/// Returns success if all references up to 'within' could be computed.
38static LogicalResult
39collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
40 Operation *within,
42 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
43 MLIRContext *ctx = symbol->getContext();
44
45 auto leafRef = FlatSymbolRefAttr::get(symbolName);
46 results.push_back(leafRef);
47
48 // Early exit for when 'within' is the parent of 'symbol'.
49 Operation *symbolTableOp = symbol->getParentOp();
50 if (within == symbolTableOp)
51 return success();
52
53 // Collect references until 'symbolTableOp' reaches 'within'.
54 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
55 StringAttr symbolNameId =
56 StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
57 do {
58 // Each parent of 'symbol' should define a symbol table.
59 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
60 return failure();
61 // Each parent of 'symbol' should also be a symbol.
62 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
63 if (!symbolTableName)
64 return failure();
65 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
66
67 symbolTableOp = symbolTableOp->getParentOp();
68 if (symbolTableOp == within)
69 break;
70 nestedRefs.insert(nestedRefs.begin(),
71 FlatSymbolRefAttr::get(symbolTableName));
72 } while (true);
73 return success();
74}
75
76/// Walk all of the operations within the given set of regions, without
77/// traversing into any nested symbol tables. Stops walking if the result of the
78/// callback is anything other than `WalkResult::advance`.
79static std::optional<WalkResult>
81 function_ref<std::optional<WalkResult>(Operation *)> callback) {
82 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
83 while (!worklist.empty()) {
84 for (Operation &op : worklist.pop_back_val()->getOps()) {
85 std::optional<WalkResult> result = callback(&op);
87 return result;
88
89 // If this op defines a new symbol table scope, we can't traverse. Any
90 // symbol references nested within 'op' are different semantically.
91 if (!op.hasTrait<OpTrait::SymbolTable>()) {
92 for (Region &region : op.getRegions())
93 worklist.push_back(&region);
94 }
95 }
96 }
97 return WalkResult::advance();
98}
99
100/// Walk all of the operations nested under, and including, the given operation,
101/// without traversing into any nested symbol tables. Stops walking if the
102/// result of the callback is anything other than `WalkResult::advance`.
103static std::optional<WalkResult>
105 function_ref<std::optional<WalkResult>(Operation *)> callback) {
106 std::optional<WalkResult> result = callback(op);
108 return result;
109 return walkSymbolTable(op->getRegions(), callback);
110}
111
112//===----------------------------------------------------------------------===//
113// SymbolTable
114//===----------------------------------------------------------------------===//
115
116/// Build a symbol table with the symbols within the given operation.
118 : symbolTableOp(symbolTableOp) {
119 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
120 "expected operation to have SymbolTable trait");
121 assert(symbolTableOp->getNumRegions() == 1 &&
122 "expected operation to have a single region");
123 assert(symbolTableOp->getRegion(0).hasOneBlock() &&
124 "expected operation to have a single block");
125
126 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
128 for (auto &op : symbolTableOp->getRegion(0).front()) {
129 StringAttr name = getNameIfSymbol(&op, symbolNameId);
130 if (!name)
131 continue;
132
133 auto inserted = symbolTable.insert({name, &op});
134 (void)inserted;
135 assert(inserted.second &&
136 "expected region to contain uniquely named symbol operations");
137 }
138}
139
140/// Look up a symbol with the specified name, returning null if no such name
141/// exists. Names never include the @ on them.
142Operation *SymbolTable::lookup(StringRef name) const {
143 return lookup(StringAttr::get(symbolTableOp->getContext(), name));
144}
145Operation *SymbolTable::lookup(StringAttr name) const {
146 return symbolTable.lookup(name);
147}
148
150 StringAttr name = getNameIfSymbol(op);
151 assert(name && "expected valid 'name' attribute");
152 assert(op->getParentOp() == symbolTableOp &&
153 "expected this operation to be inside of the operation with this "
154 "SymbolTable");
155
156 auto it = symbolTable.find(name);
157 if (it != symbolTable.end() && it->second == op)
158 symbolTable.erase(it);
159}
160
162 remove(symbol);
163 symbol->erase();
164}
165
166// TODO: Consider if this should be renamed to something like insertOrUpdate
167/// Insert a new symbol into the table and associated operation if not already
168/// there and rename it as necessary to avoid collisions. Return the name of
169/// the symbol after insertion as attribute.
170StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
171 // The symbol cannot be the child of another op and must be the child of the
172 // symbolTableOp after this.
173 //
174 // TODO: consider if SymbolTable's constructor should behave the same.
175 if (!symbol->getParentOp()) {
176 auto &body = symbolTableOp->getRegion(0).front();
177 if (insertPt == Block::iterator()) {
178 insertPt = Block::iterator(body.end());
179 } else {
180 assert((insertPt == body.end() ||
181 insertPt->getParentOp() == symbolTableOp) &&
182 "expected insertPt to be in the associated module operation");
183 }
184 // Insert before the terminator, if any.
185 if (insertPt == Block::iterator(body.end()) && !body.empty() &&
186 std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
187 insertPt = std::prev(body.end());
188
189 body.getOperations().insert(insertPt, symbol);
190 }
191 assert(symbol->getParentOp() == symbolTableOp &&
192 "symbol is already inserted in another op");
193
194 // Add this symbol to the symbol table, uniquing the name if a conflict is
195 // detected.
196 StringAttr name = getSymbolName(symbol);
197 if (symbolTable.insert({name, symbol}).second)
198 return name;
199 // If the symbol was already in the table, also return.
200 if (symbolTable.lookup(name) == symbol)
201 return name;
202
203 MLIRContext *context = symbol->getContext();
205 name.getValue(),
206 [&](StringRef candidate) {
207 return !symbolTable
208 .insert({StringAttr::get(context, candidate), symbol})
209 .second;
210 },
211 uniquingCounter);
212 setSymbolName(symbol, nameBuffer);
213 return getSymbolName(symbol);
214}
215
216LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
217 Operation *op = lookup(from);
218 return rename(op, to);
219}
220
221LogicalResult SymbolTable::rename(Operation *op, StringAttr to) {
222 StringAttr from = getNameIfSymbol(op);
223 (void)from;
224
225 assert(from && "expected valid 'name' attribute");
226 assert(op->getParentOp() == symbolTableOp &&
227 "expected this operation to be inside of the operation with this "
228 "SymbolTable");
229 assert(lookup(from) == op && "current name does not resolve to op");
230 assert(lookup(to) == nullptr && "new name already exists");
231
232 if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp())))
233 return failure();
234
235 // Remove op with old name, change name, add with new name. The order is
236 // important here due to how `remove` and `insert` rely on the op name.
237 remove(op);
238 setSymbolName(op, to);
239 insert(op);
240
241 assert(lookup(to) == op && "new name does not resolve to renamed op");
242 assert(lookup(from) == nullptr && "old name still exists");
243
244 return success();
245}
246
247LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
248 auto toAttr = StringAttr::get(getOp()->getContext(), to);
249 return rename(from, toAttr);
250}
251
252LogicalResult SymbolTable::rename(Operation *op, StringRef to) {
253 auto toAttr = StringAttr::get(getOp()->getContext(), to);
254 return rename(op, toAttr);
255}
256
257FailureOr<StringAttr>
260
261 // Determine new name that is unique in all symbol tables.
262 StringAttr newName;
263 {
264 MLIRContext *context = oldName.getContext();
265 SmallString<64> prefix = oldName.getValue();
266 int uniqueId = 0;
267 prefix.push_back('_');
268 while (true) {
269 newName = StringAttr::get(context, prefix + Twine(uniqueId++));
270 auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); };
271 if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) {
272 break;
273 }
274 }
275 }
276
277 // Apply renaming.
278 if (failed(rename(oldName, newName)))
279 return failure();
280 return newName;
281}
282
283FailureOr<StringAttr>
285 StringAttr from = getNameIfSymbol(op);
286 assert(from && "expected valid 'name' attribute");
287 return renameToUnique(from, others);
288}
289
290/// Returns the name of the given symbol operation.
292 StringAttr name = getNameIfSymbol(symbol);
293 assert(name && "expected valid symbol name");
294 return name;
295}
296
297/// Sets the name of the given symbol operation.
298void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
299 symbol->setAttr(getSymbolAttrName(), name);
300}
301
302/// Returns the visibility of the given symbol operation.
304 // If the attribute doesn't exist, assume public.
305 StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
306 if (!vis)
307 return Visibility::Public;
308
309 // Otherwise, switch on the string value.
310 return StringSwitch<Visibility>(vis.getValue())
311 .Case("private", Visibility::Private)
312 .Case("nested", Visibility::Nested)
313 .Case("public", Visibility::Public);
314}
315/// Sets the visibility of the given symbol operation.
317 MLIRContext *ctx = symbol->getContext();
318
319 // If the visibility is public, just drop the attribute as this is the
320 // default.
321 if (vis == Visibility::Public) {
322 symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName()));
323 return;
324 }
325
326 // Otherwise, update the attribute.
327 assert((vis == Visibility::Private || vis == Visibility::Nested) &&
328 "unknown symbol visibility kind");
329
330 StringRef visName = vis == Visibility::Private ? "private" : "nested";
331 symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
332}
333
334/// Returns the nearest symbol table from a given operation `from`. Returns
335/// nullptr if no valid parent symbol table could be found.
337 assert(from && "expected valid operation");
339 return nullptr;
340
341 while (!from->hasTrait<OpTrait::SymbolTable>()) {
342 from = from->getParentOp();
343
344 // Check that this is a valid op and isn't an unknown symbol table.
345 if (!from || isPotentiallyUnknownSymbolTable(from))
346 return nullptr;
347 }
348 return from;
349}
350
351/// Walks all symbol table operations nested within, and including, `op`. For
352/// each symbol table operation, the provided callback is invoked with the op
353/// and a boolean signifying if the symbols within that symbol table can be
354/// treated as if all uses are visible. `allSymUsesVisible` identifies whether
355/// all of the symbol uses of symbols within `op` are visible.
357 Operation *op, bool allSymUsesVisible,
358 function_ref<void(Operation *, bool)> callback) {
359 bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
360 if (isSymbolTable) {
361 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
362 allSymUsesVisible |= !symbol || symbol.isPrivate();
363 } else {
364 // Otherwise if 'op' is not a symbol table, any nested symbols are
365 // guaranteed to be hidden.
366 allSymUsesVisible = true;
367 }
368
369 for (Region &region : op->getRegions())
370 for (Block &block : region)
371 for (Operation &nestedOp : block)
372 walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
373
374 // If 'op' had the symbol table trait, visit it after any nested symbol
375 // tables.
376 if (isSymbolTable)
377 callback(op, allSymUsesVisible);
378}
379
380/// Returns the operation registered with the given symbol name within the
381/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
382/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
383/// was found.
385 StringAttr symbol) {
386 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
387 Region &region = symbolTableOp->getRegion(0);
388 if (region.empty())
389 return nullptr;
390
391 // Look for a symbol with the given name.
392 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
394 for (auto &op : region.front())
395 if (getNameIfSymbol(&op, symbolNameId) == symbol)
396 return &op;
397 return nullptr;
398}
400 SymbolRefAttr symbol) {
401 SmallVector<Operation *, 4> resolvedSymbols;
402 if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
403 return nullptr;
404 return resolvedSymbols.back();
405}
406
407/// Internal implementation of `lookupSymbolIn` that allows for specialized
408/// implementations of the lookup function.
409static LogicalResult lookupSymbolInImpl(
410 Operation *symbolTableOp, SymbolRefAttr symbol,
412 function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
413 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
414
415 // Lookup the root reference for this symbol.
416 auto *symbolOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
417 if (!symbolOp)
418 return failure();
419 symbols.push_back(symbolOp);
420
421 // Lookup each of the nested references.
422 for (FlatSymbolRefAttr ref : symbol.getNestedReferences()) {
423 // Check that we have a valid symbol table to lookup ref.
424 if (!symbolOp->hasTrait<OpTrait::SymbolTable>())
425 return failure();
426 symbolOp = lookupSymbolFn(symbolOp, ref.getAttr());
427 // If the nested symbol is private, lookup failed.
428 if (!symbolOp || SymbolTable::getSymbolVisibility(symbolOp) ==
430 return failure();
431 symbols.push_back(symbolOp);
432 }
433 return success();
434}
435
436LogicalResult
437SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
439 auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
440 return lookupSymbolIn(symbolTableOp, symbol);
441 };
442 return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
443}
444
445/// Returns the operation registered with the given symbol name within the
446/// closest parent operation with the 'OpTrait::SymbolTable' trait. Returns
447/// nullptr if no valid symbol was found.
449 StringAttr symbol) {
450 Operation *symbolTableOp = getNearestSymbolTable(from);
451 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
452}
454 SymbolRefAttr symbol) {
455 Operation *symbolTableOp = getNearestSymbolTable(from);
456 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
457}
458
460 SymbolTable::Visibility visibility) {
461 switch (visibility) {
463 return os << "public";
465 return os << "private";
467 return os << "nested";
468 }
469 llvm_unreachable("Unexpected visibility");
470}
471
472//===----------------------------------------------------------------------===//
473// SymbolTable Trait Types
474//===----------------------------------------------------------------------===//
475
477 if (op->getNumRegions() != 1)
478 return op->emitOpError()
479 << "Operations with a 'SymbolTable' must have exactly one region";
480 if (!op->getRegion(0).hasOneBlock())
481 return op->emitOpError()
482 << "Operations with a 'SymbolTable' must have exactly one block";
483
484 // Check that all symbols are uniquely named within child regions.
485 DenseMap<Attribute, Location> nameToOrigLoc;
486 for (auto &block : op->getRegion(0)) {
487 for (auto &op : block) {
488 // Check for a symbol name attribute.
489 auto nameAttr =
491 if (!nameAttr)
492 continue;
493
494 // Try to insert this symbol into the table.
495 auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
496 if (!it.second)
497 return op.emitError()
498 .append("redefinition of symbol named '", nameAttr.getValue(), "'")
499 .attachNote(it.first->second)
500 .append("see existing symbol definition here");
501 }
502 }
503
504 // Verify any nested symbol user operations.
505 SymbolTableCollection symbolTable;
506 auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
507 if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
508 if (failed(user.verifySymbolUses(symbolTable)))
509 return WalkResult::interrupt();
510 for (auto &attr : op->getDiscardableAttrs()) {
511 if (auto user = dyn_cast<SymbolUserAttrInterface>(attr.getValue())) {
512 if (failed(user.verifySymbolUses(op, symbolTable)))
513 return WalkResult::interrupt();
514 }
515 }
516 return WalkResult::advance();
517 };
518
519 std::optional<WalkResult> result =
520 walkSymbolTable(op->getRegions(), verifySymbolUserFn);
521 return success(result && !result->wasInterrupted());
522}
523
524LogicalResult detail::verifySymbol(Operation *op) {
525 // Verify the name attribute.
527 return op->emitOpError() << "requires string attribute '"
529
530 // Verify the visibility attribute.
532 StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
533 if (!visStrAttr)
534 return op->emitOpError() << "requires visibility attribute '"
536 << "' to be a string attribute, but got " << vis;
537
538 if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
539 visStrAttr.getValue()))
540 return op->emitOpError()
541 << "visibility expected to be one of [\"public\", \"private\", "
542 "\"nested\"], but got "
543 << visStrAttr;
544 }
545 return success();
546}
547
548//===----------------------------------------------------------------------===//
549// Symbol Use Lists
550//===----------------------------------------------------------------------===//
551
552/// Walk all of the symbol references within the given operation, invoking the
553/// provided callback for each found use. The callback takes the use of the
554/// symbol.
555static WalkResult
558 return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
559 [&](SymbolRefAttr symbolRef) {
560 if (callback({op, symbolRef}).wasInterrupted())
561 return WalkResult::interrupt();
562
563 // Don't walk nested references.
564 return WalkResult::skip();
565 });
566}
567
568/// Walk all of the uses, for any symbol, that are nested within the given
569/// regions, invoking the provided callback for each. This does not traverse
570/// into any nested symbol tables.
571static std::optional<WalkResult>
574 return walkSymbolTable(regions,
575 [&](Operation *op) -> std::optional<WalkResult> {
576 // Check that this isn't a potentially unknown symbol
577 // table.
579 return std::nullopt;
580
581 return walkSymbolRefs(op, callback);
582 });
583}
584/// Walk all of the uses, for any symbol, that are nested within the given
585/// operation 'from', invoking the provided callback for each. This does not
586/// traverse into any nested symbol tables.
587static std::optional<WalkResult>
590 // If this operation has regions, and it, as well as its dialect, isn't
591 // registered then conservatively fail. The operation may define a
592 // symbol table, so we can't opaquely know if we should traverse to find
593 // nested uses.
595 return std::nullopt;
596
597 // Walk the uses on this operation.
598 if (walkSymbolRefs(from, callback).wasInterrupted())
599 return WalkResult::interrupt();
600
601 // Only recurse if this operation is not a symbol table. A symbol table
602 // defines a new scope, so we can't walk the attributes from within the symbol
603 // table op.
604 if (!from->hasTrait<OpTrait::SymbolTable>())
605 return walkSymbolUses(from->getRegions(), callback);
606 return WalkResult::advance();
607}
608
609namespace {
610/// This class represents a single symbol scope. A symbol scope represents the
611/// set of operations nested within a symbol table that may reference symbols
612/// within that table. A symbol scope does not contain the symbol table
613/// operation itself, just its contained operations. A scope ends at leaf
614/// operations or another symbol table operation.
615struct SymbolScope {
616 /// Walk the symbol uses within this scope, invoking the given callback.
617 /// This variant is used when the callback type matches that expected by
618 /// 'walkSymbolUses'.
619 template <typename CallbackT,
620 std::enable_if_t<!std::is_same<
621 typename llvm::function_traits<CallbackT>::result_t,
622 void>::value> * = nullptr>
623 std::optional<WalkResult> walk(CallbackT cback) {
624 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
625 return walkSymbolUses(*region, cback);
626 return walkSymbolUses(cast<Operation *>(limit), cback);
627 }
628 /// This variant is used when the callback type matches a stripped down type:
629 /// void(SymbolTable::SymbolUse use)
630 template <typename CallbackT,
631 std::enable_if_t<std::is_same<
632 typename llvm::function_traits<CallbackT>::result_t,
633 void>::value> * = nullptr>
634 std::optional<WalkResult> walk(CallbackT cback) {
635 return walk([=](SymbolTable::SymbolUse use) {
636 return cback(use), WalkResult::advance();
637 });
638 }
639
640 /// Walk all of the operations nested under the current scope without
641 /// traversing into any nested symbol tables.
642 template <typename CallbackT>
643 std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
644 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
645 return ::walkSymbolTable(*region, cback);
646 return ::walkSymbolTable(cast<Operation *>(limit), cback);
647 }
648
649 /// The representation of the symbol within this scope.
650 SymbolRefAttr symbol;
651
652 /// The IR unit representing this scope.
653 llvm::PointerUnion<Operation *, Region *> limit;
654};
655} // namespace
656
657/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
659 Operation *limit) {
660 StringAttr symName = SymbolTable::getSymbolName(symbol);
661 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
662
663 // Compute the ancestors of 'limit'.
666 limitAncestors;
667 Operation *limitAncestor = limit;
668 do {
669 // Check to see if 'symbol' is an ancestor of 'limit'.
670 if (limitAncestor == symbol) {
671 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
672 // doesn't support parent references.
674 symbol->getParentOp())
675 return {{SymbolRefAttr::get(symName), limit}};
676 return {};
677 }
678
679 limitAncestors.insert(limitAncestor);
680 } while ((limitAncestor = limitAncestor->getParentOp()));
681
682 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
683 Operation *commonAncestor = symbol->getParentOp();
684 do {
685 if (limitAncestors.count(commonAncestor))
686 break;
687 } while ((commonAncestor = commonAncestor->getParentOp()));
688 assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
689
690 // Compute the set of valid nested references for 'symbol' as far up to the
691 // common ancestor as possible.
693 bool collectedAllReferences = succeeded(
694 collectValidReferencesFor(symbol, symName, commonAncestor, references));
695
696 // Handle the case where the common ancestor is 'limit'.
697 if (commonAncestor == limit) {
699
700 // Walk each of the ancestors of 'symbol', calling the compute function for
701 // each one.
702 Operation *limitIt = symbol->getParentOp();
703 for (size_t i = 0, e = references.size(); i != e;
704 ++i, limitIt = limitIt->getParentOp()) {
705 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
706 scopes.push_back({references[i], &limitIt->getRegion(0)});
707 }
708 return scopes;
709 }
710
711 // Otherwise, we just need the symbol reference for 'symbol' that will be
712 // used within 'limit'. This is the last reference in the list we computed
713 // above if we were able to collect all references.
714 if (!collectedAllReferences)
715 return {};
716 return {{references.back(), limit}};
717}
719 Region *limit) {
720 auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
721
722 // If we collected some scopes to walk, make sure to constrain the one for
723 // limit to the specific region requested.
724 if (!scopes.empty())
725 scopes.back().limit = limit;
726 return scopes;
727}
729 Region *limit) {
730 return {{SymbolRefAttr::get(symbol), limit}};
731}
732
734 Operation *limit) {
736 auto symbolRef = SymbolRefAttr::get(symbol);
737 for (auto &region : limit->getRegions())
738 scopes.push_back({symbolRef, &region});
739 return scopes;
740}
741
742/// Returns true if the given reference 'SubRef' is a sub reference of the
743/// reference 'ref', i.e. 'ref' is a further qualified reference.
744static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
745 if (ref == subRef)
746 return true;
747
748 // If the references are not pointer equal, check to see if `subRef` is a
749 // prefix of `ref`.
750 if (llvm::isa<FlatSymbolRefAttr>(ref) ||
751 ref.getRootReference() != subRef.getRootReference())
752 return false;
753
754 auto refLeafs = ref.getNestedReferences();
755 auto subRefLeafs = subRef.getNestedReferences();
756 return subRefLeafs.size() < refLeafs.size() &&
757 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
758}
759
760//===----------------------------------------------------------------------===//
761// SymbolTable::getSymbolUses
762//===----------------------------------------------------------------------===//
763
764/// The implementation of SymbolTable::getSymbolUses below.
765template <typename FromT>
766static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
767 std::vector<SymbolTable::SymbolUse> uses;
768 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
769 uses.push_back(symbolUse);
770 return WalkResult::advance();
771 };
772 auto result = walkSymbolUses(from, walkFn);
773 return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
774 : std::nullopt;
775}
776
777/// Get an iterator range for all of the uses, for any symbol, that are nested
778/// within the given operation 'from'. This does not traverse into any nested
779/// symbol tables, and will also only return uses on 'from' if it does not
780/// also define a symbol table. This is because we treat the region as the
781/// boundary of the symbol table, and not the op itself. This function returns
782/// std::nullopt if there are any unknown operations that may potentially be
783/// symbol tables.
784auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
785 return getSymbolUsesImpl(from);
786}
787auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
789}
790
791//===----------------------------------------------------------------------===//
792// SymbolTable::getSymbolUses
793//===----------------------------------------------------------------------===//
794
795/// The implementation of SymbolTable::getSymbolUses below.
796template <typename SymbolT, typename IRUnitT>
797static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
798 IRUnitT *limit) {
799 std::vector<SymbolTable::SymbolUse> uses;
800 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
801 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
802 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
803 uses.push_back(symbolUse);
804 }))
805 return std::nullopt;
806 }
807 return SymbolTable::UseRange(std::move(uses));
808}
809
810/// Get all of the uses of the given symbol that are nested within the given
811/// operation 'from'. This does not traverse into any nested symbol tables.
812/// This function returns std::nullopt if there are any unknown operations that
813/// may potentially be symbol tables.
814auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
815 -> std::optional<UseRange> {
816 return getSymbolUsesImpl(symbol, from);
817}
819 -> std::optional<UseRange> {
820 return getSymbolUsesImpl(symbol, from);
821}
822auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
823 -> std::optional<UseRange> {
824 return getSymbolUsesImpl(symbol, from);
825}
827 -> std::optional<UseRange> {
828 return getSymbolUsesImpl(symbol, from);
829}
830
831//===----------------------------------------------------------------------===//
832// SymbolTable::symbolKnownUseEmpty
833//===----------------------------------------------------------------------===//
834
835/// The implementation of SymbolTable::symbolKnownUseEmpty below.
836template <typename SymbolT, typename IRUnitT>
837static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
838 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
839 // Walk all of the symbol uses looking for a reference to 'symbol'.
840 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
841 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
842 ? WalkResult::interrupt()
843 : WalkResult::advance();
844 }) != WalkResult::advance())
845 return false;
846 }
847 return true;
848}
849
850/// Return if the given symbol is known to have no uses that are nested within
851/// the given operation 'from'. This does not traverse into any nested symbol
852/// tables. This function will also return false if there are any unknown
853/// operations that may potentially be symbol tables.
854bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
855 return symbolKnownUseEmptyImpl(symbol, from);
856}
858 return symbolKnownUseEmptyImpl(symbol, from);
859}
860bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
861 return symbolKnownUseEmptyImpl(symbol, from);
862}
864 return symbolKnownUseEmptyImpl(symbol, from);
865}
866
867//===----------------------------------------------------------------------===//
868// SymbolTable::replaceAllSymbolUses
869//===----------------------------------------------------------------------===//
870
871/// Generates a new symbol reference attribute with a new leaf reference.
872static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
873 FlatSymbolRefAttr newLeafAttr) {
874 if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
875 return newLeafAttr;
876 auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
877 nestedRefs.back() = newLeafAttr;
878 return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
879}
880
881/// The implementation of SymbolTable::replaceAllSymbolUses below.
882template <typename SymbolT, typename IRUnitT>
883static LogicalResult
884replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
885 // Generate a new attribute to replace the given attribute.
886 FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
887 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
888 SymbolRefAttr oldAttr = scope.symbol;
889 SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
890 AttrTypeReplacer replacer;
891 replacer.addReplacement(
892 [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
893 // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
894 // want to accidentally replace an inner reference.
895 if (attr == oldAttr)
896 return {newAttr, WalkResult::skip()};
897 // Handle prefix matches.
898 if (isReferencePrefixOf(oldAttr, attr)) {
899 auto oldNestedRefs = oldAttr.getNestedReferences();
900 auto nestedRefs = attr.getNestedReferences();
901 if (oldNestedRefs.empty())
902 return {SymbolRefAttr::get(newSymbol, nestedRefs),
904
905 auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
906 newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
907 return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
909 }
910 return {attr, WalkResult::skip()};
911 });
912
913 auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
914 replacer.replaceElementsIn(op);
915 return WalkResult::advance();
916 };
917 if (!scope.walkSymbolTable(walkFn))
918 return failure();
919 }
920 return success();
921}
922
923/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
924/// provided symbol 'newSymbol' that are nested within the given operation
925/// 'from'. This does not traverse into any nested symbol tables. If there are
926/// any unknown operations that may potentially be symbol tables, no uses are
927/// replaced and failure is returned.
928LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
929 StringAttr newSymbol,
930 Operation *from) {
931 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
932}
934 StringAttr newSymbol,
935 Operation *from) {
936 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
937}
938LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
939 StringAttr newSymbol,
940 Region *from) {
941 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
942}
944 StringAttr newSymbol,
945 Region *from) {
946 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
947}
948
949//===----------------------------------------------------------------------===//
950// SymbolTableCollection
951//===----------------------------------------------------------------------===//
952
954 StringAttr symbol) {
955 return getSymbolTable(symbolTableOp).lookup(symbol);
956}
958 SymbolRefAttr name) {
960 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
961 return nullptr;
962 return symbols.back();
963}
964/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
965/// a given SymbolRefAttr. Returns failure if any of the nested references could
966/// not be resolved.
967LogicalResult
969 SymbolRefAttr name,
971 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
972 return lookupSymbolIn(symbolTableOp, symbol);
973 };
974 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
975}
976
977/// Returns the operation registered with the given symbol name within the
978/// closest parent operation of, or including, 'from' with the
979/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
980/// found.
982 StringAttr symbol) {
983 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
984 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
985}
986Operation *
988 SymbolRefAttr symbol) {
989 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
990 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
991}
992
993/// Lookup, or create, a symbol table for an operation.
995 auto it = symbolTables.try_emplace(op, nullptr);
996 if (it.second)
997 it.first->second = std::make_unique<SymbolTable>(op);
998 return *it.first->second;
999}
1000
1002 symbolTables.erase(op);
1003}
1004
1005//===----------------------------------------------------------------------===//
1006// LockedSymbolTableCollection
1007//===----------------------------------------------------------------------===//
1008
1010 StringAttr symbol) {
1011 return getSymbolTable(symbolTableOp).lookup(symbol);
1012}
1013
1014Operation *
1016 FlatSymbolRefAttr symbol) {
1017 return lookupSymbolIn(symbolTableOp, symbol.getAttr());
1018}
1019
1021 SymbolRefAttr name) {
1023 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
1024 return nullptr;
1025 return symbols.back();
1026}
1027
1029 Operation *symbolTableOp, SymbolRefAttr name,
1031 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
1032 return lookupSymbolIn(symbolTableOp, symbol);
1033 };
1034 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
1035}
1036
1038LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
1039 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
1040 // Try to find an existing symbol table.
1041 {
1042 llvm::sys::SmartScopedReader<true> lock(mutex);
1043 auto it = collection.symbolTables.find(symbolTableOp);
1044 if (it != collection.symbolTables.end())
1045 return *it->second;
1046 }
1047 // Create a symbol table for the operation. Perform construction outside of
1048 // the critical section.
1049 auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
1050 // Insert the constructed symbol table.
1051 llvm::sys::SmartScopedWriter<true> lock(mutex);
1052 return *collection.symbolTables
1053 .insert({symbolTableOp, std::move(symbolTable)})
1054 .first->second;
1055}
1056
1057//===----------------------------------------------------------------------===//
1058// SymbolUserMap
1059//===----------------------------------------------------------------------===//
1060
1062 Operation *symbolTableOp)
1063 : symbolTable(symbolTable) {
1064 // Walk each of the symbol tables looking for discardable callgraph nodes.
1066 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1067 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
1068 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
1069 assert(symbolUses && "expected uses to be valid");
1070
1071 for (const SymbolTable::SymbolUse &use : *symbolUses) {
1072 symbols.clear();
1073 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
1074 symbols);
1075 for (Operation *symbolOp : symbols)
1076 symbolToUsers[symbolOp].insert(use.getUser());
1077 }
1078 }
1079 };
1080 // We just set `allSymUsesVisible` to false here because it isn't necessary
1081 // for building the user map.
1082 SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
1083 walkFn);
1084}
1085
1087 StringAttr newSymbolName) {
1088 auto it = symbolToUsers.find(symbol);
1089 if (it == symbolToUsers.end())
1090 return;
1091
1092 // Replace the uses within the users of `symbol`.
1093 for (Operation *user : it->second)
1094 (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1095
1096 // Move the current users of `symbol` to the new symbol if it is in the
1097 // symbol table.
1098 Operation *newSymbol =
1099 symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1100 if (newSymbol != symbol) {
1101 // Transfer over the users to the new symbol. The reference to the old one
1102 // is fetched again as the iterator is invalidated during the insertion.
1103 auto newIt = symbolToUsers.try_emplace(newSymbol);
1104 auto oldIt = symbolToUsers.find(symbol);
1105 assert(oldIt != symbolToUsers.end() && "missing old users list");
1106 if (newIt.second)
1107 newIt.first->second = std::move(oldIt->second);
1108 else
1109 newIt.first->second.set_union(oldIt->second);
1110 symbolToUsers.erase(oldIt);
1111 }
1112}
1113
1114//===----------------------------------------------------------------------===//
1115// Visibility parsing implementation.
1116//===----------------------------------------------------------------------===//
1117
1119 NamedAttrList &attrs) {
1120 StringRef visibility;
1121 if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"}))
1122 return failure();
1123
1124 StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1125 attrs.push_back(parser.getBuilder().getNamedAttr(
1126 SymbolTable::getVisibilityAttrName(), visibilityAttr));
1127 return success();
1128}
1129
1130//===----------------------------------------------------------------------===//
1131// Symbol Interfaces
1132//===----------------------------------------------------------------------===//
1133
1134/// Include the generated symbol interfaces.
1135#include "mlir/IR/SymbolInterfaces.cpp.inc"
1136#include "mlir/IR/SymbolInterfacesAttrInterface.cpp.inc"
return success()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::optional< WalkResult > walkSymbolTable(MutableArrayRef< Region > regions, function_ref< std::optional< WalkResult >(Operation *)> callback)
Walk all of the operations within the given set of regions, without traversing into any nested symbol...
static std::optional< SymbolTable::UseRange > getSymbolUsesImpl(FromT from)
The implementation of SymbolTable::getSymbolUses below.
static LogicalResult collectValidReferencesFor(Operation *symbol, StringAttr symbolName, Operation *within, SmallVectorImpl< SymbolRefAttr > &results)
Computes the nested symbol reference attribute for the symbol 'symbolName' that are usable within the...
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit)
The implementation of SymbolTable::symbolKnownUseEmpty below.
static SmallVector< SymbolScope, 2 > collectSymbolScopes(Operation *symbol, Operation *limit)
Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
static WalkResult walkSymbolRefs(Operation *op, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the symbol references within the given operation, invoking the provided callback for each...
static StringAttr getNameIfSymbol(Operation *op)
Returns the string name of the given symbol, or null if this is not a symbol.
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref)
Returns true if the given reference 'subRef' is a sub reference of the reference 'ref',...
static std::optional< WalkResult > walkSymbolUses(MutableArrayRef< Region > regions, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the uses, for any symbol, that are nested within the given regions, invoking the provided...
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, FlatSymbolRefAttr newLeafAttr)
Generates a new symbol reference attribute with a new leaf reference.
static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit)
The implementation of SymbolTable::replaceAllSymbolUses below.
static LogicalResult lookupSymbolInImpl(Operation *symbolTableOp, SymbolRefAttr symbol, SmallVectorImpl< Operation * > &symbols, function_ref< Operation *(Operation *, StringAttr)> lookupSymbolFn)
Internal implementation of lookupSymbolIn that allows for specialized implementations of the lookup f...
static bool isPotentiallyUnknownSymbolTable(Operation *op)
Return true if the given operation is unknown and may potentially define a symbol table.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
This is an attribute/type replacer that is naively cached.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol) override
Look up a symbol with the specified name within the specified symbol table operation,...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Definition Operation.h:486
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:263
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition Operation.h:600
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
virtual void invalidateSymbolTable(Operation *op)
Invalidate the cached symbol table for an operation.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static SmallString< N > generateSymbolName(StringRef name, UniqueChecker uniqueChecker, unsigned &uniquingCounter)
Generate a unique symbol name.
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition SymbolTable.h:90
@ Nested
The symbol is visible to the current IR, which may include operations in symbol tables above the one ...
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
Definition SymbolTable.h:93
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static StringRef getVisibilityAttrName()
Return the name of the attribute used for symbol visibility.
Definition SymbolTable.h:82
LogicalResult rename(StringAttr from, StringAttr to)
Renames the given op or the op referred to by the given name to the given new name and updates the sym...
void erase(Operation *symbol)
Erase the given symbol from the table and delete the operation.
Operation * getOp() const
Returns the associated operation.
Definition SymbolTable.h:79
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
SymbolTable(Operation *symbolTableOp)
Build a symbol table with the symbols within the given operation.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from)
Return if the given symbol is known to have no uses that are nested within the given operation 'from'...
FailureOr< StringAttr > renameToUnique(StringAttr from, ArrayRef< SymbolTable * > others)
Renames the given op or the op referred to by the given name to a name that is unique within this ...
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName)
Replace all of the uses of the given symbol with newSymbolName.
SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp)
Build a user map for all of the symbols defined in regions nested under 'symbolTableOp'.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
void addReplacement(ReplaceFn< Attribute > fn)
AttrTypeReplacerBase.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
LogicalResult verifySymbol(Operation *op)
LogicalResult verifySymbolTable(Operation *op)
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, NamedAttrList &attrs)
Parse an optional visibility attribute keyword (i.e., public, private, or nested) without quotes in a...
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
llvm::StringSwitch< T, R > StringSwitch
Definition LLVM.h:133
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144