19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/SetOperations.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/SmallVectorExtras.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ManagedStatic.h"
26 #include "llvm/Support/Regex.h"
29 #define DEBUG_TYPE "dialect"
32 using namespace detail;
39 : name(name), dialectID(id), context(context) {
67 <<
"' provides no attribute parsing hook";
80 <<
"dialect '" <<
getNamespace() <<
"' provides no type parsing hook";
84 std::optional<Dialect::ParseOpHook>
92 "Dialect hook invoked on non-dialect owned operation");
99 llvm::Regex dialectNameRegex(
"^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
100 return dialectNameRegex.match(str);
108 auto it = registeredInterfaces.try_emplace(interface->getID(),
109 std::move(interface));
113 llvm::dbgs() <<
"[" DEBUG_TYPE
114 "] repeated interface registration for dialect "
130 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
134 dialect->handleUseOfUndefinedPromisedInterface(
135 dialect->getTypeID(), interfaceKind, interfaceName);
137 if (
auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
138 interfaces.insert(interface);
139 orderedInterfaces.push_back(interface);
144 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() =
default;
149 DialectInterfaceCollectionBase::getInterfaceFor(
Operation *op)
const {
161 StringRef interfaceName) {
163 interfaceID, interfaceName);
173 TypeID interfaceRequestorID,
183 template <
typename Fn>
184 void applyExtensionsFn(
186 const llvm::MapVector<
TypeID, std::unique_ptr<DialectExtensionBase>>
192 const auto extractExtension =
194 return entry.second.get();
197 auto startIt = extensions.begin(), endIt = extensions.end();
199 while (startIt != endIt) {
200 count += endIt - startIt;
204 llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension);
206 for (
const auto *ext : subset)
207 applyExtension(*ext);
210 startIt = extensions.begin() + count;
211 endIt = extensions.end();
220 auto it = registry.find(name.str());
221 if (it == registry.end())
223 return it->second.second;
228 auto inserted = registry.insert(
229 std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
230 if (!inserted.second && inserted.first->second.first != typeID) {
231 llvm::report_fatal_error(
232 "Trying to register different dialects for the same namespace: " +
242 TypeID typeID = TypeID::get<void>();
245 auto constructor = [nameStr = name.str(), ctor](
MLIRContext *ctx) {
246 auto *dynDialect = ctx->getOrLoadDynamicDialect(
247 nameStr, [ctx, ctor](
DynamicDialect *dialect) { ctor(ctx, dialect); });
248 assert(dynDialect &&
"Dynamic dialect creation unexpectedly failed");
252 insert(typeID, name, constructor);
263 if (dialectNames.empty()) {
264 extension.apply(ctx, dialect);
270 if (dialectNames.size() == 1) {
271 if (dialectNames.front() == dialectName)
272 extension.apply(ctx, dialect);
277 const StringRef *nameIt = llvm::find(dialectNames, dialectName);
278 if (nameIt == dialectNames.end())
284 requiredDialects.reserve(dialectNames.size());
285 for (
auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
289 requiredDialects.push_back(dialect);
296 requiredDialects.push_back(loadedDialect);
298 extension.apply(ctx, requiredDialects);
301 applyExtensionsFn(applyExtension, extensions);
308 if (dialectNames.empty()) {
310 extension.apply(ctx, loadedDialects);
316 requiredDialects.reserve(dialectNames.size());
317 for (StringRef dialectName : dialectNames) {
321 requiredDialects.push_back(loadedDialect);
323 extension.apply(ctx, requiredDialects);
326 applyExtensionsFn(applyExtension, extensions);
331 const auto hasExtension = [&](
const auto &key) {
332 return rhs.extensions.contains(key);
334 if (!llvm::all_of(make_first_range(extensions), hasExtension))
339 registry, [&](
const auto &it) {
return rhs.registry.count(it.first); });
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
Attributes are known-constant values of operations.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and types.
virtual StringRef getFullSymbolSpec() const =0
Returns the full specification of the symbol being parsed.
This class represents an opaque dialect extension.
virtual ~DialectExtensionBase()
This class represents an interface overridden for a single dialect.
virtual ~DialectInterface()
MLIRContext * getContext() const
Return the context that holds the parent dialect of this interface.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs' contains all of the components of this registry.
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace, or nullptr if the namespace is not in this registry.
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
virtual std::optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName="")
Checks if the given interface, which is attempting to be used, is a promised interface of this dialect that has yet to be implemented.
StringRef getNamespace() const
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' on the given operation.
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' on the given operation.
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached to a construct owned by this dialect, is a promised interface of this dialect that has yet to be implemented.
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect namespace but not registered with addType.
MLIRContext * getContext() const
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
bool hasPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID) const
Checks if a promise has been made for the interface/requestor pair.
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to.
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
NamedAttribute represents a combination of a name and an Attribute value.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this dialect that has yet to be implemented.
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialect that has yet to be implemented.
Include the generated interface declarations.
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.