12 #include "llvm/ADT/SetVector.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallString.h"
15 #include "llvm/ADT/StringSwitch.h"
43 assert(within->
isAncestor(symbol) &&
"expected 'within' to be an ancestor");
47 results.push_back(leafRef);
51 if (within == symbolTableOp)
56 StringAttr symbolNameId =
63 StringAttr symbolTableName =
getNameIfSymbol(symbolTableOp, symbolNameId);
69 if (symbolTableOp == within)
71 nestedRefs.insert(nestedRefs.begin(),
80 static std::optional<WalkResult>
84 while (!worklist.empty()) {
85 for (
Operation &op : worklist.pop_back_val()->getOps()) {
86 std::optional<WalkResult> result = callback(&op);
94 worklist.push_back(®ion);
104 static std::optional<WalkResult>
107 std::optional<WalkResult> result = callback(op);
119 : symbolTableOp(symbolTableOp) {
121 "expected operation to have SymbolTable trait");
123 "expected operation to have a single region");
124 assert(llvm::hasSingleElement(symbolTableOp->
getRegion(0)) &&
125 "expected operation to have a single block");
134 auto inserted = symbolTable.insert({name, &op});
136 assert(inserted.second &&
137 "expected region to contain uniquely named symbol operations");
147 return symbolTable.lookup(name);
152 assert(name &&
"expected valid 'name' attribute");
154 "expected this operation to be inside of the operation with this "
157 auto it = symbolTable.find(name);
158 if (it != symbolTable.end() && it->second == op)
159 symbolTable.erase(it);
181 assert((insertPt == body.end() ||
182 insertPt->getParentOp() == symbolTableOp) &&
183 "expected insertPt to be in the associated module operation");
188 insertPt = std::prev(body.end());
190 body.getOperations().insert(insertPt, symbol);
193 "symbol is already inserted in another op");
198 if (symbolTable.insert({name, symbol}).second)
201 if (symbolTable.lookup(name) == symbol)
206 unsigned originalLength = nameBuffer.size();
212 nameBuffer.resize(originalLength);
214 nameBuffer += std::to_string(uniquingCounter++);
215 }
while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
224 assert(name &&
"expected valid symbol name");
259 "unknown symbol visibility kind");
268 assert(from &&
"expected valid operation");
292 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
293 allSymUsesVisible |= !symbol || symbol.isPrivate();
297 allSymUsesVisible =
true;
301 for (
Block &block : region)
308 callback(op, allSymUsesVisible);
325 for (
auto &op : region.
front())
331 SymbolRefAttr symbol) {
335 return resolvedSymbols.back();
341 Operation *symbolTableOp, SymbolRefAttr symbol,
347 symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
350 symbols.push_back(symbolTableOp);
354 if (nestedRefs.empty())
364 symbolTableOp = lookupSymbolFn(symbolTableOp, ref.
getAttr());
367 symbols.push_back(symbolTableOp);
369 symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
370 return success(symbols.back());
376 auto lookupFn = [](
Operation *symbolTableOp, StringAttr symbol) {
388 return symbolTableOp ?
lookupSymbolIn(symbolTableOp, symbol) :
nullptr;
391 SymbolRefAttr symbol) {
393 return symbolTableOp ?
lookupSymbolIn(symbolTableOp, symbol) :
nullptr;
398 switch (visibility) {
400 return os <<
"public";
402 return os <<
"private";
404 return os <<
"nested";
406 llvm_unreachable(
"Unexpected visibility");
416 <<
"Operations with a 'SymbolTable' must have exactly one region";
417 if (!llvm::hasSingleElement(op->
getRegion(0)))
419 <<
"Operations with a 'SymbolTable' must have exactly one block";
424 for (
auto &op : block) {
432 auto it = nameToOrigLoc.try_emplace(nameAttr, op.
getLoc());
435 .
append(
"redefinition of symbol named '", nameAttr.getValue(),
"'")
436 .attachNote(it.first->second)
437 .
append(
"see existing symbol definition here");
443 auto verifySymbolUserFn = [&](
Operation *op) -> std::optional<WalkResult> {
444 if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
445 return WalkResult(user.verifySymbolUses(symbolTable));
449 std::optional<WalkResult> result =
451 return success(result && !result->wasInterrupted());
457 return op->
emitOpError() <<
"requires string attribute '"
462 StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
464 return op->
emitOpError() <<
"requires visibility attribute '"
466 <<
"' to be a string attribute, but got " << vis;
469 visStrAttr.getValue()))
471 <<
"visibility expected to be one of [\"public\", \"private\", "
472 "\"nested\"], but got "
489 [&](SymbolRefAttr symbolRef) {
490 if (callback({op, symbolRef}).wasInterrupted())
501 static std::optional<WalkResult>
505 [&](
Operation *op) -> std::optional<WalkResult> {
517 static std::optional<WalkResult>
549 template <
typename CallbackT,
550 std::enable_if_t<!std::is_same<
551 typename llvm::function_traits<CallbackT>::result_t,
552 void>::value> * =
nullptr>
553 std::optional<WalkResult>
walk(CallbackT cback) {
554 if (
Region *region = llvm::dyn_cast_if_present<Region *>(limit))
560 template <
typename CallbackT,
561 std::enable_if_t<std::is_same<
562 typename llvm::function_traits<CallbackT>::result_t,
563 void>::value> * =
nullptr>
564 std::optional<WalkResult>
walk(CallbackT cback) {
572 template <
typename CallbackT>
574 if (
Region *region = llvm::dyn_cast_if_present<Region *>(limit))
580 SymbolRefAttr symbol;
600 if (limitAncestor == symbol) {
609 limitAncestors.insert(limitAncestor);
610 }
while ((limitAncestor = limitAncestor->
getParentOp()));
615 if (limitAncestors.count(commonAncestor))
617 }
while ((commonAncestor = commonAncestor->
getParentOp()));
618 assert(commonAncestor &&
"'limit' and 'symbol' have no common ancestor");
627 if (commonAncestor == limit) {
633 for (
size_t i = 0, e = references.size(); i != e;
636 scopes.push_back({references[i], &limitIt->
getRegion(0)});
644 if (!collectedAllReferences)
646 return {{references.back(), limit}};
655 scopes.back().limit = limit;
658 template <
typename IRUnit>
672 if (llvm::isa<FlatSymbolRefAttr>(ref) ||
673 ref.getRootReference() != subRef.getRootReference())
676 auto refLeafs = ref.getNestedReferences();
677 auto subRefLeafs = subRef.getNestedReferences();
678 return subRefLeafs.size() < refLeafs.size() &&
679 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
686 template <
typename FromT>
688 std::vector<SymbolTable::SymbolUse> uses;
690 uses.push_back(symbolUse);
694 return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
716 template <
typename SymbolT,
typename IRUnitT>
719 std::vector<SymbolTable::SymbolUse> uses;
722 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
723 uses.push_back(symbolUse);
735 -> std::optional<UseRange> {
739 -> std::optional<UseRange> {
743 -> std::optional<UseRange> {
747 -> std::optional<UseRange> {
755 template <
typename SymbolT,
typename IRUnitT>
760 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
761 ? WalkResult::interrupt()
762 : WalkResult::advance();
792 if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
794 auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
795 nestedRefs.back() = newLeafAttr;
800 template <
typename SymbolT,
typename IRUnitT>
806 SymbolRefAttr oldAttr = scope.symbol;
810 [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
817 auto oldNestedRefs = oldAttr.getNestedReferences();
818 auto nestedRefs = attr.getNestedReferences();
819 if (oldNestedRefs.empty())
823 auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
824 newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
831 auto walkFn = [&](
Operation *op) -> std::optional<WalkResult> {
835 if (!scope.walkSymbolTable(walkFn))
847 StringAttr newSymbol,
852 StringAttr newSymbol,
857 StringAttr newSymbol,
862 StringAttr newSymbol,
876 SymbolRefAttr name) {
880 return symbols.back();
889 auto lookupFn = [
this](
Operation *symbolTableOp, StringAttr symbol) {
902 return symbolTableOp ?
lookupSymbolIn(symbolTableOp, symbol) :
nullptr;
906 SymbolRefAttr symbol) {
908 return symbolTableOp ?
lookupSymbolIn(symbolTableOp, symbol) :
nullptr;
913 auto it = symbolTables.try_emplace(op,
nullptr);
915 it.first->second = std::make_unique<SymbolTable>(op);
916 return *it.first->second;
925 return getSymbolTable(symbolTableOp).
lookup(symbol);
935 SymbolRefAttr name) {
939 return symbols.back();
943 Operation *symbolTableOp, SymbolRefAttr name,
945 auto lookupFn = [
this](
Operation *symbolTableOp, StringAttr symbol) {
952 LockedSymbolTableCollection::getSymbolTable(
Operation *symbolTableOp) {
956 llvm::sys::SmartScopedReader<true> lock(mutex);
957 auto it = collection.symbolTables.find(symbolTableOp);
958 if (it != collection.symbolTables.end())
963 auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
965 llvm::sys::SmartScopedWriter<true> lock(mutex);
966 return *collection.symbolTables
967 .insert({symbolTableOp, std::move(symbolTable)})
977 : symbolTable(symbolTable) {
980 auto walkFn = [&](
Operation *symbolTableOp,
bool allUsesVisible) {
983 assert(symbolUses &&
"expected uses to be valid");
987 (void)symbolTable.
lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
990 symbolToUsers[symbolOp].insert(use.getUser());
1001 StringAttr newSymbolName) {
1002 auto it = symbolToUsers.find(symbol);
1003 if (it == symbolToUsers.end())
1014 if (newSymbol != symbol) {
1018 auto oldIt = symbolToUsers.find(symbol);
1019 assert(oldIt != symbolToUsers.end() &&
"missing old users list");
1021 newIt.first->second = std::move(oldIt->second);
1023 newIt.first->second.set_union(oldIt->second);
1024 symbolToUsers.erase(oldIt);
1034 StringRef visibility;
1049 #include "mlir/IR/SymbolInterfaces.cpp.inc"
static std::optional< WalkResult > walkSymbolTable(MutableArrayRef< Region > regions, function_ref< std::optional< WalkResult >(Operation *)> callback)
Walk all of the operations within the given set of regions, without traversing into any nested symbol...
static LogicalResult collectValidReferencesFor(Operation *symbol, StringAttr symbolName, Operation *within, SmallVectorImpl< SymbolRefAttr > &results)
Computes the nested symbol reference attribute for the symbol 'symbolName' that are usable within the...
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit)
The implementation of SymbolTable::symbolKnownUseEmpty below.
static SmallVector< SymbolScope, 2 > collectSymbolScopes(Operation *symbol, Operation *limit)
Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
static WalkResult walkSymbolRefs(Operation *op, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the symbol references within the given operation, invoking the provided callback for each...
static StringAttr getNameIfSymbol(Operation *op)
Returns the string name of the given symbol, or null if this is not a symbol.
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref)
Returns true if the given reference 'subRef' is a sub reference of the reference 'ref',...
static std::optional< WalkResult > walkSymbolUses(MutableArrayRef< Region > regions, function_ref< WalkResult(SymbolTable::SymbolUse)> callback)
Walk all of the uses, for any symbol, that are nested within the given regions, invoking the provided...
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, FlatSymbolRefAttr newLeafAttr)
Generates a new symbol reference attribute with a new leaf reference.
static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit)
The implementation of SymbolTable::replaceAllSymbolUses below.
static LogicalResult lookupSymbolInImpl(Operation *symbolTableOp, SymbolRefAttr symbol, SmallVectorImpl< Operation * > &symbols, function_ref< Operation *(Operation *, StringAttr)> lookupSymbolFn)
Internal implementation of lookupSymbolIn that allows for specialized implementations of the lookup f...
static std::optional< SymbolTable::UseRange > getSymbolUsesImpl(FromT from)
The implementation of SymbolTable::getSymbolUses below.
static bool isPotentiallyUnknownSymbolTable(Operation *op)
Return true if the given operation is unknown and may potentially define a symbol table.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
StringAttr getStringAttr(const Twine &bytes)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
IRUnit is a union of the different types of IR objects that constitute the IR structure (other than T...
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
Operation * getParentOp()
Return the parent operation this region is attached to.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
Visibility
An enumeration detailing the different visibility types that a symbol may have.
@ Nested
The symbol is visible to the current IR, which may include operations in symbol tables above the one ...
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static StringRef getVisibilityAttrName()
Return the name of the attribute used for symbol visibility.
void erase(Operation *symbol)
Erase the given symbol from the table and delete the operation.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of the symbol table operation 'op'.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
SymbolTable(Operation *symbolTableOp)
Build a symbol table with the symbols within the given operation.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from)
Return true if the given symbol is known to have no uses that are nested within the given operation 'from'...
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from the given operation 'from'.
void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName)
Replace all of the uses of the given symbol with newSymbolName.
SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp)
Build a user map for all of the symbols defined in regions nested under 'symbolTableOp'.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
LogicalResult verifySymbol(Operation *op)
LogicalResult verifySymbolTable(Operation *op)
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, NamedAttrList &attrs)
Parse an optional visibility attribute keyword (i.e., public, private, or nested) without quotes in a...
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
This class represents an efficient way to signal success or failure.