MLIR  22.0.0git
Dialect.cpp
Go to the documentation of this file.
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Support/TypeID.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/SmallVectorExtras.h"
21 #include "llvm/ADT/Twine.h"
22 #include "llvm/Support/DebugLog.h"
23 #include "llvm/Support/Regex.h"
24 #include <memory>
25 
26 #define DEBUG_TYPE "dialect"
27 
28 using namespace mlir;
29 using namespace detail;
30 
31 //===----------------------------------------------------------------------===//
32 // Dialect
33 //===----------------------------------------------------------------------===//
34 
35 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
36  : name(name), dialectID(id), context(context) {
37  assert(isValidNamespace(name) && "invalid dialect namespace");
38 }
39 
40 Dialect::~Dialect() = default;
41 
42 /// Verify an attribute from this dialect on the argument at 'argIndex' for
43 /// the region at 'regionIndex' on the given operation. Returns failure if
44 /// the verification failed, success otherwise. This hook may optionally be
45 /// invoked from any operation containing a region.
46 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
48  return success();
49 }
50 
51 /// Verify an attribute from this dialect on the result at 'resultIndex' for
52 /// the region at 'regionIndex' on the given operation. Returns failure if
53 /// the verification failed, success otherwise. This hook may optionally be
54 /// invoked from any operation containing a region.
56  unsigned, NamedAttribute) {
57  return success();
58 }
59 
60 /// Parse an attribute registered to this dialect.
62  parser.emitError(parser.getNameLoc())
63  << "dialect '" << getNamespace()
64  << "' provides no attribute parsing hook";
65  return Attribute();
66 }
67 
68 /// Parse a type registered to this dialect.
70  // If this dialect allows unknown types, then represent this with OpaqueType.
71  if (allowsUnknownTypes()) {
72  StringAttr ns = StringAttr::get(getContext(), getNamespace());
73  return OpaqueType::get(ns, parser.getFullSymbolSpec());
74  }
75 
76  parser.emitError(parser.getNameLoc())
77  << "dialect '" << getNamespace() << "' provides no type parsing hook";
78  return Type();
79 }
80 
81 std::optional<Dialect::ParseOpHook>
82 Dialect::getParseOperationHook(StringRef opName) const {
83  return std::nullopt;
84 }
85 
86 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
88  assert(op->getDialect() == this &&
89  "Dialect hook invoked on non-dialect owned operation");
90  return nullptr;
91 }
92 
93 /// Utility function that returns if the given string is a valid dialect
94 /// namespace
95 bool Dialect::isValidNamespace(StringRef str) {
96  llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
97  return dialectNameRegex.match(str);
98 }
99 
100 /// Register a set of dialect interfaces with this dialect instance.
101 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
102  // Handle the case where the models resolve a promised interface.
104 
105  auto it = registeredInterfaces.try_emplace(interface->getID(),
106  std::move(interface));
107  if (!it.second)
108  LDBG() << "repeated interface registration for dialect " << getNamespace();
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // Dialect Interface
113 //===----------------------------------------------------------------------===//
114 
116 
118  return dialect->getContext();
119 }
120 
121 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
122  MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
123  for (auto *dialect : ctx->getLoadedDialects()) {
124 #ifndef NDEBUG
125  dialect->handleUseOfUndefinedPromisedInterface(
126  dialect->getTypeID(), interfaceKind, interfaceName);
127 #endif
128  if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
129  interfaces.insert(interface);
130  orderedInterfaces.push_back(interface);
131  }
132  }
133 }
134 
135 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
136 
137 /// Get the interface for the dialect of given operation, or null if one
138 /// is not registered.
139 const DialectInterface *
140 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
141  return getInterfaceFor(op->getDialect());
142 }
143 
144 //===----------------------------------------------------------------------===//
145 // DialectExtension
146 //===----------------------------------------------------------------------===//
147 
149 
151  Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
152  StringRef interfaceName) {
153  dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
154  interfaceID, interfaceName);
155 }
156 
158  Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
159  dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
160  interfaceID);
161 }
162 
164  TypeID interfaceRequestorID,
165  TypeID interfaceID) {
166  return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // DialectRegistry
171 //===----------------------------------------------------------------------===//
172 
173 namespace {
174 template <typename Fn>
175 void applyExtensionsFn(
176  Fn &&applyExtension,
177  const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>>
178  &extensions) {
179  // Note: Additional extensions may be added while applying an extension.
180  // The iterators will be invalidated if extensions are added so we'll keep
181  // a copy of the extensions for ourselves.
182 
183  const auto extractExtension =
184  [](const auto &entry) -> DialectExtensionBase * {
185  return entry.second.get();
186  };
187 
188  auto startIt = extensions.begin(), endIt = extensions.end();
189  size_t count = 0;
190  while (startIt != endIt) {
191  count += endIt - startIt;
192 
193  // Grab the subset of extensions we'll apply in this iteration.
194  const auto subset =
195  llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension);
196 
197  for (const auto *ext : subset)
198  applyExtension(*ext);
199 
200  // Book-keep for the next iteration.
201  startIt = extensions.begin() + count;
202  endIt = extensions.end();
203  }
204 }
205 } // namespace
206 
207 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
208 
211  auto it = registry.find(name);
212  if (it == registry.end())
213  return nullptr;
214  return it->second.second;
215 }
216 
217 void DialectRegistry::insert(TypeID typeID, StringRef name,
218  const DialectAllocatorFunction &ctor) {
219  auto inserted = registry.insert(
220  std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
221  if (!inserted.second && inserted.first->second.first != typeID) {
222  llvm::report_fatal_error(
223  "Trying to register different dialects for the same namespace: " +
224  name);
225  }
226 }
227 
229  StringRef name, const DynamicDialectPopulationFunction &ctor) {
230  // This TypeID marks dynamic dialects. We cannot give a TypeID for the
231  // dialect yet, since the TypeID of a dynamic dialect is defined at its
232  // construction.
233  TypeID typeID = TypeID::get<void>();
234 
235  // Create the dialect, and then call ctor, which allocates its components.
236  auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
237  auto *dynDialect = ctx->getOrLoadDynamicDialect(
238  nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
239  assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
240  return dynDialect;
241  };
242 
243  insert(typeID, name, constructor);
244 }
245 
247  MLIRContext *ctx = dialect->getContext();
248  StringRef dialectName = dialect->getNamespace();
249 
250  // Functor used to try to apply the given extension.
251  auto applyExtension = [&](const DialectExtensionBase &extension) {
252  ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
253  // An empty set is equivalent to always invoke.
254  if (dialectNames.empty()) {
255  extension.apply(ctx, dialect);
256  return;
257  }
258 
259  // Handle the simple case of a single dialect name. In this case, the
260  // required dialect should be the current dialect.
261  if (dialectNames.size() == 1) {
262  if (dialectNames.front() == dialectName)
263  extension.apply(ctx, dialect);
264  return;
265  }
266 
267  // Otherwise, check to see if this extension requires this dialect.
268  const StringRef *nameIt = llvm::find(dialectNames, dialectName);
269  if (nameIt == dialectNames.end())
270  return;
271 
272  // If it does, ensure that all of the other required dialects have been
273  // loaded.
274  SmallVector<Dialect *> requiredDialects;
275  requiredDialects.reserve(dialectNames.size());
276  for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
277  ++it) {
278  // The current dialect is known to be loaded.
279  if (it == nameIt) {
280  requiredDialects.push_back(dialect);
281  continue;
282  }
283  // Otherwise, check if it is loaded.
284  Dialect *loadedDialect = ctx->getLoadedDialect(*it);
285  if (!loadedDialect)
286  return;
287  requiredDialects.push_back(loadedDialect);
288  }
289  extension.apply(ctx, requiredDialects);
290  };
291 
292  applyExtensionsFn(applyExtension, extensions);
293 }
294 
296  // Functor used to try to apply the given extension.
297  auto applyExtension = [&](const DialectExtensionBase &extension) {
298  ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
299  if (dialectNames.empty()) {
300  auto loadedDialects = ctx->getLoadedDialects();
301  extension.apply(ctx, loadedDialects);
302  return;
303  }
304 
305  // Check to see if all of the dialects for this extension are loaded.
306  SmallVector<Dialect *> requiredDialects;
307  requiredDialects.reserve(dialectNames.size());
308  for (StringRef dialectName : dialectNames) {
309  Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
310  if (!loadedDialect)
311  return;
312  requiredDialects.push_back(loadedDialect);
313  }
314  extension.apply(ctx, requiredDialects);
315  };
316 
317  applyExtensionsFn(applyExtension, extensions);
318 }
319 
321  // Check that all extension keys are present in 'rhs'.
322  const auto hasExtension = [&](const auto &key) {
323  return rhs.extensions.contains(key);
324  };
325  if (!llvm::all_of(make_first_range(extensions), hasExtension))
326  return false;
327 
328  // Check that the current dialects fully overlap with the dialects in 'rhs'.
329  return llvm::all_of(
330  registry, [&](const auto &it) { return rhs.registry.count(it.first); });
331 }
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and types.
virtual StringRef getFullSymbolSpec() const =0
Returns the full specification of the symbol being parsed.
This class represents an opaque dialect extension.
This class represents an interface overridden for a single dialect.
MLIRContext * getContext() const
Return the context that holds the parent dialect of this interface.
Definition: Dialect.cpp:117
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs' contains every extension and every dialect namespace registered in this registry.
Definition: Dialect.cpp:320
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace, or nullptr if no dialect is registered under that namespace.
Definition: Dialect.cpp:210
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition: Dialect.cpp:228
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition: Dialect.cpp:246
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual ~Dialect()
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition: Dialect.cpp:69
virtual std::optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
Definition: Dialect.cpp:82
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName="")
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition: Dialect.h:228
StringRef getNamespace() const
Definition: Dialect.h:54
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:95
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' on the given operation.
Definition: Dialect.cpp:55
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' on the given operation.
Definition: Dialect.cpp:46
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
Definition: Dialect.cpp:87
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached to a construct owned by this dialec...
Definition: Dialect.h:245
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect name...
Definition: Dialect.h:67
MLIRContext * getContext() const
Definition: Dialect.h:52
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition: Dialect.cpp:61
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition: Dialect.cpp:101
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition: Dialect.h:57
bool hasPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID) const
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.h:251
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to.
Definition: Dialect.cpp:35
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.cpp:163
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
Definition: Dialect.cpp:157
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition: Dialect.cpp:150
Include the generated interface declarations.
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...