MLIR  16.0.0git
DialectRegistry.h
Go to the documentation of this file.
1 //===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines functionality for registering and extending dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECTREGISTRY_H
14 #define MLIR_IR_DIALECTREGISTRY_H
15 
16 #include "mlir/IR/MLIRContext.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringRef.h"
20 
21 #include <map>
22 #include <tuple>
23 
24 namespace mlir {
25 class Dialect;
26 
27 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
30  std::function<void(MLIRContext *, DynamicDialect *)>;
31 
32 //===----------------------------------------------------------------------===//
33 // DialectExtension
34 //===----------------------------------------------------------------------===//
35 
36 /// This class represents an opaque dialect extension. It contains a set of
37 /// required dialects and an application function. The required dialects control
38 /// when the extension is applied, i.e. the extension is applied when all
39 /// required dialects are loaded. The application function can be used to attach
40 /// additional functionality to attributes, dialects, operations, types, etc.,
41 /// and may also load additional necessary dialects.
43 public:
45 
46  /// Return the names of the dialects that are required by this extension to
47  /// be loaded before it is applied.
48  ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
49 
50  /// Apply this extension to the given context and the required dialects.
    /// NOTE(review): `dialects` is presumably ordered like the names returned by
    /// getRequiredDialects(); the typed DialectExtension override downcasts the
    /// entries positionally, so confirm that callers preserve that order.
51  virtual void apply(MLIRContext *context,
52  MutableArrayRef<Dialect *> dialects) const = 0;
53 
54  /// Return a copy of this extension.
55  virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
56 
57 protected:
58  /// Initialize the extension with a set of required dialects. Note that there
59  /// should always be at least one affected dialect.
61  : dialectNames(dialectNames.begin(), dialectNames.end()) {
62  assert(!dialectNames.empty() && "expected at least one affected dialect");
63  }
64 
65 private:
66  /// The names of the dialects affected by this extension.
67  SmallVector<StringRef> dialectNames;
68 };
69 
70 /// This class represents a dialect extension anchored on the given set of
71 /// dialects. When all of the specified dialects have been loaded, the
72 /// application function of this extension will be executed.
73 template <typename DerivedT, typename... DialectsT>
75 public:
76  /// Applies this extension to the given context and set of required dialects.
77  virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
78 
79  /// Return a newly allocated copy of this extension. The copy is made by
    /// copy-constructing DerivedT from `*this`, so DerivedT must be
    /// copy-constructible.
80  std::unique_ptr<DialectExtensionBase> clone() const final {
81  return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
82  }
83 
84 protected:
87  ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
88 
89  /// Override the base apply method to allow providing the exact dialect types.
    /// The i-th entry of `dialects` is downcast to the i-th type in DialectsT and
    /// the results are forwarded to the typed apply overload. The cast is an
    /// unchecked static_cast, so each entry must really be of the matching type.
90  void apply(MLIRContext *context,
91  MutableArrayRef<Dialect *> dialects) const final {
92  unsigned dialectIdx = 0;
    // The elements of a braced initializer are evaluated left to right, so each
    // pack expansion element consumes the next entry of `dialects` in order.
93  auto derivedDialects = std::tuple<DialectsT *...>{
94  static_cast<DialectsT *>(dialects[dialectIdx++])...};
    // Unpack the tuple into individual arguments for the typed overload.
95  std::apply([&](DialectsT *...dialect) { apply(context, dialect...); },
96  derivedDialects);
97  }
98 };
99 
100 //===----------------------------------------------------------------------===//
101 // DialectRegistry
102 //===----------------------------------------------------------------------===//
103 
104 /// The DialectRegistry maps a dialect namespace to a constructor for the
105 /// matching dialect. This allows for decoupling the list of dialects
106 /// "available" from the dialects loaded in the Context. The parser in
107 /// particular will lazily load dialects in the Context as operations are
108 /// encountered.
110  using MapTy =
111  std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
112 
113 public:
114  explicit DialectRegistry();
115 
116  template <typename ConcreteDialect>
117  void insert() {
     // Register ConcreteDialect under its TypeID and its dialect namespace. The
     // allocator passed along asks the context to get (or load) the dialect.
118  insert(TypeID::get<ConcreteDialect>(),
119  ConcreteDialect::getDialectNamespace(),
120  static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
121  // Just allocate the dialect, the context
122  // takes ownership of it.
123  return ctx->getOrLoadDialect<ConcreteDialect>();
124  })));
125  }
126 
127  template <typename ConcreteDialect, typename OtherDialect,
128  typename... MoreDialects>
129  void insert() {
     // Register the first dialect, then recurse on the remaining ones, so the
     // dialects are inserted in template-argument order.
130  insert<ConcreteDialect>();
131  insert<OtherDialect, MoreDialects...>();
132  }
133 
134  /// Add a new dialect constructor to the registry. The constructor must be
135  /// calling MLIRContext::getOrLoadDialect in order for the context to take
136  /// ownership of the dialect and for delayed interface registration to happen.
137  void insert(TypeID typeID, StringRef name,
138  const DialectAllocatorFunction &ctor);
139 
140  /// Add a new dynamic dialect constructor in the registry. The constructor
141  /// provides as argument the created dynamic dialect, and is expected to
142  /// register the dialect types, attributes, and ops, using the
143  /// methods defined in ExtensibleDialect such as registerDynamicOperation.
144  void insertDynamic(StringRef name,
146 
147  /// Return an allocation function for constructing the dialect identified
148  /// by its namespace, or nullptr if the namespace is not in this registry.
149  DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
150 
151  /// Copy every dialect registration and every extension held by this registry
152  /// into `destination`. This registry itself is not modified.
153  void appendTo(DialectRegistry &destination) const {
     // Each map entry is (namespace -> (TypeID, allocator)); re-insert it into
     // the destination with the arguments in insert()'s parameter order.
154  for (const auto &nameAndRegistrationIt : registry)
155  destination.insert(nameAndRegistrationIt.second.first,
156  nameAndRegistrationIt.first,
157  nameAndRegistrationIt.second.second);
158  // Merge the extensions; each one is cloned so `destination` owns its copy.
159  for (const auto &extension : extensions)
160  destination.extensions.push_back(extension->clone());
161  }
162 
163  /// Return the names of dialects known to this registry. The result is a range
     /// over the registry map in which each element is a StringRef built from the
     /// std::string key (the dialect namespace) of one entry.
164  auto getDialectNames() const {
165  return llvm::map_range(
166  registry,
167  [](const MapTy::value_type &item) -> StringRef { return item.first; });
168  }
169 
170  /// Apply any held extensions that require the given dialect. Users are not
171  /// expected to call this directly.
172  void applyExtensions(Dialect *dialect) const;
173 
174  /// Apply any applicable extensions to the given context. Users are not
175  /// expected to call this directly.
176  void applyExtensions(MLIRContext *ctx) const;
177 
178  /// Add the given extension to the registry. The registry takes ownership of
     /// the extension and appends it after any previously added extensions.
179  void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
180  extensions.push_back(std::move(extension));
181  }
182 
183  /// Add the given extensions to the registry. One instance of each type in
     /// ExtensionsT is default-constructed and added, in template-argument order.
184  template <typename... ExtensionsT>
185  void addExtensions() {
     // Fold over the comma operator: one addExtension call per pack element.
186  (addExtension(std::make_unique<ExtensionsT>()), ...);
187  }
188 
189  /// Add an extension function that requires the given dialects.
190  /// Note: This plain function-pointer overload is provided in addition to the
191  /// std::function variant to enable dialect type deduction, e.g.:
192  /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
193  ///
194  /// is equivalent to:
195  /// registry.addExtension<MyDialect>(
196  /// [](MLIRContext *ctx, MyDialect *dialect){ ... }
197  /// )
     /// (The unary `+` converts the captureless lambda to a function pointer, so
     /// DialectsT can be deduced from the pointer's signature.)
198  template <typename... DialectsT>
199  void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
     // Wrap the pointer in a std::function and defer to the generic overload.
200  addExtension<DialectsT...>(
201  std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
202  }
203  template <typename... DialectsT>
204  void
205  addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
     // Add an extension that invokes `extensionFn` with the context and the
     // dialects named by DialectsT.
206  using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
207 
     // Local adapter that stores the callback and exposes it as a
     // DialectExtension anchored on DialectsT.
208  struct Extension : public DialectExtension<Extension, DialectsT...> {
     // Copy construction is what DialectExtension::clone() uses to duplicate
     // the adapter together with its stored callback.
209  Extension(const Extension &) = default;
210  Extension(ExtensionFnT extensionFn)
211  : extensionFn(std::move(extensionFn)) {}
212  ~Extension() override = default;
213 
     // Forward the typed dialect pointers straight to the stored callback.
214  void apply(MLIRContext *context, DialectsT *...dialects) const final {
215  extensionFn(context, dialects...);
216  }
217  ExtensionFnT extensionFn;
218  };
     // Hand the adapter to the owning unique_ptr overload.
219  addExtension(std::make_unique<Extension>(std::move(extensionFn)));
220  }
221 
222  /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
223  /// contains all of the components of this registry.
224  bool isSubsetOf(const DialectRegistry &rhs) const;
225 
226 private:
227  MapTy registry;
228  std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
229 };
230 
231 } // namespace mlir
232 
233 #endif // MLIR_IR_DIALECTREGISTRY_H
This class represents an opaque dialect extension.
ArrayRef< StringRef > getRequiredDialects() const
Return the dialects that are required by this extension to be loaded before applying.
DialectExtensionBase(ArrayRef< StringRef > dialectNames)
Initialize the extension with a set of required dialects.
virtual void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const =0
Apply this extension to the given context and the required dialects.
virtual std::unique_ptr< DialectExtensionBase > clone() const =0
Return a copy of this extension.
This class represents a dialect extension anchored on the given set of dialects.
virtual void apply(MLIRContext *context, DialectsT *...dialects) const =0
Applies this extension to the given context and set of required dialects.
std::unique_ptr< DialectExtensionBase > clone() const final
Return a copy of this extension.
void apply(MLIRContext *context, MutableArrayRef< Dialect * > dialects) const final
Override the base apply method to allow providing the exact dialect types.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs' contains all of the components of this registry.
Definition: Dialect.cpp:255
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace, or nullptr if the namespace is not in this registry.
Definition: Dialect.cpp:153
void appendTo(DialectRegistry &destination) const
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition: Dialect.cpp:171
void addExtension(std::function< void(MLIRContext *, DialectsT *...)> extensionFn)
auto getDialectNames() const
Return the names of dialects known to this registry.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
void addExtension(void(*extensionFn)(MLIRContext *, DialectsT *...))
Add an extension function that requires the given dialects.
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition: Dialect.cpp:189
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:93
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Include the generated interface declarations.
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction