MLIR 22.0.0git
DialectRegistry.h
Go to the documentation of this file.
1//===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines functionality for registering and extending dialects.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_DIALECTREGISTRY_H
14#define MLIR_IR_DIALECTREGISTRY_H
15
16#include "mlir/IR/MLIRContext.h"
17#include "mlir/Support/TypeID.h"
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/ADT/MapVector.h"
20
21#include <map>
22#include <tuple>
23
24namespace mlir {
25class Dialect;
26
27using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
30 std::function<void(MLIRContext *, DynamicDialect *)>;
31
32//===----------------------------------------------------------------------===//
33// DialectExtension
34//===----------------------------------------------------------------------===//
35
36/// This class represents an opaque dialect extension. It contains a set of
37/// required dialects and an application function. The required dialects control
38/// when the extension is applied, i.e. the extension is applied when all
39/// required dialects are loaded. The application function can be used to attach
40/// additional functionality to attributes, dialects, operations, types, etc.,
41/// and may also load additional necessary dialects.
43public:
45
46 /// Return the dialects that are required by this extension to be loaded
47 /// before applying. If empty then the extension is invoked for every loaded
48 /// dialect independently.
49 ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
50
51 /// Apply this extension to the given context and the required dialects.
52 virtual void apply(MLIRContext *context,
53 MutableArrayRef<Dialect *> dialects) const = 0;
54
55 /// Return a copy of this extension.
56 virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
57
58protected:
59 /// Initialize the extension with a set of required dialects.
60 /// If the list is empty, the extension is invoked for every loaded dialect
61 /// independently.
63 : dialectNames(dialectNames) {}
64
65private:
66 /// The names of the dialects affected by this extension.
67 SmallVector<StringRef> dialectNames;
68};
69
70/// This class represents a dialect extension anchored on the given set of
71/// dialects. When all of the specified dialects have been loaded, the
72/// application function of this extension will be executed.
73template <typename DerivedT, typename... DialectsT>
75public:
76 /// Applies this extension to the given context and set of required dialects.
77 virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
78
79 /// Return a copy of this extension.
 /// The copy is made by copy-constructing `DerivedT` from `*this`, so
 /// `DerivedT` must be copy-constructible and must derive from this class
 /// (the `static_cast` below downcasts `*this` to `const DerivedT &`).
80 std::unique_ptr<DialectExtensionBase> clone() const final {
81 return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
82 }
83
84protected:
87 ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
88
89 /// Override the base apply method to allow providing the exact dialect types.
 /// Each entry of `dialects` is downcast, in order, to the corresponding type
 /// in `DialectsT` and then forwarded to the typed `apply` overload.
 /// NOTE(review): assumes `dialects` holds one entry per `DialectsT`, in the
 /// same order as the template arguments; nothing here checks that.
