MLIR  20.0.0git
DialectRegistry.h
Go to the documentation of this file.
1 //===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines functionality for registering and extending dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECTREGISTRY_H
14 #define MLIR_IR_DIALECTREGISTRY_H
15 
16 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/MapVector.h"
20 
21 #include <map>
22 #include <tuple>
23 
24 namespace mlir {
25 class Dialect;
26 
27 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
30  std::function<void(MLIRContext *, DynamicDialect *)>;
31 
32 //===----------------------------------------------------------------------===//
33 // DialectExtension
34 //===----------------------------------------------------------------------===//
35 
36 /// This class represents an opaque dialect extension. It contains a set of
37 /// required dialects and an application function. The required dialects control
38 /// when the extension is applied, i.e. the extension is applied when all
39 /// required dialects are loaded. The application function can be used to attach
40 /// additional functionality to attributes, dialects, operations, types, etc.,
41 /// and may also load additional necessary dialects.
43 public:
45 
46  /// Return the dialects that are required by this extension to be loaded
47  /// before applying. If empty then the extension is invoked for every loaded
48  /// dialect independently.
49  ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
50 
51  /// Apply this extension to the given context and the required dialects.
52  virtual void apply(MLIRContext *context,
53  MutableArrayRef<Dialect *> dialects) const = 0;
54 
55  /// Return a copy of this extension.
56  virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
57 
58 protected:
59  /// Initialize the extension with a set of required dialects.
60  /// If the list is empty, the extension is invoked for every loaded dialect
61  /// independently.
63  : dialectNames(dialectNames) {}
64 
65 private:
66  /// The names of the dialects affected by this extension.
67  SmallVector<StringRef> dialectNames;
68 };
69 
70 /// This class represents a dialect extension anchored on the given set of
71 /// dialects. When all of the specified dialects have been loaded, the
72 /// application function of this extension will be executed.
73 template <typename DerivedT, typename... DialectsT>
75 public:
76  /// Applies this extension to the given context and set of required dialects.
77  virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
78 
79  /// Return a copy of this extension.
80  std::unique_ptr<DialectExtensionBase> clone() const final {
81  return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
82  }
83 
84 protected:
87  ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
88 
89  /// Override the base apply method to allow providing the exact dialect types.
90  void apply(MLIRContext *context,
91  MutableArrayRef<Dialect *> dialects) const final {
92  unsigned dialectIdx = 0;
93  auto derivedDialects = std::tuple<DialectsT *...>{
94  static_cast<DialectsT *>(dialects[dialectIdx++])...};
95  std::apply([&](DialectsT *...dialect) { apply(context, dialect...); },
96  derivedDialects);
97  }
98 };
99 
100 namespace dialect_extension_detail {
101 
102 /// Checks if the given interface, which is attempting to be used, is a
103 /// promised interface of this dialect that has yet to be implemented. If so,
104 /// emits a fatal error.
106  TypeID interfaceRequestorID,
107  TypeID interfaceID,
108  StringRef interfaceName);
109 
110 /// Checks if the given interface, which is attempting to be attached, is a
111 /// promised interface of this dialect that has yet to be implemented. If so,
112 /// the promised interface is marked as resolved.
114  TypeID interfaceRequestorID,
115  TypeID interfaceID);
116 
117 /// Checks if a promise has been made for the interface/requestor pair.
118 bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
119  TypeID interfaceID);
120 
121 /// Checks if a promise has been made for the interface/requestor pair.
122 template <typename ConcreteT, typename InterfaceT>
124  return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(),
125  InterfaceT::getInterfaceID());
126 }
127 
128 } // namespace dialect_extension_detail
129 
130 //===----------------------------------------------------------------------===//
131 // DialectRegistry
132 //===----------------------------------------------------------------------===//
133 
134 /// The DialectRegistry maps a dialect namespace to a constructor for the
135 /// matching dialect. This allows for decoupling the list of dialects
136 /// "available" from the dialects loaded in the Context. The parser in
137 /// particular will lazily load dialects in the Context as operations are
138 /// encountered.
140  using MapTy =
141  std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
142 
143 public:
144  explicit DialectRegistry();
145 
146  template <typename ConcreteDialect>
147  void insert() {
148  insert(TypeID::get<ConcreteDialect>(),
149  ConcreteDialect::getDialectNamespace(),
150  static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
151  // Just allocate the dialect, the context
152  // takes ownership of it.
153  return ctx->getOrLoadDialect<ConcreteDialect>();
154  })));
155  }
156 
157  template <typename ConcreteDialect, typename OtherDialect,
158  typename... MoreDialects>
159  void insert() {
160  insert<ConcreteDialect>();
161  insert<OtherDialect, MoreDialects...>();
162  }
163 
164  /// Add a new dialect constructor to the registry. The constructor must be
165  /// calling MLIRContext::getOrLoadDialect in order for the context to take
166  /// ownership of the dialect and for delayed interface registration to happen.
167  void insert(TypeID typeID, StringRef name,
168  const DialectAllocatorFunction &ctor);
169 
170  /// Add a new dynamic dialect constructor in the registry. The constructor
171  /// provides as argument the created dynamic dialect, and is expected to
172  /// register the dialect types, attributes, and ops, using the
173  /// methods defined in ExtensibleDialect such as registerDynamicOperation.
174  void insertDynamic(StringRef name,
176 
177  /// Return an allocation function for constructing the dialect identified
178  /// by its namespace, or nullptr if the namespace is not in this registry.
179  DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
180 
181  // Register all dialects and extensions available in the current registry
182  // with the provided destination registry.
183  void appendTo(DialectRegistry &destination) const {
184  for (const auto &nameAndRegistrationIt : registry)
185  destination.insert(nameAndRegistrationIt.second.first,
186  nameAndRegistrationIt.first,
187  nameAndRegistrationIt.second.second);
188  // Merge the extensions.
189  for (const auto &extension : extensions)
190  destination.extensions.try_emplace(extension.first,
191  extension.second->clone());
192  }
193 
194  /// Return the names of dialects known to this registry.
195  auto getDialectNames() const {
196  return llvm::map_range(
197  registry,
198  [](const MapTy::value_type &item) -> StringRef { return item.first; });
199  }
200 
201  /// Apply any held extensions that require the given dialect. Users are not
202  /// expected to call this directly.
203  void applyExtensions(Dialect *dialect) const;
204 
205  /// Apply any applicable extensions to the given context. Users are not
206  /// expected to call this directly.
207  void applyExtensions(MLIRContext *ctx) const;
208 
209  /// Add the given extension to the registry.
210  bool addExtension(TypeID extensionID,
211  std::unique_ptr<DialectExtensionBase> extension) {
212  return extensions.try_emplace(extensionID, std::move(extension)).second;
213  }
214 
215  /// Add the given extensions to the registry.
216  template <typename... ExtensionsT>
217  void addExtensions() {
218  (addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
219  ...);
220  }
221 
222  /// Add an extension function that requires the given dialects.
223  /// Note: This bare functor overload is provided in addition to the
224  /// std::function variant to enable dialect type deduction, e.g.:
225  /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
226  /// ... })
227  ///
228  /// is equivalent to:
229  /// registry.addExtension<MyDialect>(
230  /// [](MLIRContext *ctx, MyDialect *dialect){ ... }
231  /// )
232  template <typename... DialectsT>
233  bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
234  using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...);
235 
236  struct Extension : public DialectExtension<Extension, DialectsT...> {
237  Extension(const Extension &) = default;
238  Extension(ExtensionFnT extensionFn)
239  : DialectExtension<Extension, DialectsT...>(),
240  extensionFn(extensionFn) {}
241  ~Extension() override = default;
242 
243  void apply(MLIRContext *context, DialectsT *...dialects) const final {
244  extensionFn(context, dialects...);
245  }
246  ExtensionFnT extensionFn;
247  };
249  reinterpret_cast<const void *>(extensionFn)),
250  std::make_unique<Extension>(extensionFn));
251  }
252 
253  /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
254  /// contains all of the components of this registry.
255  bool isSubsetOf(const DialectRegistry &rhs) const;
256 
257 private:
258  MapTy registry;
259  llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
260 };
261 
262 } // namespace mlir
263 
264 #endif // MLIR_IR_DIALECTREGISTRY_H
This class represents an opaque dialect extension.
ArrayRef< StringRef > getRequiredDialects() const
Return the dialects that are required by this extension to be loaded before applying.
DialectExtensionBase(ArrayRef< StringRef > dialectNames)
Initialize the extension with a set of required dialects.
virtual void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const =0
Apply this extension to the given context and the required dialects.
virtual std::unique_ptr< DialectExtensionBase > clone() const =0
Return a copy of this extension.
This class represents a dialect extension anchored on the given set of dialects.
virtual void apply(MLIRContext *context, DialectsT *...dialects) const =0
Applies this extension to the given context and set of required dialects.
std::unique_ptr< DialectExtensionBase > clone() const final
Return a copy of this extension.
void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const final
Override the base apply method to allow providing the exact dialect types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(void(*extensionFn)(MLIRContext *, DialectsT *...))
Add an extension function that requires the given dialects.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs' contains all of the components of this registry.
Definition: Dialect.cpp:329
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace, or nullptr if the namespace is not in this registry.
Definition: Dialect.cpp:219
void appendTo(DialectRegistry &destination) const
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition: Dialect.cpp:237
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
void addExtensions()
Add the given extensions to the registry.
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition: Dialect.cpp:255
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:97
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
static TypeID getFromOpaquePointer(const void *pointer)
Definition: TypeID.h:132
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.cpp:172
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this dialect that has yet to be implemented. If so, the promised interface is marked as resolved.
Definition: Dialect.cpp:166
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialect that has yet to be implemented. If so, emits a fatal error.
Definition: Dialect.cpp:159
Include the generated interface declarations.
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction