MLIR  17.0.0git
DialectInterface.h
Go to the documentation of this file.
1 //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_DIALECTINTERFACE_H
10 #define MLIR_IR_DIALECTINTERFACE_H
11 
12 #include "mlir/Support/TypeID.h"
13 #include "llvm/ADT/DenseSet.h"
14 #include "llvm/ADT/STLExtras.h"
15 
16 namespace mlir {
17 class Dialect;
18 class MLIRContext;
19 class Operation;
20 
21 //===----------------------------------------------------------------------===//
22 // DialectInterface
23 //===----------------------------------------------------------------------===//
24 namespace detail {
25 /// The base class used for all derived interface types. This class provides
26 /// utilities necessary for registration.
27 template <typename ConcreteType, typename BaseT>
28 class DialectInterfaceBase : public BaseT {
29 public:
31 
32  /// Get a unique id for the derived interface type.
33  static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
34 
35 protected:
36  DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {}
37 };
38 } // namespace detail
39 
40 /// This class represents an interface overridden for a single dialect.
42 public:
43  virtual ~DialectInterface();
44 
45  /// The base class used for all derived interface types. This class provides
46  /// utilities necessary for registration.
47  template <typename ConcreteType>
49 
50  /// Return the dialect that this interface represents.
51  Dialect *getDialect() const { return dialect; }
52 
53  /// Return the context that holds the parent dialect of this interface.
54  MLIRContext *getContext() const;
55 
56  /// Return the derived interface id.
57  TypeID getID() const { return interfaceID; }
58 
59 protected:
61  : dialect(dialect), interfaceID(id) {}
62 
63 private:
64  /// The dialect that represents this interface.
65  Dialect *dialect;
66 
67  /// The unique identifier for the derived interface type.
68  TypeID interfaceID;
69 };
70 
71 //===----------------------------------------------------------------------===//
72 // DialectInterfaceCollection
73 //===----------------------------------------------------------------------===//
74 
75 namespace detail {
76 /// This class is the base class for a collection of instances for a specific
77 /// interface kind.
79  /// DenseMap info for dialect interfaces that allows lookup by the dialect.
80  struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> {
82 
83  static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); }
84  static unsigned getHashValue(const DialectInterface *key) {
85  return getHashValue(key->getDialect());
86  }
87 
88  static bool isEqual(Dialect *lhs, const DialectInterface *rhs) {
89  if (rhs == getEmptyKey() || rhs == getTombstoneKey())
90  return false;
91  return lhs == rhs->getDialect();
92  }
93  };
94 
95  /// A set of registered dialect interface instances.
97  using InterfaceVectorT = std::vector<const DialectInterface *>;
98 
99 public:
102 
103 protected:
104  /// Get the interface for the dialect of given operation, or null if one
105  /// is not registered.
106  const DialectInterface *getInterfaceFor(Operation *op) const;
107 
108  /// Get the interface for the given dialect.
109  const DialectInterface *getInterfaceFor(Dialect *dialect) const {
110  auto it = interfaces.find_as(dialect);
111  return it == interfaces.end() ? nullptr : *it;
112  }
113 
114  /// An iterator class that iterates the held interface objects of the given
115  /// derived interface type.
116  template <typename InterfaceT>
117  struct iterator
118  : public llvm::mapped_iterator_base<iterator<InterfaceT>,
119  InterfaceVectorT::const_iterator,
120  const InterfaceT &> {
121  using llvm::mapped_iterator_base<iterator<InterfaceT>,
122  InterfaceVectorT::const_iterator,
123  const InterfaceT &>::mapped_iterator_base;
124 
125  /// Map the element to the iterator result type.
126  const InterfaceT &mapElement(const DialectInterface *interface) const {
127  return *static_cast<const InterfaceT *>(interface);
128  }
129  };
130 
131  /// Iterator access to the held interfaces.
132  template <typename InterfaceT>
134  return iterator<InterfaceT>(orderedInterfaces.begin());
135  }
136  template <typename InterfaceT>
138  return iterator<InterfaceT>(orderedInterfaces.end());
139  }
140 
141 private:
142  /// A set of registered dialect interface instances.
143  InterfaceSetT interfaces;
144  /// An ordered list of the registered interface instances, necessary for
145  /// deterministic iteration.
146  // NOTE: SetVector does not provide find access, so it can't be used here.
147  InterfaceVectorT orderedInterfaces;
148 };
149 } // namespace detail
150 
151 /// A collection of dialect interfaces within a context, for a given concrete
152 /// interface type.
153 template <typename InterfaceType>
156 public:
158 
159  /// Collect the registered dialect interfaces within the provided context.
162  ctx, InterfaceType::getInterfaceID()) {}
163 
164  /// Get the interface for a given object, or null if one is not registered.
165  /// The object may be a dialect or an operation instance.
166  template <typename Object>
167  const InterfaceType *getInterfaceFor(Object *obj) const {
168  return static_cast<const InterfaceType *>(
170  }
171 
172  /// Iterator access to the held interfaces.
173  using iterator =
175  iterator begin() const { return interface_begin<InterfaceType>(); }
176  iterator end() const { return interface_end<InterfaceType>(); }
177 
178 private:
181 };
182 
183 } // namespace mlir
184 
185 #endif
A collection of dialect interfaces within a context, for a given concrete interface type.
DialectInterfaceCollection(MLIRContext *ctx)
Collect the registered dialect interfaces within the provided context.
const InterfaceType * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This class represents an interface overridden for a single dialect.
MLIRContext * getContext() const
Return the context that holds the parent dialect of this interface.
Definition: Dialect.cpp:117
DialectInterface(Dialect *dialect, TypeID id)
Dialect * getDialect() const
Return the dialect that this interface represents.
TypeID getID() const
Return the derived interface id.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
The base class used for all derived interface types.
static TypeID getInterfaceID()
Get a unique id for the derived interface type.
This class is the base class for a collection of instances for a specific interface kind.
const DialectInterface * getInterfaceFor(Operation *op) const
Get the interface for the dialect of given operation, or null if one is not registered.
Definition: Dialect.cpp:136
iterator< InterfaceT > interface_end() const
DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind)
Definition: Dialect.cpp:121
const DialectInterface * getInterfaceFor(Dialect *dialect) const
Get the interface for the given dialect.
iterator< InterfaceT > interface_begin() const
Iterator access to the held interfaces.
llvm::hash_code hash_value(const MPInt &x)
Redeclaration of the friend declaration above, to make it discoverable by lookup.
Definition: MPInt.cpp:15
Include the generated interface declarations.
An iterator class that iterates the held interface objects of the given derived interface type.
const InterfaceT & mapElement(const DialectInterface *interface) const
Map the element to the iterator result type.