MLIR 23.0.0git
DialectRegistry.h
Go to the documentation of this file.
1//===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines functionality for registering and extending dialects.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_DIALECTREGISTRY_H
14#define MLIR_IR_DIALECTREGISTRY_H
15
16#include "mlir/IR/MLIRContext.h"
17#include "mlir/Support/TypeID.h"
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/ADT/MapVector.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/Support/LogicalResult.h"
22
23#include <map>
24#include <string>
25#include <tuple>
26
27namespace mlir {
28class Dialect;
29
30using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
33 std::function<void(MLIRContext *, DynamicDialect *)>;
34
35//===----------------------------------------------------------------------===//
36// DialectExtension
37//===----------------------------------------------------------------------===//
38
39/// This class represents an opaque dialect extension. It contains a set of
40/// required dialects and an application function. The required dialects control
41/// when the extension is applied, i.e. the extension is applied when all
42/// required dialects are loaded. The application function can be used to attach
43/// additional functionality to attributes, dialects, operations, types, etc.,
44/// and may also load additional necessary dialects.
46public:
48
49 /// Return the dialects that are required by this extension to be loaded
50 /// before applying. If empty then the extension is invoked for every loaded
51 /// dialect independently.
52 ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
53
54 /// Apply this extension to the given context and the required dialects.
55 virtual void apply(MLIRContext *context,
56 MutableArrayRef<Dialect *> dialects) const = 0;
57
58 /// Return a copy of this extension.
59 virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
60
61protected:
62 /// Initialize the extension with a set of required dialects.
63 /// If the list is empty, the extension is invoked for every loaded dialect
64 /// independently.
66 : dialectNames(dialectNames) {}
67
68private:
69 /// The names of the dialects affected by this extension.
70 SmallVector<StringRef> dialectNames;
71};
72
73/// This class represents a dialect extension anchored on the given set of
74/// dialects. When all of the specified dialects have been loaded, the
75/// application function of this extension will be executed.
76template <typename DerivedT, typename... DialectsT>
78public:
79 /// Applies this extension to the given context and set of required dialects.
80 virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
81
82 /// Return a copy of this extension.
83 std::unique_ptr<DialectExtensionBase> clone() const final {
84 return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
85 }
86
87protected:
90 ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
91
92 /// Override the base apply method to allow providing the exact dialect types.
93 void apply(MLIRContext *context,
94 MutableArrayRef<Dialect *> dialects) const final {
95 unsigned dialectIdx = 0;
96 auto derivedDialects = std::tuple<DialectsT *...>{
97 static_cast<DialectsT *>(dialects[dialectIdx++])...};
98 std::apply([&](DialectsT *...dialect) { apply(context, dialect...); },
99 derivedDialects);
100 }
101};
102
104
105/// Checks if the given interface, which is attempting to be used, is a
106/// promised interface of this dialect that has yet to be implemented. If so,
107/// emits a fatal error.
109 TypeID interfaceRequestorID,
110 TypeID interfaceID,
111 StringRef interfaceName);
112
113/// Checks if the given interface, which is attempting to be attached, is a
114/// promised interface of this dialect that has yet to be implemented. If so,
115/// the promised interface is marked as resolved.
117 TypeID interfaceRequestorID,
118 TypeID interfaceID);
119
120/// Checks if a promise has been made for the interface/requestor pair.
121bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
122 TypeID interfaceID);
123
124/// Checks if a promise has been made for the interface/requestor pair.
125template <typename ConcreteT, typename InterfaceT>
128 InterfaceT::getInterfaceID());
129}
130
131} // namespace dialect_extension_detail
132
133//===----------------------------------------------------------------------===//
134// DialectRegistry
135//===----------------------------------------------------------------------===//
136
137/// The DialectRegistry maps a dialect namespace to a constructor for the
138/// matching dialect. This allows for decoupling the list of dialects
139/// "available" from the dialects loaded in the Context. The parser in
140/// particular will lazily load dialects in the Context as operations are
141/// encountered.
142///
143/// In addition to allocator-backed registrations, the registry can also carry
144/// a set of dialect *names* that some caller has asked to be preloaded into
145/// the context (see `addDialectToPreload(StringRef)`). The registry itself
146/// does not load those dialects — it merely records the request; the
147/// allocator is expected to live in the MLIRContext's own registry, and
148/// actually loading them is the caller's responsibility via
149/// `preloadSelectDialects(MLIRContext *)`.
151 using MapTy =
152 std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>,
153 std::less<>>;
154
155public:
156 explicit DialectRegistry();
161
162 template <typename ConcreteDialect>
163 void insert() {
165 ConcreteDialect::getDialectNamespace(),
166 static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
167 // Just allocate the dialect, the context
168 // takes ownership of it.
169 return ctx->getOrLoadDialect<ConcreteDialect>();
170 })));
171 }
172
173 template <typename ConcreteDialect, typename OtherDialect,
174 typename... MoreDialects>
175 void insert() {
177 insert<OtherDialect, MoreDialects...>();
178 }
179
180 /// Add a new dialect constructor to the registry. The constructor must
181 /// call MLIRContext::getOrLoadDialect in order for the context to take
182 /// ownership of the dialect and for delayed interface registration to happen.
183 void insert(TypeID typeID, StringRef name,
184 const DialectAllocatorFunction &ctor);
185
186 /// Request that the dialect with the given name be preloaded into the
187 /// MLIRContext, without providing an allocator. Useful when a caller knows a
188 /// dialect is required but expects its allocator to be available in the
189 /// MLIRContext's own registry at load time (e.g. a pass learning dialect
190 /// names from string-valued options).
191 void addDialectToPreload(StringRef name);
192
193 /// Add a new dynamic dialect constructor in the registry. The constructor
194 /// provides as argument the created dynamic dialect, and is expected to
195 /// register the dialect types, attributes, and ops, using the
196 /// methods defined in ExtensibleDialect such as registerDynamicOperation.
197 void insertDynamic(StringRef name,
199
200 /// Return an allocation function for constructing the dialect identified
201 /// by its namespace, or nullptr if the namespace is not in this registry.
203
204 // Register all dialects available in the current registry with the registry
205 // in the provided context.
206 void appendTo(DialectRegistry &destination) const {
207 for (const auto &nameAndRegistrationIt : registry)
208 destination.insert(nameAndRegistrationIt.second.first,
209 nameAndRegistrationIt.first,
210 nameAndRegistrationIt.second.second);
211 for (const std::string &name : dialectsToPreload)
212 destination.addDialectToPreload(StringRef(name));
213 // Merge the extensions.
214 for (const auto &extension : extensions)
215 destination.extensions.try_emplace(extension.first,
216 extension.second->clone());
217 }
218
219 /// Return the names of dialects registered in this registry with an
220 /// allocator function. Does not include preload-only entries added via
221 /// `addDialectToPreload(StringRef)` — use `getDialectsToPreload()` for those.
224 names.reserve(registry.size());
225 for (const auto &item : registry)
226 names.push_back(item.first);
227 return names;
228 }
229
230 /// Return the names of dialects that should be preloaded into the context
231 /// but whose allocator is expected to be resolved from the context's own
232 /// registry (added via `addDialectToPreload(StringRef)`).
234 return dialectsToPreload;
235 }
236
237 /// Load into `ctx` every dialect previously added via
238 /// `addDialectToPreload(StringRef)`. The allocator is resolved from the
239 /// context's own registry. On failure, if `emitError` is provided, it is
240 /// invoked to produce a diagnostic naming the offending dialect; otherwise
241 /// the failure is silent.
242 LogicalResult preloadSelectDialects(
243 MLIRContext *ctx,
245
246 /// Apply any held extensions that require the given dialect. Users are not
247 /// expected to call this directly.
248 void applyExtensions(Dialect *dialect) const;
249
250 /// Apply any applicable extensions to the given context. Users are not
251 /// expected to call this directly.
252 void applyExtensions(MLIRContext *ctx) const;
253
254 /// Add the given extension to the registry.
255 bool addExtension(TypeID extensionID,
256 std::unique_ptr<DialectExtensionBase> extension) {
257 return extensions.try_emplace(extensionID, std::move(extension)).second;
258 }
259
260 /// Add the given extensions to the registry.
261 template <typename... ExtensionsT>
263 (addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
264 ...);
265 }
266
267 /// Add an extension function that requires the given dialects.
268 /// Note: This bare functor overload is provided in addition to the
269 /// std::function variant to enable dialect type deduction, e.g.:
270 /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
271 /// ... })
272 ///
273 /// is equivalent to:
274 /// registry.addExtension<MyDialect>(
275 /// [](MLIRContext *ctx, MyDialect *dialect){ ... }
276 /// )
277 template <typename... DialectsT>
278 bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
279 using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...);
280
281 struct Extension : public DialectExtension<Extension, DialectsT...> {
282 Extension(const Extension &) = default;
283 Extension(ExtensionFnT extensionFn)
284 : DialectExtension<Extension, DialectsT...>(),
285 extensionFn(extensionFn) {}
286 ~Extension() override = default;
287
288 void apply(MLIRContext *context, DialectsT *...dialects) const final {
289 extensionFn(context, dialects...);
290 }
291 ExtensionFnT extensionFn;
292 };
294 reinterpret_cast<const void *>(extensionFn)),
295 std::make_unique<Extension>(extensionFn));
296 }
297
298 /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
299 /// contains all of the components of this registry.
300 bool isSubsetOf(const DialectRegistry &rhs) const;
301
302private:
303 MapTy registry;
304 /// Names of dialects that should be preloaded into the MLIRContext but for
305 /// which no allocator has been registered here. The allocator is expected
306 /// to be resolved from the MLIRContext's own registry when the dialect is
307 /// loaded (e.g. via MLIRContext::getOrLoadDialect).
308 SmallVector<std::string> dialectsToPreload;
309 llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
310};
311
312} // namespace mlir
313
314#endif // MLIR_IR_DIALECTREGISTRY_H
DialectExtensionBase(ArrayRef< StringRef > dialectNames)
Initialize the extension with a set of required dialects.
ArrayRef< StringRef > getRequiredDialects() const
Return the dialects that are required by this extension to be loaded before applying.
virtual void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const =0
Apply this extension to the given context and the required dialects.
virtual std::unique_ptr< DialectExtensionBase > clone() const =0
Return a copy of this extension.
This class represents a dialect extension anchored on the given set of dialects.
std::unique_ptr< DialectExtensionBase > clone() const final
Return a copy of this extension.
virtual void apply(MLIRContext *context, DialectsT *...dialects) const =0
Applies this extension to the given context and set of required dialects.
void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const final
Override the base apply method to allow providing the exact dialect types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(void(*extensionFn)(MLIRContext *, DialectsT *...))
Add an extension function that requires the given dialects.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e.
Definition Dialect.cpp:343
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace,...
Definition Dialect.cpp:210
void appendTo(DialectRegistry &destination) const
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition Dialect.cpp:251
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
SmallVector< StringRef > getRegisteredDialectNames() const
Return the names of dialects registered in this registry with an allocator function.
ArrayRef< std::string > getDialectsToPreload() const
Return the names of dialects that should be preloaded into the context but whose allocator is expecte...
LogicalResult preloadSelectDialects(MLIRContext *ctx, function_ref< InFlightDiagnostic()> emitError={}) const
Load into ctx every dialect previously added via addDialectToPreload(StringRef).
Definition Dialect.cpp:228
DialectRegistry(const DialectRegistry &)=delete
void addExtensions()
Add the given extensions to the registry.
DialectRegistry & operator=(const DialectRegistry &other)=delete
DialectRegistry & operator=(DialectRegistry &&other)=default
void addDialectToPreload(StringRef name)
Request that the dialect with the given name be preloaded into the MLIRContext, without providing an ...
Definition Dialect.cpp:241
DialectRegistry(DialectRegistry &&)=default
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition Dialect.cpp:269
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A dialect that can be defined at runtime.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
static TypeID getFromOpaquePointer(const void *pointer)
Definition TypeID.h:135
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
Definition Dialect.cpp:163
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
Definition Dialect.cpp:157
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition Dialect.cpp:150
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
function_ref< Dialect *(MLIRContext *)> DialectAllocatorFunctionRef
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147