MLIR  21.0.0git
DialectRegistry.h
Go to the documentation of this file.
1 //===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines functionality for registering and extending dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECTREGISTRY_H
14 #define MLIR_IR_DIALECTREGISTRY_H
15 
16 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/MapVector.h"
20 
21 #include <map>
22 #include <tuple>
23 
24 namespace mlir {
25 class Dialect;
26 
27 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
30  std::function<void(MLIRContext *, DynamicDialect *)>;
31 
32 //===----------------------------------------------------------------------===//
33 // DialectExtension
34 //===----------------------------------------------------------------------===//
35 
36 /// This class represents an opaque dialect extension. It contains a set of
37 /// required dialects and an application function. The required dialects control
38 /// when the extension is applied, i.e. the extension is applied when all
39 /// required dialects are loaded. The application function can be used to attach
40 /// additional functionality to attributes, dialects, operations, types, etc.,
41 /// and may also load additional necessary dialects.
43 public:
45 
46  /// Return the dialects that our required by this extension to be loaded
47  /// before applying. If empty then the extension is invoked for every loaded
48  /// dialect indepently.
49  ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
50 
51  /// Apply this extension to the given context and the required dialects.
52  virtual void apply(MLIRContext *context,
53  MutableArrayRef<Dialect *> dialects) const = 0;
54 
55  /// Return a copy of this extension.
56  virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
57 
58 protected:
59  /// Initialize the extension with a set of required dialects.
60  /// If the list is empty, the extension is invoked for every loaded dialect
61  /// independently.
63  : dialectNames(dialectNames) {}
64 
65 private:
66  /// The names of the dialects affected by this extension.
67  SmallVector<StringRef> dialectNames;
68 };
69 
70 /// This class represents a dialect extension anchored on the given set of
71 /// dialects. When all of the specified dialects have been loaded, the
72 /// application function of this extension will be executed.
73 template <typename DerivedT, typename... DialectsT>
75 public:
76  /// Applies this extension to the given context and set of required dialects.
77  virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
78 
79  /// Return a copy of this extension.
80  std::unique_ptr<DialectExtensionBase> clone() const final {
81  return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
82  }
83 
84 protected:
87  ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
88 
89  /// Override the base apply method to allow providing the exact dialect types.
90  void apply(MLIRContext *context,
91  MutableArrayRef<Dialect *> dialects) const final {
92  unsigned dialectIdx = 0;
93  auto derivedDialects = std::tuple<DialectsT *...>{
94  static_cast<DialectsT *>(dialects[dialectIdx++])...};
95  std::apply([&](DialectsT *...dialect) { apply(context, dialect...); },
96  derivedDialects);
97  }
98 };
99 
100 namespace dialect_extension_detail {
101 
102 /// Checks if the given interface, which is attempting to be used, is a
103 /// promised interface of this dialect that has yet to be implemented. If so,
104 /// emits a fatal error.
106  TypeID interfaceRequestorID,
107  TypeID interfaceID,
108  StringRef interfaceName);
109 
110 /// Checks if the given interface, which is attempting to be attached, is a
111 /// promised interface of this dialect that has yet to be implemented. If so,
112 /// the promised interface is marked as resolved.
114  TypeID interfaceRequestorID,
115  TypeID interfaceID);
116 
117 /// Checks if a promise has been made for the interface/requestor pair.
118 bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
119  TypeID interfaceID);
120 
121 /// Checks if a promise has been made for the interface/requestor pair.
122 template <typename ConcreteT, typename InterfaceT>
124  return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(),
125  InterfaceT::getInterfaceID());
126 }
127 
128 } // namespace dialect_extension_detail
129 
130 //===----------------------------------------------------------------------===//
131 // DialectRegistry
132 //===----------------------------------------------------------------------===//
133 
134 /// The DialectRegistry maps a dialect namespace to a constructor for the
135 /// matching dialect. This allows for decoupling the list of dialects
136 /// "available" from the dialects loaded in the Context. The parser in
137 /// particular will lazily load dialects in the Context as operations are
138 /// encountered.
140  using MapTy =
141  std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>,
142  std::less<>>;
143 
144 public:
145  explicit DialectRegistry();
146  DialectRegistry(const DialectRegistry &) = delete;
147  DialectRegistry &operator=(const DialectRegistry &other) = delete;
150 
151  template <typename ConcreteDialect>
152  void insert() {
153  insert(TypeID::get<ConcreteDialect>(),
154  ConcreteDialect::getDialectNamespace(),
155  static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
156  // Just allocate the dialect, the context
157  // takes ownership of it.
158  return ctx->getOrLoadDialect<ConcreteDialect>();
159  })));
160  }
161 
/// Register several dialects at once. Each dialect is inserted individually,
/// in the order it appears in the template argument list.
template <typename ConcreteDialect, typename OtherDialect,
          typename... MoreDialects>
void insert() {
  insert<ConcreteDialect>();
  insert<OtherDialect>();
  (insert<MoreDialects>(), ...);
}
168 
169  /// Add a new dialect constructor to the registry. The constructor must be
170  /// calling MLIRContext::getOrLoadDialect in order for the context to take
171  /// ownership of the dialect and for delayed interface registration to happen.
172  void insert(TypeID typeID, StringRef name,
173  const DialectAllocatorFunction &ctor);
174 
175  /// Add a new dynamic dialect constructor in the registry. The constructor
176  /// provides as argument the created dynamic dialect, and is expected to
177  /// register the dialect types, attributes, and ops, using the
178  /// methods defined in ExtensibleDialect such as registerDynamicOperation.
179  void insertDynamic(StringRef name,
181 
182  /// Return an allocation function for constructing the dialect identified
183  /// by its namespace, or nullptr if the namespace is not in this registry.
184  DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
185 
186  // Register all dialects available in the current registry with the registry
187  // in the provided context.
188  void appendTo(DialectRegistry &destination) const {
189  for (const auto &nameAndRegistrationIt : registry)
190  destination.insert(nameAndRegistrationIt.second.first,
191  nameAndRegistrationIt.first,
192  nameAndRegistrationIt.second.second);
193  // Merge the extensions.
194  for (const auto &extension : extensions)
195  destination.extensions.try_emplace(extension.first,
196  extension.second->clone());
197  }
198 
199  /// Return the names of dialects known to this registry.
200  auto getDialectNames() const {
201  return llvm::map_range(
202  registry,
203  [](const MapTy::value_type &item) -> StringRef { return item.first; });
204  }
205 
206  /// Apply any held extensions that require the given dialect. Users are not
207  /// expected to call this directly.
208  void applyExtensions(Dialect *dialect) const;
209 
210  /// Apply any applicable extensions to the given context. Users are not
211  /// expected to call this directly.
212  void applyExtensions(MLIRContext *ctx) const;
213 
214  /// Add the given extension to the registry.
215  bool addExtension(TypeID extensionID,
216  std::unique_ptr<DialectExtensionBase> extension) {
217  return extensions.try_emplace(extensionID, std::move(extension)).second;
218  }
219 
220  /// Add the given extensions to the registry.
221  template <typename... ExtensionsT>
222  void addExtensions() {
223  (addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
224  ...);
225  }
226 
227  /// Add an extension function that requires the given dialects.
228  /// Note: This bare functor overload is provided in addition to the
229  /// std::function variant to enable dialect type deduction, e.g.:
230  /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
231  /// ... })
232  ///
233  /// is equivalent to:
234  /// registry.addExtension<MyDialect>(
235  /// [](MLIRContext *ctx, MyDialect *dialect){ ... }
236  /// )
237  template <typename... DialectsT>
238  bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
239  using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...);
240 
241  struct Extension : public DialectExtension<Extension, DialectsT...> {
242  Extension(const Extension &) = default;
243  Extension(ExtensionFnT extensionFn)
244  : DialectExtension<Extension, DialectsT...>(),
245  extensionFn(extensionFn) {}
246  ~Extension() override = default;
247 
248  void apply(MLIRContext *context, DialectsT *...dialects) const final {
249  extensionFn(context, dialects...);
250  }
251  ExtensionFnT extensionFn;
252  };
254  reinterpret_cast<const void *>(extensionFn)),
255  std::make_unique<Extension>(extensionFn));
256  }
257 
258  /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
259  /// contains all of the components of this registry.
260  bool isSubsetOf(const DialectRegistry &rhs) const;
261 
262 private:
263  MapTy registry;
264  llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
265 };
266 
267 } // namespace mlir
268 
269 #endif // MLIR_IR_DIALECTREGISTRY_H
This class represents an opaque dialect extension.
ArrayRef< StringRef > getRequiredDialects() const
Return the dialects that are required by this extension to be loaded before applying.
DialectExtensionBase(ArrayRef< StringRef > dialectNames)
Initialize the extension with a set of required dialects.
virtual void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const =0
Apply this extension to the given context and the required dialects.
virtual std::unique_ptr< DialectExtensionBase > clone() const =0
Return a copy of this extension.
This class represents a dialect extension anchored on the given set of dialects.
virtual void apply(MLIRContext *context, DialectsT *...dialects) const =0
Applies this extension to the given context and set of required dialects.
std::unique_ptr< DialectExtensionBase > clone() const final
Return a copy of this extension.
void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const final
Override the base apply method to allow providing the exact dialect types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(void(*extensionFn)(MLIRContext *, DialectsT *...))
Add an extension function that requires the given dialects.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e.
Definition: Dialect.cpp:329
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace,...
Definition: Dialect.cpp:219
void appendTo(DialectRegistry &destination) const
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition: Dialect.cpp:237
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
DialectRegistry & operator=(const DialectRegistry &other)=delete
DialectRegistry(const DialectRegistry &)=delete
DialectRegistry & operator=(DialectRegistry &&other)=default
void addExtensions()
Add the given extensions to the registry.
DialectRegistry(DialectRegistry &&)=default
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition: Dialect.cpp:255
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:97
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
static TypeID getFromOpaquePointer(const void *pointer)
Definition: TypeID.h:135
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.cpp:172
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
Definition: Dialect.cpp:166
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition: Dialect.cpp:159
Include the generated interface declarations.
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction