MLIR  19.0.0git
DialectRegistry.h
Go to the documentation of this file.
1 //===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines functionality for registering and extending dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECTREGISTRY_H
14 #define MLIR_IR_DIALECTREGISTRY_H
15 
16 #include "mlir/IR/MLIRContext.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringRef.h"
20 
21 #include <map>
22 #include <tuple>
23 
24 namespace mlir {
25 class Dialect;
26 
27 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
30  std::function<void(MLIRContext *, DynamicDialect *)>;
31 
32 //===----------------------------------------------------------------------===//
33 // DialectExtension
34 //===----------------------------------------------------------------------===//
35 
36 /// This class represents an opaque dialect extension. It contains a set of
37 /// required dialects and an application function. The required dialects control
38 /// when the extension is applied, i.e. the extension is applied when all
39 /// required dialects are loaded. The application function can be used to attach
40 /// additional functionality to attributes, dialects, operations, types, etc.,
41 /// and may also load additional necessary dialects.
43 public:
45 
46  /// Return the dialects that are required by this extension to be loaded
47  /// before applying. If empty then the extension is invoked for every loaded
48  /// dialect independently.
49  ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
50 
51  /// Apply this extension to the given context and the required dialects.
52  virtual void apply(MLIRContext *context,
53  MutableArrayRef<Dialect *> dialects) const = 0;
54 
55  /// Return a copy of this extension.
56  virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
57 
58 protected:
59  /// Initialize the extension with a set of required dialects.
60  /// If the list is empty, the extension is invoked for every loaded dialect
61  /// independently.
63  : dialectNames(dialectNames.begin(), dialectNames.end()) {}
64 
65 private:
66  /// The names of the dialects affected by this extension.
67  SmallVector<StringRef> dialectNames;
68 };
69 
70 /// This class represents a dialect extension anchored on the given set of
71 /// dialects. When all of the specified dialects have been loaded, the
72 /// application function of this extension will be executed.
73 template <typename DerivedT, typename... DialectsT>
75 public:
76  /// Applies this extension to the given context and set of required dialects.
77  virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
78 
79  /// Return a copy of this extension.
80  std::unique_ptr<DialectExtensionBase> clone() const final {
81  return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
82  }
83 
84 protected:
87  ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
88 
89  /// Override the base apply method to allow providing the exact dialect types.
90  void apply(MLIRContext *context,
91  MutableArrayRef<Dialect *> dialects) const final {
92  unsigned dialectIdx = 0;
93  auto derivedDialects = std::tuple<DialectsT *...>{
94  static_cast<DialectsT *>(dialects[dialectIdx++])...};
95  std::apply([&](DialectsT *...dialect) { apply(context, dialect...); },
96  derivedDialects);
97  }
98 };
99 
100 namespace dialect_extension_detail {
101 
102 /// Checks if the given interface, which is attempting to be used, is a
103 /// promised interface of this dialect that has yet to be implemented. If so,
104 /// emits a fatal error.
106  TypeID interfaceRequestorID,
107  TypeID interfaceID,
108  StringRef interfaceName);
109 
110 /// Checks if the given interface, which is attempting to be attached, is a
111 /// promised interface of this dialect that has yet to be implemented. If so,
112 /// the promised interface is marked as resolved.
114  TypeID interfaceRequestorID,
115  TypeID interfaceID);
116 
117 /// Checks if a promise has been made for the interface/requestor pair.
118 bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
119  TypeID interfaceID);
120 
121 /// Checks if a promise has been made for the interface/requestor pair.
122 template <typename ConcreteT, typename InterfaceT>
124  return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(),
125  InterfaceT::getInterfaceID());
126 }
127 
128 } // namespace dialect_extension_detail
129 
130 //===----------------------------------------------------------------------===//
131 // DialectRegistry
132 //===----------------------------------------------------------------------===//
133 
134 /// The DialectRegistry maps a dialect namespace to a constructor for the
135 /// matching dialect. This allows for decoupling the list of dialects
136 /// "available" from the dialects loaded in the Context. The parser in
137 /// particular will lazily load dialects in the Context as operations are
138 /// encountered.
140  using MapTy =
141  std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
142 
143 public:
144  explicit DialectRegistry();
145 
146  template <typename ConcreteDialect>
147  void insert() {
148  insert(TypeID::get<ConcreteDialect>(),
149  ConcreteDialect::getDialectNamespace(),
150  static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
151  // Just allocate the dialect, the context
152  // takes ownership of it.
153  return ctx->getOrLoadDialect<ConcreteDialect>();
154  })));
155  }
156 
/// Register two or more dialects at once, in the order they are listed.
template <typename ConcreteDialect, typename OtherDialect,
          typename... MoreDialects>
