MLIR  19.0.0git
DialectInterface.h
Go to the documentation of this file.
1 //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_DIALECTINTERFACE_H
10 #define MLIR_IR_DIALECTINTERFACE_H
11 
12 #include "mlir/Support/TypeID.h"
13 #include "llvm/ADT/DenseSet.h"
14 #include "llvm/ADT/STLExtras.h"
15 
16 namespace mlir {
17 class Dialect;
18 class MLIRContext;
19 class Operation;
20 
21 //===----------------------------------------------------------------------===//
22 // DialectInterface
23 //===----------------------------------------------------------------------===//
24 namespace detail {
25 /// The base class used for all derived interface types. This class provides
26 /// utilities necessary for registration.
27 template <typename ConcreteType, typename BaseT>
28 class DialectInterfaceBase : public BaseT {
29 public:
31 
32  /// Get a unique id for the derived interface type.
33  static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
34 
35 protected:
36  DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {}
37 };
38 } // namespace detail
39 
40 /// This class represents an interface overridden for a single dialect.
42 public:
43  virtual ~DialectInterface();
44 
45  /// The base class used for all derived interface types. This class provides
46  /// utilities necessary for registration.
47  template <typename ConcreteType>
49 
50  /// Return the dialect that this interface represents.
51  Dialect *getDialect() const { return dialect; }
52 
53  /// Return the context that holds the parent dialect of this interface.
54  MLIRContext *getContext() const;
55 
56  /// Return the derived interface id.
57  TypeID getID() const { return interfaceID; }
58 
59 protected:
61  : dialect(dialect), interfaceID(id) {}
62 
63 private:
64  /// The dialect that represents this interface.
65  Dialect *dialect;
66 
67  /// The unique identifier for the derived interface type.
68  TypeID interfaceID;
69 };
70 
71 //===----------------------------------------------------------------------===//
72 // DialectInterfaceCollection
73 //===----------------------------------------------------------------------===//
74 
75 namespace detail {
76 /// This class is the base class for a collection of instances for a specific
77 /// interface kind.
79  /// DenseMap info for dialect interfaces that allows lookup by the dialect.
80  struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> {
82 
83  static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); }
84  static unsigned getHashValue(const DialectInterface *key) {
85  return getHashValue(key->getDialect());
86  }
87 
88  static bool isEqual(Dialect *lhs, const DialectInterface *rhs) {
89  if (rhs == getEmptyKey() || rhs == getTombstoneKey())
90  return false;
91  return lhs == rhs->getDialect();
92  }
93  };
94 
95  /// A set of registered dialect interface instances.
97  using InterfaceVectorT = std::vector<const DialectInterface *>;
98 
99 public:
101  StringRef interfaceName);
103 
104 protected:
105  /// Get the interface for the dialect of given operation, or null if one
106  /// is not registered.
107  const DialectInterface *getInterfaceFor(Operation *op) const;
108 
109  /// Get the interface for the given dialect.
110  const DialectInterface *getInterfaceFor(Dialect *dialect) const {
111  auto it = interfaces.find_as(dialect);
112  return it == interfaces.end() ? nullptr : *it;
113  }
114 
115  /// An iterator class that iterates the held interface objects of the given
116  /// derived interface type.
117  template <typename InterfaceT>
118  struct iterator
119  : public llvm::mapped_iterator_base<iterator<InterfaceT>,
120  InterfaceVectorT::const_iterator,
121  const InterfaceT &> {
122  using llvm::mapped_iterator_base<iterator<InterfaceT>,
123  InterfaceVectorT::const_iterator,
124  const InterfaceT &>::mapped_iterator_base;
125 
126  /// Map the element to the iterator result type.
127  const InterfaceT &mapElement(const DialectInterface *interface) const {
128  return *static_cast<const InterfaceT *>(interface);
129  }
130  };
131 
132  /// Iterator access to the held interfaces.
133  template <typename InterfaceT>
135  return iterator<InterfaceT>(orderedInterfaces.begin());
136  }
137  template <typename InterfaceT>
139  return iterator<InterfaceT>(orderedInterfaces.end());
140  }
141 
142 private:
143  /// A set of registered dialect interface instances.
144  InterfaceSetT interfaces;
145  /// An ordered list of the registered interface instances, necessary for
146  /// deterministic iteration.
147  // NOTE: SetVector does not provide find access, so it can't be used here.
148  InterfaceVectorT orderedInterfaces;
149 };
150 } // namespace detail
151 
152 /// A collection of dialect interfaces within a context, for a given concrete
153 /// interface type.
154 template <typename InterfaceType>
157 public:
159 
160  /// Collect the registered dialect interfaces within the provided context.
163  ctx, InterfaceType::getInterfaceID(),
164  llvm::getTypeName<InterfaceType>()) {}
165 
166  /// Get the interface for a given object, or null if one is not registered.
167  /// The object may be a dialect or an operation instance.
168  template <typename Object>
169  const InterfaceType *getInterfaceFor(Object *obj) const {
170  return static_cast<const InterfaceType *>(
172  }
173 
174  /// Iterator access to the held interfaces.
175  using iterator =
177  iterator begin() const { return interface_begin<InterfaceType>(); }
178  iterator end() const { return interface_end<InterfaceType>(); }
179 
180 private:
183 };
184 
185 } // namespace mlir
186 
187 #endif
A collection of dialect interfaces within a context, for a given concrete interface type.
DialectInterfaceCollection(MLIRContext *ctx)
Collect the registered dialect interfaces within the provided context.
const InterfaceType * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This class represents an interface overridden for a single dialect.
MLIRContext * getContext() const
Return the context that holds the parent dialect of this interface.
Definition: Dialect.cpp:120
DialectInterface(Dialect *dialect, TypeID id)
Dialect * getDialect() const
Return the dialect that this interface represents.
TypeID getID() const
Return the derived interface id.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
The base class used for all derived interface types.
static TypeID getInterfaceID()
Get a unique id for the derived interface type.
This class is the base class for a collection of instances for a specific interface kind.
const DialectInterface * getInterfaceFor(Operation *op) const
Get the interface for the dialect of given operation, or null if one is not registered.
Definition: Dialect.cpp:143
iterator< InterfaceT > interface_end() const
const DialectInterface * getInterfaceFor(Dialect *dialect) const
Get the interface for the given dialect.
iterator< InterfaceT > interface_begin() const
Iterator access to the held interfaces.
DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName)
Definition: Dialect.cpp:124
Include the generated interface declarations.
Definition: CallGraph.h:229
llvm::hash_code hash_value(const MPInt &x)
Redeclarations of friend declaration above to make it discoverable by lookups.
Definition: MPInt.cpp:17
Include the generated interface declarations.
An iterator class that iterates the held interface objects of the given derived interface type.
const InterfaceT & mapElement(const DialectInterface *interface) const
Map the element to the iterator result type.