MLIR  19.0.0git
Dialect.h
Go to the documentation of this file.
1 //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the 'dialect' abstraction.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECT_H
14 #define MLIR_IR_DIALECT_H
15 
18 #include "mlir/Support/TypeID.h"
19 
20 #include <map>
21 #include <tuple>
22 
23 namespace mlir {
24 class DialectAsmParser;
25 class DialectAsmPrinter;
26 class DialectInterface;
27 class OpBuilder;
28 class Type;
29 
30 //===----------------------------------------------------------------------===//
31 // Dialect
32 //===----------------------------------------------------------------------===//
33 
34 /// Dialects are groups of MLIR operations, types and attributes, as well as
35 /// behavior associated with the entire group. For example, hooks into other
36 /// systems for constant folding, interfaces, default named types for asm
37 /// printing, etc.
38 ///
39 /// Instances of the dialect object are loaded in a specific MLIRContext.
40 ///
41 class Dialect {
42 public:
43  /// Type for a callback provided by the dialect to parse a custom operation.
44  /// This is used for the dialect to provide an alternative way to parse custom
45  /// operations, including unregistered ones.
46  using ParseOpHook =
48 
49  virtual ~Dialect();
50 
51  /// Utility function that returns if the given string is a valid dialect
52  /// namespace
53  static bool isValidNamespace(StringRef str);
54 
55  MLIRContext *getContext() const { return context; }
56 
57  StringRef getNamespace() const { return name; }
58 
59  /// Returns the unique identifier that corresponds to this dialect.
60  TypeID getTypeID() const { return dialectID; }
61 
62  /// Returns true if this dialect allows for unregistered operations, i.e.
63  /// operations prefixed with the dialect namespace but not registered with
64  /// addOperation.
65  bool allowsUnknownOperations() const { return unknownOpsAllowed; }
66 
67  /// Return true if this dialect allows for unregistered types, i.e., types
68  /// prefixed with the dialect namespace but not registered with addType.
69  /// These are represented with OpaqueType.
70  bool allowsUnknownTypes() const { return unknownTypesAllowed; }
71 
72  /// Register dialect-wide canonicalization patterns. This method should only
73  /// be used to register canonicalization patterns that do not conceptually
74  /// belong to any single operation in the dialect. (In that case, use the op's
75  /// canonicalizer.) E.g., canonicalization patterns for op interfaces should
76  /// be registered here.
77  virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
78 
79  /// Registered hook to materialize a single constant operation from a given
80  /// attribute value with the desired resultant type. This method should use
81  /// the provided builder to create the operation without changing the
82  /// insertion position. The generated operation is expected to be constant
83  /// like, i.e. single result, zero operands, non side-effecting, etc. On
84  /// success, this hook should return the value generated to represent the
85  /// constant value. Otherwise, it should return null on failure.
87  Type type, Location loc) {
88  return nullptr;
89  }
90 
91  //===--------------------------------------------------------------------===//
92  // Parsing Hooks
93  //===--------------------------------------------------------------------===//
94 
95  /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
96  /// refers to the expected type of the attribute.
97  virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
98 
99  /// Print an attribute registered to this dialect. Note: The type of the
100  /// attribute need not be printed by this method as it is always printed by
101  /// the caller.
102  virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
103  llvm_unreachable("dialect has no registered attribute printing hook");
104  }
105 
106  /// Parse a type registered to this dialect.
107  virtual Type parseType(DialectAsmParser &parser) const;
108 
109  /// Print a type registered to this dialect.
110  virtual void printType(Type, DialectAsmPrinter &) const {
111  llvm_unreachable("dialect has no registered type printing hook");
112  }
113 
114  /// Return the hook to parse an operation registered to this dialect, if any.
115  /// By default this will lookup for registered operations and return the
116  /// `parse()` method registered on the RegisteredOperationName. Dialects can
117  /// override this behavior and handle unregistered operations as well.
118  virtual std::optional<ParseOpHook>
119  getParseOperationHook(StringRef opName) const;
120 
121  /// Print an operation registered to this dialect.
122  /// This hook is invoked for registered operation which don't override the
123  /// `print()` method to define their own custom assembly.
124  virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
125  getOperationPrinter(Operation *op) const;
126 
127  //===--------------------------------------------------------------------===//
128  // Verification Hooks
129  //===--------------------------------------------------------------------===//
130 
131  /// Verify an attribute from this dialect on the argument at 'argIndex' for
132  /// the region at 'regionIndex' on the given operation. Returns failure if
133  /// the verification failed, success otherwise. This hook may optionally be
134  /// invoked from any operation containing a region.
136  unsigned regionIndex,
137  unsigned argIndex,
139 
140  /// Verify an attribute from this dialect on the result at 'resultIndex' for
141  /// the region at 'regionIndex' on the given operation. Returns failure if
142  /// the verification failed, success otherwise. This hook may optionally be
143  /// invoked from any operation containing a region.
145  unsigned regionIndex,
146  unsigned resultIndex,
148 
149  /// Verify an attribute from this dialect on the given operation. Returns
150  /// failure if the verification failed, success otherwise.
152  return success();
153  }
154 
155  //===--------------------------------------------------------------------===//
156  // Interfaces
157  //===--------------------------------------------------------------------===//
158 
159  /// Lookup an interface for the given ID if one is registered, otherwise
160  /// nullptr.
162 #ifndef NDEBUG
164 #endif
165 
166  auto it = registeredInterfaces.find(interfaceID);
167  return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
168  }
169  template <typename InterfaceT>
170  InterfaceT *getRegisteredInterface() {
171 #ifndef NDEBUG
173  InterfaceT::getInterfaceID(),
174  llvm::getTypeName<InterfaceT>());
175 #endif
176 
177  return static_cast<InterfaceT *>(
178  getRegisteredInterface(InterfaceT::getInterfaceID()));
179  }
180 
181  /// Lookup an op interface for the given ID if one is registered, otherwise
182  /// nullptr.
183  virtual void *getRegisteredInterfaceForOp(TypeID interfaceID,
184  OperationName opName) {
185  return nullptr;
186  }
187  template <typename InterfaceT>
188  typename InterfaceT::Concept *
190  return static_cast<typename InterfaceT::Concept *>(
191  getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName));
192  }
193 
194  /// Register a dialect interface with this dialect instance.
195  void addInterface(std::unique_ptr<DialectInterface> interface);
196 
197  /// Register a set of dialect interfaces with this dialect instance.
198  template <typename... Args>
199  void addInterfaces() {
200  (addInterface(std::make_unique<Args>(this)), ...);
201  }
202  template <typename InterfaceT, typename... Args>
203  InterfaceT &addInterface(Args &&...args) {
204  InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...);
205  addInterface(std::unique_ptr<DialectInterface>(interface));
206  return *interface;
207  }
208 
209  /// Declare that the given interface will be implemented, but has a delayed
210  /// registration. The promised interface type can be an interface of any type
211  /// not just a dialect interface, i.e. it may also be an
212  /// AttributeInterface/OpInterface/TypeInterface/etc.
213  template <typename InterfaceT, typename ConcreteT>
215  unresolvedPromisedInterfaces.insert(
216  {TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
217  }
218 
219  // Declare the same interface for multiple types.
220  // Example:
221  // declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>()
222  template <typename InterfaceT, typename... ConcreteT>
224  (declarePromisedInterface<InterfaceT, ConcreteT>(), ...);
225  }
226 
227  /// Checks if the given interface, which is attempting to be used, is a
228  /// promised interface of this dialect that has yet to be implemented. If so,
229  /// emits a fatal error. `interfaceName` is an optional string that contains a
230  /// more user readable name for the interface (such as the class name).
232  TypeID interfaceID,
233  StringRef interfaceName = "") {
234  if (unresolvedPromisedInterfaces.count(
235  {interfaceRequestorID, interfaceID})) {
236  llvm::report_fatal_error(
237  "checking for an interface (`" + interfaceName +
238  "`) that was promised by dialect '" + getNamespace() +
239  "' but never implemented. This is generally an indication "
240  "that the dialect extension implementing the interface was never "
241  "registered.");
242  }
243  }
244 
245  /// Checks if the given interface, which is attempting to be attached to a
246  /// construct owned by this dialect, is a promised interface of this dialect
247  /// that has yet to be implemented. If so, it resolves the interface promise.
249  TypeID interfaceID) {
250  unresolvedPromisedInterfaces.erase({interfaceRequestorID, interfaceID});
251  }
252 
253  /// Checks if a promise has been made for the interface/requestor pair.
254  bool hasPromisedInterface(TypeID interfaceRequestorID,
255  TypeID interfaceID) const {
256  return unresolvedPromisedInterfaces.count(
257  {interfaceRequestorID, interfaceID});
258  }
259 
260  /// Checks if a promise has been made for the interface/requestor pair.
261  template <typename ConcreteT, typename InterfaceT>
262  bool hasPromisedInterface() const {
263  return hasPromisedInterface(TypeID::get<ConcreteT>(),
264  InterfaceT::getInterfaceID());
265  }
266 
267 protected:
268  /// The constructor takes a unique namespace for this dialect as well as the
269  /// context to bind to.
270  /// Note: The namespace must not contain '.' characters.
271  /// Note: All operations belonging to this dialect must have names starting
272  /// with the namespace followed by '.'.
273  /// Example:
274  /// - "tf" for the TensorFlow ops like "tf.add".
275  Dialect(StringRef name, MLIRContext *context, TypeID id);
276 
277  /// This method is used by derived classes to add their operations to the set.
278  ///
279  template <typename... Args>
280  void addOperations() {
281  // This initializer_list argument pack expansion is essentially equal to
282  // using a fold expression with a comma operator. Clang however, refuses
283  // to compile a fold expression with a depth of more than 256 by default.
284  // There seem to be no such limitations for initializer_list.
285  (void)std::initializer_list<int>{
286  0, (RegisteredOperationName::insert<Args>(*this), 0)...};
287  }
288 
289  /// Register a set of type classes with this dialect.
290  template <typename... Args>
291  void addTypes() {
292  // This initializer_list argument pack expansion is essentially equal to
293  // using a fold expression with a comma operator. Clang however, refuses
294  // to compile a fold expression with a depth of more than 256 by default.
295  // There seem to be no such limitations for initializer_list.
296  (void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
297  }
298 
299  /// Register a type instance with this dialect.
300  /// The use of this method is in general discouraged in favor of
301  /// 'addTypes<CustomType>()'.
302  void addType(TypeID typeID, AbstractType &&typeInfo);
303 
304  /// Register a set of attribute classes with this dialect.
305  template <typename... Args>
306  void addAttributes() {
307  // This initializer_list argument pack expansion is essentially equal to
308  // using a fold expression with a comma operator. Clang however, refuses
309  // to compile a fold expression with a depth of more than 256 by default.
310  // There seem to be no such limitations for initializer_list.
311  (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
312  }
313 
314  /// Register an attribute instance with this dialect.
315  /// The use of this method is in general discouraged in favor of
316  /// 'addAttributes<CustomAttr>()'.
317  void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
318 
319  /// Enable support for unregistered operations.
320  void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
321 
322  /// Enable support for unregistered types.
323  void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
324 
325 private:
326  Dialect(const Dialect &) = delete;
327  void operator=(Dialect &) = delete;
328 
329  /// Register an attribute instance with this dialect.
330  template <typename T>
331  void addAttribute() {
332  // Add this attribute to the dialect and register it with the uniquer.
333  addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
334  detail::AttributeUniquer::registerAttribute<T>(context);
335  }
336 
337  /// Register a type instance with this dialect.
338  template <typename T>
339  void addType() {
340  // Add this type to the dialect and register it with the uniquer.
341  addType(T::getTypeID(), AbstractType::get<T>(*this));
342  detail::TypeUniquer::registerType<T>(context);
343  }
344 
345  /// The namespace of this dialect.
346  StringRef name;
347 
348  /// The unique identifier of the derived Op class, this is used in the context
349  /// to allow registering multiple times the same dialect.
350  TypeID dialectID;
351 
352  /// This is the context that owns this Dialect object.
353  MLIRContext *context;
354 
355  /// Flag that specifies whether this dialect supports unregistered operations,
356  /// i.e. operations prefixed with the dialect namespace but not registered
357  /// with addOperation.
358  bool unknownOpsAllowed = false;
359 
360  /// Flag that specifies whether this dialect allows unregistered types, i.e.
361  /// types prefixed with the dialect namespace but not registered with addType.
362  /// These types are represented with OpaqueType.
363  bool unknownTypesAllowed = false;
364 
365  /// A collection of registered dialect interfaces.
366  DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
367 
368  /// A set of interfaces that the dialect (or its constructs, i.e.
369  /// Attributes/Operations/Types/etc.) has promised to implement, but has yet
370  /// to provide an implementation for.
371  DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces;
372 
373  friend class DialectRegistry;
374  friend void registerDialect();
375  friend class MLIRContext;
376 };
377 
378 } // namespace mlir
379 
380 namespace llvm {
381 /// Provide isa functionality for Dialects.
382 template <typename T>
383 struct isa_impl<T, ::mlir::Dialect,
384  std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
385  static inline bool doit(const ::mlir::Dialect &dialect) {
386  return mlir::TypeID::get<T>() == dialect.getTypeID();
387  }
388 };
389 template <typename T>
390 struct isa_impl<
391  T, ::mlir::Dialect,
392  std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
393  static inline bool doit(const ::mlir::Dialect &dialect) {
394  return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
395  }
396 };
397 template <typename T>
398 struct cast_retty_impl<T, ::mlir::Dialect *> {
399  using ret_type = T *;
400 };
401 template <typename T>
402 struct cast_retty_impl<T, ::mlir::Dialect> {
403  using ret_type = T &;
404 };
405 
406 template <typename T>
407 struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
408  template <typename To>
409  static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
410  doitImpl(::mlir::Dialect &dialect) {
411  return static_cast<To &>(dialect);
412  }
413  template <typename To>
414  static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
415  To &>
416  doitImpl(::mlir::Dialect &dialect) {
417  return *dialect.getRegisteredInterface<To>();
418  }
419 
420  static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
421 };
422 template <class T>
423 struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
424  static auto doit(::mlir::Dialect *dialect) {
426  *dialect);
427  }
428 };
429 
430 } // namespace llvm
431 
432 #endif
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
This class contains all of the static information common to all instances of a registered Attribute.
This class contains all of the static information common to all instances of a registered Type.
Definition: TypeSupport.h:30
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents an interface overridden for a single dialect.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Definition: Dialect.h:41
void addAttributes()
Register a set of attribute classes with this dialect.
Definition: Dialect.h:306
void addInterfaces()
Register a set of dialect interfaces with this dialect instance.
Definition: Dialect.h:199
virtual ~Dialect()
virtual void * getRegisteredInterfaceForOp(TypeID interfaceID, OperationName opName)
Lookup an op interface for the given ID if one is registered, otherwise nullptr.
Definition: Dialect.h:183
friend class MLIRContext
Definition: Dialect.h:375
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition: Dialect.cpp:66
virtual std::optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
Definition: Dialect.cpp:79
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName="")
Checks if the given interface, which is attempting to be used, is a promised interface of this dialect that has yet to be implemented. If so, emits a fatal error.
Definition: Dialect.h:231
StringRef getNamespace() const
Definition: Dialect.h:57
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:92
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' on the given operation.
Definition: Dialect.cpp:52
DialectInterface * getRegisteredInterface(TypeID interfaceID)
Lookup an interface for the given ID if one is registered, otherwise nullptr.
Definition: Dialect.h:161
InterfaceT::Concept * getRegisteredInterfaceForOp(OperationName opName)
Definition: Dialect.h:189
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' on the given operation.
Definition: Dialect.cpp:43
void addOperations()
This method is used by derived classes to add their operations to the set.
Definition: Dialect.h:280
bool hasPromisedInterface() const
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.h:262
InterfaceT * getRegisteredInterface()
Definition: Dialect.h:170
void declarePromisedInterface()
Declare that the given interface will be implemented, but has a delayed registration.
Definition: Dialect.h:214
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
Definition: Dialect.cpp:84
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID)
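Checks if the given interface, which is attempting to be attached to a construct owned by this dialect, is a promised interface of this dialect that has yet to be implemented. If so, it resolves the interface promise.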
Definition: Dialect.h:248
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect namespace but not registered with addType.
Definition: Dialect.h:70
virtual void printAttribute(Attribute, DialectAsmPrinter &) const
Print an attribute registered to this dialect.
Definition: Dialect.h:102
void addTypes()
Register a set of type classes with this dialect.
Definition: Dialect.h:291
virtual void printType(Type, DialectAsmPrinter &) const
Print a type registered to this dialect.
Definition: Dialect.h:110
MLIRContext * getContext() const
Definition: Dialect.h:55
void allowUnknownTypes(bool allow=true)
Enable support for unregistered types.
Definition: Dialect.h:323
virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute)
Verify an attribute from this dialect on the given operation.
Definition: Dialect.h:151
void allowUnknownOperations(bool allow=true)
Enable support for unregistered operations.
Definition: Dialect.h:320
bool allowsUnknownOperations() const
Returns true if this dialect allows for unregistered operations, i.e. operations prefixed with the dialect namespace but not registered with addOperation.
Definition: Dialect.h:65
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition: Dialect.cpp:58
void declarePromisedInterfaces()
Definition: Dialect.h:223
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition: Dialect.cpp:98
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition: Dialect.h:60
bool hasPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID) const
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.h:254
friend void registerDialect()
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to.
Definition: Dialect.cpp:32
InterfaceT & addInterface(Args &&...args)
Definition: Dialect.h:203
virtual void getCanonicalizationPatterns(RewritePatternSet &results) const
Register dialect-wide canonicalization patterns.
Definition: Dialect.h:77
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desired resultant type.
Definition: Dialect.h:86
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition: Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Include the generated interface declarations.
Definition: CallGraph.h:229
@ Type
An inlay hint for a type annotation.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
static auto & doit(::mlir::Dialect &dialect)
Definition: Dialect.h:420
static std::enable_if_t< std::is_base_of<::mlir::Dialect, To >::value, To & > doitImpl(::mlir::Dialect &dialect)
Definition: Dialect.h:410
static std::enable_if_t< std::is_base_of<::mlir::DialectInterface, To >::value, To & > doitImpl(::mlir::Dialect &dialect)
Definition: Dialect.h:416
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.