MLIR  15.0.0git
DialectInterface.h
Go to the documentation of this file.
1 //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_DIALECTINTERFACE_H
10 #define MLIR_IR_DIALECTINTERFACE_H
11 
12 #include "mlir/Support/TypeID.h"
13 #include "llvm/ADT/DenseSet.h"
14 #include "llvm/ADT/STLExtras.h"
15 
16 namespace mlir {
17 class Dialect;
18 class MLIRContext;
19 class Operation;
20 
21 //===----------------------------------------------------------------------===//
22 // DialectInterface
23 //===----------------------------------------------------------------------===//
24 namespace detail {
25 /// The base class used for all derived interface types. This class provides
26 /// utilities necessary for registration.
27 template <typename ConcreteType, typename BaseT>
28 class DialectInterfaceBase : public BaseT {
29 public:
31 
32  /// Get a unique id for the derived interface type.
33  static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
34 
35 protected:
36  DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {}
37 };
38 } // namespace detail
39 
40 /// This class represents an interface overridden for a single dialect.
42 public:
43  virtual ~DialectInterface();
44 
45  /// The base class used for all derived interface types. This class provides
46  /// utilities necessary for registration.
47  template <typename ConcreteType>
49 
50  /// Return the dialect that this interface represents.
51  Dialect *getDialect() const { return dialect; }
52 
53  /// Return the derived interface id.
54  TypeID getID() const { return interfaceID; }
55 
56 protected:
58  : dialect(dialect), interfaceID(id) {}
59 
60 private:
61  /// The dialect that represents this interface.
62  Dialect *dialect;
63 
64  /// The unique identifier for the derived interface type.
65  TypeID interfaceID;
66 };
67 
68 //===----------------------------------------------------------------------===//
69 // DialectInterfaceCollection
70 //===----------------------------------------------------------------------===//
71 
72 namespace detail {
73 /// This class is the base class for a collection of instances for a specific
74 /// interface kind.
76  /// DenseMap info for dialect interfaces that allows lookup by the dialect.
77  struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> {
79 
80  static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); }
81  static unsigned getHashValue(const DialectInterface *key) {
82  return getHashValue(key->getDialect());
83  }
84 
85  static bool isEqual(Dialect *lhs, const DialectInterface *rhs) {
86  if (rhs == getEmptyKey() || rhs == getTombstoneKey())
87  return false;
88  return lhs == rhs->getDialect();
89  }
90  };
91 
92  /// A set of registered dialect interface instances.
94  using InterfaceVectorT = std::vector<const DialectInterface *>;
95 
96 public:
99 
100 protected:
101  /// Get the interface for the dialect of given operation, or null if one
102  /// is not registered.
103  const DialectInterface *getInterfaceFor(Operation *op) const;
104 
105  /// Get the interface for the given dialect.
106  const DialectInterface *getInterfaceFor(Dialect *dialect) const {
107  auto it = interfaces.find_as(dialect);
108  return it == interfaces.end() ? nullptr : *it;
109  }
110 
111  /// An iterator class that iterates the held interface objects of the given
112  /// derived interface type.
113  template <typename InterfaceT>
114  struct iterator
115  : public llvm::mapped_iterator_base<iterator<InterfaceT>,
116  InterfaceVectorT::const_iterator,
117  const InterfaceT &> {
118  using llvm::mapped_iterator_base<iterator<InterfaceT>,
119  InterfaceVectorT::const_iterator,
120  const InterfaceT &>::mapped_iterator_base;
121 
122  /// Map the element to the iterator result type.
123  const InterfaceT &mapElement(const DialectInterface *interface) const {
124  return *static_cast<const InterfaceT *>(interface);
125  }
126  };
127 
128  /// Iterator access to the held interfaces.
129  template <typename InterfaceT> iterator<InterfaceT> interface_begin() const {
130  return iterator<InterfaceT>(orderedInterfaces.begin());
131  }
132  template <typename InterfaceT> iterator<InterfaceT> interface_end() const {
133  return iterator<InterfaceT>(orderedInterfaces.end());
134  }
135 
136 private:
137  /// A set of registered dialect interface instances.
138  InterfaceSetT interfaces;
139  /// An ordered list of the registered interface instances, necessary for
140  /// deterministic iteration.
141  // NOTE: SetVector does not provide find access, so it can't be used here.
142  InterfaceVectorT orderedInterfaces;
143 };
144 } // namespace detail
145 
146 /// A collection of dialect interfaces within a context, for a given concrete
147 /// interface type.
148 template <typename InterfaceType>
151 public:
153 
154  /// Collect the registered dialect interfaces within the provided context.
156  : detail::DialectInterfaceCollectionBase(
157  ctx, InterfaceType::getInterfaceID()) {}
158 
159  /// Get the interface for a given object, or null if one is not registered.
160  /// The object may be a dialect or an operation instance.
161  template <typename Object>
162  const InterfaceType *getInterfaceFor(Object *obj) const {
163  return static_cast<const InterfaceType *>(
165  }
166 
167  /// Iterator access to the held interfaces.
168  using iterator =
170  iterator begin() const { return interface_begin<InterfaceType>(); }
171  iterator end() const { return interface_end<InterfaceType>(); }
172 
173 private:
176 };
177 
178 } // namespace mlir
179 
180 #endif
An iterator class that iterates the held interface objects of the given derived interface type...
Include the generated interface declarations.
TypeID getID() const
Return the derived interface id.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
The base class used for all derived interface types.
iterator< InterfaceT > interface_end() const
Dialect * getDialect() const
Return the dialect that this interface represents.
A collection of dialect interfaces within a context, for a given concrete interface type...
const InterfaceT & mapElement(const DialectInterface *interface) const
Map the element to the iterator result type.
const InterfaceType * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
const DialectInterface * getInterfaceFor(Dialect *dialect) const
Get the interface for the given dialect.
static TypeID getInterfaceID()
Get a unique id for the derived interface type.
DialectInterface(Dialect *dialect, TypeID id)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
This class is the base class for a collection of instances for a specific interface kind...
iterator< InterfaceT > interface_begin() const
Iterator access to the held interfaces.
DialectInterfaceCollection(MLIRContext *ctx)
Collect the registered dialect interfaces within the provided context.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
const DialectInterface * getInterfaceFor(Operation *op) const
Get the interface for the dialect of given operation, or null if one is not registered.
Definition: Dialect.cpp:131
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
This class represents an interface overridden for a single dialect.