MLIR 22.0.0git
Dialect.h
Go to the documentation of this file.
1//===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the 'dialect' abstraction.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_DIALECT_H
14#define MLIR_IR_DIALECT_H
15
18#include "mlir/Support/TypeID.h"
19
20namespace mlir {
24class OpBuilder;
25class Type;
26
27//===----------------------------------------------------------------------===//
28// Dialect
29//===----------------------------------------------------------------------===//
30
31/// Dialects are groups of MLIR operations, types and attributes, as well as
32/// behavior associated with the entire group. For example, hooks into other
33/// systems for constant folding, interfaces, default named types for asm
34/// printing, etc.
35///
36/// Instances of the dialect object are loaded in a specific MLIRContext.
37///
38class Dialect {
39public:
40 /// Type for a callback provided by the dialect to parse a custom operation.
41 /// This is used for the dialect to provide an alternative way to parse custom
42 /// operations, including unregistered ones.
44 function_ref<ParseResult(OpAsmParser &parser, OperationState &result)>;
45
46 virtual ~Dialect();
47
48 /// Utility function that returns if the given string is a valid dialect
49 /// namespace
50 static bool isValidNamespace(StringRef str);
51
52 MLIRContext *getContext() const { return context; }
53
54 StringRef getNamespace() const { return name; }
55
56 /// Returns the unique identifier that corresponds to this dialect.
57 TypeID getTypeID() const { return dialectID; }
58
59 /// Returns true if this dialect allows for unregistered operations, i.e.
60 /// operations prefixed with the dialect namespace but not registered with
61 /// addOperation.
62 bool allowsUnknownOperations() const { return unknownOpsAllowed; }
63
64 /// Return true if this dialect allows for unregistered types, i.e., types
65 /// prefixed with the dialect namespace but not registered with addType.
66 /// These are represented with OpaqueType.
67 bool allowsUnknownTypes() const { return unknownTypesAllowed; }
68
69 /// Register dialect-wide canonicalization patterns. This method should only
70 /// be used to register canonicalization patterns that do not conceptually
71 /// belong to any single operation in the dialect. (In that case, use the op's
72 /// canonicalizer.) E.g., canonicalization patterns for op interfaces should
73 /// be registered here.
74 virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
75
76 /// Registered hook to materialize a single constant operation from a given
77 /// attribute value with the desired resultant type. This method should use
78 /// the provided builder to create the operation without changing the
79 /// insertion position. The generated operation is expected to be constant
80 /// like, i.e. single result, zero operands, non side-effecting, etc. On
81 /// success, this hook should return the value generated to represent the
82 /// constant value. Otherwise, it should return null on failure.
84 Type type, Location loc) {
85 return nullptr;
86 }
87
88 //===--------------------------------------------------------------------===//
89 // Parsing Hooks
90 //===--------------------------------------------------------------------===//
91
92 /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
93 /// refers to the expected type of the attribute.
94 virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
95
96 /// Print an attribute registered to this dialect. Note: The type of the
97 /// attribute need not be printed by this method as it is always printed by
98 /// the caller.
100 llvm_unreachable("dialect has no registered attribute printing hook");
101 }
102
103 /// Parse a type registered to this dialect.
104 virtual Type parseType(DialectAsmParser &parser) const;
105
106 /// Print a type registered to this dialect.
107 virtual void printType(Type, DialectAsmPrinter &) const {
108 llvm_unreachable("dialect has no registered type printing hook");
109 }
110
111 /// Return the hook to parse an operation registered to this dialect, if any.
112 /// By default this will lookup for registered operations and return the
113 /// `parse()` method registered on the RegisteredOperationName. Dialects can
114 /// override this behavior and handle unregistered operations as well.
115 virtual std::optional<ParseOpHook>
116 getParseOperationHook(StringRef opName) const;
117
118 /// Print an operation registered to this dialect.
119 /// This hook is invoked for registered operation which don't override the
120 /// `print()` method to define their own custom assembly.
121 virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
123
124 //===--------------------------------------------------------------------===//
125 // Verification Hooks
126 //===--------------------------------------------------------------------===//
127
128 /// Verify an attribute from this dialect on the argument at 'argIndex' for
129 /// the region at 'regionIndex' on the given operation. Returns failure if
130 /// the verification failed, success otherwise. This hook may optionally be
131 /// invoked from any operation containing a region.
132 virtual LogicalResult verifyRegionArgAttribute(Operation *,
133 unsigned regionIndex,
134 unsigned argIndex,
136
137 /// Verify an attribute from this dialect on the result at 'resultIndex' for
138 /// the region at 'regionIndex' on the given operation. Returns failure if
139 /// the verification failed, success otherwise. This hook may optionally be
140 /// invoked from any operation containing a region.
141 virtual LogicalResult verifyRegionResultAttribute(Operation *,
142 unsigned regionIndex,
143 unsigned resultIndex,
145
146 /// Verify an attribute from this dialect on the given operation. Returns
147 /// failure if the verification failed, success otherwise.
149 return success();
150 }
151
152 //===--------------------------------------------------------------------===//
153 // Interfaces
154 //===--------------------------------------------------------------------===//
155
156 /// Lookup an interface for the given ID if one is registered, otherwise
157 /// nullptr.
159#ifndef NDEBUG
161#endif
162
163 auto it = registeredInterfaces.find(interfaceID);
164 return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
165 }
166 template <typename InterfaceT>
168#ifndef NDEBUG
170 InterfaceT::getInterfaceID(),
171 llvm::getTypeName<InterfaceT>());
172#endif
173
174 return static_cast<InterfaceT *>(
175 getRegisteredInterface(InterfaceT::getInterfaceID()));
176 }
177
178 /// Lookup an op interface for the given ID if one is registered, otherwise
179 /// nullptr.
180 virtual void *getRegisteredInterfaceForOp(TypeID interfaceID,
181 OperationName opName) {
182 return nullptr;
183 }
184 template <typename InterfaceT>
185 typename InterfaceT::Concept *
187 return static_cast<typename InterfaceT::Concept *>(
188 getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName));
189 }
190
191 /// Register a dialect interface with this dialect instance.
192 void addInterface(std::unique_ptr<DialectInterface> interface);
193
194 /// Register a set of dialect interfaces with this dialect instance.
195 template <typename... Args>
197 (addInterface(std::make_unique<Args>(this)), ...);
198 }
199 template <typename InterfaceT, typename... Args>
200 InterfaceT &addInterface(Args &&...args) {
201 InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...);
202 addInterface(std::unique_ptr<DialectInterface>(interface));
203 return *interface;
204 }
205
206 /// Declare that the given interface will be implemented, but has a delayed
207 /// registration. The promised interface type can be an interface of any type
208 /// not just a dialect interface, i.e. it may also be an
209 /// AttributeInterface/OpInterface/TypeInterface/etc.
210 template <typename InterfaceT, typename ConcreteT>
212 unresolvedPromisedInterfaces.insert(
213 {TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
214 }
215
216 // Declare the same interface for multiple types.
217 // Example:
218 // declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>()
219 template <typename InterfaceT, typename... ConcreteT>
223
224 /// Checks if the given interface, which is attempting to be used, is a
225 /// promised interface of this dialect that has yet to be implemented. If so,
226 /// emits a fatal error. `interfaceName` is an optional string that contains a
227 /// more user readable name for the interface (such as the class name).
229 TypeID interfaceID,
230 StringRef interfaceName = "") {
231 if (unresolvedPromisedInterfaces.count(
232 {interfaceRequestorID, interfaceID})) {
233 llvm::report_fatal_error(
234 "checking for an interface (`" + interfaceName +
235 "`) that was promised by dialect '" + getNamespace() +
236 "' but never implemented. This is generally an indication "
237 "that the dialect extension implementing the interface was never "
238 "registered.");
239 }
240 }
241
242 /// Checks if the given interface, which is attempting to be attached to a
243 /// construct owned by this dialect, is a promised interface of this dialect
244 /// that has yet to be implemented. If so, it resolves the interface promise.
246 TypeID interfaceID) {
247 unresolvedPromisedInterfaces.erase({interfaceRequestorID, interfaceID});
248 }
249
250 /// Checks if a promise has been made for the interface/requestor pair.
251 bool hasPromisedInterface(TypeID interfaceRequestorID,
252 TypeID interfaceID) const {
253 return unresolvedPromisedInterfaces.count(
254 {interfaceRequestorID, interfaceID});
255 }
256
257 /// Checks if a promise has been made for the interface/requestor pair.
258 template <typename ConcreteT, typename InterfaceT>
259 bool hasPromisedInterface() const {
261 InterfaceT::getInterfaceID());
262 }
263
264protected:
265 /// The constructor takes a unique namespace for this dialect as well as the
266 /// context to bind to.
267 /// Note: The namespace must not contain '.' characters.
268 /// Note: All operations belonging to this dialect must have names starting
269 /// with the namespace followed by '.'.
270 /// Example:
271 /// - "tf" for the TensorFlow ops like "tf.add".
272 Dialect(StringRef name, MLIRContext *context, TypeID id);
273
274 /// This method is used by derived classes to add their operations to the set.
275 ///
276 template <typename... Args>
278 // This initializer_list argument pack expansion is essentially equal to
279 // using a fold expression with a comma operator. Clang however, refuses
280 // to compile a fold expression with a depth of more than 256 by default.
281 // There seem to be no such limitations for initializer_list.
282 (void)std::initializer_list<int>{
283 0, (RegisteredOperationName::insert<Args>(*this), 0)...};
284 }
285
286 /// Register a set of type classes with this dialect.
287 template <typename... Args>
288 void addTypes() {
289 // This initializer_list argument pack expansion is essentially equal to
290 // using a fold expression with a comma operator. Clang however, refuses
291 // to compile a fold expression with a depth of more than 256 by default.
292 // There seem to be no such limitations for initializer_list.
293 (void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
294 }
295
296 /// Register a type instance with this dialect.
297 /// The use of this method is in general discouraged in favor of
298 /// 'addTypes<CustomType>()'.
299 void addType(TypeID typeID, AbstractType &&typeInfo);
300
301 /// Register a set of attribute classes with this dialect.
302 template <typename... Args>
304 // This initializer_list argument pack expansion is essentially equal to
305 // using a fold expression with a comma operator. Clang however, refuses
306 // to compile a fold expression with a depth of more than 256 by default.
307 // There seem to be no such limitations for initializer_list.
308 (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
309 }
310
311 /// Register an attribute instance with this dialect.
312 /// The use of this method is in general discouraged in favor of
313 /// 'addAttributes<CustomAttr>()'.
314 void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
315
316 /// Enable support for unregistered operations.
317 void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
318
319 /// Enable support for unregistered types.
320 void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
321
322private:
323 Dialect(const Dialect &) = delete;
324 void operator=(Dialect &) = delete;
325
326 /// Register an attribute instance with this dialect.
327 template <typename T>
328 void addAttribute() {
329 // Add this attribute to the dialect and register it with the uniquer.
330 addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
332 }
333
334 /// Register a type instance with this dialect.
335 template <typename T>
336 void addType() {
337 // Add this type to the dialect and register it with the uniquer.
338 addType(T::getTypeID(), AbstractType::get<T>(*this));
340 }
341
342 /// The namespace of this dialect.
343 StringRef name;
344
345 /// The unique identifier of the derived Op class, this is used in the context
346 /// to allow registering multiple times the same dialect.
347 TypeID dialectID;
348
349 /// This is the context that owns this Dialect object.
350 MLIRContext *context;
351
352 /// Flag that specifies whether this dialect supports unregistered operations,
353 /// i.e. operations prefixed with the dialect namespace but not registered
354 /// with addOperation.
355 bool unknownOpsAllowed = false;
356
357 /// Flag that specifies whether this dialect allows unregistered types, i.e.
358 /// types prefixed with the dialect namespace but not registered with addType.
359 /// These types are represented with OpaqueType.
360 bool unknownTypesAllowed = false;
361
362 /// A collection of registered dialect interfaces.
364
365 /// A set of interfaces that the dialect (or its constructs, i.e.
366 /// Attributes/Operations/Types/etc.) has promised to implement, but has yet
367 /// to provide an implementation for.
368 DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces;
369
370 friend class DialectRegistry;
371 friend class MLIRContext;
372};
373
374} // namespace mlir
375
376namespace llvm {
377/// Provide isa functionality for Dialects.
378template <typename T>
379struct isa_impl<T, ::mlir::Dialect,
380 std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
381 static inline bool doit(const ::mlir::Dialect &dialect) {
382 return mlir::TypeID::get<T>() == dialect.getTypeID();
383 }
384};
385template <typename T>
386struct isa_impl<
387 T, ::mlir::Dialect,
388 std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
389 static inline bool doit(const ::mlir::Dialect &dialect) {
390 return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
391 }
392};
393template <typename T>
394struct cast_retty_impl<T, ::mlir::Dialect *> {
395 using ret_type = T *;
396};
397template <typename T>
398struct cast_retty_impl<T, ::mlir::Dialect> {
399 using ret_type = T &;
400};
401
402template <typename T>
403struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
404 template <typename To>
405 static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
407 return static_cast<To &>(dialect);
408 }
409 template <typename To>
410 static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
411 To &>
413 return *dialect.getRegisteredInterface<To>();
414 }
415
416 static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
417};
418template <class T>
419struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
420 static auto doit(::mlir::Dialect *dialect) {
421 return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
422 *dialect);
423 }
424};
425
426} // namespace llvm
427
428#endif
return success()
This class contains all of the static information common to all instances of a registered Attribute.
static AbstractAttribute get(Dialect &dialect)
This method is used by Dialect objects when they register the list of attributes they contain.
This class contains all of the static information common to all instances of a registered Type.
Definition TypeSupport.h:30
static AbstractType get(Dialect &dialect)
This method is used by Dialect objects when they register the list of types they contain.
Definition TypeSupport.h:50
Attributes are known-constant values of operations.
Definition Attributes.h:25
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and types.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom printAttribute/printType() method on a dialect.
This class represents an interface overridden for a single dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Definition Dialect.h:38
MLIRContext * getContext() const
Definition Dialect.h:52
void addAttributes()
Register a set of attribute classes with this dialect.
Definition Dialect.h:303
void addInterfaces()
Register a set of dialect interfaces with this dialect instance.
Definition Dialect.h:196
virtual ~Dialect()
friend class MLIRContext
Definition Dialect.h:371
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition Dialect.cpp:69
virtual std::optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
Definition Dialect.cpp:82
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName="")
Checks if the given interface, which is attempting to be used, is a promised interface of this dialect that has yet to be implemented.
Definition Dialect.h:228
StringRef getNamespace() const
Definition Dialect.h:54
static bool isValidNamespace(StringRef str)
Utility function that returns true if the given string is a valid dialect namespace.
Definition Dialect.cpp:95
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' on the given operation.
Definition Dialect.cpp:55
function_ref< ParseResult(OpAsmParser &parser, OperationState &result)> ParseOpHook
Type for a callback provided by the dialect to parse a custom operation.
Definition Dialect.h:43
InterfaceT::Concept * getRegisteredInterfaceForOp(OperationName opName)
Definition Dialect.h:186
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' on the given operation.
Definition Dialect.cpp:46
void addOperations()
This method is used by derived classes to add their operations to the set.
Definition Dialect.h:277
bool hasPromisedInterface() const
Checks if a promise has been made for the interface/requestor pair.
Definition Dialect.h:259
InterfaceT & addInterface(Args &&...args)
Definition Dialect.h:200
void declarePromisedInterface()
Declare that the given interface will be implemented, but has a delayed registration.
Definition Dialect.h:211
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
Definition Dialect.cpp:87
virtual void * getRegisteredInterfaceForOp(TypeID interfaceID, OperationName opName)
Lookup an op interface for the given ID if one is registered, otherwise nullptr.
Definition Dialect.h:180
void addType(TypeID typeID, AbstractType &&typeInfo)
Register a type instance with this dialect.
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached to a construct owned by this dialect, is a promised interface of this dialect that has yet to be implemented.
Definition Dialect.h:245
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect namespace but not registered with addType.
Definition Dialect.h:67
virtual void printAttribute(Attribute, DialectAsmPrinter &) const
Print an attribute registered to this dialect.
Definition Dialect.h:99
void addTypes()
Register a set of type classes with this dialect.
Definition Dialect.h:288
virtual void printType(Type, DialectAsmPrinter &) const
Print a type registered to this dialect.
Definition Dialect.h:107
void allowUnknownTypes(bool allow=true)
Enable support for unregistered types.
Definition Dialect.h:320
virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute)
Verify an attribute from this dialect on the given operation.
Definition Dialect.h:148
void allowUnknownOperations(bool allow=true)
Enable support for unregistered operations.
Definition Dialect.h:317
bool allowsUnknownOperations() const
Returns true if this dialect allows for unregistered operations, i.e. operations prefixed with the dialect namespace but not registered with addOperation.
Definition Dialect.h:62
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition Dialect.cpp:61
void declarePromisedInterfaces()
Definition Dialect.h:220
InterfaceT * getRegisteredInterface()
Definition Dialect.h:167
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desired resultant type.
Definition Dialect.h:83
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition Dialect.cpp:101
friend class DialectRegistry
Definition Dialect.h:370
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition Dialect.h:57
DialectInterface * getRegisteredInterface(TypeID interfaceID)
Lookup an interface for the given ID if one is registered, otherwise nullptr.
Definition Dialect.h:158
bool hasPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID) const
Checks if a promise has been made for the interface/requestor pair.
Definition Dialect.h:251
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to.
Definition Dialect.cpp:35
void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo)
Register an attribute instance with this dialect.
virtual void getCanonicalizationPatterns(RewritePatternSet &results) const
Register dialect-wide canonicalization patterns.
Definition Dialect.h:74
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors, etc.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
static void insert(Dialect &dialect)
Register a new operation in a Dialect object.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
static void registerAttribute(MLIRContext *ctx)
Register an attribute instance T with the uniquer.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
static std::enable_if_t< std::is_base_of<::mlir::DialectInterface, To >::value, To & > doitImpl(::mlir::Dialect &dialect)
Definition Dialect.h:412
static auto & doit(::mlir::Dialect &dialect)
Definition Dialect.h:416
static std::enable_if_t< std::is_base_of<::mlir::Dialect, To >::value, To & > doitImpl(::mlir::Dialect &dialect)
Definition Dialect.h:406
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static void registerType(MLIRContext *ctx)
Register a type instance T with the uniquer.