MLIR  16.0.0git
Dialect.cpp
Go to the documentation of this file.
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Operation.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Regex.h"
21 
22 #define DEBUG_TYPE "dialect"
23 
24 using namespace mlir;
25 using namespace detail;
26 
27 //===----------------------------------------------------------------------===//
28 // Dialect
29 //===----------------------------------------------------------------------===//
30 
31 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
32  : name(name), dialectID(id), context(context) {
33  assert(isValidNamespace(name) && "invalid dialect namespace");
34 }
35 
36 Dialect::~Dialect() = default;
37 
38 /// Verify an attribute from this dialect on the argument at 'argIndex' for
39 /// the region at 'regionIndex' on the given operation. Returns failure if
40 /// the verification failed, success otherwise. This hook may optionally be
41 /// invoked from any operation containing a region.
44  return success();
45 }
46 
47 /// Verify an attribute from this dialect on the result at 'resultIndex' for
48 /// the region at 'regionIndex' on the given operation. Returns failure if
49 /// the verification failed, success otherwise. This hook may optionally be
50 /// invoked from any operation containing a region.
52  unsigned, NamedAttribute) {
53  return success();
54 }
55 
56 /// Parse an attribute registered to this dialect.
58  parser.emitError(parser.getNameLoc())
59  << "dialect '" << getNamespace()
60  << "' provides no attribute parsing hook";
61  return Attribute();
62 }
63 
64 /// Parse a type registered to this dialect.
66  // If this dialect allows unknown types, then represent this with OpaqueType.
67  if (allowsUnknownTypes()) {
68  StringAttr ns = StringAttr::get(getContext(), getNamespace());
69  return OpaqueType::get(ns, parser.getFullSymbolSpec());
70  }
71 
72  parser.emitError(parser.getNameLoc())
73  << "dialect '" << getNamespace() << "' provides no type parsing hook";
74  return Type();
75 }
76 
78 Dialect::getParseOperationHook(StringRef opName) const {
79  return None;
80 }
81 
82 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
84  assert(op->getDialect() == this &&
85  "Dialect hook invoked on non-dialect owned operation");
86  return nullptr;
87 }
88 
89 /// Utility function that returns if the given string is a valid dialect
90 /// namespace
91 bool Dialect::isValidNamespace(StringRef str) {
92  llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
93  return dialectNameRegex.match(str);
94 }
95 
96 /// Register a set of dialect interfaces with this dialect instance.
97 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
98  auto it = registeredInterfaces.try_emplace(interface->getID(),
99  std::move(interface));
100  (void)it;
101  LLVM_DEBUG({
102  if (!it.second) {
103  llvm::dbgs() << "[" DEBUG_TYPE
104  "] repeated interface registration for dialect "
105  << getNamespace();
106  }
107  });
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // Dialect Interface
112 //===----------------------------------------------------------------------===//
113 
115 
116 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
117  MLIRContext *ctx, TypeID interfaceKind) {
118  for (auto *dialect : ctx->getLoadedDialects()) {
119  if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
120  interfaces.insert(interface);
121  orderedInterfaces.push_back(interface);
122  }
123  }
124 }
125 
126 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
127 
128 /// Get the interface for the dialect of given operation, or null if one
129 /// is not registered.
130 const DialectInterface *
131 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
132  return getInterfaceFor(op->getDialect());
133 }
134 
135 //===----------------------------------------------------------------------===//
136 // DialectExtension
137 //===----------------------------------------------------------------------===//
138 
140 
141 //===----------------------------------------------------------------------===//
142 // DialectRegistry
143 //===----------------------------------------------------------------------===//
144 
145 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
146 
149  auto it = registry.find(name.str());
150  if (it == registry.end())
151  return nullptr;
152  return it->second.second;
153 }
154 
155 void DialectRegistry::insert(TypeID typeID, StringRef name,
156  const DialectAllocatorFunction &ctor) {
157  auto inserted = registry.insert(
158  std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
159  if (!inserted.second && inserted.first->second.first != typeID) {
160  llvm::report_fatal_error(
161  "Trying to register different dialects for the same namespace: " +
162  name);
163  }
164 }
165 
167  MLIRContext *ctx = dialect->getContext();
168  StringRef dialectName = dialect->getNamespace();
169 
170  // Functor used to try to apply the given extension.
171  auto applyExtension = [&](const DialectExtensionBase &extension) {
172  ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
173 
174  // Handle the simple case of a single dialect name. In this case, the
175  // required dialect should be the current dialect.
176  if (dialectNames.size() == 1) {
177  if (dialectNames.front() == dialectName)
178  extension.apply(ctx, dialect);
179  return;
180  }
181 
182  // Otherwise, check to see if this extension requires this dialect.
183  const StringRef *nameIt = llvm::find(dialectNames, dialectName);
184  if (nameIt == dialectNames.end())
185  return;
186 
187  // If it does, ensure that all of the other required dialects have been
188  // loaded.
189  SmallVector<Dialect *> requiredDialects;
190  requiredDialects.reserve(dialectNames.size());
191  for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
192  ++it) {
193  // The current dialect is known to be loaded.
194  if (it == nameIt) {
195  requiredDialects.push_back(dialect);
196  continue;
197  }
198  // Otherwise, check if it is loaded.
199  Dialect *loadedDialect = ctx->getLoadedDialect(*it);
200  if (!loadedDialect)
201  return;
202  requiredDialects.push_back(loadedDialect);
203  }
204  extension.apply(ctx, requiredDialects);
205  };
206 
207  for (const auto &extension : extensions)
208  applyExtension(*extension);
209 }
210 
212  // Functor used to try to apply the given extension.
213  auto applyExtension = [&](const DialectExtensionBase &extension) {
214  ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
215 
216  // Check to see if all of the dialects for this extension are loaded.
217  SmallVector<Dialect *> requiredDialects;
218  requiredDialects.reserve(dialectNames.size());
219  for (StringRef dialectName : dialectNames) {
220  Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
221  if (!loadedDialect)
222  return;
223  requiredDialects.push_back(loadedDialect);
224  }
225  extension.apply(ctx, requiredDialects);
226  };
227 
228  for (const auto &extension : extensions)
229  applyExtension(*extension);
230 }
231 
233  // Treat any extensions conservatively.
234  if (!extensions.empty())
235  return false;
236  // Check that the current dialects fully overlap with the dialects in 'rhs'.
237  return llvm::all_of(
238  registry, [&](const auto &it) { return rhs.registry.count(it.first); });
239 }
Include the generated interface declarations.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
This class represents an opaque dialect extension.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition: Dialect.cpp:57
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
Definition: Dialect.cpp:83
virtual ~Dialect()
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to...
Definition: Dialect.cpp:31
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:149
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual StringRef getFullSymbolSpec() const =0
Returns the full specification of the symbol being parsed.
virtual Optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
Definition: Dialect.cpp:78
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect name...
Definition: Dialect.h:70
MLIRContext * getContext() const
Definition: Dialect.h:55
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace, or nullptr if the namespace is not in this registry.
Definition: Dialect.cpp:148
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition: Dialect.cpp:166
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
StringRef getNamespace() const
Definition: Dialect.h:57
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition: Dialect.cpp:97
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:151
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' o...
Definition: Dialect.cpp:42
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e.
Definition: Dialect.cpp:232
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition: Dialect.cpp:65
This class represents an interface overridden for a single dialect.
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' ...
Definition: Dialect.cpp:51
static bool isValidNamespace(StringRef str)
Utility function that returns whether the given string is a valid dialect namespace.
Definition: Dialect.cpp:91