MLIR 22.0.0git
SymbolTable.cpp
Go to the documentation of this file.
1//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Builders.h"
12#include "llvm/ADT/SetVector.h"
13#include "llvm/ADT/SmallString.h"
14#include "llvm/ADT/StringSwitch.h"
15#include <optional>
16
17using namespace mlir;
18
19/// Return true if the given operation is unknown and may potentially define a
20/// symbol table.
22 return op->getNumRegions() == 1 && !op->getDialect();
23}
24
25/// Returns the string name of the given symbol, or null if this is not a
26/// symbol.
27static StringAttr getNameIfSymbol(Operation *op) {
28 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
29}
30static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
31 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
32}
33
34/// Computes the nested symbol reference attribute for the symbol 'symbolName'
35/// that are usable within the symbol table operations from 'symbol' as far up
36/// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
37/// Returns success if all references up to 'within' could be computed.
38static LogicalResult
39collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
40 Operation *within,
42 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
43 MLIRContext *ctx = symbol->getContext();
44
45 auto leafRef = FlatSymbolRefAttr::get(symbolName);
46 results.push_back(leafRef);
47
48 // Early exit for when 'within' is the parent of 'symbol'.
49 Operation *symbolTableOp = symbol->getParentOp();
50 if (within == symbolTableOp)
51 return success();
52
53 // Collect references until 'symbolTableOp' reaches 'within'.
54 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
55 StringAttr symbolNameId =
56 StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
57 do {
58 // Each parent of 'symbol' should define a symbol table.
59 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
60 return failure();
61 // Each parent of 'symbol' should also be a symbol.
62 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
63 if (!symbolTableName)
64 return failure();
65 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
66
67 symbolTableOp = symbolTableOp->getParentOp();
68 if (symbolTableOp == within)
69 break;
70 nestedRefs.insert(nestedRefs.begin(),
71 FlatSymbolRefAttr::get(symbolTableName));
72 } while (true);
73 return success();
74}
75
76/// Walk all of the operations within the given set of regions, without
77/// traversing into any nested symbol tables. Stops walking if the result of the
78/// callback is anything other than `WalkResult::advance`.
79static std::optional<WalkResult>
81 function_ref<std::optional<WalkResult>(Operation *)> callback) {
82 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
83 while (!worklist.empty()) {
84 for (Operation &op : worklist.pop_back_val()->getOps()) {
85 std::optional<WalkResult> result = callback(&op);
87 return result;
88
89 // If this op defines a new symbol table scope, we can't traverse. Any
90 // symbol references nested within 'op' are different semantically.
91 if (!op.hasTrait<OpTrait::SymbolTable>()) {
92 for (Region &region : op.getRegions())
93 worklist.push_back(&region);
94 }
95 }
96 }
97 return WalkResult::advance();
98}
99
100/// Walk all of the operations nested under, and including, the given operation,
101/// without traversing into any nested symbol tables. Stops walking if the
102/// result of the callback is anything other than `WalkResult::advance`.
103static std::optional<WalkResult>
105 function_ref<std::optional<WalkResult>(Operation *)> callback) {
106 std::optional<WalkResult> result = callback(op);
108 return result;
109 return walkSymbolTable(op->getRegions(), callback);
110}
111
112//===----------------------------------------------------------------------===//
113// SymbolTable
114//===----------------------------------------------------------------------===//
115
116/// Build a symbol table with the symbols within the given operation.
118 : symbolTableOp(symbolTableOp) {
119 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
120 "expected operation to have SymbolTable trait");
121 assert(symbolTableOp->getNumRegions() == 1 &&
122 "expected operation to have a single region");
123 assert(symbolTableOp->getRegion(0).hasOneBlock() &&
124 "expected operation to have a single block");
125
126 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
128 for (auto &op : symbolTableOp->getRegion(0).front()) {
129 StringAttr name = getNameIfSymbol(&op, symbolNameId);
130 if (!name)
131 continue;
132
133 auto inserted = symbolTable.insert({name, &op});
134 (void)inserted;
135 assert(inserted.second &&
136 "expected region to contain uniquely named symbol operations");
137 }
138}
139
140/// Look up a symbol with the specified name, returning null if no such name
141/// exists. Names never include the @ on them.
142Operation *SymbolTable::lookup(StringRef name) const {
143 return lookup(StringAttr::get(symbolTableOp->getContext(), name));
144}
145Operation *SymbolTable::lookup(StringAttr name) const {
146 return symbolTable.lookup(name);
147}
148
150 StringAttr name = getNameIfSymbol(op);
151 assert(name && "expected valid 'name' attribute");
152 assert(op->getParentOp() == symbolTableOp &&
153 "expected this operation to be inside of the operation with this "
154 "SymbolTable");
155
156 auto it = symbolTable.find(name);
157 if (it != symbolTable.end() && it->second == op)
158 symbolTable.erase(it);
159}
160
162 remove(symbol);
163 symbol->erase();
164}
165
166// TODO: Consider if this should be renamed to something like insertOrUpdate
167/// Insert a new symbol into the table and associated operation if not already
168/// there and rename it as necessary to avoid collisions. Return the name of
169/// the symbol after insertion as attribute.
170StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
171 // The symbol cannot be the child of another op and must be the child of the
172 // symbolTableOp after this.
173 //
174 // TODO: consider if SymbolTable's constructor should behave the same.
175 if (!symbol->getParentOp()) {
176 auto &body = symbolTableOp->getRegion(0).front();
177 if (insertPt == Block::iterator()) {
178 insertPt = Block::iterator(body.end());
179 } else {
180 assert((insertPt == body.end() ||
181 insertPt->getParentOp() == symbolTableOp) &&
182 "expected insertPt to be in the associated module operation");
183 }
184 // Insert before the terminator, if any.
185 if (insertPt == Block::iterator(body.end()) && !body.empty() &&
186 std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
187 insertPt = std::prev(body.end());
188
189 body.getOperations().insert(insertPt, symbol);
190 }
191 assert(symbol->getParentOp() == symbolTableOp &&
192 "symbol is already inserted in another op");
193
194 // Add this symbol to the symbol table, uniquing the name if a conflict is
195 // detected.
196 StringAttr name = getSymbolName(symbol);
197 if (symbolTable.insert({name, symbol}).second)
198 return name;
199 // If the symbol was already in the table, also return.
200 if (symbolTable.lookup(name) == symbol)
201 return name;
202
203 MLIRContext *context = symbol->getContext();
205 name.getValue(),
206 [&](StringRef candidate) {
207 return !symbolTable
208 .insert({StringAttr::get(context, candidate), symbol})
209 .second;
210 },
211 uniquingCounter);
212 setSymbolName(symbol, nameBuffer);
213 return getSymbolName(symbol);
214}
215
216LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
217 Operation *op = lookup(from);
218 return rename(op, to);
219}
220
221LogicalResult SymbolTable::rename(Operation *op, StringAttr to) {
222 StringAttr from = getNameIfSymbol(op);
223 (void)from;
224
225 assert(from && "expected valid 'name' attribute");
226 assert(op->getParentOp() == symbolTableOp &&
227 "expected this operation to be inside of the operation with this "
228 "SymbolTable");
229 assert(lookup(from) == op && "current name does not resolve to op");
230 assert(lookup(to) == nullptr && "new name already exists");
231
232 if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp())))
233 return failure();
234
235 // Remove op with old name, change name, add with new name. The order is
236 // important here due to how `remove` and `insert` rely on the op name.
237 remove(op);
238 setSymbolName(op, to);
239 insert(op);
240
241 assert(lookup(to) == op && "new name does not resolve to renamed op");
242 assert(lookup(from) == nullptr && "old name still exists");
243
244 return success();
245}
246
247LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
248 auto toAttr = StringAttr::get(getOp()->getContext(), to);
249 return rename(from, toAttr);
250}
251
252LogicalResult SymbolTable::rename(Operation *op, StringRef to) {
253 auto toAttr = StringAttr::get(getOp()->getContext(), to);
254 return rename(op, toAttr);
255}
256
257FailureOr<StringAttr>
260
261 // Determine new name that is unique in all symbol tables.
262 StringAttr newName;
263 {
264 MLIRContext *context = oldName.getContext();
265 SmallString<64> prefix = oldName.getValue();
266 int uniqueId = 0;
267 prefix.push_back('_');
268 while (true) {
269 newName = StringAttr::get(context, prefix + Twine(uniqueId++));
270 auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); };
271 if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) {
272 break;
273 }
274 }
275 }
276
277 // Apply renaming.
278 if (failed(rename(oldName, newName)))
279 return failure();
280 return newName;
281}
282
283FailureOr<StringAttr>
285 StringAttr from = getNameIfSymbol(op);
286 assert(from && "expected valid 'name' attribute");
287 return renameToUnique(from, others);
288}
289
290/// Returns the name of the given symbol operation.
292 StringAttr name = getNameIfSymbol(symbol);
293 assert(name && "expected valid symbol name");
294 return name;
295}
296
297/// Sets the name of the given symbol operation.
298void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
299 symbol->setAttr(getSymbolAttrName(), name);
300}
301
302/// Returns the visibility of the given symbol operation.
304 // If the attribute doesn't exist, assume public.
305 StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
306 if (!vis)
307 return Visibility::Public;
308
309 // Otherwise, switch on the string value.
310 return StringSwitch<Visibility>(vis.getValue())
311 .Case("private", Visibility::Private)
312 .Case("nested", Visibility::Nested)
313 .Case("public", Visibility::Public);
314}
315/// Sets the visibility of the given symbol operation.
317 MLIRContext *ctx = symbol->getContext();
318
319 // If the visibility is public, just drop the attribute as this is the
320 // default.
321 if (vis == Visibility::Public) {
322 symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName()));
323 return;
324 }
325
326 // Otherwise, update the attribute.
327 assert((vis == Visibility::Private || vis == Visibility::Nested) &&
328 "unknown symbol visibility kind");
329
330 StringRef visName = vis == Visibility::Private ? "private" : "nested";
331 symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
332}
333
334/// Returns the nearest symbol table from a given operation `from`. Returns
335/// nullptr if no valid parent symbol table could be found.
337 assert(from && "expected valid operation");
339 return nullptr;
340
341 while (!from->hasTrait<OpTrait::SymbolTable>()) {
342 from = from->getParentOp();
343
344 // Check that this is a valid op and isn't an unknown symbol table.
345 if (!from || isPotentiallyUnknownSymbolTable(from))
346 return nullptr;
347 }
348 return from;
349}
350
351/// Walks all symbol table operations nested within, and including, `op`. For
352/// each symbol table operation, the provided callback is invoked with the op
353/// and a boolean signifying if the symbols within that symbol table can be
354/// treated as if all uses are visible. `allSymUsesVisible` identifies whether
355/// all of the symbol uses of symbols within `op` are visible.
357 Operation *op, bool allSymUsesVisible,
358 function_ref<void(Operation *, bool)> callback) {
359 bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
360 if (isSymbolTable) {
361 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
362 allSymUsesVisible |= !symbol || symbol.isPrivate();
363 } else {
364 // Otherwise if 'op' is not a symbol table, any nested symbols are
365 // guaranteed to be hidden.
366 allSymUsesVisible = true;
367 }
368
369 for (Region &region : op->getRegions())
370 for (Block &block : region)
371 for (Operation &nestedOp : block)
372 walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
373
374 // If 'op' had the symbol table trait, visit it after any nested symbol
375 // tables.
376 if (isSymbolTable)
377 callback(op, allSymUsesVisible);
378}
379
380/// Returns the operation registered with the given symbol name within the
381/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
382/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
383/// was found.
385 StringAttr symbol) {
386 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
387 Region &region = symbolTableOp->getRegion(0);
388 if (region.empty())
389 return nullptr;
390
391 // Look for a symbol with the given name.
392 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
394 for (auto &op : region.front())
395 if (getNameIfSymbol(&op, symbolNameId) == symbol)
396 return &op;
397 return nullptr;
398}
400 SymbolRefAttr symbol) {
401 SmallVector<Operation *, 4> resolvedSymbols;
402 if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
403 return nullptr;
404 return resolvedSymbols.back();
405}
406
407/// Internal implementation of `lookupSymbolIn` that allows for specialized
408/// implementations of the lookup function.
409static LogicalResult lookupSymbolInImpl(
410 Operation *symbolTableOp, SymbolRefAttr symbol,
412 function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
413 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
414
415 // Lookup the root reference for this symbol.
416 symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
417 if (!symbolTableOp)
418 return failure();
419 symbols.push_back(symbolTableOp);
420
421 // If there are no nested references, just return the root symbol directly.
422 ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
423 if (nestedRefs.empty())
424 return success();
425
426 // Verify that the root is also a symbol table.
427 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
428 return failure();
429
430 // Otherwise, lookup each of the nested non-leaf references and ensure that
431 // each corresponds to a valid symbol table.
432 for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
433 symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
434 if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
435 return failure();
436 symbols.push_back(symbolTableOp);
437 }
438 symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
439 return success(symbols.back());
440}
441
442LogicalResult
443SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
445 auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
446 return lookupSymbolIn(symbolTableOp, symbol);
447 };
448 return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
449}
450
451/// Returns the operation registered with the given symbol name within the
452/// closest parent operation with the 'OpTrait::SymbolTable' trait. Returns
453/// nullptr if no valid symbol was found.
455 StringAttr symbol) {
456 Operation *symbolTableOp = getNearestSymbolTable(from);
457 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
458}
460 SymbolRefAttr symbol) {
461 Operation *symbolTableOp = getNearestSymbolTable(from);
462 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
463}
464
466 SymbolTable::Visibility visibility) {
467 switch (visibility) {
469 return os << "public";
471 return os << "private";
473 return os << "nested";
474 }
475 llvm_unreachable("Unexpected visibility");
476}
477
478//===----------------------------------------------------------------------===//
479// SymbolTable Trait Types
480//===----------------------------------------------------------------------===//
481
483 if (op->getNumRegions() != 1)
484 return op->emitOpError()
485 << "Operations with a 'SymbolTable' must have exactly one region";
486 if (!op->getRegion(0).hasOneBlock())
487 return op->emitOpError()
488 << "Operations with a 'SymbolTable' must have exactly one block";
489
490 // Check that all symbols are uniquely named within child regions.
491 DenseMap<Attribute, Location> nameToOrigLoc;
492 for (auto &block : op->getRegion(0)) {
493 for (auto &op : block) {
494 // Check for a symbol name attribute.
495 auto nameAttr =
497 if (!nameAttr)
498 continue;
499
500 // Try to insert this symbol into the table.
501 auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
502 if (!it.second)
503 return op.emitError()
504 .append("redefinition of symbol named '", nameAttr.getValue(), "'")
505 .attachNote(it.first->second)
506 .append("see existing symbol definition here");
507 }
508 }
509
510 // Verify any nested symbol user operations.
511 SymbolTableCollection symbolTable;
512 auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
513 if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
514 if (failed(user.verifySymbolUses(symbolTable)))
515 return WalkResult::interrupt();
516 for (auto &attr : op->getDiscardableAttrs()) {
517 if (auto user = dyn_cast<SymbolUserAttrInterface>(attr.getValue())) {
518 if (failed(user.verifySymbolUses(op, symbolTable)))
519 return WalkResult::interrupt();
520 }
521 }
522 return WalkResult::advance();
523 };
524
525 std::optional<WalkResult> result =
526 walkSymbolTable(op->getRegions(), verifySymbolUserFn);
527 return success(result && !result->wasInterrupted());
528}
529
530LogicalResult detail::verifySymbol(Operation *op) {
531 // Verify the name attribute.
533 return op->emitOpError() << "requires string attribute '"
535
536 // Verify the visibility attribute.
538 StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
539 if (!visStrAttr)
540 return op->emitOpError() << "requires visibility attribute '"
542 << "' to be a string attribute, but got " << vis;
543
544 if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
545 visStrAttr.getValue()))
546 return op->emitOpError()
547 << "visibility expected to be one of [\"public\", \"private\", "
548 "\"nested\"], but got "
549 << visStrAttr;
550 }
551 return success();
552}
553
554//===----------------------------------------------------------------------===//
555// Symbol Use Lists
556//===----------------------------------------------------------------------===//
557
558/// Walk all of the symbol references within the given operation, invoking the
559/// provided callback for each found use. The callback takes the use of the
560/// symbol.
561static WalkResult
564 return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
565 [&](SymbolRefAttr symbolRef) {
566 if (callback({op, symbolRef}).wasInterrupted())
567 return WalkResult::interrupt();
568
569 // Don't walk nested references.
570 return WalkResult::skip();
571 });
572}
573
574/// Walk all of the uses, for any symbol, that are nested within the given
575/// regions, invoking the provided callback for each. This does not traverse
576/// into any nested symbol tables.
577static std::optional<WalkResult>
580 return walkSymbolTable(regions,
581 [&](Operation *op) -> std::optional<WalkResult> {
582 // Check that this isn't a potentially unknown symbol
583 // table.
585 return std::nullopt;
586
587 return walkSymbolRefs(op, callback);
588 });
589}
590/// Walk all of the uses, for any symbol, that are nested within the given
591/// operation 'from', invoking the provided callback for each. This does not
592/// traverse into any nested symbol tables.
593static std::optional<WalkResult>
596 // If this operation has regions, and it, as well as its dialect, isn't
597 // registered then conservatively fail. The operation may define a
598 // symbol table, so we can't opaquely know if we should traverse to find
599 // nested uses.
601 return std::nullopt;
602
603 // Walk the uses on this operation.
604 if (walkSymbolRefs(from, callback).wasInterrupted())
605 return WalkResult::interrupt();
606
607 // Only recurse if this operation is not a symbol table. A symbol table
608 // defines a new scope, so we can't walk the attributes from within the symbol
609 // table op.
610 if (!from->hasTrait<OpTrait::SymbolTable>())
611 return walkSymbolUses(from->getRegions(), callback);
612 return WalkResult::advance();
613}
614
615namespace {
616/// This class represents a single symbol scope. A symbol scope represents the
617/// set of operations nested within a symbol table that may reference symbols
618/// within that table. A symbol scope does not contain the symbol table
619/// operation itself, just its contained operations. A scope ends at leaf
620/// operations or another symbol table operation.
621struct SymbolScope {
622 /// Walk the symbol uses within this scope, invoking the given callback.
623 /// This variant is used when the callback type matches that expected by
624 /// 'walkSymbolUses'.
625 template <typename CallbackT,
626 std::enable_if_t<!std::is_same<
627 typename llvm::function_traits<CallbackT>::result_t,
628 void>::value> * = nullptr>
629 std::optional<WalkResult> walk(CallbackT cback) {
630 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
631 return walkSymbolUses(*region, cback);
632 return walkSymbolUses(cast<Operation *>(limit), cback);
633 }
634 /// This variant is used when the callback type matches a stripped down type:
635 /// void(SymbolTable::SymbolUse use)
636 template <typename CallbackT,
637 std::enable_if_t<std::is_same<
638 typename llvm::function_traits<CallbackT>::result_t,
639 void>::value> * = nullptr>
640 std::optional<WalkResult> walk(CallbackT cback) {
641 return walk([=](SymbolTable::SymbolUse use) {
642 return cback(use), WalkResult::advance();
643 });
644 }
645
646 /// Walk all of the operations nested under the current scope without
647 /// traversing into any nested symbol tables.
648 template <typename CallbackT>
649 std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
650 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
651 return ::walkSymbolTable(*region, cback);
652 return ::walkSymbolTable(cast<Operation *>(limit), cback);
653 }
654
655 /// The representation of the symbol within this scope.
656 SymbolRefAttr symbol;
657
658 /// The IR unit representing this scope.
659 llvm::PointerUnion<Operation *, Region *> limit;
660};
661} // namespace
662
663/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
665 Operation *limit) {
666 StringAttr symName = SymbolTable::getSymbolName(symbol);
667 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
668
669 // Compute the ancestors of 'limit'.
672 limitAncestors;
673 Operation *limitAncestor = limit;
674 do {
675 // Check to see if 'symbol' is an ancestor of 'limit'.
676 if (limitAncestor == symbol) {
677 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
678 // doesn't support parent references.
680 symbol->getParentOp())
681 return {{SymbolRefAttr::get(symName), limit}};
682 return {};
683 }
684
685 limitAncestors.insert(limitAncestor);
686 } while ((limitAncestor = limitAncestor->getParentOp()));
687
688 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
689 Operation *commonAncestor = symbol->getParentOp();
690 do {
691 if (limitAncestors.count(commonAncestor))
692 break;
693 } while ((commonAncestor = commonAncestor->getParentOp()));
694 assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
695
696 // Compute the set of valid nested references for 'symbol' as far up to the
697 // common ancestor as possible.
699 bool collectedAllReferences = succeeded(
700 collectValidReferencesFor(symbol, symName, commonAncestor, references));
701
702 // Handle the case where the common ancestor is 'limit'.
703 if (commonAncestor == limit) {
705
706 // Walk each of the ancestors of 'symbol', calling the compute function for
707 // each one.
708 Operation *limitIt = symbol->getParentOp();
709 for (size_t i = 0, e = references.size(); i != e;
710 ++i, limitIt = limitIt->getParentOp()) {
711 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
712 scopes.push_back({references[i], &limitIt->getRegion(0)});
713 }
714 return scopes;
715 }
716
717 // Otherwise, we just need the symbol reference for 'symbol' that will be
718 // used within 'limit'. This is the last reference in the list we computed
719 // above if we were able to collect all references.
720 if (!collectedAllReferences)
721 return {};
722 return {{references.back(), limit}};
723}
725 Region *limit) {
726 auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
727
728 // If we collected some scopes to walk, make sure to constrain the one for
729 // limit to the specific region requested.
730 if (!scopes.empty())
731 scopes.back().limit = limit;
732 return scopes;
733}
735 Region *limit) {
736 return {{SymbolRefAttr::get(symbol), limit}};
737}
738
740 Operation *limit) {
742 auto symbolRef = SymbolRefAttr::get(symbol);
743 for (auto &region : limit->getRegions())
744 scopes.push_back({symbolRef, &region});
745 return scopes;
746}
747
748/// Returns true if the given reference 'SubRef' is a sub reference of the
749/// reference 'ref', i.e. 'ref' is a further qualified reference.
750static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
751 if (ref == subRef)
752 return true;
753
754 // If the references are not pointer equal, check to see if `subRef` is a
755 // prefix of `ref`.
756 if (llvm::isa<FlatSymbolRefAttr>(ref) ||
757 ref.getRootReference() != subRef.getRootReference())
758 return false;
759
760 auto refLeafs = ref.getNestedReferences();
761 auto subRefLeafs = subRef.getNestedReferences();
762 return subRefLeafs.size() < refLeafs.size() &&
763 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
764}
765
766//===----------------------------------------------------------------------===//
767// SymbolTable::getSymbolUses
768//===----------------------------------------------------------------------===//
769
770/// The implementation of SymbolTable::getSymbolUses below.
771template <typename FromT>
772static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
773 std::vector<SymbolTable::SymbolUse> uses;
774 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
775 uses.push_back(symbolUse);
776 return WalkResult::advance();
777 };
778 auto result = walkSymbolUses(from, walkFn);
779 return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
780 : std::nullopt;
781}
782
783/// Get an iterator range for all of the uses, for any symbol, that are nested
784/// within the given operation 'from'. This does not traverse into any nested
785/// symbol tables, and will also only return uses on 'from' if it does not
786/// also define a symbol table. This is because we treat the region as the
787/// boundary of the symbol table, and not the op itself. This function returns
788/// std::nullopt if there are any unknown operations that may potentially be
789/// symbol tables.
790auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
791 return getSymbolUsesImpl(from);
792}
793auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
795}
796
797//===----------------------------------------------------------------------===//
798// SymbolTable::getSymbolUses
799//===----------------------------------------------------------------------===//
800
801/// The implementation of SymbolTable::getSymbolUses below.
802template <typename SymbolT, typename IRUnitT>
803static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
804 IRUnitT *limit) {
805 std::vector<SymbolTable::SymbolUse> uses;
806 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
807 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
808 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
809 uses.push_back(symbolUse);
810 }))
811 return std::nullopt;
812 }
813 return SymbolTable::UseRange(std::move(uses));
814}
815
816/// Get all of the uses of the given symbol that are nested within the given
817/// operation 'from'. This does not traverse into any nested symbol tables.
818/// This function returns std::nullopt if there are any unknown operations that
819/// may potentially be symbol tables.
820auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
821 -> std::optional<UseRange> {
822 return getSymbolUsesImpl(symbol, from);
823}
825 -> std::optional<UseRange> {
826 return getSymbolUsesImpl(symbol, from);
827}
828auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
829 -> std::optional<UseRange> {
830 return getSymbolUsesImpl(symbol, from);
831}
833 -> std::optional<UseRange> {
834 return getSymbolUsesImpl(symbol, from);
835}
836
837//===----------------------------------------------------------------------===//
838// SymbolTable::symbolKnownUseEmpty
839//===----------------------------------------------------------------------===//
840
841/// The implementation of SymbolTable::symbolKnownUseEmpty below.
842template <typename SymbolT, typename IRUnitT>
843static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
844 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
845 // Walk all of the symbol uses looking for a reference to 'symbol'.
846 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
847 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
848 ? WalkResult::interrupt()
849 : WalkResult::advance();
850 }) != WalkResult::advance())
851 return false;
852 }
853 return true;
854}
855
856/// Return if the given symbol is known to have no uses that are nested within
857/// the given operation 'from'. This does not traverse into any nested symbol
858/// tables. This function will also return false if there are any unknown
859/// operations that may potentially be symbol tables.
860bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
861 return symbolKnownUseEmptyImpl(symbol, from);
862}
864 return symbolKnownUseEmptyImpl(symbol, from);
865}
866bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
867 return symbolKnownUseEmptyImpl(symbol, from);
868}
870 return symbolKnownUseEmptyImpl(symbol, from);
871}
872
873//===----------------------------------------------------------------------===//
874// SymbolTable::replaceAllSymbolUses
875//===----------------------------------------------------------------------===//
876
877/// Generates a new symbol reference attribute with a new leaf reference.
878static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
879 FlatSymbolRefAttr newLeafAttr) {
880 if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
881 return newLeafAttr;
882 auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
883 nestedRefs.back() = newLeafAttr;
884 return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
885}
886
887/// The implementation of SymbolTable::replaceAllSymbolUses below.
888template <typename SymbolT, typename IRUnitT>
889static LogicalResult
890replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
891 // Generate a new attribute to replace the given attribute.
892 FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
893 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
894 SymbolRefAttr oldAttr = scope.symbol;
895 SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
896 AttrTypeReplacer replacer;
897 replacer.addReplacement(
898 [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
899 // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
900 // want to accidentally replace an inner reference.
901 if (attr == oldAttr)
902 return {newAttr, WalkResult::skip()};
903 // Handle prefix matches.
904 if (isReferencePrefixOf(oldAttr, attr)) {
905 auto oldNestedRefs = oldAttr.getNestedReferences();
906 auto nestedRefs = attr.getNestedReferences();
907 if (oldNestedRefs.empty())
908 return {SymbolRefAttr::get(newSymbol, nestedRefs),
910
911 auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
912 newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
913 return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
915 }
916 return {attr, WalkResult::skip()};
917 });
918
919 auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
920 replacer.replaceElementsIn(op);
921 return WalkResult::advance();
922 };
923 if (!scope.walkSymbolTable(walkFn))
924 return failure();
925 }
926 return success();
927}
928
929/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
930/// provided symbol 'newSymbol' that are nested within the given operation
931/// 'from'. This does not traverse into any nested symbol tables. If there are
932/// any unknown operations that may potentially be symbol tables, no uses are
933/// replaced and failure is returned.
934LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
935 StringAttr newSymbol,
936 Operation *from) {
937 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
938}
940 StringAttr newSymbol,
941 Operation *from) {
942 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
943}
944LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
945 StringAttr newSymbol,
946 Region *from) {
947 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
948}
950 StringAttr newSymbol,
951 Region *from) {
952 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
953}
954
955//===----------------------------------------------------------------------===//
956// SymbolTableCollection
957//===----------------------------------------------------------------------===//
958
960 StringAttr symbol) {
961 return getSymbolTable(symbolTableOp).lookup(symbol);
962}
964 SymbolRefAttr name) {
966 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
967 return nullptr;
968 return symbols.back();
969}
970/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
971/// a given SymbolRefAttr. Returns failure if any of the nested references could
972/// not be resolved.
973LogicalResult
975 SymbolRefAttr name,
977 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
978 return lookupSymbolIn(symbolTableOp, symbol);
979 };
980 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
981}
982
983/// Returns the operation registered with the given symbol name within the
984/// closest parent operation of, or including, 'from' with the
985/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
986/// found.
988 StringAttr symbol) {
989 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
990 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
991}
992Operation *
994 SymbolRefAttr symbol) {
995 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
996 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
997}
998
999/// Lookup, or create, a symbol table for an operation.
1001 auto it = symbolTables.try_emplace(op, nullptr);
1002 if (it.second)
1003 it.first->second = std::make_unique<SymbolTable>(op);
1004 return *it.first->second;
1005}
1006
1008 symbolTables.erase(op);
1009}
1010
1011//===----------------------------------------------------------------------===//
1012// LockedSymbolTableCollection
1013//===----------------------------------------------------------------------===//
1014
1016 StringAttr symbol) {
1017 return getSymbolTable(symbolTableOp).lookup(symbol);
1018}
1019
1020Operation *
1022 FlatSymbolRefAttr symbol) {
1023 return lookupSymbolIn(symbolTableOp, symbol.getAttr());
1024}
1025
1027 SymbolRefAttr name) {
1029 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
1030 return nullptr;
1031 return symbols.back();
1032}
1033
1035 Operation *symbolTableOp, SymbolRefAttr name,
1037 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
1038 return lookupSymbolIn(symbolTableOp, symbol);
1039 };
1040 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
1041}
1042
1044LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
1045 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
1046 // Try to find an existing symbol table.
1047 {
1048 llvm::sys::SmartScopedReader<true> lock(mutex);
1049 auto it = collection.symbolTables.find(symbolTableOp);
1050 if (it != collection.symbolTables.end())
1051 return *it->second;
1052 }
1053 // Create a symbol table for the operation. Perform construction outside of
1054 // the critical section.
1055 auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
1056 // Insert the constructed symbol table.
1057 llvm::sys::SmartScopedWriter<true> lock(mutex);
1058 return *collection.symbolTables
1059 .insert({symbolTableOp, std::move(symbolTable)})
1060 .first->second;
1061}
1062
1063//===----------------------------------------------------------------------===//
1064// SymbolUserMap
1065//===----------------------------------------------------------------------===//
1066
1068 Operation *symbolTableOp)
1069 : symbolTable(symbolTable) {
1070 // Walk each of the symbol tables looking for discardable callgraph nodes.
1072 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1073 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
1074 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
1075 assert(symbolUses && "expected uses to be valid");
1076
1077 for (const SymbolTable::SymbolUse &use : *symbolUses) {
1078 symbols.clear();
1079 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
1080 symbols);
1081 for (Operation *symbolOp : symbols)
1082 symbolToUsers[symbolOp].insert(use.getUser());
1083 }
1084 }
1085 };
1086 // We just set `allSymUsesVisible` to false here because it isn't necessary
1087 // for building the user map.
1088 SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
1089 walkFn);
1090}
1091
1093 StringAttr newSymbolName) {
1094 auto it = symbolToUsers.find(symbol);
1095 if (it == symbolToUsers.end())
1096 return;
1097
1098 // Replace the uses within the users of `symbol`.
1099 for (Operation *user : it->second)
1100 (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1101
1102 // Move the current users of `symbol` to the new symbol if it is in the
1103 // symbol table.
1104 Operation *newSymbol =
1105 symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1106 if (newSymbol != symbol) {
1107 // Transfer over the users to the new symbol. The reference to the old one
1108 // is fetched again as the iterator is invalidated during the insertion.
1109 auto newIt = symbolToUsers.try_emplace(newSymbol);
1110 auto oldIt = symbolToUsers.find(symbol);
1111 assert(oldIt != symbolToUsers.end() && "missing old users list");
1112 if (newIt.second)
1113 newIt.first->second = std::move(oldIt->second);
1114 else
1115 newIt.first->second.set_union(oldIt->second);
1116 symbolToUsers.erase(oldIt);
1117 }
1118}
1119
1120//===----------------------------------------------------------------------===//
1121// Visibility parsing implementation.
1122//===----------------------------------------------------------------------===//
1123
1125 NamedAttrList &attrs) {
1126 StringRef visibility;
1127 if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"}))
1128 return failure();
1129
1130 StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1131 attrs.push_back(parser.getBuilder().getNamedAttr(
1132 SymbolTable::getVisibilityAttrName(), visibilityAttr));
1133 return success();
1134}
1135
1136//===----------------------------------------------------------------------===//
1137// Symbol Interfaces
1138//===----------------------------------------------------------------------===//
1139
1140/// Include the generated symbol interfaces.
1141#include "mlir/IR/SymbolInterfaces.cpp.inc"
1142#include "mlir/IR/SymbolInterfacesAttrInterface.cpp.inc"
return success()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::optional< WalkResult > walkSymbolTable(MutableArrayRef< Region > regions, function_ref< std::optional< WalkResult >(Operation *)> callback)
Walk all of the operations within the given set of regions, without traversing into any nested symbol...
static std::optional< SymbolTable::UseRange > getSymbolUsesImpl(FromT from)
The implementation of SymbolTable::getSymbolUses below.
static LogicalResult collectValidReferencesFor(Operation *symbol, StringAttr symbolName, Operation *within, SmallVectorImpl< SymbolRefAttr > &results)
Computes the nested symbol reference attribute for the symbol 'symbolName' that are usable within the...
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit)
The implementation of SymbolTable::symbolKnownUseEmpty below.
static SmallVector< SymbolScope, 2 > collectSymbolScopes(Operation *symbol, Operation *limit)
Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
static WalkResult walkSymbolRefs(Operation *op, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the symbol references within the given operation, invoking the provided callback for each...
static StringAttr getNameIfSymbol(Operation *op)
Returns the string name of the given symbol, or null if this is not a symbol.
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref)
Returns true if the given reference 'SubRef' is a sub reference of the reference 'ref',...
static std::optional< WalkResult > walkSymbolUses(MutableArrayRef< Region > regions, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the uses, for any symbol, that are nested within the given regions, invoking the provided...
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, FlatSymbolRefAttr newLeafAttr)
Generates a new symbol reference attribute with a new leaf reference.
static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit)
The implementation of SymbolTable::replaceAllSymbolUses below.
static LogicalResult lookupSymbolInImpl(Operation *symbolTableOp, SymbolRefAttr symbol, SmallVectorImpl< Operation * > &symbols, function_ref< Operation *(Operation *, StringAttr)> lookupSymbolFn)
Internal implementation of lookupSymbolIn that allows for specialized implementations of the lookup f...
static bool isPotentiallyUnknownSymbolTable(Operation *op)
Return true if the given operation is unknown and may potentially define a symbol table.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
This is an attribute/type replacer that is naively cached.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:140
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:94
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol) override
Look up a symbol with the specified name within the specified symbol table operation,...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Definition Operation.h:486
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:263
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition Operation.h:600
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
virtual void invalidateSymbolTable(Operation *op)
Invalidate the cached symbol table for an operation.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static SmallString< N > generateSymbolName(StringRef name, UniqueChecker uniqueChecker, unsigned &uniquingCounter)
Generate a unique symbol name.
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition SymbolTable.h:90
@ Nested
The symbol is visible to the current IR, which may include operations in symbol tables above the one ...
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
Definition SymbolTable.h:93
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static StringRef getVisibilityAttrName()
Return the name of the attribute used for symbol visibility.
Definition SymbolTable.h:82
LogicalResult rename(StringAttr from, StringAttr to)
Renames the given op or the op refered to by the given name to the given new name and updates the sym...
void erase(Operation *symbol)
Erase the given symbol from the table and delete the operation.
Operation * getOp() const
Returns the associated operation.
Definition SymbolTable.h:79
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
SymbolTable(Operation *symbolTableOp)
Build a symbol table with the symbols within the given operation.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from)
Return if the given symbol is known to have no uses that are nested within the given operation 'from'...
FailureOr< StringAttr > renameToUnique(StringAttr from, ArrayRef< SymbolTable * > others)
Renames the given op or the op refered to by the given name to the a name that is unique within this ...
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName)
Replace all of the uses of the given symbol with newSymbolName.
SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp)
Build a user map for all of the symbols defined in regions nested under 'symbolTableOp'.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
void addReplacement(ReplaceFn< Attribute > fn)
AttrTypeReplacerBase.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
LogicalResult verifySymbol(Operation *op)
LogicalResult verifySymbolTable(Operation *op)
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, NamedAttrList &attrs)
Parse an optional visibility attribute keyword (i.e., public, private, or nested) without quotes in a...
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::StringSwitch< T, R > StringSwitch
Definition LLVM.h:141
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152