MLIR 22.0.0git
Dialect.cpp
Go to the documentation of this file.
1//===- Dialect.cpp - Dialect implementation -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/Dialect.h"
11#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/MLIRContext.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/Support/TypeID.h"
19#include "llvm/ADT/MapVector.h"
20#include "llvm/ADT/SmallVectorExtras.h"
21#include "llvm/ADT/Twine.h"
22#include "llvm/Support/DebugLog.h"
23#include "llvm/Support/Regex.h"
24#include <memory>
25
26#define DEBUG_TYPE "dialect"
27
28using namespace mlir;
29using namespace detail;
30
31//===----------------------------------------------------------------------===//
32// Dialect
33//===----------------------------------------------------------------------===//
34
35Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
36 : name(name), dialectID(id), context(context) {
37 assert(isValidNamespace(name) && "invalid dialect namespace");
38}
39
40Dialect::~Dialect() = default;
41
42/// Verify an attribute from this dialect on the argument at 'argIndex' for
43/// the region at 'regionIndex' on the given operation. Returns failure if
44/// the verification failed, success otherwise. This hook may optionally be
45/// invoked from any operation containing a region.
/// The base Dialect implementation performs no checks and always returns
/// success(); dialects that attach region-argument attributes override it.
46LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
48 return success();
49}
50
51/// Verify an attribute from this dialect on the result at 'resultIndex' for
52/// the region at 'regionIndex' on the given operation. Returns failure if
53/// the verification failed, success otherwise. This hook may optionally be
54/// invoked from any operation containing a region.
/// The base Dialect implementation performs no checks and always returns
/// success(); dialects that attach region-result attributes override it.
56 unsigned, NamedAttribute) {
57 return success();
58}
59
60/// Parse an attribute registered to this dialect.
/// The base implementation has no attribute syntax: it emits an error at the
/// parser's name location and returns a null Attribute.
62 parser.emitError(parser.getNameLoc())
63 << "dialect '" << getNamespace()
64 << "' provides no attribute parsing hook";
65 return Attribute();
66}
67
68/// Parse a type registered to this dialect.
/// Base implementation: if the dialect allows unknown types, the full symbol
/// spec is wrapped in an OpaqueType keyed by this dialect's namespace;
/// otherwise an error is emitted and a null Type is returned.
70 // If this dialect allows unknown types, then represent this with OpaqueType.
71 if (allowsUnknownTypes()) {
72 StringAttr ns = StringAttr::get(getContext(), getNamespace());
73 return OpaqueType::get(ns, parser.getFullSymbolSpec());
74 }
75
76 parser.emitError(parser.getNameLoc())
77 << "dialect '" << getNamespace() << "' provides no type parsing hook";
78 return Type();
79}
80
81std::optional<Dialect::ParseOpHook>
82Dialect::getParseOperationHook(StringRef opName) const {
83 return std::nullopt;
84}
85
/// Default operation-printer hook: returns an empty (null) function, meaning
/// this dialect supplies no custom printer. Debug builds assert that `op` is
/// owned by this dialect.
86llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
88 assert(op->getDialect() == this &&
89 "Dialect hook invoked on non-dialect owned operation");
90 return nullptr;
91}
92
93/// Utility function that returns if the given string is a valid dialect
94/// namespace
95bool Dialect::isValidNamespace(StringRef str) {
96 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
97 return dialectNameRegex.match(str);
98}
99
100/// Register a dialect interface with this dialect instance.
/// Exactly one interface is registered per call. If an interface with the
/// same ID is already registered, try_emplace keeps the existing entry and
/// the repeat is only reported through LDBG debug output.
101void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
102 // Handle the case where the models resolve a promised interface.
// NOTE(review): the statement this comment introduces (source line 103) is
// not shown in this listing.
104
105 auto it = registeredInterfaces.try_emplace(interface->getID(),
106 std::move(interface));
107 if (!it.second)
108 LDBG() << "repeated interface registration for dialect " << getNamespace();
109}
110
111//===----------------------------------------------------------------------===//
112// Dialect Interface
113//===----------------------------------------------------------------------===//
114
116
// NOTE(review): the signature (source line 117) is missing from this listing;
// per the listing's index this is DialectInterface::getContext(), which
// returns the context of the dialect that owns this interface.
118 return dialect->getContext();
119}
120
122 MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
// Scan every dialect currently loaded in `ctx` and collect those that have
// registered an interface of kind `interfaceKind`. Each hit is recorded both
// in `interfaces` and in `orderedInterfaces`; the latter keeps the order in
// which getLoadedDialects() returned the dialects.
123 for (auto *dialect : ctx->getLoadedDialects()) {
// Debug builds only: let the dialect check whether this interface was
// promised but never actually registered.
124#ifndef NDEBUG
125 dialect->handleUseOfUndefinedPromisedInterface(
126 dialect->getTypeID(), interfaceKind, interfaceName);
127#endif
128 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
129 interfaces.insert(interface);
130 orderedInterfaces.push_back(interface);
131 }
132 }
133}
134
136
137/// Get the interface for the dialect of given operation, or null if one
138/// is not registered.
139const DialectInterface *
// NOTE(review): source lines 140-142 (the qualified name and body of this
// definition) are not shown in this listing.
143
144//===----------------------------------------------------------------------===//
145// DialectExtension
146//===----------------------------------------------------------------------===//
147
149
151 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
152 StringRef interfaceName) {
// Free-function forwarder to Dialect::handleUseOfUndefinedPromisedInterface.
153 dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
154 interfaceID, interfaceName);
155}
156
158 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
// Free-function forwarder to Dialect::handleAdditionOfUndefinedPromisedInterface.
159 dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
160 interfaceID);
161}
162
164 TypeID interfaceRequestorID,
165 TypeID interfaceID) {
// Free-function forwarder to Dialect::hasPromisedInterface.
166 return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
167}
168
169//===----------------------------------------------------------------------===//
170// DialectRegistry
171//===----------------------------------------------------------------------===//
172
173namespace {
174template <typename Fn>
175void applyExtensionsFn(
176 Fn &&applyExtension,
177 const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>>
178 &extensions) {
179 // Note: Additional extensions may be added while applying an extension.
180 // The iterators will be invalidated if extensions are added so we'll keep
181 // a copy of the extensions for ourselves.
182
183 const auto extractExtension =
184 [](const auto &entry) -> DialectExtensionBase * {
185 return entry.second.get();
186 };
187
188 auto startIt = extensions.begin(), endIt = extensions.end();
189 size_t count = 0;
190 while (startIt != endIt) {
191 count += endIt - startIt;
192
193 // Grab the subset of extensions we'll apply in this iteration.
194 const auto subset =
195 llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension);
196
197 for (const auto *ext : subset)
198 applyExtension(*ext);
199
200 // Book-keep for the next iteration.
201 startIt = extensions.begin() + count;
202 endIt = extensions.end();
203 }
204}
205} // namespace
206
208
// Look up the allocator registered for namespace `name`; a null allocator is
// returned when no dialect is registered under that namespace.
211 auto it = registry.find(name);
212 if (it == registry.end())
213 return nullptr;
214 return it->second.second;
215}
216
217void DialectRegistry::insert(TypeID typeID, StringRef name,
218 const DialectAllocatorFunction &ctor) {
219 auto inserted = registry.insert(
220 std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
221 if (!inserted.second && inserted.first->second.first != typeID) {
222 llvm::report_fatal_error(
223 "Trying to register different dialects for the same namespace: " +
224 name);
225 }
226}
227
229 StringRef name, const DynamicDialectPopulationFunction &ctor) {
// Register a dynamic dialect. The allocator stored in the registry loads (or
// reuses) a DynamicDialect named `name` in the given context and lets `ctor`
// populate it.
230 // This TypeID marks dynamic dialects. We cannot give a TypeID for the
231 // dialect yet, since the TypeID of a dynamic dialect is defined at its
232 // construction.
// NOTE(review): because every dynamic dialect uses this same placeholder
// TypeID, the mismatch check in DialectRegistry::insert cannot tell two
// different dynamic dialects registered under the same name apart.
233 TypeID typeID = TypeID::get<void>();
234
235 // Create the dialect, and then call ctor, which allocates its components.
// `name` is copied into `nameStr` and `ctor` is captured by value because
// the allocator is stored in the registry and may run after the caller's
// StringRef has gone out of scope.
236 auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
237 auto *dynDialect = ctx->getOrLoadDynamicDialect(
238 nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
239 assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
240 return dynDialect;
241 };
242
243 insert(typeID, name, constructor);
244}
245
// Apply each held extension that the newly available `dialect` may now
// satisfy. An extension with no required dialects is always applied to
// `dialect`. Otherwise `dialect` must be one of the required dialects, and
// every other required dialect must already be loaded in the context.
247 MLIRContext *ctx = dialect->getContext();
248 StringRef dialectName = dialect->getNamespace();
249
250 // Functor used to try to apply the given extension.
251 auto applyExtension = [&](const DialectExtensionBase &extension) {
252 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
253 // An empty set is equivalent to always invoke.
254 if (dialectNames.empty()) {
255 extension.apply(ctx, dialect);
256 return;
257 }
258
259 // Handle the simple case of a single dialect name. In this case, the
260 // required dialect should be the current dialect.
261 if (dialectNames.size() == 1) {
262 if (dialectNames.front() == dialectName)
263 extension.apply(ctx, dialect);
264 return;
265 }
266
267 // Otherwise, check to see if this extension requires this dialect.
268 const StringRef *nameIt = llvm::find(dialectNames, dialectName);
269 if (nameIt == dialectNames.end())
270 return;
271
272 // If it does, ensure that all of the other required dialects have been
273 // loaded.
274 SmallVector<Dialect *> requiredDialects;
275 requiredDialects.reserve(dialectNames.size());
276 for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
277 ++it) {
278 // The current dialect is known to be loaded.
279 if (it == nameIt) {
280 requiredDialects.push_back(dialect);
281 continue;
282 }
283 // Otherwise, check if it is loaded.
284 Dialect *loadedDialect = ctx->getLoadedDialect(*it);
285 if (!loadedDialect)
286 return;
287 requiredDialects.push_back(loadedDialect);
288 }
289 extension.apply(ctx, requiredDialects);
290 };
291
// applyExtensionsFn also visits extensions that are appended while earlier
// ones are being applied.
292 applyExtensionsFn(applyExtension, extensions);
293}
294
// Apply every held extension whose required dialects are all loaded in
// `ctx`. An extension with no required dialects is applied to every dialect
// currently loaded in `ctx`.
296 // Functor used to try to apply the given extension.
297 auto applyExtension = [&](const DialectExtensionBase &extension) {
298 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
299 if (dialectNames.empty()) {
300 auto loadedDialects = ctx->getLoadedDialects();
301 extension.apply(ctx, loadedDialects);
302 return;
303 }
304
305 // Check to see if all of the dialects for this extension are loaded.
306 SmallVector<Dialect *> requiredDialects;
307 requiredDialects.reserve(dialectNames.size());
308 for (StringRef dialectName : dialectNames) {
309 Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
310 if (!loadedDialect)
311 return;
312 requiredDialects.push_back(loadedDialect);
313 }
314 extension.apply(ctx, requiredDialects);
315 };
316
317 applyExtensionsFn(applyExtension, extensions);
318}
319
// This registry is a subset of `rhs` when every extension key and every
// registered dialect namespace held here is also held by `rhs`. Only keys are
// compared; the TypeIDs and allocators stored for each namespace are not.
321 // Check that all extension keys are present in 'rhs'.
322 const auto hasExtension = [&](const auto &key) {
323 return rhs.extensions.contains(key);
324 };
325 if (!llvm::all_of(make_first_range(extensions), hasExtension))
326 return false;
327
328 // Check that the current dialects fully overlap with the dialects in 'rhs'.
329 return llvm::all_of(
330 registry, [&](const auto &it) { return rhs.registry.count(it.first); });
331}
return success()
If copies could not be generated due to yet-unimplemented cases, `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
Attributes are known-constant values of operations.
Definition Attributes.h:25
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
virtual StringRef getFullSymbolSpec() const =0
Returns the full specification of the symbol being parsed.
This class represents an opaque dialect extension.
This class represents an interface overridden for a single dialect.
MLIRContext * getContext() const
Return the context that holds the parent dialect of this interface.
Definition Dialect.cpp:117
bool isSubsetOf(const DialectRegistry &rhs) const
Returns true if the current registry is a subset of 'rhs', i.e.
Definition Dialect.cpp:320
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace,...
Definition Dialect.cpp:210
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition Dialect.cpp:228
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition Dialect.cpp:246
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
MLIRContext * getContext() const
Definition Dialect.h:52
virtual ~Dialect()
friend class MLIRContext
Definition Dialect.h:371
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition Dialect.cpp:69
virtual std::optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
Definition Dialect.cpp:82
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName="")
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition Dialect.h:228
StringRef getNamespace() const
Definition Dialect.h:54
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition Dialect.cpp:95
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' ...
Definition Dialect.cpp:55
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' o...
Definition Dialect.cpp:46
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
Definition Dialect.cpp:87
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached to a construct owned by this dialec...
Definition Dialect.h:245
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect name...
Definition Dialect.h:67
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition Dialect.cpp:61
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition Dialect.cpp:101
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition Dialect.h:57
bool hasPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID) const
Checks if a promise has been made for the interface/requestor pair.
Definition Dialect.h:251
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to.
Definition Dialect.cpp:35
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
const DialectInterface * getInterfaceFor(Operation *op) const
Get the interface for the dialect of given operation, or null if one is not registered.
Definition Dialect.cpp:140
DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName)
Definition Dialect.cpp:121
AttrTypeReplacer.
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
Definition Dialect.cpp:163
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
Definition Dialect.cpp:157
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition Dialect.cpp:150
Include the generated interface declarations.
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
function_ref< Dialect *(MLIRContext *)> DialectAllocatorFunctionRef
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction