MLIR  20.0.0git
DialectInterface.h
Go to the documentation of this file.
1 //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_DIALECTINTERFACE_H
10 #define MLIR_IR_DIALECTINTERFACE_H
11 
12 #include "mlir/Support/TypeID.h"
13 #include "llvm/ADT/DenseSet.h"
14 #include "llvm/ADT/STLExtras.h"
15 #include <vector>
16 
17 namespace mlir {
18 class Dialect;
19 class MLIRContext;
20 class Operation;
21 
22 //===----------------------------------------------------------------------===//
23 // DialectInterface
24 //===----------------------------------------------------------------------===//
25 namespace detail {
26 /// Base class for all derived interface types. Provides the registration
27 /// utilities: a TypeID keyed on `ConcreteType`, forwarded to `BaseT` on construction.
28 template <typename ConcreteType, typename BaseT>
29 class DialectInterfaceBase : public BaseT {
30 public:
32 
33  /// Get the unique TypeID of the derived interface type `ConcreteType`.
34  static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
35 
36 protected:
37  DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {}
38 };
39 } // namespace detail
40 
41 /// This class represents an interface overridden for a single dialect.
43 public:
44  virtual ~DialectInterface();
45 
46  /// The base class used for all derived interface types. This class provides
47  /// utilities necessary for registration.
48  template <typename ConcreteType>
50 
51  /// Return the dialect that this interface represents, as passed at construction.
52  Dialect *getDialect() const { return dialect; }
53 
54  /// Return the context that holds the parent dialect of this interface.
55  MLIRContext *getContext() const;
56 
57  /// Return the TypeID of the derived interface type, as passed at construction.
58  TypeID getID() const { return interfaceID; }
59 
60 protected:
62  : dialect(dialect), interfaceID(id) {}
63 
64 private:
65  /// The dialect that this interface represents.
66  Dialect *dialect;
67 
68  /// The unique identifier for the derived interface type.
69  TypeID interfaceID;
70 };
71 
72 //===----------------------------------------------------------------------===//
73 // DialectInterfaceCollection
74 //===----------------------------------------------------------------------===//
75 
76 namespace detail {
77 /// This class is the base class for a collection of instances for a specific
78 /// interface kind.
80  /// DenseMap info for dialect interfaces that allows lookup by the dialect; an interface hashes and compares as its dialect.
81  struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> {
83 
84  static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); } // Hash the dialect pointer itself.
85  static unsigned getHashValue(const DialectInterface *key) {
86  return getHashValue(key->getDialect()); // An interface hashes as its dialect.
87  }
88 
89  static bool isEqual(Dialect *lhs, const DialectInterface *rhs) {
90  if (rhs == getEmptyKey() || rhs == getTombstoneKey()) // Sentinel keys never match and are never dereferenced.
91  return false;
92  return lhs == rhs->getDialect();
93  }
94  };
95 
96  /// A set of registered dialect interface instances.
98  using InterfaceVectorT = std::vector<const DialectInterface *>;
99 
100 public:
102  StringRef interfaceName);
104 
105 protected:
106  /// Get the interface for the dialect of the given operation, or null if one
107  /// is not registered.
108  const DialectInterface *getInterfaceFor(Operation *op) const;
109 
110  /// Get the interface for the given dialect, or null if one is not registered.
111  const DialectInterface *getInterfaceFor(Dialect *dialect) const {
112  auto it = interfaces.find_as(dialect); // Look up by the dialect pointer rather than by an interface key.
113  return it == interfaces.end() ? nullptr : *it;
114  }
115 
116  /// An iterator over the held interface objects that yields each one as a
117  /// const reference to the given derived interface type.
118  template <typename InterfaceT>
119  struct iterator
120  : public llvm::mapped_iterator_base<iterator<InterfaceT>,
121  InterfaceVectorT::const_iterator,
122  const InterfaceT &> {
123  using llvm::mapped_iterator_base<iterator<InterfaceT>,
124  InterfaceVectorT::const_iterator,
125  const InterfaceT &>::mapped_iterator_base; // Inherit the base class constructors.
126 
127  /// Map the stored base-class pointer to a reference of the derived interface type.
128  const InterfaceT &mapElement(const DialectInterface *interface) const {
129  return *static_cast<const InterfaceT *>(interface); // Unchecked downcast; assumes the element really is an InterfaceT.
130  }
131  };
132 
133  /// Iterator access to the held interfaces.
134  template <typename InterfaceT>
136  return iterator<InterfaceT>(orderedInterfaces.begin());
137  }
138  template <typename InterfaceT>
140  return iterator<InterfaceT>(orderedInterfaces.end());
141  }
142 
143 private:
144  /// A set of registered dialect interface instances.
145  InterfaceSetT interfaces;
146  /// An ordered list of the registered interface instances, necessary for
147  /// deterministic iteration.
148  // NOTE: SetVector does not provide find access, so it can't be used here.
149  InterfaceVectorT orderedInterfaces;
150 };
151 } // namespace detail
152 
153 /// A collection of dialect interfaces within a context, for a given concrete
154 /// interface type.
155 template <typename InterfaceType>
158 public:
160 
161  /// Collect the registered dialect interfaces within the provided context.
164  ctx, InterfaceType::getInterfaceID(),
165  llvm::getTypeName<InterfaceType>()) {}
166 
167  /// Get the interface for a given object, or null if one is not registered.
168  /// The object may be a dialect or an operation instance.
169  template <typename Object>
170  const InterfaceType *getInterfaceFor(Object *obj) const {
171  return static_cast<const InterfaceType *>(
173  }
174 
175  /// Iterator access to the held interfaces.
176  using iterator =
178  iterator begin() const { return interface_begin<InterfaceType>(); }
179  iterator end() const { return interface_end<InterfaceType>(); }
180 
181 private:
184 };
185 
186 } // namespace mlir
187 
188 #endif
A collection of dialect interfaces within a context, for a given concrete interface type.
DialectInterfaceCollection(MLIRContext *ctx)
Collect the registered dialect interfaces within the provided context.
const InterfaceType * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This class represents an interface overridden for a single dialect.
MLIRContext * getContext() const
Return the context that holds the parent dialect of this interface.
Definition: Dialect.cpp:126
DialectInterface(Dialect *dialect, TypeID id)
Dialect * getDialect() const
Return the dialect that this interface represents.
TypeID getID() const
Return the derived interface id.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
The base class used for all derived interface types.
static TypeID getInterfaceID()
Get a unique id for the derived interface type.
This class is the base class for a collection of instances for a specific interface kind.
const DialectInterface * getInterfaceFor(Operation *op) const
Get the interface for the dialect of the given operation, or null if one is not registered.
Definition: Dialect.cpp:149
iterator< InterfaceT > interface_end() const
const DialectInterface * getInterfaceFor(Dialect *dialect) const
Get the interface for the given dialect.
iterator< InterfaceT > interface_begin() const
Iterator access to the held interfaces.
DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName)
Definition: Dialect.cpp:130
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
inline ::llvm::hash_code hash_value(const PolynomialBase< D, T > &arg)
Definition: Polynomial.h:262
Include the generated interface declarations.
An iterator class that iterates the held interface objects of the given derived interface type.
const InterfaceT & mapElement(const DialectInterface *interface) const
Map the element to the iterator result type.