MLIR 22.0.0git
SymbolTable.cpp
Go to the documentation of this file.
1//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Builders.h"
12#include "llvm/ADT/SetVector.h"
13#include "llvm/ADT/SmallString.h"
14#include "llvm/ADT/StringSwitch.h"
15#include <optional>
16
17using namespace mlir;
18
19/// Return true if the given operation is unknown and may potentially define a
20/// symbol table.
22 return op->getNumRegions() == 1 && !op->getDialect();
23}
24
25/// Returns the string name of the given symbol, or null if this is not a
26/// symbol.
27static StringAttr getNameIfSymbol(Operation *op) {
28 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
29}
30static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
31 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
32}
33
34/// Computes the nested symbol reference attributes for the symbol 'symbolName'
35/// that are usable within the symbol table operations from 'symbol' as far up
36/// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
37/// Returns success if all references up to 'within' could be computed.
38static LogicalResult
39collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
40 Operation *within,
42 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
43 MLIRContext *ctx = symbol->getContext();
44
45 auto leafRef = FlatSymbolRefAttr::get(symbolName);
46 results.push_back(leafRef);
47
48 // Early exit for when 'within' is the parent of 'symbol'.
49 Operation *symbolTableOp = symbol->getParentOp();
50 if (within == symbolTableOp)
51 return success();
52
53 // Collect references until 'symbolTableOp' reaches 'within'.
54 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
55 StringAttr symbolNameId =
56 StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
57 do {
58 // Each parent of 'symbol' should define a symbol table.
59 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
60 return failure();
61 // Each parent of 'symbol' should also be a symbol.
62 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
63 if (!symbolTableName)
64 return failure();
65 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
66
67 symbolTableOp = symbolTableOp->getParentOp();
68 if (symbolTableOp == within)
69 break;
70 nestedRefs.insert(nestedRefs.begin(),
71 FlatSymbolRefAttr::get(symbolTableName));
72 } while (true);
73 return success();
74}
75
76/// Walk all of the operations within the given set of regions, without
77/// traversing into any nested symbol tables. Stops walking if the result of the
78/// callback is anything other than `WalkResult::advance`.
79static std::optional<WalkResult>
81 function_ref<std::optional<WalkResult>(Operation *)> callback) {
82 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
83 while (!worklist.empty()) {
84 for (Operation &op : worklist.pop_back_val()->getOps()) {
85 std::optional<WalkResult> result = callback(&op);
87 return result;
88
89 // If this op defines a new symbol table scope, we can't traverse. Any
90 // symbol references nested within 'op' are different semantically.
91 if (!op.hasTrait<OpTrait::SymbolTable>()) {
92 for (Region &region : op.getRegions())
93 worklist.push_back(&region);
94 }
95 }
96 }
97 return WalkResult::advance();
98}
99
100/// Walk all of the operations nested under, and including, the given operation,
101/// without traversing into any nested symbol tables. Stops walking if the
102/// result of the callback is anything other than `WalkResult::advance`.
103static std::optional<WalkResult>
105 function_ref<std::optional<WalkResult>(Operation *)> callback) {
106 std::optional<WalkResult> result = callback(op);
108 return result;
109 return walkSymbolTable(op->getRegions(), callback);
110}
111
112//===----------------------------------------------------------------------===//
113// SymbolTable
114//===----------------------------------------------------------------------===//
115
116/// Build a symbol table with the symbols within the given operation.
118 : symbolTableOp(symbolTableOp) {
119 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
120 "expected operation to have SymbolTable trait");
121 assert(symbolTableOp->getNumRegions() == 1 &&
122 "expected operation to have a single region");
123 assert(symbolTableOp->getRegion(0).hasOneBlock() &&
124 "expected operation to have a single block");
125
126 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
128 for (auto &op : symbolTableOp->getRegion(0).front()) {
129 StringAttr name = getNameIfSymbol(&op, symbolNameId);
130 if (!name)
131 continue;
132
133 auto inserted = symbolTable.insert({name, &op});
134 (void)inserted;
135 assert(inserted.second &&
136 "expected region to contain uniquely named symbol operations");
137 }
138}
139
140/// Look up a symbol with the specified name, returning null if no such name
141/// exists. Names never include the @ on them.
142Operation *SymbolTable::lookup(StringRef name) const {
143 return lookup(StringAttr::get(symbolTableOp->getContext(), name));
144}
145Operation *SymbolTable::lookup(StringAttr name) const {
146 return symbolTable.lookup(name);
147}
148
150 StringAttr name = getNameIfSymbol(op);
151 assert(name && "expected valid 'name' attribute");
152 assert(op->getParentOp() == symbolTableOp &&
153 "expected this operation to be inside of the operation with this "
154 "SymbolTable");
155
156 auto it = symbolTable.find(name);
157 if (it != symbolTable.end() && it->second == op)
158 symbolTable.erase(it);
159}
160
162 remove(symbol);
163 symbol->erase();
164}
165
166// TODO: Consider if this should be renamed to something like insertOrUpdate
167/// Insert a new symbol into the table and associated operation if not already
168/// there and rename it as necessary to avoid collisions. Return the name of
169/// the symbol after insertion as attribute.
170StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
171 // The symbol cannot be the child of another op and must be the child of the
172 // symbolTableOp after this.
173 //
174 // TODO: consider if SymbolTable's constructor should behave the same.
175 if (!symbol->getParentOp()) {
176 auto &body = symbolTableOp->getRegion(0).front();
177 if (insertPt == Block::iterator()) {
178 insertPt = Block::iterator(body.end());
179 } else {
180 assert((insertPt == body.end() ||
181 insertPt->getParentOp() == symbolTableOp) &&
182 "expected insertPt to be in the associated module operation");
183 }
184 // Insert before the terminator, if any.
185 if (insertPt == Block::iterator(body.end()) && !body.empty() &&
186 std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
187 insertPt = std::prev(body.end());
188
189 body.getOperations().insert(insertPt, symbol);
190 }
191 assert(symbol->getParentOp() == symbolTableOp &&
192 "symbol is already inserted in another op");
193
194 // Add this symbol to the symbol table, uniquing the name if a conflict is
195 // detected.
196 StringAttr name = getSymbolName(symbol);
197 if (symbolTable.insert({name, symbol}).second)
198 return name;
199 // If the symbol was already in the table, also return.
200 if (symbolTable.lookup(name) == symbol)
201 return name;
202
203 MLIRContext *context = symbol->getContext();
205 name.getValue(),
206 [&](StringRef candidate) {
207 return !symbolTable
208 .insert({StringAttr::get(context, candidate), symbol})
209 .second;
210 },
211 uniquingCounter);
212 setSymbolName(symbol, nameBuffer);
213 return getSymbolName(symbol);
214}
215
216LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
217 Operation *op = lookup(from);
218 return rename(op, to);
219}
220
221LogicalResult SymbolTable::rename(Operation *op, StringAttr to) {
222 StringAttr from = getNameIfSymbol(op);
223 (void)from;
224
225 assert(from && "expected valid 'name' attribute");
226 assert(op->getParentOp() == symbolTableOp &&
227 "expected this operation to be inside of the operation with this "
228 "SymbolTable");
229 assert(lookup(from) == op && "current name does not resolve to op");
230 assert(lookup(to) == nullptr && "new name already exists");
231
232 if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp())))
233 return failure();
234
235 // Remove op with old name, change name, add with new name. The order is
236 // important here due to how `remove` and `insert` rely on the op name.
237 remove(op);
238 setSymbolName(op, to);
239 insert(op);
240
241 assert(lookup(to) == op && "new name does not resolve to renamed op");
242 assert(lookup(from) == nullptr && "old name still exists");
243
244 return success();
245}
246
247LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
248 auto toAttr = StringAttr::get(getOp()->getContext(), to);
249 return rename(from, toAttr);
250}
251
252LogicalResult SymbolTable::rename(Operation *op, StringRef to) {
253 auto toAttr = StringAttr::get(getOp()->getContext(), to);
254 return rename(op, toAttr);
255}
256
257FailureOr<StringAttr>
260
261 // Determine new name that is unique in all symbol tables.
262 StringAttr newName;
263 {
264 MLIRContext *context = oldName.getContext();
265 SmallString<64> prefix = oldName.getValue();
266 int uniqueId = 0;
267 prefix.push_back('_');
268 while (true) {
269 newName = StringAttr::get(context, prefix + Twine(uniqueId++));
270 auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); };
271 if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) {
272 break;
273 }
274 }
275 }
276
277 // Apply renaming.
278 if (failed(rename(oldName, newName)))
279 return failure();
280 return newName;
281}
282
283FailureOr<StringAttr>
285 StringAttr from = getNameIfSymbol(op);
286 assert(from && "expected valid 'name' attribute");
287 return renameToUnique(from, others);
288}
289
290/// Returns the name of the given symbol operation.
292 StringAttr name = getNameIfSymbol(symbol);
293 assert(name && "expected valid symbol name");
294 return name;
295}
296
297/// Sets the name of the given symbol operation.
298void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
299 symbol->setAttr(getSymbolAttrName(), name);
300}
301
302/// Returns the visibility of the given symbol operation.
304 // If the attribute doesn't exist, assume public.
305 StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
306 if (!vis)
307 return Visibility::Public;
308
309 // Otherwise, switch on the string value.
310 return StringSwitch<Visibility>(vis.getValue())
311 .Case("private", Visibility::Private)
312 .Case("nested", Visibility::Nested)
313 .Case("public", Visibility::Public);
314}
315/// Sets the visibility of the given symbol operation.
317 MLIRContext *ctx = symbol->getContext();
318
319 // If the visibility is public, just drop the attribute as this is the
320 // default.
321 if (vis == Visibility::Public) {
322 symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName()));
323 return;
324 }
325
326 // Otherwise, update the attribute.
327 assert((vis == Visibility::Private || vis == Visibility::Nested) &&
328 "unknown symbol visibility kind");
329
330 StringRef visName = vis == Visibility::Private ? "private" : "nested";
331 symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
332}
333
334/// Returns the nearest symbol table from a given operation `from`. Returns
335/// nullptr if no valid parent symbol table could be found.
337 assert(from && "expected valid operation");
339 return nullptr;
340
341 while (!from->hasTrait<OpTrait::SymbolTable>()) {
342 from = from->getParentOp();
343
344 // Check that this is a valid op and isn't an unknown symbol table.
345 if (!from || isPotentiallyUnknownSymbolTable(from))
346 return nullptr;
347 }
348 return from;
349}
350
351/// Walks all symbol table operations nested within, and including, `op`. For
352/// each symbol table operation, the provided callback is invoked with the op
353/// and a boolean signifying if the symbols within that symbol table can be
354/// treated as if all uses are visible. `allSymUsesVisible` identifies whether
355/// all of the symbol uses of symbols within `op` are visible.
357 Operation *op, bool allSymUsesVisible,
358 function_ref<void(Operation *, bool)> callback) {
359 bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
360 if (isSymbolTable) {
361 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
362 allSymUsesVisible |= !symbol || symbol.isPrivate();
363 } else {
364 // Otherwise if 'op' is not a symbol table, any nested symbols are
365 // guaranteed to be hidden.
366 allSymUsesVisible = true;
367 }
368
369 for (Region &region : op->getRegions())
370 for (Block &block : region)
371 for (Operation &nestedOp : block)
372 walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
373
374 // If 'op' had the symbol table trait, visit it after any nested symbol
375 // tables.
376 if (isSymbolTable)
377 callback(op, allSymUsesVisible);
378}
379
380/// Returns the operation registered with the given symbol name with the
381/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
382/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
383/// was found.
385 StringAttr symbol) {
386 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
387 Region &region = symbolTableOp->getRegion(0);
388 if (region.empty())
389 return nullptr;
390
391 // Look for a symbol with the given name.
392 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
394 for (auto &op : region.front())
395 if (getNameIfSymbol(&op, symbolNameId) == symbol)
396 return &op;
397 return nullptr;
398}
400 SymbolRefAttr symbol) {
401 SmallVector<Operation *, 4> resolvedSymbols;
402 if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
403 return nullptr;
404 return resolvedSymbols.back();
405}
406
407/// Internal implementation of `lookupSymbolIn` that allows for specialized
408/// implementations of the lookup function.
409static LogicalResult lookupSymbolInImpl(
410 Operation *symbolTableOp, SymbolRefAttr symbol,
412 function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
413 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
414
415 // Lookup the root reference for this symbol.
416 symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
417 if (!symbolTableOp)
418 return failure();
419 symbols.push_back(symbolTableOp);
420
421 // If there are no nested references, just return the root symbol directly.
422 ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
423 if (nestedRefs.empty())
424 return success();
425
426 // Verify that the root is also a symbol table.
427 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
428 return failure();
429
430 // Otherwise, lookup each of the nested non-leaf references and ensure that
431 // each corresponds to a valid symbol table.
432 for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
433 symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
434 if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
435 return failure();
436 symbols.push_back(symbolTableOp);
437 }
438 symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
439 return success(symbols.back());
440}
441
442LogicalResult
443SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
445 auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
446 return lookupSymbolIn(symbolTableOp, symbol);
447 };
448 return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
449}
450
451/// Returns the operation registered with the given symbol name within the
452/// closest parent operation with the 'OpTrait::SymbolTable' trait. Returns
453/// nullptr if no valid symbol was found.
455 StringAttr symbol) {
456 Operation *symbolTableOp = getNearestSymbolTable(from);
457 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
458}
460 SymbolRefAttr symbol) {
461 Operation *symbolTableOp = getNearestSymbolTable(from);
462 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
463}
464
466 SymbolTable::Visibility visibility) {
467 switch (visibility) {
469 return os << "public";
471 return os << "private";
473 return os << "nested";
474 }
475 llvm_unreachable("Unexpected visibility");
476}
477
478//===----------------------------------------------------------------------===//
479// SymbolTable Trait Types
480//===----------------------------------------------------------------------===//
481
483 if (op->getNumRegions() != 1)
484 return op->emitOpError()
485 << "Operations with a 'SymbolTable' must have exactly one region";
486 if (!op->getRegion(0).hasOneBlock())
487 return op->emitOpError()
488 << "Operations with a 'SymbolTable' must have exactly one block";
489
490 // Check that all symbols are uniquely named within child regions.
491 DenseMap<Attribute, Location> nameToOrigLoc;
492 for (auto &block : op->getRegion(0)) {
493 for (auto &op : block) {
494 // Check for a symbol name attribute.
495 auto nameAttr =
497 if (!nameAttr)
498 continue;
499
500 // Try to insert this symbol into the table.
501 auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
502 if (!it.second)
503 return op.emitError()
504 .append("redefinition of symbol named '", nameAttr.getValue(), "'")
505 .attachNote(it.first->second)
506 .append("see existing symbol definition here");
507 }
508 }
509
510 // Verify any nested symbol user operations.
511 SymbolTableCollection symbolTable;
512 auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
513 if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
514 return WalkResult(user.verifySymbolUses(symbolTable));
515 return WalkResult::advance();
516 };
517
518 std::optional<WalkResult> result =
519 walkSymbolTable(op->getRegions(), verifySymbolUserFn);
520 return success(result && !result->wasInterrupted());
521}
522
523LogicalResult detail::verifySymbol(Operation *op) {
524 // Verify the name attribute.
526 return op->emitOpError() << "requires string attribute '"
528
529 // Verify the visibility attribute.
531 StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
532 if (!visStrAttr)
533 return op->emitOpError() << "requires visibility attribute '"
535 << "' to be a string attribute, but got " << vis;
536
537 if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
538 visStrAttr.getValue()))
539 return op->emitOpError()
540 << "visibility expected to be one of [\"public\", \"private\", "
541 "\"nested\"], but got "
542 << visStrAttr;
543 }
544 return success();
545}
546
547//===----------------------------------------------------------------------===//
548// Symbol Use Lists
549//===----------------------------------------------------------------------===//
550
551/// Walk all of the symbol references within the given operation, invoking the
552/// provided callback for each found use. The callback takes the use of the
553/// symbol.
554static WalkResult
557 return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
558 [&](SymbolRefAttr symbolRef) {
559 if (callback({op, symbolRef}).wasInterrupted())
560 return WalkResult::interrupt();
561
562 // Don't walk nested references.
563 return WalkResult::skip();
564 });
565}
566
567/// Walk all of the uses, for any symbol, that are nested within the given
568/// regions, invoking the provided callback for each. This does not traverse
569/// into any nested symbol tables.
570static std::optional<WalkResult>
573 return walkSymbolTable(regions,
574 [&](Operation *op) -> std::optional<WalkResult> {
575 // Check that this isn't a potentially unknown symbol
576 // table.
578 return std::nullopt;
579
580 return walkSymbolRefs(op, callback);
581 });
582}
583/// Walk all of the uses, for any symbol, that are nested within the given
584/// operation 'from', invoking the provided callback for each. This does not
585/// traverse into any nested symbol tables.
586static std::optional<WalkResult>
589 // If this operation has regions, and it, as well as its dialect, isn't
590 // registered then conservatively fail. The operation may define a
591 // symbol table, so we can't opaquely know if we should traverse to find
592 // nested uses.
594 return std::nullopt;
595
596 // Walk the uses on this operation.
597 if (walkSymbolRefs(from, callback).wasInterrupted())
598 return WalkResult::interrupt();
599
600 // Only recurse if this operation is not a symbol table. A symbol table
601 // defines a new scope, so we can't walk the attributes from within the symbol
602 // table op.
603 if (!from->hasTrait<OpTrait::SymbolTable>())
604 return walkSymbolUses(from->getRegions(), callback);
605 return WalkResult::advance();
606}
607
608namespace {
609/// This class represents a single symbol scope. A symbol scope represents the
610/// set of operations nested within a symbol table that may reference symbols
611/// within that table. A symbol scope does not contain the symbol table
612/// operation itself, just its contained operations. A scope ends at leaf
613/// operations or another symbol table operation.
614struct SymbolScope {
615 /// Walk the symbol uses within this scope, invoking the given callback.
616 /// This variant is used when the callback type matches that expected by
617 /// 'walkSymbolUses'.
618 template <typename CallbackT,
619 std::enable_if_t<!std::is_same<
620 typename llvm::function_traits<CallbackT>::result_t,
621 void>::value> * = nullptr>
622 std::optional<WalkResult> walk(CallbackT cback) {
623 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
624 return walkSymbolUses(*region, cback);
625 return walkSymbolUses(cast<Operation *>(limit), cback);
626 }
627 /// This variant is used when the callback type matches a stripped down type:
628 /// void(SymbolTable::SymbolUse use)
629 template <typename CallbackT,
630 std::enable_if_t<std::is_same<
631 typename llvm::function_traits<CallbackT>::result_t,
632 void>::value> * = nullptr>
633 std::optional<WalkResult> walk(CallbackT cback) {
634 return walk([=](SymbolTable::SymbolUse use) {
635 return cback(use), WalkResult::advance();
636 });
637 }
638
639 /// Walk all of the operations nested under the current scope without
640 /// traversing into any nested symbol tables.
641 template <typename CallbackT>
642 std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
643 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
644 return ::walkSymbolTable(*region, cback);
645 return ::walkSymbolTable(cast<Operation *>(limit), cback);
646 }
647
648 /// The representation of the symbol within this scope.
649 SymbolRefAttr symbol;
650
651 /// The IR unit representing this scope.
652 llvm::PointerUnion<Operation *, Region *> limit;
653};
654} // namespace
655
656/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
658 Operation *limit) {
659 StringAttr symName = SymbolTable::getSymbolName(symbol);
660 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
661
662 // Compute the ancestors of 'limit'.
665 limitAncestors;
666 Operation *limitAncestor = limit;
667 do {
668 // Check to see if 'symbol' is an ancestor of 'limit'.
669 if (limitAncestor == symbol) {
670 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
671 // doesn't support parent references.
673 symbol->getParentOp())
674 return {{SymbolRefAttr::get(symName), limit}};
675 return {};
676 }
677
678 limitAncestors.insert(limitAncestor);
679 } while ((limitAncestor = limitAncestor->getParentOp()));
680
681 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
682 Operation *commonAncestor = symbol->getParentOp();
683 do {
684 if (limitAncestors.count(commonAncestor))
685 break;
686 } while ((commonAncestor = commonAncestor->getParentOp()));
687 assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
688
689 // Compute the set of valid nested references for 'symbol' as far up to the
690 // common ancestor as possible.
692 bool collectedAllReferences = succeeded(
693 collectValidReferencesFor(symbol, symName, commonAncestor, references));
694
695 // Handle the case where the common ancestor is 'limit'.
696 if (commonAncestor == limit) {
698
699 // Walk each of the ancestors of 'symbol', calling the compute function for
700 // each one.
701 Operation *limitIt = symbol->getParentOp();
702 for (size_t i = 0, e = references.size(); i != e;
703 ++i, limitIt = limitIt->getParentOp()) {
704 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
705 scopes.push_back({references[i], &limitIt->getRegion(0)});
706 }
707 return scopes;
708 }
709
710 // Otherwise, we just need the symbol reference for 'symbol' that will be
711 // used within 'limit'. This is the last reference in the list we computed
712 // above if we were able to collect all references.
713 if (!collectedAllReferences)
714 return {};
715 return {{references.back(), limit}};
716}
718 Region *limit) {
719 auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
720
721 // If we collected some scopes to walk, make sure to constrain the one for
722 // limit to the specific region requested.
723 if (!scopes.empty())
724 scopes.back().limit = limit;
725 return scopes;
726}
728 Region *limit) {
729 return {{SymbolRefAttr::get(symbol), limit}};
730}
731
733 Operation *limit) {
735 auto symbolRef = SymbolRefAttr::get(symbol);
736 for (auto &region : limit->getRegions())
737 scopes.push_back({symbolRef, &region});
738 return scopes;
739}
740
741/// Returns true if the given reference 'SubRef' is a sub reference of the
742/// reference 'ref', i.e. 'ref' is a further qualified reference.
743static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
744 if (ref == subRef)
745 return true;
746
747 // If the references are not pointer equal, check to see if `subRef` is a
748 // prefix of `ref`.
749 if (llvm::isa<FlatSymbolRefAttr>(ref) ||
750 ref.getRootReference() != subRef.getRootReference())
751 return false;
752
753 auto refLeafs = ref.getNestedReferences();
754 auto subRefLeafs = subRef.getNestedReferences();
755 return subRefLeafs.size() < refLeafs.size() &&
756 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
757}
758
759//===----------------------------------------------------------------------===//
760// SymbolTable::getSymbolUses
761//===----------------------------------------------------------------------===//
762
763/// The implementation of SymbolTable::getSymbolUses below.
764template <typename FromT>
765static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
766 std::vector<SymbolTable::SymbolUse> uses;
767 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
768 uses.push_back(symbolUse);
769 return WalkResult::advance();
770 };
771 auto result = walkSymbolUses(from, walkFn);
772 return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
773 : std::nullopt;
774}
775
776/// Get an iterator range for all of the uses, for any symbol, that are nested
777/// within the given operation 'from'. This does not traverse into any nested
778/// symbol tables, and will also only return uses on 'from' if it does not
779/// also define a symbol table. This is because we treat the region as the
780/// boundary of the symbol table, and not the op itself. This function returns
781/// std::nullopt if there are any unknown operations that may potentially be
782/// symbol tables.
783auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
784 return getSymbolUsesImpl(from);
785}
786auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
788}
789
790//===----------------------------------------------------------------------===//
791// SymbolTable::getSymbolUses
792//===----------------------------------------------------------------------===//
793
794/// The implementation of SymbolTable::getSymbolUses below.
795template <typename SymbolT, typename IRUnitT>
796static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
797 IRUnitT *limit) {
798 std::vector<SymbolTable::SymbolUse> uses;
799 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
800 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
801 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
802 uses.push_back(symbolUse);
803 }))
804 return std::nullopt;
805 }
806 return SymbolTable::UseRange(std::move(uses));
807}
808
809/// Get all of the uses of the given symbol that are nested within the given
810/// operation 'from'. This does not traverse into any nested symbol tables.
811/// This function returns std::nullopt if there are any unknown operations that
812/// may potentially be symbol tables.
813auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
814 -> std::optional<UseRange> {
815 return getSymbolUsesImpl(symbol, from);
816}
818 -> std::optional<UseRange> {
819 return getSymbolUsesImpl(symbol, from);
820}
821auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
822 -> std::optional<UseRange> {
823 return getSymbolUsesImpl(symbol, from);
824}
826 -> std::optional<UseRange> {
827 return getSymbolUsesImpl(symbol, from);
828}
829
830//===----------------------------------------------------------------------===//
831// SymbolTable::symbolKnownUseEmpty
832//===----------------------------------------------------------------------===//
833
834/// The implementation of SymbolTable::symbolKnownUseEmpty below.
835template <typename SymbolT, typename IRUnitT>
836static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
837 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
838 // Walk all of the symbol uses looking for a reference to 'symbol'.
839 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
840 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
841 ? WalkResult::interrupt()
842 : WalkResult::advance();
843 }) != WalkResult::advance())
844 return false;
845 }
846 return true;
847}
848
849/// Return if the given symbol is known to have no uses that are nested within
850/// the given operation 'from'. This does not traverse into any nested symbol
851/// tables. This function will also return false if there are any unknown
852/// operations that may potentially be symbol tables.
853bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
854 return symbolKnownUseEmptyImpl(symbol, from);
855}
857 return symbolKnownUseEmptyImpl(symbol, from);
858}
859bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
860 return symbolKnownUseEmptyImpl(symbol, from);
861}
863 return symbolKnownUseEmptyImpl(symbol, from);
864}
865
866//===----------------------------------------------------------------------===//
867// SymbolTable::replaceAllSymbolUses
868//===----------------------------------------------------------------------===//
869
870/// Generates a new symbol reference attribute with a new leaf reference.
871static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
872 FlatSymbolRefAttr newLeafAttr) {
873 if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
874 return newLeafAttr;
875 auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
876 nestedRefs.back() = newLeafAttr;
877 return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
878}
879
880/// The implementation of SymbolTable::replaceAllSymbolUses below.
881template <typename SymbolT, typename IRUnitT>
882static LogicalResult
883replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
884 // Generate a new attribute to replace the given attribute.
885 FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
886 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
887 SymbolRefAttr oldAttr = scope.symbol;
888 SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
889 AttrTypeReplacer replacer;
890 replacer.addReplacement(
891 [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
892 // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
893 // want to accidentally replace an inner reference.
894 if (attr == oldAttr)
895 return {newAttr, WalkResult::skip()};
896 // Handle prefix matches.
897 if (isReferencePrefixOf(oldAttr, attr)) {
898 auto oldNestedRefs = oldAttr.getNestedReferences();
899 auto nestedRefs = attr.getNestedReferences();
900 if (oldNestedRefs.empty())
901 return {SymbolRefAttr::get(newSymbol, nestedRefs),
903
904 auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
905 newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
906 return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
908 }
909 return {attr, WalkResult::skip()};
910 });
911
912 auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
913 replacer.replaceElementsIn(op);
914 return WalkResult::advance();
915 };
916 if (!scope.walkSymbolTable(walkFn))
917 return failure();
918 }
919 return success();
920}
921
922/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
923/// provided symbol 'newSymbol' that are nested within the given operation
924/// 'from'. This does not traverse into any nested symbol tables. If there are
925/// any unknown operations that may potentially be symbol tables, no uses are
926/// replaced and failure is returned.
927LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
928 StringAttr newSymbol,
929 Operation *from) {
930 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
931}
933 StringAttr newSymbol,
934 Operation *from) {
935 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
936}
937LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
938 StringAttr newSymbol,
939 Region *from) {
940 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
941}
943 StringAttr newSymbol,
944 Region *from) {
945 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
946}
947
948//===----------------------------------------------------------------------===//
949// SymbolTableCollection
950//===----------------------------------------------------------------------===//
951
953 StringAttr symbol) {
954 return getSymbolTable(symbolTableOp).lookup(symbol);
955}
957 SymbolRefAttr name) {
959 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
960 return nullptr;
961 return symbols.back();
962}
963/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
964/// a given SymbolRefAttr. Returns failure if any of the nested references could
965/// not be resolved.
966LogicalResult
968 SymbolRefAttr name,
970 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
971 return lookupSymbolIn(symbolTableOp, symbol);
972 };
973 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
974}
975
976/// Returns the operation registered with the given symbol name within the
977/// closest parent operation of, or including, 'from' with the
978/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
979/// found.
981 StringAttr symbol) {
982 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
983 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
984}
985Operation *
987 SymbolRefAttr symbol) {
988 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
989 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
990}
991
992/// Lookup, or create, a symbol table for an operation.
994 auto it = symbolTables.try_emplace(op, nullptr);
995 if (it.second)
996 it.first->second = std::make_unique<SymbolTable>(op);
997 return *it.first->second;
998}
999
1001 symbolTables.erase(op);
1002}
1003
1004//===----------------------------------------------------------------------===//
1005// LockedSymbolTableCollection
1006//===----------------------------------------------------------------------===//
1007
1009 StringAttr symbol) {
1010 return getSymbolTable(symbolTableOp).lookup(symbol);
1011}
1012
1013Operation *
1015 FlatSymbolRefAttr symbol) {
1016 return lookupSymbolIn(symbolTableOp, symbol.getAttr());
1017}
1018
1020 SymbolRefAttr name) {
1022 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
1023 return nullptr;
1024 return symbols.back();
1025}
1026
1028 Operation *symbolTableOp, SymbolRefAttr name,
1030 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
1031 return lookupSymbolIn(symbolTableOp, symbol);
1032 };
1033 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
1034}
1035
1037LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
1038 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
1039 // Try to find an existing symbol table.
1040 {
1041 llvm::sys::SmartScopedReader<true> lock(mutex);
1042 auto it = collection.symbolTables.find(symbolTableOp);
1043 if (it != collection.symbolTables.end())
1044 return *it->second;
1045 }
1046 // Create a symbol table for the operation. Perform construction outside of
1047 // the critical section.
1048 auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
1049 // Insert the constructed symbol table.
1050 llvm::sys::SmartScopedWriter<true> lock(mutex);
1051 return *collection.symbolTables
1052 .insert({symbolTableOp, std::move(symbolTable)})
1053 .first->second;
1054}
1055
1056//===----------------------------------------------------------------------===//
1057// SymbolUserMap
1058//===----------------------------------------------------------------------===//
1059
1061 Operation *symbolTableOp)
1062 : symbolTable(symbolTable) {
1063 // Walk each of the symbol tables looking for discardable callgraph nodes.
1065 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1066 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
1067 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
1068 assert(symbolUses && "expected uses to be valid");
1069
1070 for (const SymbolTable::SymbolUse &use : *symbolUses) {
1071 symbols.clear();
1072 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
1073 symbols);
1074 for (Operation *symbolOp : symbols)
1075 symbolToUsers[symbolOp].insert(use.getUser());
1076 }
1077 }
1078 };
1079 // We just set `allSymUsesVisible` to false here because it isn't necessary
1080 // for building the user map.
1081 SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
1082 walkFn);
1083}
1084
1086 StringAttr newSymbolName) {
1087 auto it = symbolToUsers.find(symbol);
1088 if (it == symbolToUsers.end())
1089 return;
1090
1091 // Replace the uses within the users of `symbol`.
1092 for (Operation *user : it->second)
1093 (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1094
1095 // Move the current users of `symbol` to the new symbol if it is in the
1096 // symbol table.
1097 Operation *newSymbol =
1098 symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1099 if (newSymbol != symbol) {
1100 // Transfer over the users to the new symbol. The reference to the old one
1101 // is fetched again as the iterator is invalidated during the insertion.
1102 auto newIt = symbolToUsers.try_emplace(newSymbol);
1103 auto oldIt = symbolToUsers.find(symbol);
1104 assert(oldIt != symbolToUsers.end() && "missing old users list");
1105 if (newIt.second)
1106 newIt.first->second = std::move(oldIt->second);
1107 else
1108 newIt.first->second.set_union(oldIt->second);
1109 symbolToUsers.erase(oldIt);
1110 }
1111}
1112
1113//===----------------------------------------------------------------------===//
1114// Visibility parsing implementation.
1115//===----------------------------------------------------------------------===//
1116
1118 NamedAttrList &attrs) {
1119 StringRef visibility;
1120 if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"}))
1121 return failure();
1122
1123 StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1124 attrs.push_back(parser.getBuilder().getNamedAttr(
1125 SymbolTable::getVisibilityAttrName(), visibilityAttr));
1126 return success();
1127}
1128
1129//===----------------------------------------------------------------------===//
1130// Symbol Interfaces
1131//===----------------------------------------------------------------------===//
1132
1133/// Include the generated symbol interfaces.
1134#include "mlir/IR/SymbolInterfaces.cpp.inc"
return success()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::optional< WalkResult > walkSymbolTable(MutableArrayRef< Region > regions, function_ref< std::optional< WalkResult >(Operation *)> callback)
Walk all of the operations within the given set of regions, without traversing into any nested symbol...
static std::optional< SymbolTable::UseRange > getSymbolUsesImpl(FromT from)
The implementation of SymbolTable::getSymbolUses below.
static LogicalResult collectValidReferencesFor(Operation *symbol, StringAttr symbolName, Operation *within, SmallVectorImpl< SymbolRefAttr > &results)
Computes the nested symbol reference attribute for the symbol 'symbolName' that are usable within the...
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit)
The implementation of SymbolTable::symbolKnownUseEmpty below.
static SmallVector< SymbolScope, 2 > collectSymbolScopes(Operation *symbol, Operation *limit)
Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
static WalkResult walkSymbolRefs(Operation *op, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the symbol references within the given operation, invoking the provided callback for each...
static StringAttr getNameIfSymbol(Operation *op)
Returns the string name of the given symbol, or null if this is not a symbol.
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref)
Returns true if the given reference 'SubRef' is a sub reference of the reference 'ref',...
static std::optional< WalkResult > walkSymbolUses(MutableArrayRef< Region > regions, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the uses, for any symbol, that are nested within the given regions, invoking the provided...
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, FlatSymbolRefAttr newLeafAttr)
Generates a new symbol reference attribute with a new leaf reference.
static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit)
The implementation of SymbolTable::replaceAllSymbolUses below.
static LogicalResult lookupSymbolInImpl(Operation *symbolTableOp, SymbolRefAttr symbol, SmallVectorImpl< Operation * > &symbols, function_ref< Operation *(Operation *, StringAttr)> lookupSymbolFn)
Internal implementation of lookupSymbolIn that allows for specialized implementations of the lookup f...
static bool isPotentiallyUnknownSymbolTable(Operation *op)
Return true if the given operation is unknown and may potentially define a symbol table.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
This is an attribute/type replacer that is naively cached.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:140
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:94
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol) override
Look up a symbol with the specified name within the specified symbol table operation,...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:263
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition Operation.h:600
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
virtual void invalidateSymbolTable(Operation *op)
Invalidate the cached symbol table for an operation.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static SmallString< N > generateSymbolName(StringRef name, UniqueChecker uniqueChecker, unsigned &uniquingCounter)
Generate a unique symbol name.
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition SymbolTable.h:90
@ Nested
The symbol is visible to the current IR, which may include operations in symbol tables above the one ...
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
Definition SymbolTable.h:93
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static StringRef getVisibilityAttrName()
Return the name of the attribute used for symbol visibility.
Definition SymbolTable.h:82
LogicalResult rename(StringAttr from, StringAttr to)
Renames the given op or the op referred to by the given name to the given new name and updates the sym...
void erase(Operation *symbol)
Erase the given symbol from the table and delete the operation.
Operation * getOp() const
Returns the associated operation.
Definition SymbolTable.h:79
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
SymbolTable(Operation *symbolTableOp)
Build a symbol table with the symbols within the given operation.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from)
Return if the given symbol is known to have no uses that are nested within the given operation 'from'...
FailureOr< StringAttr > renameToUnique(StringAttr from, ArrayRef< SymbolTable * > others)
Renames the given op or the op referred to by the given name to a name that is unique within this ...
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName)
Replace all of the uses of the given symbol with newSymbolName.
SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp)
Build a user map for all of the symbols defined in regions nested under 'symbolTableOp'.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
void addReplacement(ReplaceFn< Attribute > fn)
AttrTypeReplacerBase.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
LogicalResult verifySymbol(Operation *op)
LogicalResult verifySymbolTable(Operation *op)
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, NamedAttrList &attrs)
Parse an optional visibility attribute keyword (i.e., public, private, or nested) without quotes in a...
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::StringSwitch< T, R > StringSwitch
Definition LLVM.h:141
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152