13 #ifndef MLIR_IR_DIALECTREGISTRY_H
14 #define MLIR_IR_DIALECTREGISTRY_H
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/MapVector.h"
56 virtual std::unique_ptr<DialectExtensionBase>
clone()
const = 0;
63 : dialectNames(dialectNames) {}
73 template <
typename DerivedT,
typename... DialectsT>
80 std::unique_ptr<DialectExtensionBase>
clone() const final {
81 return std::make_unique<DerivedT>(
static_cast<const DerivedT &
>(*
this));
87 ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
92 unsigned dialectIdx = 0;
93 auto derivedDialects = std::tuple<DialectsT *...>{
94 static_cast<DialectsT *
>(dialects[dialectIdx++])...};
95 std::apply([&](DialectsT *...dialect) {
apply(context, dialect...); },
100 namespace dialect_extension_detail {
106 TypeID interfaceRequestorID,
108 StringRef interfaceName);
114 TypeID interfaceRequestorID,
122 template <
typename ConcreteT,
typename InterfaceT>
125 InterfaceT::getInterfaceID());
141 std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
146 template <
typename ConcreteDialect>
148 insert(TypeID::get<ConcreteDialect>(),
149 ConcreteDialect::getDialectNamespace(),
157 template <
typename ConcreteDialect,
typename OtherDialect,
158 typename... MoreDialects>
160 insert<ConcreteDialect>();
161 insert<OtherDialect, MoreDialects...>();
184 for (
const auto &nameAndRegistrationIt : registry)
185 destination.
insert(nameAndRegistrationIt.second.first,
186 nameAndRegistrationIt.first,
187 nameAndRegistrationIt.second.second);
189 for (
const auto &extension : extensions)
190 destination.extensions.try_emplace(extension.first,
191 extension.second->clone());
196 return llvm::map_range(
198 [](
const MapTy::value_type &item) -> StringRef {
return item.first; });
211 std::unique_ptr<DialectExtensionBase> extension) {
212 return extensions.try_emplace(extensionID, std::move(extension)).second;
216 template <
typename... ExtensionsT>
218 (
addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
232 template <
typename... DialectsT>
234 using ExtensionFnT = void (*)(
MLIRContext *, DialectsT *...);
237 Extension(
const Extension &) =
default;
238 Extension(ExtensionFnT extensionFn)
240 extensionFn(extensionFn) {}
241 ~Extension()
override =
default;
243 void apply(
MLIRContext *context, DialectsT *...dialects)
const final {
244 extensionFn(context, dialects...);
246 ExtensionFnT extensionFn;
249 reinterpret_cast<const void *
>(extensionFn)),
250 std::make_unique<Extension>(extensionFn));
259 llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
This class represents an opaque dialect extension.
ArrayRef< StringRef > getRequiredDialects() const
Return the dialects that are required by this extension to be loaded before applying.
virtual ~DialectExtensionBase()
DialectExtensionBase(ArrayRef< StringRef > dialectNames)
Initialize the extension with a set of required dialects.
virtual void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const =0
Apply this extension to the given context and the required dialects.
virtual std::unique_ptr< DialectExtensionBase > clone() const =0
Return a copy of this extension.
This class represents a dialect extension anchored on the given set of dialects.
virtual void apply(MLIRContext *context, DialectsT *...dialects) const =0
Applies this extension to the given context and set of required dialects.
std::unique_ptr< DialectExtensionBase > clone() const final
Return a copy of this extension.
void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const final
Override the base apply method to allow providing the exact dialect types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(void(*extensionFn)(MLIRContext *, DialectsT *...))
Add an extension function that requires the given dialects.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs'.
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace,...
void appendTo(DialectRegistry &destination) const
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
void addExtensions()
Add the given extensions to the registry.
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
This class provides an efficient unique identifier for a specific C++ type.
static TypeID getFromOpaquePointer(const void *pointer)
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Include the generated interface declarations.
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction