MLIR  14.0.0git
Dialect.h
Go to the documentation of this file.
1 //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the 'dialect' abstraction.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECT_H
14 #define MLIR_IR_DIALECT_H
15 
17 #include "mlir/Support/TypeID.h"
18 
19 #include <map>
20 #include <tuple>
21 
22 namespace mlir {
23 class DialectAsmParser;
24 class DialectAsmPrinter;
25 class DialectInterface;
26 class OpBuilder;
27 class Type;
28 
29 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
32  std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
33 using ObjectInterfaceAllocatorFunction = std::function<void(MLIRContext *)>;
34 
35 /// Dialects are groups of MLIR operations, types and attributes, as well as
36 /// behavior associated with the entire group. For example, hooks into other
37 /// systems for constant folding, interfaces, default named types for asm
38 /// printing, etc.
39 ///
40 /// Instances of the dialect object are loaded in a specific MLIRContext.
41 ///
42 class Dialect {
43 public:
44  /// Type for a callback provided by the dialect to parse a custom operation.
45  /// This is used for the dialect to provide an alternative way to parse custom
46  /// operations, including unregistered ones.
47  using ParseOpHook =
49 
50  virtual ~Dialect();
51 
52  /// Utility function that returns if the given string is a valid dialect
53  /// namespace
54  static bool isValidNamespace(StringRef str);
55 
56  MLIRContext *getContext() const { return context; }
57 
58  StringRef getNamespace() const { return name; }
59 
60  /// Returns the unique identifier that corresponds to this dialect.
61  TypeID getTypeID() const { return dialectID; }
62 
63  /// Returns true if this dialect allows for unregistered operations, i.e.
64  /// operations prefixed with the dialect namespace but not registered with
65  /// addOperation.
66  bool allowsUnknownOperations() const { return unknownOpsAllowed; }
67 
68  /// Return true if this dialect allows for unregistered types, i.e., types
69  /// prefixed with the dialect namespace but not registered with addType.
70  /// These are represented with OpaqueType.
71  bool allowsUnknownTypes() const { return unknownTypesAllowed; }
72 
73  /// Register dialect-wide canonicalization patterns. This method should only
74  /// be used to register canonicalization patterns that do not conceptually
75  /// belong to any single operation in the dialect. (In that case, use the op's
76  /// canonicalizer.) E.g., canonicalization patterns for op interfaces should
77  /// be registered here.
78  virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
79 
80  /// Registered hook to materialize a single constant operation from a given
81  /// attribute value with the desired resultant type. This method should use
82  /// the provided builder to create the operation without changing the
83  /// insertion position. The generated operation is expected to be constant
84  /// like, i.e. single result, zero operands, non side-effecting, etc. On
85  /// success, this hook should return the value generated to represent the
86  /// constant value. Otherwise, it should return null on failure.
88  Type type, Location loc) {
89  return nullptr;
90  }
91 
92  //===--------------------------------------------------------------------===//
93  // Parsing Hooks
94  //===--------------------------------------------------------------------===//
95 
96  /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
97  /// refers to the expected type of the attribute.
98  virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
99 
100  /// Print an attribute registered to this dialect. Note: The type of the
101  /// attribute need not be printed by this method as it is always printed by
102  /// the caller.
103  virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
104  llvm_unreachable("dialect has no registered attribute printing hook");
105  }
106 
107  /// Parse a type registered to this dialect.
108  virtual Type parseType(DialectAsmParser &parser) const;
109 
110  /// Print a type registered to this dialect.
111  virtual void printType(Type, DialectAsmPrinter &) const {
112  llvm_unreachable("dialect has no registered type printing hook");
113  }
114 
115  /// Return the hook to parse an operation registered to this dialect, if any.
116  /// By default this will lookup for registered operations and return the
117  /// `parse()` method registered on the RegisteredOperationName. Dialects can
118  /// override this behavior and handle unregistered operations as well.
119  virtual Optional<ParseOpHook> getParseOperationHook(StringRef opName) const;
120 
121  /// Print an operation registered to this dialect.
122  /// This hook is invoked for registered operation which don't override the
123  /// `print()` method to define their own custom assembly.
124  virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
125  getOperationPrinter(Operation *op) const;
126 
127  //===--------------------------------------------------------------------===//
128  // Verification Hooks
129  //===--------------------------------------------------------------------===//
130 
131  /// Verify an attribute from this dialect on the argument at 'argIndex' for
132  /// the region at 'regionIndex' on the given operation. Returns failure if
133  /// the verification failed, success otherwise. This hook may optionally be
134  /// invoked from any operation containing a region.
136  unsigned regionIndex,
137  unsigned argIndex,
139 
140  /// Verify an attribute from this dialect on the result at 'resultIndex' for
141  /// the region at 'regionIndex' on the given operation. Returns failure if
142  /// the verification failed, success otherwise. This hook may optionally be
143  /// invoked from any operation containing a region.
145  unsigned regionIndex,
146  unsigned resultIndex,
148 
149  /// Verify an attribute from this dialect on the given operation. Returns
150  /// failure if the verification failed, success otherwise.
152  return success();
153  }
154 
155  //===--------------------------------------------------------------------===//
156  // Interfaces
157  //===--------------------------------------------------------------------===//
158 
159  /// Lookup an interface for the given ID if one is registered, otherwise
160  /// nullptr.
162  auto it = registeredInterfaces.find(interfaceID);
163  return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
164  }
165  template <typename InterfaceT> const InterfaceT *getRegisteredInterface() {
166  return static_cast<const InterfaceT *>(
167  getRegisteredInterface(InterfaceT::getInterfaceID()));
168  }
169 
170  /// Lookup an op interface for the given ID if one is registered, otherwise
171  /// nullptr.
172  virtual void *getRegisteredInterfaceForOp(TypeID interfaceID,
173  OperationName opName) {
174  return nullptr;
175  }
176  template <typename InterfaceT>
177  typename InterfaceT::Concept *
179  return static_cast<typename InterfaceT::Concept *>(
180  getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName));
181  }
182 
183 protected:
184  /// The constructor takes a unique namespace for this dialect as well as the
185  /// context to bind to.
186  /// Note: The namespace must not contain '.' characters.
187  /// Note: All operations belonging to this dialect must have names starting
188  /// with the namespace followed by '.'.
189  /// Example:
190  /// - "tf" for the TensorFlow ops like "tf.add".
191  Dialect(StringRef name, MLIRContext *context, TypeID id);
192 
193  /// This method is used by derived classes to add their operations to the set.
194  ///
195  template <typename... Args> void addOperations() {
196  (void)std::initializer_list<int>{
197  0, (RegisteredOperationName::insert<Args>(*this), 0)...};
198  }
199 
200  /// Register a set of type classes with this dialect.
201  template <typename... Args> void addTypes() {
202  (void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
203  }
204 
205  /// Register a type instance with this dialect.
206  /// The use of this method is in general discouraged in favor of
207  /// 'addTypes<CustomType>()'.
208  void addType(TypeID typeID, AbstractType &&typeInfo);
209 
210  /// Register a set of attribute classes with this dialect.
211  template <typename... Args> void addAttributes() {
212  (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
213  }
214 
215  /// Enable support for unregistered operations.
216  void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
217 
218  /// Enable support for unregistered types.
219  void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
220 
221  /// Register a dialect interface with this dialect instance.
222  void addInterface(std::unique_ptr<DialectInterface> interface);
223 
224  /// Register a set of dialect interfaces with this dialect instance.
225  template <typename... Args> void addInterfaces() {
226  (void)std::initializer_list<int>{
227  0, (addInterface(std::make_unique<Args>(this)), 0)...};
228  }
229 
230 private:
231  Dialect(const Dialect &) = delete;
232  void operator=(Dialect &) = delete;
233 
234  /// Register an attribute instance with this dialect.
235  template <typename T> void addAttribute() {
236  // Add this attribute to the dialect and register it with the uniquer.
237  addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
238  detail::AttributeUniquer::registerAttribute<T>(context);
239  }
240  void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
241 
242  /// Register a type instance with this dialect.
243  template <typename T> void addType() {
244  // Add this type to the dialect and register it with the uniquer.
245  addType(T::getTypeID(), AbstractType::get<T>(*this));
246  detail::TypeUniquer::registerType<T>(context);
247  }
248 
249  /// The namespace of this dialect.
250  StringRef name;
251 
252  /// The unique identifier of the derived Op class, this is used in the context
253  /// to allow registering multiple times the same dialect.
254  TypeID dialectID;
255 
256  /// This is the context that owns this Dialect object.
257  MLIRContext *context;
258 
259  /// Flag that specifies whether this dialect supports unregistered operations,
260  /// i.e. operations prefixed with the dialect namespace but not registered
261  /// with addOperation.
262  bool unknownOpsAllowed = false;
263 
264  /// Flag that specifies whether this dialect allows unregistered types, i.e.
265  /// types prefixed with the dialect namespace but not registered with addType.
266  /// These types are represented with OpaqueType.
267  bool unknownTypesAllowed = false;
268 
269  /// A collection of registered dialect interfaces.
271 
272  friend class DialectRegistry;
273  friend void registerDialect();
274  friend class MLIRContext;
275 };
276 
277 /// The DialectRegistry maps a dialect namespace to a constructor for the
278 /// matching dialect.
279 /// This allows for decoupling the list of dialects "available" from the
280 /// dialects loaded in the Context. The parser in particular will lazily load
281 /// dialects in the Context as operations are encountered.
283  /// Lists of interfaces that need to be registered when the dialect is loaded.
284  struct DelayedInterfaces {
285  /// Dialect interfaces.
287  dialectInterfaces;
288  /// Attribute/Operation/Type interfaces.
290  objectInterfaces;
291  };
292 
293  using MapTy =
294  std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
296 
297 public:
298  explicit DialectRegistry();
299 
300  template <typename ConcreteDialect> void insert() {
301  insert(TypeID::get<ConcreteDialect>(),
302  ConcreteDialect::getDialectNamespace(),
303  static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
304  // Just allocate the dialect, the context
305  // takes ownership of it.
306  return ctx->getOrLoadDialect<ConcreteDialect>();
307  })));
308  }
309 
310  template <typename ConcreteDialect, typename OtherDialect,
311  typename... MoreDialects>
312  void insert() {
313  insert<ConcreteDialect>();
314  insert<OtherDialect, MoreDialects...>();
315  }
316 
317  /// Add a new dialect constructor to the registry. The constructor must be
318  /// calling MLIRContext::getOrLoadDialect in order for the context to take
319  /// ownership of the dialect and for delayed interface registration to happen.
320  void insert(TypeID typeID, StringRef name,
321  const DialectAllocatorFunction &ctor);
322 
323  /// Return an allocation function for constructing the dialect identified by
324  /// its namespace, or nullptr if the namespace is not in this registry.
325  DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
326 
327  // Register all dialects available in the current registry with the registry
328  // in the provided context.
329  void appendTo(DialectRegistry &destination) const {
330  for (const auto &nameAndRegistrationIt : registry)
331  destination.insert(nameAndRegistrationIt.second.first,
332  nameAndRegistrationIt.first,
333  nameAndRegistrationIt.second.second);
334  // Merge interfaces.
335  for (auto it : interfaces) {
336  TypeID dialect = it.first;
337  auto destInterfaces = destination.interfaces.find(dialect);
338  if (destInterfaces == destination.interfaces.end()) {
339  destination.interfaces[dialect] = it.second;
340  continue;
341  }
342  // The destination already has delayed interface registrations for this
343  // dialect. Merge registrations into the destination registry.
344  destInterfaces->second.dialectInterfaces.append(
345  it.second.dialectInterfaces.begin(),
346  it.second.dialectInterfaces.end());
347  destInterfaces->second.objectInterfaces.append(
348  it.second.objectInterfaces.begin(), it.second.objectInterfaces.end());
349  }
350  }
351 
352  /// Return the names of dialects known to this registry.
353  auto getDialectNames() const {
354  return llvm::map_range(
355  registry,
356  [](const MapTy::value_type &item) -> StringRef { return item.first; });
357  }
358 
359  /// Add an interface constructed with the given allocation function to the
360  /// dialect provided as template parameter. The dialect must be present in
361  /// the registry.
362  template <typename DialectTy>
363  void addDialectInterface(TypeID interfaceTypeID,
365  addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID,
366  allocator);
367  }
368 
369  /// Add an interface to the dialect, both provided as template parameter. The
370  /// dialect must be present in the registry.
371  template <typename DialectTy, typename InterfaceTy>
373  addDialectInterface<DialectTy>(
374  InterfaceTy::getInterfaceID(), [](Dialect *dialect) {
375  return std::make_unique<InterfaceTy>(dialect);
376  });
377  }
378 
379  /// Add an external op interface model for an op that belongs to a dialect,
380  /// both provided as template parameters. The dialect must be present in the
381  /// registry.
382  template <typename OpTy, typename ModelTy> void addOpInterface() {
383  StringRef opName = OpTy::getOperationName();
384  StringRef dialectName = opName.split('.').first;
385  addObjectInterface(dialectName, TypeID::get<OpTy>(),
386  ModelTy::Interface::getInterfaceID(),
387  [](MLIRContext *context) {
388  OpTy::template attachInterface<ModelTy>(*context);
389  });
390  }
391 
392  /// Add an external attribute interface model for an attribute type `AttrTy`
393  /// that is going to belong to `DialectTy`. The dialect must be present in the
394  /// registry.
395  template <typename DialectTy, typename AttrTy, typename ModelTy>
397  addStorageUserInterface<AttrTy, ModelTy>(DialectTy::getDialectNamespace());
398  }
399 
400  /// Add an external type interface model for an type class `TypeTy` that is
401  /// going to belong to `DialectTy`. The dialect must be present in the
402  /// registry.
403  template <typename DialectTy, typename TypeTy, typename ModelTy>
405  addStorageUserInterface<TypeTy, ModelTy>(DialectTy::getDialectNamespace());
406  }
407 
408  /// Register any interfaces required for the given dialect (based on its
409  /// TypeID). Users are not expected to call this directly.
410  void registerDelayedInterfaces(Dialect *dialect) const;
411 
412 private:
413  /// Add an interface constructed with the given allocation function to the
414  /// dialect identified by its namespace.
415  void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID,
416  const DialectInterfaceAllocatorFunction &allocator);
417 
418  /// Add an attribute/operation/type interface constructible with the given
419  /// allocation function to the dialect identified by its namespace.
420  void addObjectInterface(StringRef dialectName, TypeID objectID,
421  TypeID interfaceTypeID,
422  const ObjectInterfaceAllocatorFunction &allocator);
423 
424  /// Add an external model for an attribute/type interface to the dialect
425  /// identified by its namespace.
426  template <typename ObjectTy, typename ModelTy>
427  void addStorageUserInterface(StringRef dialectName) {
428  addObjectInterface(dialectName, TypeID::get<ObjectTy>(),
429  ModelTy::Interface::getInterfaceID(),
430  [](MLIRContext *context) {
431  ObjectTy::template attachInterface<ModelTy>(*context);
432  });
433  }
434 
435  MapTy registry;
436  InterfaceMapTy interfaces;
437 };
438 
439 } // namespace mlir
440 
441 namespace llvm {
442 /// Provide isa functionality for Dialects: isa<T>(dialect) holds exactly when the dialect's TypeID equals mlir::TypeID::get<T>().
443 template <typename T> struct isa_impl<T, ::mlir::Dialect> {
444  static inline bool doit(const ::mlir::Dialect &dialect) {
445  return mlir::TypeID::get<T>() == dialect.getTypeID();
446  }
447 };
448 } // namespace llvm
449 
450 #endif
Include the generated interface declarations.
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition: Dialect.h:61
void addInterfaces()
Register a set of dialect interfaces with this dialect instance.
Definition: Dialect.h:225
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
void allowUnknownOperations(bool allow=true)
Enable support for unregistered operations.
Definition: Dialect.h:216
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition: Dialect.cpp:148
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
Definition: Dialect.h:29
void addDialectInterface()
Add an interface to the dialect, both provided as template parameter.
Definition: Dialect.h:372
virtual void getCanonicalizationPatterns(RewritePatternSet &results) const
Register dialect-wide canonicalization patterns.
Definition: Dialect.h:78
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
Definition: Dialect.cpp:174
virtual ~Dialect()
This class contains all of the static information common to all instances of a registered Type...
Definition: TypeSupport.h:30
void addOperations()
This method is used by derived classes to add their operations to the set.
Definition: Dialect.h:195
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to.
Definition: Dialect.cpp:122
static constexpr const bool value
void addDialectInterface(TypeID interfaceTypeID, DialectInterfaceAllocatorFunction allocator)
Add an interface constructed with the given allocation function to the dialect provided as template parameter.
Definition: Dialect.h:363
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:52
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:137
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
friend void registerDialect()
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
void addOpInterface()
Add an external op interface model for an op that belongs to a dialect, both provided as template parameters.
Definition: Dialect.h:382
const InterfaceT * getRegisteredInterface()
Definition: Dialect.h:165
virtual void * getRegisteredInterfaceForOp(TypeID interfaceID, OperationName opName)
Lookup an op interface for the given ID if one is registered, otherwise nullptr.
Definition: Dialect.h:172
const DialectInterface * getRegisteredInterface(TypeID interfaceID)
Lookup an interface for the given ID if one is registered, otherwise nullptr.
Definition: Dialect.h:161
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:42
virtual Optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
Definition: Dialect.cpp:169
std::function< void(MLIRContext *)> ObjectInterfaceAllocatorFunction
Definition: Dialect.h:33
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect namespace but not registered with addType.
Definition: Dialect.h:71
MLIRContext * getContext() const
Definition: Dialect.h:56
virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute)
Verify an attribute from this dialect on the given operation.
Definition: Dialect.h:151
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
auto getDialectNames() const
Return the names of dialects known to this registry.
Definition: Dialect.h:353
static int resultIndex(int i)
Definition: Operator.cpp:308
void addAttrInterface()
Add an external attribute interface model for an attribute type AttrTy that is going to belong to DialectTy.
Definition: Dialect.h:396
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:92
void addTypeInterface()
Add an external type interface model for a type class TypeTy that is going to belong to DialectTy.
Definition: Dialect.h:404
void addAttributes()
Register a set of attribute classes with this dialect.
Definition: Dialect.h:211
StringRef getNamespace() const
Definition: Dialect.h:58
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition: Dialect.cpp:188
void allowUnknownTypes(bool allow=true)
Enable support for unregistered types.
Definition: Dialect.h:219
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Definition: Dialect.h:282
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void addTypes()
Register a set of type classes with this dialect.
Definition: Dialect.h:201
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
bool allowsUnknownOperations() const
Returns true if this dialect allows for unregistered operations, i.e. operations prefixed with the dialect namespace but not registered with addOperation.
Definition: Dialect.h:66
std::function< std::unique_ptr< DialectInterface >(Dialect *)> DialectInterfaceAllocatorFunction
Definition: Dialect.h:32
virtual void printType(Type, DialectAsmPrinter &) const
Print a type registered to this dialect.
Definition: Dialect.h:111
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' on the given operation.
Definition: Dialect.cpp:133
InterfaceT::Concept * getRegisteredInterfaceForOp(OperationName opName)
Definition: Dialect.h:178
virtual void printAttribute(Attribute, DialectAsmPrinter &) const
Print an attribute registered to this dialect.
Definition: Dialect.h:103
friend class DialectRegistry
Definition: Dialect.h:272
void addType(TypeID typeID, AbstractType &&typeInfo)
Register a type instance with this dialect.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desired resultant type.
Definition: Dialect.h:87
static bool doit(const ::mlir::Dialect &dialect)
Definition: Dialect.h:444
This class helps build Operations.
Definition: Builders.h:177
This class contains all of the static information common to all instances of a registered Attribute...
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition: Dialect.cpp:156
void appendTo(DialectRegistry &destination) const
Definition: Dialect.h:329
This class represents an interface overridden for a single dialect.
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' on the given operation.
Definition: Dialect.cpp:142
static bool isValidNamespace(StringRef str)
Utility function that returns whether the given string is a valid dialect namespace.
Definition: Dialect.cpp:182