MLIR  20.0.0git
Dialect.h
Go to the documentation of this file.
1 //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the 'dialect' abstraction.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECT_H
14 #define MLIR_IR_DIALECT_H
15 
18 #include "mlir/Support/TypeID.h"
19 
20 namespace mlir {
21 class DialectAsmParser;
22 class DialectAsmPrinter;
23 class DialectInterface;
24 class OpBuilder;
25 class Type;
26 
27 //===----------------------------------------------------------------------===//
28 // Dialect
29 //===----------------------------------------------------------------------===//
30 
31 /// Dialects are groups of MLIR operations, types and attributes, as well as
32 /// behavior associated with the entire group. For example, hooks into other
33 /// systems for constant folding, interfaces, default named types for asm
34 /// printing, etc.
35 ///
36 /// Instances of the dialect object are loaded in a specific MLIRContext.
37 ///
38 class Dialect {
39 public:
40  /// Type for a callback provided by the dialect to parse a custom operation.
41  /// This is used for the dialect to provide an alternative way to parse custom
42  /// operations, including unregistered ones.
43  using ParseOpHook =
44  function_ref<ParseResult(OpAsmParser &parser, OperationState &result)>;
45 
46  virtual ~Dialect();
47 
48  /// Utility function that returns if the given string is a valid dialect
49  /// namespace
50  static bool isValidNamespace(StringRef str);
51 
52  MLIRContext *getContext() const { return context; }
53 
54  StringRef getNamespace() const { return name; }
55 
56  /// Returns the unique identifier that corresponds to this dialect.
57  TypeID getTypeID() const { return dialectID; }
58 
59  /// Returns true if this dialect allows for unregistered operations, i.e.
60  /// operations prefixed with the dialect namespace but not registered with
61  /// addOperation.
62  bool allowsUnknownOperations() const { return unknownOpsAllowed; }
63 
64  /// Return true if this dialect allows for unregistered types, i.e., types
65  /// prefixed with the dialect namespace but not registered with addType.
66  /// These are represented with OpaqueType.
67  bool allowsUnknownTypes() const { return unknownTypesAllowed; }
68 
69  /// Register dialect-wide canonicalization patterns. This method should only
70  /// be used to register canonicalization patterns that do not conceptually
71  /// belong to any single operation in the dialect. (In that case, use the op's
72  /// canonicalizer.) E.g., canonicalization patterns for op interfaces should
73  /// be registered here.
74  virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
75 
76  /// Registered hook to materialize a single constant operation from a given
77  /// attribute value with the desired resultant type. This method should use
78  /// the provided builder to create the operation without changing the
79  /// insertion position. The generated operation is expected to be constant
80  /// like, i.e. single result, zero operands, non side-effecting, etc. On
81  /// success, this hook should return the value generated to represent the
82  /// constant value. Otherwise, it should return null on failure.
84  Type type, Location loc) {
85  return nullptr;
86  }
87 
88  //===--------------------------------------------------------------------===//
89  // Parsing Hooks
90  //===--------------------------------------------------------------------===//
91 
92  /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
93  /// refers to the expected type of the attribute.
94  virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
95 
96  /// Print an attribute registered to this dialect. Note: The type of the
97  /// attribute need not be printed by this method as it is always printed by
98  /// the caller.
99  virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
100  llvm_unreachable("dialect has no registered attribute printing hook");
101  }
102 
103  /// Parse a type registered to this dialect.
104  virtual Type parseType(DialectAsmParser &parser) const;
105 
106  /// Print a type registered to this dialect.
107  virtual void printType(Type, DialectAsmPrinter &) const {
108  llvm_unreachable("dialect has no registered type printing hook");
109  }
110 
111  /// Return the hook to parse an operation registered to this dialect, if any.
112  /// By default this will lookup for registered operations and return the
113  /// `parse()` method registered on the RegisteredOperationName. Dialects can
114  /// override this behavior and handle unregistered operations as well.
115  virtual std::optional<ParseOpHook>
116  getParseOperationHook(StringRef opName) const;
117 
118  /// Print an operation registered to this dialect.
119  /// This hook is invoked for registered operation which don't override the
120  /// `print()` method to define their own custom assembly.
121  virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
122  getOperationPrinter(Operation *op) const;
123 
124  //===--------------------------------------------------------------------===//
125  // Verification Hooks
126  //===--------------------------------------------------------------------===//
127 
128  /// Verify an attribute from this dialect on the argument at 'argIndex' for
129  /// the region at 'regionIndex' on the given operation. Returns failure if
130  /// the verification failed, success otherwise. This hook may optionally be
131  /// invoked from any operation containing a region.
132  virtual LogicalResult verifyRegionArgAttribute(Operation *,
133  unsigned regionIndex,
134  unsigned argIndex,
136 
137  /// Verify an attribute from this dialect on the result at 'resultIndex' for
138  /// the region at 'regionIndex' on the given operation. Returns failure if
139  /// the verification failed, success otherwise. This hook may optionally be
140  /// invoked from any operation containing a region.
141  virtual LogicalResult verifyRegionResultAttribute(Operation *,
142  unsigned regionIndex,
143  unsigned resultIndex,
145 
146  /// Verify an attribute from this dialect on the given operation. Returns
147  /// failure if the verification failed, success otherwise.
149  return success();
150  }
151 
152  //===--------------------------------------------------------------------===//
153  // Interfaces
154  //===--------------------------------------------------------------------===//
155 
156  /// Lookup an interface for the given ID if one is registered, otherwise
157  /// nullptr.
159 #ifndef NDEBUG
161 #endif
162 
163  auto it = registeredInterfaces.find(interfaceID);
164  return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
165  }
166  template <typename InterfaceT>
167  InterfaceT *getRegisteredInterface() {
168 #ifndef NDEBUG
170  InterfaceT::getInterfaceID(),
171  llvm::getTypeName<InterfaceT>());
172 #endif
173 
174  return static_cast<InterfaceT *>(
175  getRegisteredInterface(InterfaceT::getInterfaceID()));
176  }
177 
178  /// Lookup an op interface for the given ID if one is registered, otherwise
179  /// nullptr.
180  virtual void *getRegisteredInterfaceForOp(TypeID interfaceID,
181  OperationName opName) {
182  return nullptr;
183  }
184  template <typename InterfaceT>
185  typename InterfaceT::Concept *
187  return static_cast<typename InterfaceT::Concept *>(
188  getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName));
189  }
190 
191  /// Register a dialect interface with this dialect instance.
192  void addInterface(std::unique_ptr<DialectInterface> interface);
193 
194  /// Register a set of dialect interfaces with this dialect instance.
195  template <typename... Args>
196  void addInterfaces() {
197  (addInterface(std::make_unique<Args>(this)), ...);
198  }
199  template <typename InterfaceT, typename... Args>
200  InterfaceT &addInterface(Args &&...args) {
201  InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...);
202  addInterface(std::unique_ptr<DialectInterface>(interface));
203  return *interface;
204  }
205 
206  /// Declare that the given interface will be implemented, but has a delayed
207  /// registration. The promised interface type can be an interface of any type
208  /// not just a dialect interface, i.e. it may also be an
209  /// AttributeInterface/OpInterface/TypeInterface/etc.
210  template <typename InterfaceT, typename ConcreteT>
212  unresolvedPromisedInterfaces.insert(
213  {TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
214  }
215 
216  // Declare the same interface for multiple types.
217  // Example:
218  // declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>()
219  template <typename InterfaceT, typename... ConcreteT>
221  (declarePromisedInterface<InterfaceT, ConcreteT>(), ...);
222  }
223 
224  /// Checks if the given interface, which is attempting to be used, is a
225  /// promised interface of this dialect that has yet to be implemented. If so,
226  /// emits a fatal error. `interfaceName` is an optional string that contains a
227  /// more user readable name for the interface (such as the class name).
229  TypeID interfaceID,
230  StringRef interfaceName = "") {
231  if (unresolvedPromisedInterfaces.count(
232  {interfaceRequestorID, interfaceID})) {
233  llvm::report_fatal_error(
234  "checking for an interface (`" + interfaceName +
235  "`) that was promised by dialect '" + getNamespace() +
236  "' but never implemented. This is generally an indication "
237  "that the dialect extension implementing the interface was never "
238  "registered.");
239  }
240  }
241 
242  /// Checks if the given interface, which is attempting to be attached to a
243  /// construct owned by this dialect, is a promised interface of this dialect
244  /// that has yet to be implemented. If so, it resolves the interface promise.
246  TypeID interfaceID) {
247  unresolvedPromisedInterfaces.erase({interfaceRequestorID, interfaceID});
248  }
249 
250  /// Checks if a promise has been made for the interface/requestor pair.
251  bool hasPromisedInterface(TypeID interfaceRequestorID,
252  TypeID interfaceID) const {
253  return unresolvedPromisedInterfaces.count(
254  {interfaceRequestorID, interfaceID});
255  }
256 
257  /// Checks if a promise has been made for the interface/requestor pair.
258  template <typename ConcreteT, typename InterfaceT>
259  bool hasPromisedInterface() const {
260  return hasPromisedInterface(TypeID::get<ConcreteT>(),
261  InterfaceT::getInterfaceID());
262  }
263 
264 protected:
265  /// The constructor takes a unique namespace for this dialect as well as the
266  /// context to bind to.
267  /// Note: The namespace must not contain '.' characters.
268  /// Note: All operations belonging to this dialect must have names starting
269  /// with the namespace followed by '.'.
270  /// Example:
271  /// - "tf" for the TensorFlow ops like "tf.add".
272  Dialect(StringRef name, MLIRContext *context, TypeID id);
273 
274  /// This method is used by derived classes to add their operations to the set.
275  ///
276  template <typename... Args>
277  void addOperations() {
278  // This initializer_list argument pack expansion is essentially equal to
279  // using a fold expression with a comma operator. Clang however, refuses
280  // to compile a fold expression with a depth of more than 256 by default.
281  // There seem to be no such limitations for initializer_list.
282  (void)std::initializer_list<int>{
283  0, (RegisteredOperationName::insert<Args>(*this), 0)...};
284  }
285 
286  /// Register a set of type classes with this dialect.
287  template <typename... Args>
288  void addTypes() {
289  // This initializer_list argument pack expansion is essentially equal to
290  // using a fold expression with a comma operator. Clang however, refuses
291  // to compile a fold expression with a depth of more than 256 by default.
292  // There seem to be no such limitations for initializer_list.
293  (void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
294  }
295 
296  /// Register a type instance with this dialect.
297  /// The use of this method is in general discouraged in favor of
298  /// 'addTypes<CustomType>()'.
299  void addType(TypeID typeID, AbstractType &&typeInfo);
300 
301  /// Register a set of attribute classes with this dialect.
302  template <typename... Args>
303  void addAttributes() {
304  // This initializer_list argument pack expansion is essentially equal to
305  // using a fold expression with a comma operator. Clang however, refuses
306  // to compile a fold expression with a depth of more than 256 by default.
307  // There seem to be no such limitations for initializer_list.
308  (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
309  }
310 
311  /// Register an attribute instance with this dialect.
312  /// The use of this method is in general discouraged in favor of
313  /// 'addAttributes<CustomAttr>()'.
314  void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
315 
316  /// Enable support for unregistered operations.
317  void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
318 
319  /// Enable support for unregistered types.
320  void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
321 
322 private:
323  Dialect(const Dialect &) = delete;
324  void operator=(Dialect &) = delete;
325 
326  /// Register an attribute instance with this dialect.
327  template <typename T>
328  void addAttribute() {
329  // Add this attribute to the dialect and register it with the uniquer.
330  addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
331  detail::AttributeUniquer::registerAttribute<T>(context);
332  }
333 
334  /// Register a type instance with this dialect.
335  template <typename T>
336  void addType() {
337  // Add this type to the dialect and register it with the uniquer.
338  addType(T::getTypeID(), AbstractType::get<T>(*this));
339  detail::TypeUniquer::registerType<T>(context);
340  }
341 
342  /// The namespace of this dialect.
343  StringRef name;
344 
345  /// The unique identifier of the derived Op class, this is used in the context
346  /// to allow registering multiple times the same dialect.
347  TypeID dialectID;
348 
349  /// This is the context that owns this Dialect object.
350  MLIRContext *context;
351 
352  /// Flag that specifies whether this dialect supports unregistered operations,
353  /// i.e. operations prefixed with the dialect namespace but not registered
354  /// with addOperation.
355  bool unknownOpsAllowed = false;
356 
357  /// Flag that specifies whether this dialect allows unregistered types, i.e.
358  /// types prefixed with the dialect namespace but not registered with addType.
359  /// These types are represented with OpaqueType.
360  bool unknownTypesAllowed = false;
361 
362  /// A collection of registered dialect interfaces.
363  DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
364 
365  /// A set of interfaces that the dialect (or its constructs, i.e.
366  /// Attributes/Operations/Types/etc.) has promised to implement, but has yet
367  /// to provide an implementation for.
368  DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces;
369 
370  friend class DialectRegistry;
371  friend void registerDialect();
372  friend class MLIRContext;
373 };
374 
375 } // namespace mlir
376 
377 namespace llvm {
378 /// Provide isa functionality for Dialects.
379 template <typename T>
380 struct isa_impl<T, ::mlir::Dialect,
381  std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
382  static inline bool doit(const ::mlir::Dialect &dialect) {
383  return mlir::TypeID::get<T>() == dialect.getTypeID();
384  }
385 };
386 template <typename T>
387 struct isa_impl<
388  T, ::mlir::Dialect,
389  std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
390  static inline bool doit(const ::mlir::Dialect &dialect) {
391  return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
392  }
393 };
394 template <typename T>
395 struct cast_retty_impl<T, ::mlir::Dialect *> {
396  using ret_type = T *;
397 };
398 template <typename T>
399 struct cast_retty_impl<T, ::mlir::Dialect> {
400  using ret_type = T &;
401 };
402 
403 template <typename T>
404 struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
405  template <typename To>
406  static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
407  doitImpl(::mlir::Dialect &dialect) {
408  return static_cast<To &>(dialect);
409  }
410  template <typename To>
411  static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
412  To &>
413  doitImpl(::mlir::Dialect &dialect) {
414  return *dialect.getRegisteredInterface<To>();
415  }
416 
417  static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
418 };
419 template <class T>
420 struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
421  static auto doit(::mlir::Dialect *dialect) {
423  *dialect);
424  }
425 };
426 
427 } // namespace llvm
428 
429 #endif
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
This class contains all of the static information common to all instances of a registered Attribute.
This class contains all of the static information common to all instances of a registered Type.
Definition: TypeSupport.h:30
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents an interface overridden for a single dialect.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Definition: Dialect.h:38
void addAttributes()
Register a set of attribute classes with this dialect.
Definition: Dialect.h:303
void addInterfaces()
Register a set of dialect interfaces with this dialect instance.
Definition: Dialect.h:196
virtual ~Dialect()
virtual void * getRegisteredInterfaceForOp(TypeID interfaceID, OperationName opName)
Lookup an op interface for the given ID if one is registered, otherwise nullptr.
Definition: Dialect.h:180
friend class MLIRContext
Definition: Dialect.h:372
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition: Dialect.cpp:72
virtual std::optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
Definition: Dialect.cpp:85
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName="")
Checks if the given interface, which is attempting to be used, is a promised interface of this dialect that has yet to be implemented. If so, emits a fatal error.
Definition: Dialect.h:228
StringRef getNamespace() const
Definition: Dialect.h:54
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:98
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' on the given operation. Returns failure if the verification failed, success otherwise.
Definition: Dialect.cpp:58
DialectInterface * getRegisteredInterface(TypeID interfaceID)
Lookup an interface for the given ID if one is registered, otherwise nullptr.
Definition: Dialect.h:158
InterfaceT::Concept * getRegisteredInterfaceForOp(OperationName opName)
Definition: Dialect.h:186
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' on the given operation. Returns failure if the verification failed, success otherwise.
Definition: Dialect.cpp:49
void addOperations()
This method is used by derived classes to add their operations to the set.
Definition: Dialect.h:277
bool hasPromisedInterface() const
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.h:259
InterfaceT * getRegisteredInterface()
Definition: Dialect.h:167
void declarePromisedInterface()
Declare that the given interface will be implemented, but has a delayed registration.
Definition: Dialect.h:211
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
Definition: Dialect.cpp:90
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID)
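Checks if the given interface, which is attempting to be attached to a construct owned by this dialect, is a promised interface of this dialect that has yet to be implemented. If so, it resolves the interface promise.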
Definition: Dialect.h:245
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect namespace but not registered with addType.
Definition: Dialect.h:67
virtual void printAttribute(Attribute, DialectAsmPrinter &) const
Print an attribute registered to this dialect.
Definition: Dialect.h:99
void addTypes()
Register a set of type classes with this dialect.
Definition: Dialect.h:288
virtual void printType(Type, DialectAsmPrinter &) const
Print a type registered to this dialect.
Definition: Dialect.h:107
MLIRContext * getContext() const
Definition: Dialect.h:52
void allowUnknownTypes(bool allow=true)
Enable support for unregistered types.
Definition: Dialect.h:320
virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute)
Verify an attribute from this dialect on the given operation.
Definition: Dialect.h:148
void allowUnknownOperations(bool allow=true)
Enable support for unregistered operations.
Definition: Dialect.h:317
bool allowsUnknownOperations() const
Returns true if this dialect allows for unregistered operations, i.e. operations prefixed with the dialect namespace but not registered with addOperations.
Definition: Dialect.h:62
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition: Dialect.cpp:64
void declarePromisedInterfaces()
Definition: Dialect.h:220
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition: Dialect.cpp:104
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition: Dialect.h:57
bool hasPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID) const
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.h:251
friend void registerDialect()
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to.
Definition: Dialect.cpp:38
InterfaceT & addInterface(Args &&...args)
Definition: Dialect.h:200
virtual void getCanonicalizationPatterns(RewritePatternSet &results) const
Register dialect-wide canonicalization patterns.
Definition: Dialect.h:74
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desired resultant type.
Definition: Dialect.h:83
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition: Builders.h:216
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
@ Type
An inlay hint for a type annotation.
Include the generated interface declarations.
static auto & doit(::mlir::Dialect &dialect)
Definition: Dialect.h:417
static std::enable_if_t< std::is_base_of<::mlir::Dialect, To >::value, To & > doitImpl(::mlir::Dialect &dialect)
Definition: Dialect.h:407
static std::enable_if_t< std::is_base_of<::mlir::DialectInterface, To >::value, To & > doitImpl(::mlir::Dialect &dialect)
Definition: Dialect.h:413
This represents an operation in an abstracted form, suitable for use with the builder APIs.