MLIR 23.0.0git
Dialect.cpp
Go to the documentation of this file.
1//===- Dialect.cpp - Dialect implementation -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/Dialect.h"
11#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/MLIRContext.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/Support/TypeID.h"
19#include "llvm/ADT/MapVector.h"
20#include "llvm/ADT/SmallVectorExtras.h"
21#include "llvm/ADT/Twine.h"
22#include "llvm/Support/DebugLog.h"
23#include "llvm/Support/Regex.h"
24#include <memory>
25
26#define DEBUG_TYPE "dialect"
27
28using namespace mlir;
29using namespace detail;
30
31//===----------------------------------------------------------------------===//
32// Dialect
33//===----------------------------------------------------------------------===//
34
35Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
36 : name(name), dialectID(id), context(context) {
37 assert(isValidNamespace(name) && "invalid dialect namespace");
38}
39
40Dialect::~Dialect() = default;
41
42/// Verify an attribute from this dialect on the argument at 'argIndex' for
43/// the region at 'regionIndex' on the given operation. Returns failure if
44/// the verification failed, success otherwise. This hook may optionally be
45/// invoked from any operation containing a region.
46LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
48 return success();
49}
50
51/// Verify an attribute from this dialect on the result at 'resultIndex' for
52/// the region at 'regionIndex' on the given operation. Returns failure if
53/// the verification failed, success otherwise. This hook may optionally be
54/// invoked from any operation containing a region.
56 unsigned, NamedAttribute) {
57 return success();
58}
59
60/// Parse an attribute registered to this dialect.
62 parser.emitError(parser.getNameLoc())
63 << "dialect '" << getNamespace()
64 << "' provides no attribute parsing hook";
65 return Attribute();
66}
67
68/// Parse a type registered to this dialect.
70 // If this dialect allows unknown types, then represent this with OpaqueType.
71 if (allowsUnknownTypes()) {
72 StringAttr ns = StringAttr::get(getContext(), getNamespace());
73 return OpaqueType::get(ns, parser.getFullSymbolSpec());
74 }
75
76 parser.emitError(parser.getNameLoc())
77 << "dialect '" << getNamespace() << "' provides no type parsing hook";
78 return Type();
79}
80
81std::optional<Dialect::ParseOpHook>
82Dialect::getParseOperationHook(StringRef opName) const {
83 return std::nullopt;
84}
85
86llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
88 assert(op->getDialect() == this &&
89 "Dialect hook invoked on non-dialect owned operation");
90 return nullptr;
91}
92
93/// Utility function that returns if the given string is a valid dialect
94/// namespace
95bool Dialect::isValidNamespace(StringRef str) {
96 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
97 return dialectNameRegex.match(str);
98}
99
100/// Register a set of dialect interfaces with this dialect instance.
101void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
102 // Handle the case where the models resolve a promised interface.
104
105 auto it = registeredInterfaces.try_emplace(interface->getID(),
106 std::move(interface));
107 if (!it.second)
108 LDBG() << "repeated interface registration for dialect " << getNamespace();
109}
110
111//===----------------------------------------------------------------------===//
112// Dialect Interface
113//===----------------------------------------------------------------------===//
114
116
118 return dialect->getContext();
119}
120
122 MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
123 for (auto *dialect : ctx->getLoadedDialects()) {
124#ifndef NDEBUG
125 dialect->handleUseOfUndefinedPromisedInterface(
126 dialect->getTypeID(), interfaceKind, interfaceName);
127#endif
128 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
129 interfaces.insert(interface);
130 orderedInterfaces.push_back(interface);
131 }
132 }
133}
134
136
137/// Get the interface for the dialect of given operation, or null if one
138/// is not registered.
139const DialectInterface *
143
144//===----------------------------------------------------------------------===//
145// DialectExtension
146//===----------------------------------------------------------------------===//
147
149
151 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
152 StringRef interfaceName) {
153 dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
154 interfaceID, interfaceName);
155}
156
158 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
159 dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
160 interfaceID);
161}
162
164 TypeID interfaceRequestorID,
165 TypeID interfaceID) {
166 return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
167}
168
169//===----------------------------------------------------------------------===//
170// DialectRegistry
171//===----------------------------------------------------------------------===//
172
173namespace {
174template <typename Fn>
175void applyExtensionsFn(
176 Fn &&applyExtension,
177 const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>>
178 &extensions) {
179 // Note: Additional extensions may be added while applying an extension.
180 // The iterators will be invalidated if extensions are added so we'll keep
181 // a copy of the extensions for ourselves.
182
183 const auto extractExtension =
184 [](const auto &entry) -> DialectExtensionBase * {
185 return entry.second.get();
186 };
187
188 auto startIt = extensions.begin(), endIt = extensions.end();
189 size_t count = 0;
190 while (startIt != endIt) {
191 count += endIt - startIt;
192
193 // Grab the subset of extensions we'll apply in this iteration.
194 const auto subset =
195 llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension);
196
197 for (const auto *ext : subset)
198 applyExtension(*ext);
199
200 // Book-keep for the next iteration.
201 startIt = extensions.begin() + count;
202 endIt = extensions.end();
203 }
204}
205} // namespace
206
208
211 auto it = registry.find(name);
212 if (it == registry.end())
213 return nullptr;
214 return it->second.second;
215}
216
217void DialectRegistry::insert(TypeID typeID, StringRef name,
218 const DialectAllocatorFunction &ctor) {
219 auto inserted = registry.insert(
220 std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
221 if (!inserted.second && inserted.first->second.first != typeID) {
222 llvm::report_fatal_error(
223 "Trying to register different dialects for the same namespace: " +
224 name);
225 }
226}
227
230 for (const std::string &name : dialectsToPreload) {
231 if (!ctx->getOrLoadDialect(name)) {
232 if (emitError)
233 emitError() << "can't load dialect '" << name
234 << "': missing registration?";
235 return failure();
236 }
237 }
238 return success();
239}
240
242 // If we already have an allocator for this name, nothing to do: the existing
243 // registration will take care of loading the dialect.
244 if (registry.count(name))
245 return;
246 if (llvm::is_contained(dialectsToPreload, name))
247 return;
248 dialectsToPreload.emplace_back(name);
249}
250
252 StringRef name, const DynamicDialectPopulationFunction &ctor) {
253 // This TypeID marks dynamic dialects. We cannot give a TypeID for the
254 // dialect yet, since the TypeID of a dynamic dialect is defined at its
255 // construction.
256 TypeID typeID = TypeID::get<void>();
257
258 // Create the dialect, and then call ctor, which allocates its components.
259 auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
260 auto *dynDialect = ctx->getOrLoadDynamicDialect(
261 nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
262 assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
263 return dynDialect;
264 };
265
266 insert(typeID, name, constructor);
267}
268
270 MLIRContext *ctx = dialect->getContext();
271 StringRef dialectName = dialect->getNamespace();
272
273 // Functor used to try to apply the given extension.
274 auto applyExtension = [&](const DialectExtensionBase &extension) {
275 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
276 // An empty set is equivalent to always invoke.
277 if (dialectNames.empty()) {
278 extension.apply(ctx, dialect);
279 return;
280 }
281
282 // Handle the simple case of a single dialect name. In this case, the
283 // required dialect should be the current dialect.
284 if (dialectNames.size() == 1) {
285 if (dialectNames.front() == dialectName)
286 extension.apply(ctx, dialect);
287 return;
288 }
289
290 // Otherwise, check to see if this extension requires this dialect.
291 const StringRef *nameIt = llvm::find(dialectNames, dialectName);
292 if (nameIt == dialectNames.end())
293 return;
294
295 // If it does, ensure that all of the other required dialects have been
296 // loaded.
297 SmallVector<Dialect *> requiredDialects;
298 requiredDialects.reserve(dialectNames.size());
299 for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
300 ++it) {
301 // The current dialect is known to be loaded.
302 if (it == nameIt) {
303 requiredDialects.push_back(dialect);
304 continue;
305 }
306 // Otherwise, check if it is loaded.
307 Dialect *loadedDialect = ctx->getLoadedDialect(*it);
308 if (!loadedDialect)
309 return;
310 requiredDialects.push_back(loadedDialect);
311 }
312 extension.apply(ctx, requiredDialects);
313 };
314
315 applyExtensionsFn(applyExtension, extensions);
316}
317
319 // Functor used to try to apply the given extension.
320 auto applyExtension = [&](const DialectExtensionBase &extension) {
321 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
322 if (dialectNames.empty()) {
323 auto loadedDialects = ctx->getLoadedDialects();
324 extension.apply(ctx, loadedDialects);
325 return;
326 }
327
328 // Check to see if all of the dialects for this extension are loaded.
329 SmallVector<Dialect *> requiredDialects;
330 requiredDialects.reserve(dialectNames.size());
331 for (StringRef dialectName : dialectNames) {
332 Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
333 if (!loadedDialect)
334 return;
335 requiredDialects.push_back(loadedDialect);
336 }
337 extension.apply(ctx, requiredDialects);
338 };
339
340 applyExtensionsFn(applyExtension, extensions);
341}
342
344 // Check that all extension keys are present in 'rhs'.
345 const auto hasExtension = [&](const auto &key) {
346 return rhs.extensions.contains(key);
347 };
348 if (!llvm::all_of(make_first_range(extensions), hasExtension))
349 return false;
350
351 // Check that the current dialects fully overlap with the dialects in 'rhs'.
352 if (!llvm::all_of(registry, [&](const auto &it) {
353 return rhs.registry.count(it.first);
354 }))
355 return false;
356
357 // Check that all preload-only entries are known in 'rhs'.
358 return llvm::all_of(dialectsToPreload, [&](const std::string &name) {
359 return rhs.registry.count(name) ||
360 llvm::is_contained(rhs.dialectsToPreload, name);
361 });
362}
return success()
if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
Attributes are known-constant values of operations.
Definition Attributes.h:25
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
virtual StringRef getFullSymbolSpec() const =0
Returns the full specification of the symbol being parsed.
This class represents an opaque dialect extension.
This class represents an interface overridden for a single dialect.
MLIRContext * getContext() const
Return the context that holds the parent dialect of this interface.
Definition Dialect.cpp:117
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e.
Definition Dialect.cpp:343
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace,...
Definition Dialect.cpp:210
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition Dialect.cpp:251
LogicalResult preloadSelectDialects(MLIRContext *ctx, function_ref< InFlightDiagnostic()> emitError={}) const
Load into ctx every dialect previously added via addDialectToPreload(StringRef).
Definition Dialect.cpp:228
void addDialectToPreload(StringRef name)
Request that the dialect with the given name be preloaded into the MLIRContext, without providing an ...
Definition Dialect.cpp:241
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition Dialect.cpp:269
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
MLIRContext * getContext() const
Definition Dialect.h:52
virtual ~Dialect()
friend class MLIRContext
Definition Dialect.h:371
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition Dialect.cpp:69
virtual std::optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
Definition Dialect.cpp:82
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName="")
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition Dialect.h:228
StringRef getNamespace() const
Definition Dialect.h:54
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition Dialect.cpp:95
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' ...
Definition Dialect.cpp:55
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' o...
Definition Dialect.cpp:46
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
Definition Dialect.cpp:87
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached to a construct owned by this dialec...
Definition Dialect.h:245
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect name...
Definition Dialect.h:67
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition Dialect.cpp:61
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition Dialect.cpp:101
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition Dialect.h:57
bool hasPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID) const
Checks if a promise has been made for the interface/requestor pair.
Definition Dialect.h:251
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to.
Definition Dialect.cpp:35
A dialect that can be defined at runtime.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
DynamicDialect * getOrLoadDynamicDialect(StringRef dialectNamespace, function_ref< void(DynamicDialect *)> ctor)
Get (or create) a dynamic dialect for the given name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:238
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
const DialectInterface * getInterfaceFor(Operation *op) const
Get the interface for the dialect of given operation, or null if one is not registered.
Definition Dialect.cpp:140
DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName)
Definition Dialect.cpp:121
AttrTypeReplacer.
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
Definition Dialect.cpp:163
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
Definition Dialect.cpp:157
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition Dialect.cpp:150
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
function_ref< Dialect *(MLIRContext *)> DialectAllocatorFunctionRef
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147