MLIR  14.0.0git
Dialect.cpp
Go to the documentation of this file.
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Operation.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Regex.h"
21 
22 #define DEBUG_TYPE "dialect"
23 
24 using namespace mlir;
25 using namespace detail;
26 
27 //===----------------------------------------------------------------------===//
28 // DialectRegistry
29 //===----------------------------------------------------------------------===//
30 
31 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
32 
34  StringRef dialectName, TypeID interfaceTypeID,
35  const DialectInterfaceAllocatorFunction &allocator) {
36  assert(allocator && "unexpected null interface allocation function");
37  auto it = registry.find(dialectName.str());
38  assert(it != registry.end() &&
39  "adding an interface for an unregistered dialect");
40 
41  // Bail out if the interface with the given ID is already in the registry for
42  // the given dialect. We expect a small number (dozens) of interfaces so a
43  // linear search is fine here.
44  auto &ifaces = interfaces[it->second.first];
45  for (const auto &kvp : ifaces.dialectInterfaces) {
46  if (kvp.first == interfaceTypeID) {
47  LLVM_DEBUG(llvm::dbgs()
48  << "[" DEBUG_TYPE
49  "] repeated interface registration for dialect "
50  << dialectName);
51  return;
52  }
53  }
54 
55  ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator);
56 }
57 
58 void DialectRegistry::addObjectInterface(
59  StringRef dialectName, TypeID objectID, TypeID interfaceTypeID,
60  const ObjectInterfaceAllocatorFunction &allocator) {
61  assert(allocator && "unexpected null interface allocation function");
62 
63  auto it = registry.find(dialectName.str());
64  assert(it != registry.end() &&
65  "adding an interface for an op from an unregistered dialect");
66 
67  auto dialectID = it->second.first;
68  auto &ifaces = interfaces[dialectID];
69 
70  for (const auto &info : ifaces.objectInterfaces) {
71  if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) {
72  LLVM_DEBUG(llvm::dbgs()
73  << "[" DEBUG_TYPE
74  "] repeated interface object interface registration");
75  return;
76  }
77  }
78 
79  ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator);
80 }
81 
83 DialectRegistry::getDialectAllocator(StringRef name) const {
84  auto it = registry.find(name.str());
85  if (it == registry.end())
86  return nullptr;
87  return it->second.second;
88 }
89 
90 void DialectRegistry::insert(TypeID typeID, StringRef name,
91  const DialectAllocatorFunction &ctor) {
92  auto inserted = registry.insert(
93  std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
94  if (!inserted.second && inserted.first->second.first != typeID) {
95  llvm::report_fatal_error(
96  "Trying to register different dialects for the same namespace: " +
97  name);
98  }
99 }
100 
102  auto it = interfaces.find(dialect->getTypeID());
103  if (it == interfaces.end())
104  return;
105 
106  // Add an interface if it is not already present.
107  for (const auto &kvp : it->getSecond().dialectInterfaces) {
108  if (dialect->getRegisteredInterface(kvp.first))
109  continue;
110  dialect->addInterface(kvp.second(dialect));
111  }
112 
113  // Add attribute, operation and type interfaces.
114  for (const auto &info : it->getSecond().objectInterfaces)
115  std::get<2>(info)(dialect->getContext());
116 }
117 
118 //===----------------------------------------------------------------------===//
119 // Dialect
120 //===----------------------------------------------------------------------===//
121 
122 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
123  : name(name), dialectID(id), context(context) {
124  assert(isValidNamespace(name) && "invalid dialect namespace");
125 }
126 
127 Dialect::~Dialect() = default;
128 
129 /// Verify an attribute from this dialect on the argument at 'argIndex' for
130 /// the region at 'regionIndex' on the given operation. Returns failure if
131 /// the verification failed, success otherwise. This hook may optionally be
132 /// invoked from any operation containing a region.
134  NamedAttribute) {
135  return success();
136 }
137 
138 /// Verify an attribute from this dialect on the result at 'resultIndex' for
139 /// the region at 'regionIndex' on the given operation. Returns failure if
140 /// the verification failed, success otherwise. This hook may optionally be
141 /// invoked from any operation containing a region.
143  unsigned, NamedAttribute) {
144  return success();
145 }
146 
147 /// Parse an attribute registered to this dialect.
149  parser.emitError(parser.getNameLoc())
150  << "dialect '" << getNamespace()
151  << "' provides no attribute parsing hook";
152  return Attribute();
153 }
154 
155 /// Parse a type registered to this dialect.
157  // If this dialect allows unknown types, then represent this with OpaqueType.
158  if (allowsUnknownTypes()) {
159  StringAttr ns = StringAttr::get(getContext(), getNamespace());
160  return OpaqueType::get(ns, parser.getFullSymbolSpec());
161  }
162 
163  parser.emitError(parser.getNameLoc())
164  << "dialect '" << getNamespace() << "' provides no type parsing hook";
165  return Type();
166 }
167 
169 Dialect::getParseOperationHook(StringRef opName) const {
170  return None;
171 }
172 
173 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
175  assert(op->getDialect() == this &&
176  "Dialect hook invoked on non-dialect owned operation");
177  return nullptr;
178 }
179 
180 /// Utility function that returns if the given string is a valid dialect
181 /// namespace
182 bool Dialect::isValidNamespace(StringRef str) {
183  llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
184  return dialectNameRegex.match(str);
185 }
186 
187 /// Register a set of dialect interfaces with this dialect instance.
188 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
189  auto it = registeredInterfaces.try_emplace(interface->getID(),
190  std::move(interface));
191  (void)it;
192  assert(it.second && "interface kind has already been registered");
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // Dialect Interface
197 //===----------------------------------------------------------------------===//
198 
200 
201 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
202  MLIRContext *ctx, TypeID interfaceKind) {
203  for (auto *dialect : ctx->getLoadedDialects()) {
204  if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
205  interfaces.insert(interface);
206  orderedInterfaces.push_back(interface);
207  }
208  }
209 }
210 
211 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
212 
213 /// Get the interface for the dialect of given operation, or null if one
214 /// is not registered.
215 const DialectInterface *
216 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
217  return getInterfaceFor(op->getDialect());
218 }
Include the generated interface declarations.
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition: Dialect.h:61
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition: Dialect.cpp:148
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
Definition: Dialect.h:29
void addDialectInterface()
Add an interface to the dialect, both provided as template parameter.
Definition: Dialect.h:357
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
Definition: Dialect.cpp:174
virtual ~Dialect()
#define DEBUG_TYPE
Definition: Dialect.cpp:22
virtual llvm::SMLoc getNameLoc() const =0
Return the location of the original name token.
void registerDelayedInterfaces(Dialect *dialect) const
Register any interfaces required for the given dialect (based on its TypeID).
Definition: Dialect.cpp:101
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to...
Definition: Dialect.cpp:122
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:52
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:137
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
const DialectInterface * getRegisteredInterface(TypeID interfaceID)
Lookup an interface for the given ID if one is registered, otherwise nullptr.
Definition: Dialect.h:161
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:42
virtual StringRef getFullSymbolSpec() const =0
Returns the full specification of the symbol being parsed.
virtual Optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
Definition: Dialect.cpp:169
std::function< void(MLIRContext *)> ObjectInterfaceAllocatorFunction
Definition: Dialect.h:33
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect name...
Definition: Dialect.h:71
MLIRContext * getContext() const
Definition: Dialect.h:56
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace, or nullptr if the namespace is not in this registry.
Definition: Dialect.cpp:83
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
virtual InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
StringRef getNamespace() const
Definition: Dialect.h:58
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition: Dialect.cpp:188
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:103
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
std::function< std::unique_ptr< DialectInterface >(Dialect *)> DialectInterfaceAllocatorFunction
Definition: Dialect.h:32
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' on the given operation.
Definition: Dialect.cpp:133
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition: Dialect.cpp:156
This class represents an interface overridden for a single dialect.
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' on the given operation.
Definition: Dialect.cpp:142
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:182