MLIR  16.0.0git
DialectRegistry.h
Go to the documentation of this file.
1 //===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines functionality for registering and extending dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECTREGISTRY_H
14 #define MLIR_IR_DIALECTREGISTRY_H
15 
16 #include "mlir/IR/MLIRContext.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringRef.h"
20 
21 #include <map>
22 #include <tuple>
23 
24 namespace mlir {
25 class Dialect;
26 
27 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
29 
30 //===----------------------------------------------------------------------===//
31 // DialectExtension
32 //===----------------------------------------------------------------------===//
33 
34 /// This class represents an opaque dialect extension. It contains a set of
35 /// required dialects and an application function. The required dialects control
36 /// when the extension is applied, i.e. the extension is applied when all
37 /// required dialects are loaded. The application function can be used to attach
38 /// additional functionality to attributes, dialects, operations, types, etc.,
39 /// and may also load additional necessary dialects.
41 public:
42  virtual ~DialectExtensionBase();
43 
44  /// Return the dialects that our required by this extension to be loaded
45  /// before applying.
46  ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
47 
48  /// Apply this extension to the given context and the required dialects.
49  virtual void apply(MLIRContext *context,
50  MutableArrayRef<Dialect *> dialects) const = 0;
51 
52  /// Return a copy of this extension.
53  virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
54 
55 protected:
56  /// Initialize the extension with a set of required dialects. Note that there
57  /// should always be at least one affected dialect.
59  : dialectNames(dialectNames.begin(), dialectNames.end()) {
60  assert(!dialectNames.empty() && "expected at least one affected dialect");
61  }
62 
63 private:
64  /// The names of the dialects affected by this extension.
65  SmallVector<StringRef> dialectNames;
66 };
67 
68 /// This class represents a dialect extension anchored on the given set of
69 /// dialects. When all of the specified dialects have been loaded, the
70 /// application function of this extension will be executed.
71 template <typename DerivedT, typename... DialectsT>
73 public:
74  /// Applies this extension to the given context and set of required dialects.
75  virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
76 
77  /// Return a copy of this extension.
78  std::unique_ptr<DialectExtensionBase> clone() const final {
79  return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
80  }
81 
82 protected:
85  ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
86 
87  /// Override the base apply method to allow providing the exact dialect types.
88  void apply(MLIRContext *context,
89  MutableArrayRef<Dialect *> dialects) const final {
90  unsigned dialectIdx = 0;
91  auto derivedDialects = std::tuple<DialectsT *...>{
92  static_cast<DialectsT *>(dialects[dialectIdx++])...};
93  std::apply([&](DialectsT *...dialect) { apply(context, dialect...); },
94  derivedDialects);
95  }
96 };
97 
98 //===----------------------------------------------------------------------===//
99 // DialectRegistry
100 //===----------------------------------------------------------------------===//
101 
102 /// The DialectRegistry maps a dialect namespace to a constructor for the
103 /// matching dialect. This allows for decoupling the list of dialects
104 /// "available" from the dialects loaded in the Context. The parser in
105 /// particular will lazily load dialects in the Context as operations are
106 /// encountered.
108  using MapTy =
109  std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
110 
111 public:
112  explicit DialectRegistry();
113 
114  template <typename ConcreteDialect>
115  void insert() {
116  insert(TypeID::get<ConcreteDialect>(),
117  ConcreteDialect::getDialectNamespace(),
118  static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
119  // Just allocate the dialect, the context
120  // takes ownership of it.
121  return ctx->getOrLoadDialect<ConcreteDialect>();
122  })));
123  }
124 
125  template <typename ConcreteDialect, typename OtherDialect,
126  typename... MoreDialects>
127  void insert() {
128  insert<ConcreteDialect>();
129  insert<OtherDialect, MoreDialects...>();
130  }
131 
132  /// Add a new dialect constructor to the registry. The constructor must be
133  /// calling MLIRContext::getOrLoadDialect in order for the context to take
134  /// ownership of the dialect and for delayed interface registration to happen.
135  void insert(TypeID typeID, StringRef name,
136  const DialectAllocatorFunction &ctor);
137 
138  /// Return an allocation function for constructing the dialect identified by
139  /// its namespace, or nullptr if the namespace is not in this registry.
140  DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
141 
142  // Register all dialects available in the current registry with the registry
143  // in the provided context.
144  void appendTo(DialectRegistry &destination) const {
145  for (const auto &nameAndRegistrationIt : registry)
146  destination.insert(nameAndRegistrationIt.second.first,
147  nameAndRegistrationIt.first,
148  nameAndRegistrationIt.second.second);
149  // Merge the extensions.
150  for (const auto &extension : extensions)
151  destination.extensions.push_back(extension->clone());
152  }
153 
154  /// Return the names of dialects known to this registry.
155  auto getDialectNames() const {
156  return llvm::map_range(
157  registry,
158  [](const MapTy::value_type &item) -> StringRef { return item.first; });
159  }
160 
161  /// Apply any held extensions that require the given dialect. Users are not
162  /// expected to call this directly.
163  void applyExtensions(Dialect *dialect) const;
164 
165  /// Apply any applicable extensions to the given context. Users are not
166  /// expected to call this directly.
167  void applyExtensions(MLIRContext *ctx) const;
168 
169  /// Add the given extension to the registry.
170  void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
171  extensions.push_back(std::move(extension));
172  }
173 
174  /// Add the given extensions to the registry.
175  template <typename... ExtensionsT>
176  void addExtensions() {
177  (addExtension(std::make_unique<ExtensionsT>()), ...);
178  }
179 
180  /// Add an extension function that requires the given dialects.
181  /// Note: This bare functor overload is provided in addition to the
182  /// std::function variant to enable dialect type deduction, e.g.:
183  /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
184  ///
185  /// is equivalent to:
186  /// registry.addExtension<MyDialect>(
187  /// [](MLIRContext *ctx, MyDialect *dialect){ ... }
188  /// )
189  template <typename... DialectsT>
190  void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
191  addExtension<DialectsT...>(
192  std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
193  }
194  template <typename... DialectsT>
195  void
196  addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
197  using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
198 
199  struct Extension : public DialectExtension<Extension, DialectsT...> {
200  Extension(const Extension &) = default;
201  Extension(ExtensionFnT extensionFn)
202  : extensionFn(std::move(extensionFn)) {}
203  ~Extension() override = default;
204 
205  void apply(MLIRContext *context, DialectsT *...dialects) const final {
206  extensionFn(context, dialects...);
207  }
208  ExtensionFnT extensionFn;
209  };
210  addExtension(std::make_unique<Extension>(std::move(extensionFn)));
211  }
212 
213  /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
214  /// contains all of the components of this registry.
215  bool isSubsetOf(const DialectRegistry &rhs) const;
216 
217 private:
218  MapTy registry;
219  std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
220 };
221 
222 } // namespace mlir
223 
224 #endif // MLIR_IR_DIALECTREGISTRY_H
Include the generated interface declarations.
This class represents an opaque dialect extension.
ArrayRef< StringRef > getRequiredDialects() const
Return the dialects that are required by this extension to be loaded before applying.
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
virtual void apply(MLIRContext *context, MutableArrayRef< Dialect *> dialects) const =0
Apply this extension to the given context and the required dialects.
DialectExtensionBase(ArrayRef< StringRef > dialectNames)
Initialize the extension with a set of required dialects.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
virtual std::unique_ptr< DialectExtensionBase > clone() const =0
Return a copy of this extension.
void addExtension(void(*extensionFn)(MLIRContext *, DialectsT *...))
Add an extension function that requires the given dialects.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
void apply(MLIRContext *context, MutableArrayRef< Dialect *> dialects) const final
Override the base apply method to allow providing the exact dialect types.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
This class represents a dialect extension anchored on the given set of dialects.
auto getDialectNames() const
Return the names of dialects known to this registry.
std::unique_ptr< DialectExtensionBase > clone() const final
Return a copy of this extension.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:92
void addExtension(std::function< void(MLIRContext *, DialectsT *...)> extensionFn)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
void appendTo(DialectRegistry &destination) const