void insert() {
  insert<ConcreteDialect>();
  insert<OtherDialect>();
  // Comma-operator fold: registers the remaining dialects left to right.
  (insert<MoreDialects>(), ...);
}
163 
164  /// Add a new dialect constructor to the registry. The constructor must be
165  /// calling MLIRContext::getOrLoadDialect in order for the context to take
166  /// ownership of the dialect and for delayed interface registration to happen.
167  void insert(TypeID typeID, StringRef name,
168  const DialectAllocatorFunction &ctor);
169 
170  /// Add a new dynamic dialect constructor in the registry. The constructor
171  /// provides as argument the created dynamic dialect, and is expected to
172  /// register the dialect types, attributes, and ops, using the
173  /// methods defined in ExtensibleDialect such as registerDynamicOperation.
174  void insertDynamic(StringRef name,
176 
177  /// Return an allocation function for constructing the dialect identified
178  /// by its namespace, or nullptr if the namespace is not in this registry.
179  DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
180 
181  // Register all dialects available in the current registry with the registry
182  // in the provided context.
183  void appendTo(DialectRegistry &destination) const {
184  for (const auto &nameAndRegistrationIt : registry)
185  destination.insert(nameAndRegistrationIt.second.first,
186  nameAndRegistrationIt.first,
187  nameAndRegistrationIt.second.second);
188  // Merge the extensions.
189  for (const auto &extension : extensions)
190  destination.extensions.push_back(extension->clone());
191  }
192 
193  /// Return the names of dialects known to this registry.
194  auto getDialectNames() const {
195  return llvm::map_range(
196  registry,
197  [](const MapTy::value_type &item) -> StringRef { return item.first; });
198  }
199 
200  /// Apply any held extensions that require the given dialect. Users are not
201  /// expected to call this directly.
202  void applyExtensions(Dialect *dialect) const;
203 
204  /// Apply any applicable extensions to the given context. Users are not
205  /// expected to call this directly.
206  void applyExtensions(MLIRContext *ctx) const;
207 
208  /// Add the given extension to the registry.
209  void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
210  extensions.push_back(std::move(extension));
211  }
212 
/// Default-construct one instance of each listed extension type and add it to
/// the registry, preserving the listed order.
template <typename... ExtensionsT>
void addExtensions() {
  // Unary left fold over the comma operator: evaluated left to right.
  (..., addExtension(std::make_unique<ExtensionsT>()));
}
218 
219  /// Add an extension function that requires the given dialects.
220  /// Note: This bare functor overload is provided in addition to the
221  /// std::function variant to enable dialect type deduction, e.g.:
222  /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
223  ///
224  /// is equivalent to:
225  /// registry.addExtension<MyDialect>(
226  /// [](MLIRContext *ctx, MyDialect *dialect){ ... }
227  /// )
228  template <typename... DialectsT>
229  void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
230  addExtension<DialectsT...>(
231  std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
232  }
233  template <typename... DialectsT>
234  void
235  addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
236  using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
237 
238  struct Extension : public DialectExtension<Extension, DialectsT...> {
239  Extension(const Extension &) = default;
240  Extension(ExtensionFnT extensionFn)
241  : extensionFn(std::move(extensionFn)) {}
242  ~Extension() override = default;
243 
244  void apply(MLIRContext *context, DialectsT *...dialects) const final {
245  extensionFn(context, dialects...);
246  }
247  ExtensionFnT extensionFn;
248  };
249  addExtension(std::make_unique<Extension>(std::move(extensionFn)));
250  }
251 
252  /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
253  /// contains all of the components of this registry.
254  bool isSubsetOf(const DialectRegistry &rhs) const;
255 
256 private:
257  MapTy registry;
258  std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
259 };
260 
261 } // namespace mlir
262 
263 #endif // MLIR_IR_DIALECTREGISTRY_H
This class represents an opaque dialect extension.
ArrayRef< StringRef > getRequiredDialects() const
Return the dialects that are required by this extension to be loaded before applying.
DialectExtensionBase(ArrayRef< StringRef > dialectNames)
Initialize the extension with a set of required dialects.
virtual void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const =0
Apply this extension to the given context and the required dialects.
virtual std::unique_ptr< DialectExtensionBase > clone() const =0
Return a copy of this extension.
This class represents a dialect extension anchored on the given set of dialects.
virtual void apply(MLIRContext *context, DialectsT *...dialects) const =0
Applies this extension to the given context and set of required dialects.
std::unique_ptr< DialectExtensionBase > clone() const final
Return a copy of this extension.
void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const final
Override the base apply method to allow providing the exact dialect types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e.
Definition: Dialect.cpp:293
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace,...
Definition: Dialect.cpp:179
void appendTo(DialectRegistry &destination) const
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition: Dialect.cpp:197
void addExtension(std::function< void(MLIRContext *, DialectsT *...)> extensionFn)
auto getDialectNames() const
Return the names of dialects known to this registry.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
void addExtension(void(*extensionFn)(MLIRContext *, DialectsT *...))
Add an extension function that requires the given dialects.
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition: Dialect.cpp:215
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:97
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.cpp:166
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
Definition: Dialect.cpp:160
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition: Dialect.cpp:153
Include the generated interface declarations.
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction