MLIR  20.0.0git
DialectRegistry.h
Go to the documentation of this file.
1 //===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines functionality for registering and extending dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECTREGISTRY_H
14 #define MLIR_IR_DIALECTREGISTRY_H
15 
16 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/MapVector.h"
20 
21 #include <map>
22 #include <tuple>
23 
24 namespace mlir {
25 class Dialect;
26 
27 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
30  std::function<void(MLIRContext *, DynamicDialect *)>;
31 
32 //===----------------------------------------------------------------------===//
33 // DialectExtension
34 //===----------------------------------------------------------------------===//
35 
36 /// This class represents an opaque dialect extension. It contains a set of
37 /// required dialects and an application function. The required dialects control
38 /// when the extension is applied, i.e. the extension is applied when all
39 /// required dialects are loaded. The application function can be used to attach
40 /// additional functionality to attributes, dialects, operations, types, etc.,
41 /// and may also load additional necessary dialects.
43 public:
45 
46  /// Return the dialects that our required by this extension to be loaded
47  /// before applying. If empty then the extension is invoked for every loaded
48  /// dialect indepently.
49  ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
50 
51  /// Apply this extension to the given context and the required dialects.
52  virtual void apply(MLIRContext *context,
53  MutableArrayRef<Dialect *> dialects) const = 0;
54 
55  /// Return a copy of this extension.
56  virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
57 
58 protected:
59  /// Initialize the extension with a set of required dialects.
60  /// If the list is empty, the extension is invoked for every loaded dialect
61  /// independently.
63  : dialectNames(dialectNames) {}
64 
65 private:
66  /// The names of the dialects affected by this extension.
67  SmallVector<StringRef> dialectNames;
68 };
69 
70 /// This class represents a dialect extension anchored on the given set of
71 /// dialects. When all of the specified dialects have been loaded, the
72 /// application function of this extension will be executed.
73 template <typename DerivedT, typename... DialectsT>
75 public:
76  /// Applies this extension to the given context and set of required dialects.
77  virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
78 
79  /// Return a copy of this extension.
80  std::unique_ptr<DialectExtensionBase> clone() const final {
81  return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
82  }
83 
84 protected:
87  ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
88 
89  /// Override the base apply method to allow providing the exact dialect types.
90  void apply(MLIRContext *context,
91  MutableArrayRef<Dialect *> dialects) const final {
92  unsigned dialectIdx = 0;
93  auto derivedDialects = std::tuple<DialectsT *...>{
94  static_cast<DialectsT *>(dialects[dialectIdx++])...};
95  std::apply([&](DialectsT *...dialect) { apply(context, dialect...); },
96  derivedDialects);
97  }
98 };
99 
100 namespace dialect_extension_detail {
101 
102 /// Checks if the given interface, which is attempting to be used, is a
103 /// promised interface of this dialect that has yet to be implemented. If so,
104 /// emits a fatal error.
106  TypeID interfaceRequestorID,
107  TypeID interfaceID,
108  StringRef interfaceName);
109 
110 /// Checks if the given interface, which is attempting to be attached, is a
111 /// promised interface of this dialect that has yet to be implemented. If so,
112 /// the promised interface is marked as resolved.
114  TypeID interfaceRequestorID,
115  TypeID interfaceID);
116 
117 /// Checks if a promise has been made for the interface/requestor pair.
118 bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
119  TypeID interfaceID);
120 
121 /// Checks if a promise has been made for the interface/requestor pair.
122 template <typename ConcreteT, typename InterfaceT>
124  return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(),
125  InterfaceT::getInterfaceID());
126 }
127 
128 } // namespace dialect_extension_detail
129 
130 //===----------------------------------------------------------------------===//
131 // DialectRegistry
132 //===----------------------------------------------------------------------===//
133 
134 /// The DialectRegistry maps a dialect namespace to a constructor for the
135 /// matching dialect. This allows for decoupling the list of dialects
136 /// "available" from the dialects loaded in the Context. The parser in
137 /// particular will lazily load dialects in the Context as operations are
138 /// encountered.
140  using MapTy =
141  std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>,
142  std::less<>>;
143 
144 public:
145  explicit DialectRegistry();
146 
147  template <typename ConcreteDialect>
148  void insert() {
149  insert(TypeID::get<ConcreteDialect>(),
150  ConcreteDialect::getDialectNamespace(),
151  static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
152  // Just allocate the dialect, the context
153  // takes ownership of it.
154  return ctx->getOrLoadDialect<ConcreteDialect>();
155  })));
156  }
157 
158  template <typename ConcreteDialect, typename OtherDialect,
159  typename... MoreDialects>
160  void insert() {
161  insert<ConcreteDialect>();
162  insert<OtherDialect, MoreDialects...>();
163  }
164 
165  /// Add a new dialect constructor to the registry. The constructor must be
166  /// calling MLIRContext::getOrLoadDialect in order for the context to take
167  /// ownership of the dialect and for delayed interface registration to happen.
168  void insert(TypeID typeID, StringRef name,
169  const DialectAllocatorFunction &ctor);
170 
171  /// Add a new dynamic dialect constructor in the registry. The constructor
172  /// provides as argument the created dynamic dialect, and is expected to
173  /// register the dialect types, attributes, and ops, using the
174  /// methods defined in ExtensibleDialect such as registerDynamicOperation.
175  void insertDynamic(StringRef name,
177 
178  /// Return an allocation function for constructing the dialect identified
179  /// by its namespace, or nullptr if the namespace is not in this registry.
180  DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
181 
182  // Register all dialects available in the current registry with the registry
183  // in the provided context.
184  void appendTo(DialectRegistry &destination) const {
185  for (const auto &nameAndRegistrationIt : registry)
186  destination.insert(nameAndRegistrationIt.second.first,
187  nameAndRegistrationIt.first,
188  nameAndRegistrationIt.second.second);
189  // Merge the extensions.
190  for (const auto &extension : extensions)
191  destination.extensions.try_emplace(extension.first,
192  extension.second->clone());
193  }
194 
195  /// Return the names of dialects known to this registry.
196  auto getDialectNames() const {
197  return llvm::map_range(
198  registry,
199  [](const MapTy::value_type &item) -> StringRef { return item.first; });
200  }
201 
202  /// Apply any held extensions that require the given dialect. Users are not
203  /// expected to call this directly.
204  void applyExtensions(Dialect *dialect) const;
205 
206  /// Apply any applicable extensions to the given context. Users are not
207  /// expected to call this directly.
208  void applyExtensions(MLIRContext *ctx) const;
209 
210  /// Add the given extension to the registry.
211  bool addExtension(TypeID extensionID,
212  std::unique_ptr<DialectExtensionBase> extension) {
213  return extensions.try_emplace(extensionID, std::move(extension)).second;
214  }
215 
216  /// Add the given extensions to the registry.
217  template <typename... ExtensionsT>
218  void addExtensions() {
219  (addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
220  ...);
221  }
222 
223  /// Add an extension function that requires the given dialects.
224  /// Note: This bare functor overload is provided in addition to the
225  /// std::function variant to enable dialect type deduction, e.g.:
226  /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
227  /// ... })
228  ///
229  /// is equivalent to:
230  /// registry.addExtension<MyDialect>(
231  /// [](MLIRContext *ctx, MyDialect *dialect){ ... }
232  /// )
233  template <typename... DialectsT>
234  bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
235  using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...);
236 
237  struct Extension : public DialectExtension<Extension, DialectsT...> {
238  Extension(const Extension &) = default;
239  Extension(ExtensionFnT extensionFn)
240  : DialectExtension<Extension, DialectsT...>(),
241  extensionFn(extensionFn) {}
242  ~Extension() override = default;
243 
244  void apply(MLIRContext *context, DialectsT *...dialects) const final {
245  extensionFn(context, dialects...);
246  }
247  ExtensionFnT extensionFn;
248  };
250  reinterpret_cast<const void *>(extensionFn)),
251  std::make_unique<Extension>(extensionFn));
252  }
253 
254  /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
255  /// contains all of the components of this registry.
256  bool isSubsetOf(const DialectRegistry &rhs) const;
257 
258 private:
259  MapTy registry;
260  llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
261 };
262 
263 } // namespace mlir
264 
265 #endif // MLIR_IR_DIALECTREGISTRY_H
This class represents an opaque dialect extension.
ArrayRef< StringRef > getRequiredDialects() const
Return the dialects that are required by this extension to be loaded before applying.
DialectExtensionBase(ArrayRef< StringRef > dialectNames)
Initialize the extension with a set of required dialects.
virtual void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const =0
Apply this extension to the given context and the required dialects.
virtual std::unique_ptr< DialectExtensionBase > clone() const =0
Return a copy of this extension.
This class represents a dialect extension anchored on the given set of dialects.
virtual void apply(MLIRContext *context, DialectsT *...dialects) const =0
Applies this extension to the given context and set of required dialects.
std::unique_ptr< DialectExtensionBase > clone() const final
Return a copy of this extension.
void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const final
Override the base apply method to allow providing the exact dialect types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(void(*extensionFn)(MLIRContext *, DialectsT *...))
Add an extension function that requires the given dialects.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e.
Definition: Dialect.cpp:329
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace,...
Definition: Dialect.cpp:219
void appendTo(DialectRegistry &destination) const
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition: Dialect.cpp:237
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
void addExtensions()
Add the given extensions to the registry.
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition: Dialect.cpp:255
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:97
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
static TypeID getFromOpaquePointer(const void *pointer)
Definition: TypeID.h:132
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.cpp:172
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
Definition: Dialect.cpp:166
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition: Dialect.cpp:159
Include the generated interface declarations.
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction