13#ifndef MLIR_IR_DIALECTREGISTRY_H
14#define MLIR_IR_DIALECTREGISTRY_H
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/ADT/MapVector.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/Support/LogicalResult.h"
59 virtual std::unique_ptr<DialectExtensionBase>
clone()
const = 0;
66 : dialectNames(dialectNames) {}
76template <
typename DerivedT,
typename... DialectsT>
83 std::unique_ptr<DialectExtensionBase>
clone() const final {
84 return std::make_unique<DerivedT>(
static_cast<const DerivedT &
>(*
this));
90 ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
95 unsigned dialectIdx = 0;
96 auto derivedDialects = std::tuple<DialectsT *...>{
97 static_cast<DialectsT *
>(dialects[dialectIdx++])...};
98 std::apply([&](DialectsT *...dialect) {
apply(context, dialect...); },
109 TypeID interfaceRequestorID,
111 StringRef interfaceName);
117 TypeID interfaceRequestorID,
125template <
typename ConcreteT,
typename InterfaceT>
128 InterfaceT::getInterfaceID());
152 std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>,
162 template <
typename ConcreteDialect>
165 ConcreteDialect::getDialectNamespace(),
173 template <
typename ConcreteDialect,
typename OtherDialect,
174 typename... MoreDialects>
177 insert<OtherDialect, MoreDialects...>();
207 for (
const auto &nameAndRegistrationIt : registry)
208 destination.
insert(nameAndRegistrationIt.second.first,
209 nameAndRegistrationIt.first,
210 nameAndRegistrationIt.second.second);
211 for (
const std::string &name : dialectsToPreload)
214 for (
const auto &extension : extensions)
215 destination.extensions.try_emplace(extension.first,
216 extension.second->clone());
224 names.reserve(registry.size());
225 for (
const auto &item : registry)
226 names.push_back(item.first);
234 return dialectsToPreload;
256 std::unique_ptr<DialectExtensionBase> extension) {
257 return extensions.try_emplace(extensionID, std::move(extension)).second;
261 template <
typename... ExtensionsT>
277 template <
typename... DialectsT>
282 Extension(
const Extension &) =
default;
283 Extension(ExtensionFnT extensionFn)
285 extensionFn(extensionFn) {}
286 ~Extension()
override =
default;
288 void apply(
MLIRContext *context, DialectsT *...dialects)
const final {
289 extensionFn(context, dialects...);
291 ExtensionFnT extensionFn;
294 reinterpret_cast<const void *
>(extensionFn)),
295 std::make_unique<Extension>(extensionFn));
309 llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
virtual ~DialectExtensionBase()
DialectExtensionBase(ArrayRef< StringRef > dialectNames)
Initialize the extension with a set of required dialects.
ArrayRef< StringRef > getRequiredDialects() const
Return the dialects that are required by this extension to be loaded before applying.
virtual void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const =0
Apply this extension to the given context and the required dialects.
virtual std::unique_ptr< DialectExtensionBase > clone() const =0
Return a copy of this extension.
This class represents a dialect extension anchored on the given set of dialects.
std::unique_ptr< DialectExtensionBase > clone() const final
Return a copy of this extension.
virtual void apply(MLIRContext *context, DialectsT *...dialects) const =0
Applies this extension to the given context and set of required dialects.
void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const final
Override the base apply method to allow providing the exact dialect types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(void(*extensionFn)(MLIRContext *, DialectsT *...))
Add an extension function that requires the given dialects.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e.
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace,...
void appendTo(DialectRegistry &destination) const
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
SmallVector< StringRef > getRegisteredDialectNames() const
Return the names of dialects registered in this registry with an allocator function.
ArrayRef< std::string > getDialectsToPreload() const
Return the names of dialects that should be preloaded into the context but whose allocator is expecte...
LogicalResult preloadSelectDialects(MLIRContext *ctx, function_ref< InFlightDiagnostic()> emitError={}) const
Load into ctx every dialect previously added via addDialectToPreload(StringRef).
DialectRegistry(const DialectRegistry &)=delete
void addExtensions()
Add the given extensions to the registry.
DialectRegistry & operator=(const DialectRegistry &other)=delete
DialectRegistry & operator=(DialectRegistry &&other)=default
void addDialectToPreload(StringRef name)
Request that the dialect with the given name be preloaded into the MLIRContext, without providing an ...
DialectRegistry(DialectRegistry &&)=default
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A dialect that can be defined at runtime.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
This class provides an efficient unique identifier for a specific C++ type.
static TypeID get()
Construct a type info object for the given type T.
static TypeID getFromOpaquePointer(const void *pointer)
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
function_ref< Dialect *(MLIRContext *)> DialectAllocatorFunctionRef
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
llvm::function_ref< Fn > function_ref