MLIR  20.0.0git
Dialect.cpp
Go to the documentation of this file.
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Support/TypeID.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/SetOperations.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/SmallVectorExtras.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ManagedStatic.h"
26 #include "llvm/Support/Regex.h"
27 #include <memory>
28 
29 #define DEBUG_TYPE "dialect"
30 
31 using namespace mlir;
32 using namespace detail;
33 
34 //===----------------------------------------------------------------------===//
35 // Dialect
36 //===----------------------------------------------------------------------===//
37 
38 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
39  : name(name), dialectID(id), context(context) {
40  assert(isValidNamespace(name) && "invalid dialect namespace");
41 }
42 
43 Dialect::~Dialect() = default;
44 
45 /// Verify an attribute from this dialect on the argument at 'argIndex' for
46 /// the region at 'regionIndex' on the given operation. Returns failure if
47 /// the verification failed, success otherwise. This hook may optionally be
48 /// invoked from any operation containing a region.
49 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
51  return success();
52 }
53 
54 /// Verify an attribute from this dialect on the result at 'resultIndex' for
55 /// the region at 'regionIndex' on the given operation. Returns failure if
56 /// the verification failed, success otherwise. This hook may optionally be
57 /// invoked from any operation containing a region.
59  unsigned, NamedAttribute) {
60  return success();
61 }
62 
63 /// Parse an attribute registered to this dialect.
65  parser.emitError(parser.getNameLoc())
66  << "dialect '" << getNamespace()
67  << "' provides no attribute parsing hook";
68  return Attribute();
69 }
70 
71 /// Parse a type registered to this dialect.
73  // If this dialect allows unknown types, then represent this with OpaqueType.
74  if (allowsUnknownTypes()) {
75  StringAttr ns = StringAttr::get(getContext(), getNamespace());
76  return OpaqueType::get(ns, parser.getFullSymbolSpec());
77  }
78 
79  parser.emitError(parser.getNameLoc())
80  << "dialect '" << getNamespace() << "' provides no type parsing hook";
81  return Type();
82 }
83 
84 std::optional<Dialect::ParseOpHook>
85 Dialect::getParseOperationHook(StringRef opName) const {
86  return std::nullopt;
87 }
88 
89 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
91  assert(op->getDialect() == this &&
92  "Dialect hook invoked on non-dialect owned operation");
93  return nullptr;
94 }
95 
96 /// Utility function that returns if the given string is a valid dialect
97 /// namespace
98 bool Dialect::isValidNamespace(StringRef str) {
99  llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
100  return dialectNameRegex.match(str);
101 }
102 
103 /// Register a set of dialect interfaces with this dialect instance.
104 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
105  // Handle the case where the models resolve a promised interface.
107 
108  auto it = registeredInterfaces.try_emplace(interface->getID(),
109  std::move(interface));
110  (void)it;
111  LLVM_DEBUG({
112  if (!it.second) {
113  llvm::dbgs() << "[" DEBUG_TYPE
114  "] repeated interface registration for dialect "
115  << getNamespace();
116  }
117  });
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // Dialect Interface
122 //===----------------------------------------------------------------------===//
123 
125 
127  return dialect->getContext();
128 }
129 
130 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
131  MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
132  for (auto *dialect : ctx->getLoadedDialects()) {
133 #ifndef NDEBUG
134  dialect->handleUseOfUndefinedPromisedInterface(
135  dialect->getTypeID(), interfaceKind, interfaceName);
136 #endif
137  if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
138  interfaces.insert(interface);
139  orderedInterfaces.push_back(interface);
140  }
141  }
142 }
143 
144 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
145 
146 /// Get the interface for the dialect of given operation, or null if one
147 /// is not registered.
148 const DialectInterface *
149 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
150  return getInterfaceFor(op->getDialect());
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // DialectExtension
155 //===----------------------------------------------------------------------===//
156 
158 
160  Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
161  StringRef interfaceName) {
162  dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
163  interfaceID, interfaceName);
164 }
165 
167  Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
168  dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
169  interfaceID);
170 }
171 
173  TypeID interfaceRequestorID,
174  TypeID interfaceID) {
175  return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // DialectRegistry
180 //===----------------------------------------------------------------------===//
181 
182 namespace {
183 template <typename Fn>
184 void applyExtensionsFn(
185  Fn &&applyExtension,
186  const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>>
187  &extensions) {
188  // Note: Additional extensions may be added while applying an extension.
189  // The iterators will be invalidated if extensions are added so we'll keep
190  // a copy of the extensions for ourselves.
191 
192  const auto extractExtension =
193  [](const auto &entry) -> DialectExtensionBase * {
194  return entry.second.get();
195  };
196 
197  auto startIt = extensions.begin(), endIt = extensions.end();
198  size_t count = 0;
199  while (startIt != endIt) {
200  count += endIt - startIt;
201 
202  // Grab the subset of extensions we'll apply in this iteration.
203  const auto subset =
204  llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension);
205 
206  for (const auto *ext : subset)
207  applyExtension(*ext);
208 
209  // Book-keep for the next iteration.
210  startIt = extensions.begin() + count;
211  endIt = extensions.end();
212  }
213 }
214 } // namespace
215 
216 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
217 
220  auto it = registry.find(name);
221  if (it == registry.end())
222  return nullptr;
223  return it->second.second;
224 }
225 
226 void DialectRegistry::insert(TypeID typeID, StringRef name,
227  const DialectAllocatorFunction &ctor) {
228  auto inserted = registry.insert(
229  std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
230  if (!inserted.second && inserted.first->second.first != typeID) {
231  llvm::report_fatal_error(
232  "Trying to register different dialects for the same namespace: " +
233  name);
234  }
235 }
236 
238  StringRef name, const DynamicDialectPopulationFunction &ctor) {
239  // This TypeID marks dynamic dialects. We cannot give a TypeID for the
240  // dialect yet, since the TypeID of a dynamic dialect is defined at its
241  // construction.
242  TypeID typeID = TypeID::get<void>();
243 
244  // Create the dialect, and then call ctor, which allocates its components.
245  auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
246  auto *dynDialect = ctx->getOrLoadDynamicDialect(
247  nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
248  assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
249  return dynDialect;
250  };
251 
252  insert(typeID, name, constructor);
253 }
254 
256  MLIRContext *ctx = dialect->getContext();
257  StringRef dialectName = dialect->getNamespace();
258 
259  // Functor used to try to apply the given extension.
260  auto applyExtension = [&](const DialectExtensionBase &extension) {
261  ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
262  // An empty set is equivalent to always invoke.
263  if (dialectNames.empty()) {
264  extension.apply(ctx, dialect);
265  return;
266  }
267 
268  // Handle the simple case of a single dialect name. In this case, the
269  // required dialect should be the current dialect.
270  if (dialectNames.size() == 1) {
271  if (dialectNames.front() == dialectName)
272  extension.apply(ctx, dialect);
273  return;
274  }
275 
276  // Otherwise, check to see if this extension requires this dialect.
277  const StringRef *nameIt = llvm::find(dialectNames, dialectName);
278  if (nameIt == dialectNames.end())
279  return;
280 
281  // If it does, ensure that all of the other required dialects have been
282  // loaded.
283  SmallVector<Dialect *> requiredDialects;
284  requiredDialects.reserve(dialectNames.size());
285  for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
286  ++it) {
287  // The current dialect is known to be loaded.
288  if (it == nameIt) {
289  requiredDialects.push_back(dialect);
290  continue;
291  }
292  // Otherwise, check if it is loaded.
293  Dialect *loadedDialect = ctx->getLoadedDialect(*it);
294  if (!loadedDialect)
295  return;
296  requiredDialects.push_back(loadedDialect);
297  }
298  extension.apply(ctx, requiredDialects);
299  };
300 
301  applyExtensionsFn(applyExtension, extensions);
302 }
303 
305  // Functor used to try to apply the given extension.
306  auto applyExtension = [&](const DialectExtensionBase &extension) {
307  ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
308  if (dialectNames.empty()) {
309  auto loadedDialects = ctx->getLoadedDialects();
310  extension.apply(ctx, loadedDialects);
311  return;
312  }
313 
314  // Check to see if all of the dialects for this extension are loaded.
315  SmallVector<Dialect *> requiredDialects;
316  requiredDialects.reserve(dialectNames.size());
317  for (StringRef dialectName : dialectNames) {
318  Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
319  if (!loadedDialect)
320  return;
321  requiredDialects.push_back(loadedDialect);
322  }
323  extension.apply(ctx, requiredDialects);
324  };
325 
326  applyExtensionsFn(applyExtension, extensions);
327 }
328 
330  // Check that all extension keys are present in 'rhs'.
331  const auto hasExtension = [&](const auto &key) {
332  return rhs.extensions.contains(key);
333  };
334  if (!llvm::all_of(make_first_range(extensions), hasExtension))
335  return false;
336 
337  // Check that the current dialects fully overlap with the dialects in 'rhs'.
338  return llvm::all_of(
339  registry, [&](const auto &it) { return rhs.registry.count(it.first); });
340 }
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
virtual StringRef getFullSymbolSpec() const =0
Returns the full specification of the symbol being parsed.
This class represents an opaque dialect extension.
This class represents an interface overridden for a single dialect.
MLIRContext * getContext() const
Return the context that holds the parent dialect of this interface.
Definition: Dialect.cpp:126
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e. if every extension key and every dialect namespace held by this registry is also present in 'rhs'.
Definition: Dialect.cpp:329
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace, or nullptr if the namespace is not registered.
Definition: Dialect.cpp:219
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition: Dialect.cpp:237
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition: Dialect.cpp:255
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual ~Dialect()
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition: Dialect.cpp:72
virtual std::optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
Definition: Dialect.cpp:85
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName="")
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition: Dialect.h:228
StringRef getNamespace() const
Definition: Dialect.h:54
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:98
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' ...
Definition: Dialect.cpp:58
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' o...
Definition: Dialect.cpp:49
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
Definition: Dialect.cpp:90
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached to a construct owned by this dialec...
Definition: Dialect.h:245
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect name...
Definition: Dialect.h:67
MLIRContext * getContext() const
Definition: Dialect.h:52
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition: Dialect.cpp:64
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition: Dialect.cpp:104
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition: Dialect.h:57
bool hasPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID) const
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.h:251
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to.
Definition: Dialect.cpp:38
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.cpp:172
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
Definition: Dialect.cpp:166
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition: Dialect.cpp:159
Include the generated interface declarations.
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...