90 void apply(MLIRContext *context,
91 MutableArrayRef<Dialect *> dialects) const final {
92 unsigned dialectIdx = 0;
 // The tuple is brace-initialized on purpose: list-initialization evaluates
 // its elements left to right, so `dialectIdx++` walks `dialects` in order.
93 auto derivedDialects = std::tuple<DialectsT *...>{
94 static_cast<DialectsT *>(dialects[dialectIdx++])...};
 // Unpack the typed pointers and dispatch to the user-provided overload.
95 std::apply([&](DialectsT *...dialect) { apply(context, dialect...); },
96 derivedDialects);
97 }
98};
99
101
102/// Checks if the given interface, which is attempting to be used, is a
103/// promised interface of this dialect that has yet to be implemented. If so,
104/// emits a fatal error.
106 TypeID interfaceRequestorID,
107 TypeID interfaceID,
108 StringRef interfaceName);
109
110/// Checks if the given interface, which is attempting to be attached, is a
111/// promised interface of this dialect that has yet to be implemented. If so,
112/// the promised interface is marked as resolved.
114 TypeID interfaceRequestorID,
115 TypeID interfaceID);
116
117/// Checks if a promise has been made for the interface/requestor pair.
118bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
119 TypeID interfaceID);
120
121/// Checks if a promise has been made for the interface/requestor pair.
122template <typename ConcreteT, typename InterfaceT>
125 InterfaceT::getInterfaceID());
126}
127
128} // namespace dialect_extension_detail
129
130//===----------------------------------------------------------------------===//
131// DialectRegistry
132//===----------------------------------------------------------------------===//
133
134/// The DialectRegistry maps a dialect namespace to a constructor for the
135/// matching dialect. This allows for decoupling the list of dialects
136/// "available" from the dialects loaded in the Context. The parser in
137/// particular will lazily load dialects in the Context as operations are
138/// encountered.
140 using MapTy =
141 std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>,
142 std::less<>>;
143
144public:
145 explicit DialectRegistry();
150
151 template <typename ConcreteDialect>
152 void insert() {
154 ConcreteDialect::getDialectNamespace(),
155 static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
156 // Just allocate the dialect, the context
157 // takes ownership of it.
158 return ctx->getOrLoadDialect<ConcreteDialect>();
159 })));
160 }
161
162 template <typename ConcreteDialect, typename OtherDialect,
163 typename... MoreDialects>
164 void insert() {
166 insert<OtherDialect, MoreDialects...>();
167 }
168
169 /// Add a new dialect constructor to the registry. The constructor must
170 /// call MLIRContext::getOrLoadDialect in order for the context to take
171 /// ownership of the dialect and for delayed interface registration to happen.
172 void insert(TypeID typeID, StringRef name,
173 const DialectAllocatorFunction &ctor);
174
175 /// Add a new dynamic dialect constructor in the registry. The constructor
176 /// provides as argument the created dynamic dialect, and is expected to
177 /// register the dialect types, attributes, and ops, using the
178 /// methods defined in ExtensibleDialect such as registerDynamicOperation.
179 void insertDynamic(StringRef name,
181
182 /// Return an allocation function for constructing the dialect identified
183 /// by its namespace, or nullptr if the namespace is not in this registry.
185
186 // Register all dialects available in the current registry with the
187 // provided destination registry.
188 void appendTo(DialectRegistry &destination) const {
 // Each registry entry maps a dialect namespace to a (TypeID, allocator)
 // pair; forward all three pieces to `destination.insert`.
189 for (const auto &nameAndRegistrationIt : registry)
190 destination.insert(nameAndRegistrationIt.second.first,
191 nameAndRegistrationIt.first,
192 nameAndRegistrationIt.second.second);
193 // Merge the extensions.
 // Extensions are cloned so that `destination` owns its own copies.
 // NOTE(review): `try_emplace` presumably leaves an extension already
 // registered under the same TypeID untouched (std-style semantics); confirm
 // against llvm::MapVector.
194 for (const auto &extension : extensions)
195 destination.extensions.try_emplace(extension.first,
196 extension.second->clone());
197 }
198
199 /// Return the names of dialects known to this registry.
 /// The result is a range built by `llvm::map_range` over `registry`; each
 /// element is a `StringRef` referring to a key string stored in `registry`,
 /// so the names are only usable while those entries exist.
200 auto getDialectNames() const {
201 return llvm::map_range(
202 registry,
203 [](const MapTy::value_type &item) -> StringRef { return item.first; });
204 }
205
206 /// Apply any held extensions that require the given dialect. Users are not
207 /// expected to call this directly.
208 void applyExtensions(Dialect *dialect) const;
209
210 /// Apply any applicable extensions to the given context. Users are not
211 /// expected to call this directly.
212 void applyExtensions(MLIRContext *ctx) const;
213
214 /// Add the given extension to the registry.
 /// The extension is stored under `extensionID` and ownership is moved into
 /// the registry's extension map. The return value is the insertion flag
 /// reported by `try_emplace`.
 /// NOTE(review): presumably true when the extension was newly added and false
 /// when an extension with the same ID already existed (std-style semantics);
 /// confirm against llvm::MapVector.
215 bool addExtension(TypeID extensionID,
216 std::unique_ptr<DialectExtensionBase> extension) {
217 return extensions.try_emplace(extensionID, std::move(extension)).second;
218 }
219
220 /// Add the given extensions to the registry.
221 template <typename... ExtensionsT>
223 (addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
224 ...);
225 }
226
227 /// Add an extension function that requires the given dialects.
228 /// Note: This bare functor overload is provided in addition to the
229 /// std::function variant to enable dialect type deduction, e.g.:
230 /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
231 /// ... })
232 ///
233 /// is equivalent to:
234 /// registry.addExtension<MyDialect>(
235 /// [](MLIRContext *ctx, MyDialect *dialect){ ... }
236 /// )
237 template <typename... DialectsT>
238 bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
239 using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...);
240
241 struct Extension : public DialectExtension<Extension, DialectsT...> {
242 Extension(const Extension &) = default;
243 Extension(ExtensionFnT extensionFn)
244 : DialectExtension<Extension, DialectsT...>(),
245 extensionFn(extensionFn) {}
246 ~Extension() override = default;
247
248 void apply(MLIRContext *context, DialectsT *...dialects) const final {
249 extensionFn(context, dialects...);
250 }
251 ExtensionFnT extensionFn;
252 };
254 reinterpret_cast<const void *>(extensionFn)),
255 std::make_unique<Extension>(extensionFn));
256 }
257
258 /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
259 /// contains all of the components of this registry.
260 bool isSubsetOf(const DialectRegistry &rhs) const;
261
262private:
263 MapTy registry;
264 llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
265};
266
267} // namespace mlir
268
269#endif // MLIR_IR_DIALECTREGISTRY_H
DialectExtensionBase(ArrayRef< StringRef > dialectNames)
Initialize the extension with a set of required dialects.
ArrayRef< StringRef > getRequiredDialects() const
Return the dialects that are required by this extension to be loaded before applying.
virtual void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const =0
Apply this extension to the given context and the required dialects.
virtual std::unique_ptr< DialectExtensionBase > clone() const =0
Return a copy of this extension.
This class represents a dialect extension anchored on the given set of dialects.
std::unique_ptr< DialectExtensionBase > clone() const final
Return a copy of this extension.
virtual void apply(MLIRContext *context, DialectsT *...dialects) const =0
Applies this extension to the given context and set of required dialects.
void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const final
Override the base apply method to allow providing the exact dialect types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(void(*extensionFn)(MLIRContext *, DialectsT *...))
Add an extension function that requires the given dialects.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e.
Definition Dialect.cpp:320
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace,...
Definition Dialect.cpp:210
void appendTo(DialectRegistry &destination) const
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition Dialect.cpp:228
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
DialectRegistry(const DialectRegistry &)=delete
void addExtensions()
Add the given extensions to the registry.
DialectRegistry & operator=(const DialectRegistry &other)=delete
DialectRegistry & operator=(DialectRegistry &&other)=default
DialectRegistry(DialectRegistry &&)=default
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition Dialect.cpp:246
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
static TypeID getFromOpaquePointer(const void *pointer)
Definition TypeID.h:135
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
Definition Dialect.cpp:163
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
Definition Dialect.cpp:157
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition Dialect.cpp:150
Include the generated interface declarations.
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
function_ref< Dialect *(MLIRContext *)> DialectAllocatorFunctionRef
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152