13 #ifndef MLIR_IR_DIALECTREGISTRY_H 14 #define MLIR_IR_DIALECTREGISTRY_H 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/StringRef.h" 53 virtual std::unique_ptr<DialectExtensionBase>
clone()
const = 0;
59 : dialectNames(dialectNames.begin(), dialectNames.end()) {
60 assert(!dialectNames.empty() &&
"expected at least one affected dialect");
71 template <
typename DerivedT,
typename... DialectsT>
78 std::unique_ptr<DialectExtensionBase>
clone() const final {
79 return std::make_unique<DerivedT>(
static_cast<const DerivedT &
>(*this));
85 ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
90 unsigned dialectIdx = 0;
91 auto derivedDialects = std::tuple<DialectsT *...>{
92 static_cast<DialectsT *
>(dialects[dialectIdx++])...};
93 std::apply([&](DialectsT *...dialect) {
apply(context, dialect...); },
109 std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
114 template <
typename ConcreteDialect>
116 insert(TypeID::get<ConcreteDialect>(),
117 ConcreteDialect::getDialectNamespace(),
118 static_cast<DialectAllocatorFunction>(([](
MLIRContext *ctx) {
125 template <
typename ConcreteDialect,
typename OtherDialect,
126 typename... MoreDialects>
128 insert<ConcreteDialect>();
129 insert<OtherDialect, MoreDialects...>();
135 void insert(
TypeID typeID, StringRef name,
145 for (
const auto &nameAndRegistrationIt : registry)
146 destination.
insert(nameAndRegistrationIt.second.first,
147 nameAndRegistrationIt.first,
148 nameAndRegistrationIt.second.second);
150 for (
const auto &extension : extensions)
151 destination.extensions.push_back(extension->clone());
156 return llvm::map_range(
158 [](
const MapTy::value_type &item) -> StringRef {
return item.first; });
163 void applyExtensions(
Dialect *dialect)
const;
171 extensions.push_back(std::move(extension));
175 template <
typename... ExtensionsT>
177 (addExtension(std::make_unique<ExtensionsT>()), ...);
189 template <
typename... DialectsT>
191 addExtension<DialectsT...>(
194 template <
typename... DialectsT>
197 using ExtensionFnT = std::function<
void(
MLIRContext *, DialectsT * ...)>;
200 Extension(
const Extension &) =
default;
201 Extension(ExtensionFnT extensionFn)
202 : extensionFn(std::move(extensionFn)) {}
203 ~Extension()
override =
default;
206 extensionFn(context, dialects...);
208 ExtensionFnT extensionFn;
210 addExtension(std::make_unique<Extension>(std::move(extensionFn)));
219 std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
224 #endif // MLIR_IR_DIALECTREGISTRY_H Include the generated interface declarations.
This class represents an opaque dialect extension.
ArrayRef< StringRef > getRequiredDialects() const
Return the dialects that our required by this extension to be loaded before applying.
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
virtual void apply(MLIRContext *context, MutableArrayRef< Dialect *> dialects) const =0
Apply this extension to the given context and the required dialects.
DialectExtensionBase(ArrayRef< StringRef > dialectNames)
Initialize the extension with a set of required dialects.
This class provides an efficient unique identifier for a specific C++ type.
virtual std::unique_ptr< DialectExtensionBase > clone() const =0
Return a copy of this extension.
void addExtension(void(*extensionFn)(MLIRContext *, DialectsT *...))
Add an extension function that requires the given dialects.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
virtual ~DialectExtensionBase()
void addExtensions()
Add the given extensions to the registry.
void apply(MLIRContext *context, MutableArrayRef< Dialect *> dialects) const final
Override the base apply method to allow providing the exact dialect types.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
This class represents a dialect extension anchored on the given set of dialects.
auto getDialectNames() const
Return the names of dialects known to this registry.
std::unique_ptr< DialectExtensionBase > clone() const final
Return a copy of this extension.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
void addExtension(std::function< void(MLIRContext *, DialectsT *...)> extensionFn)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendTo(DialectRegistry &destination) const