MLIR  19.0.0git
Dialect.cpp
Go to the documentation of this file.
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/Operation.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/Regex.h"
22 
23 #define DEBUG_TYPE "dialect"
24 
25 using namespace mlir;
26 using namespace detail;
27 
28 //===----------------------------------------------------------------------===//
29 // Dialect
30 //===----------------------------------------------------------------------===//
31 
32 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
33  : name(name), dialectID(id), context(context) {
34  assert(isValidNamespace(name) && "invalid dialect namespace");
35 }
36 
37 Dialect::~Dialect() = default;
38 
39 /// Verify an attribute from this dialect on the argument at 'argIndex' for
40 /// the region at 'regionIndex' on the given operation. Returns failure if
41 /// the verification failed, success otherwise. This hook may optionally be
42 /// invoked from any operation containing a region.
45  return success();
46 }
47 
48 /// Verify an attribute from this dialect on the result at 'resultIndex' for
49 /// the region at 'regionIndex' on the given operation. Returns failure if
50 /// the verification failed, success otherwise. This hook may optionally be
51 /// invoked from any operation containing a region.
53  unsigned, NamedAttribute) {
54  return success();
55 }
56 
57 /// Parse an attribute registered to this dialect.
59  parser.emitError(parser.getNameLoc())
60  << "dialect '" << getNamespace()
61  << "' provides no attribute parsing hook";
62  return Attribute();
63 }
64 
65 /// Parse a type registered to this dialect.
67  // If this dialect allows unknown types, then represent this with OpaqueType.
68  if (allowsUnknownTypes()) {
69  StringAttr ns = StringAttr::get(getContext(), getNamespace());
70  return OpaqueType::get(ns, parser.getFullSymbolSpec());
71  }
72 
73  parser.emitError(parser.getNameLoc())
74  << "dialect '" << getNamespace() << "' provides no type parsing hook";
75  return Type();
76 }
77 
78 std::optional<Dialect::ParseOpHook>
79 Dialect::getParseOperationHook(StringRef opName) const {
80  return std::nullopt;
81 }
82 
83 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
85  assert(op->getDialect() == this &&
86  "Dialect hook invoked on non-dialect owned operation");
87  return nullptr;
88 }
89 
90 /// Utility function that returns if the given string is a valid dialect
91 /// namespace
92 bool Dialect::isValidNamespace(StringRef str) {
93  llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
94  return dialectNameRegex.match(str);
95 }
96 
97 /// Register a dialect interface with this dialect instance.
98 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
99  // Handle the case where the models resolve a promised interface.
101 
102  auto it = registeredInterfaces.try_emplace(interface->getID(),
103  std::move(interface));
104  (void)it;
105  LLVM_DEBUG({
106  if (!it.second) {
107  llvm::dbgs() << "[" DEBUG_TYPE
108  "] repeated interface registration for dialect "
109  << getNamespace();
110  }
111  });
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // Dialect Interface
116 //===----------------------------------------------------------------------===//
117 
119 
121  return dialect->getContext();
122 }
123 
124 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
125  MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
126  for (auto *dialect : ctx->getLoadedDialects()) {
127 #ifndef NDEBUG
128  dialect->handleUseOfUndefinedPromisedInterface(
129  dialect->getTypeID(), interfaceKind, interfaceName);
130 #endif
131  if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
132  interfaces.insert(interface);
133  orderedInterfaces.push_back(interface);
134  }
135  }
136 }
137 
138 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
139 
140 /// Get the interface for the dialect of given operation, or null if one
141 /// is not registered.
142 const DialectInterface *
143 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
144  return getInterfaceFor(op->getDialect());
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // DialectExtension
149 //===----------------------------------------------------------------------===//
150 
152 
154  Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
155  StringRef interfaceName) {
156  dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
157  interfaceID, interfaceName);
158 }
159 
161  Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
162  dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
163  interfaceID);
164 }
165 
167  TypeID interfaceRequestorID,
168  TypeID interfaceID) {
169  return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // DialectRegistry
174 //===----------------------------------------------------------------------===//
175 
176 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
177 
180  auto it = registry.find(name.str());
181  if (it == registry.end())
182  return nullptr;
183  return it->second.second;
184 }
185 
186 void DialectRegistry::insert(TypeID typeID, StringRef name,
187  const DialectAllocatorFunction &ctor) {
188  auto inserted = registry.insert(
189  std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
190  if (!inserted.second && inserted.first->second.first != typeID) {
191  llvm::report_fatal_error(
192  "Trying to register different dialects for the same namespace: " +
193  name);
194  }
195 }
196 
198  StringRef name, const DynamicDialectPopulationFunction &ctor) {
199  // This TypeID marks dynamic dialects. We cannot give a TypeID for the
200  // dialect yet, since the TypeID of a dynamic dialect is defined at its
201  // construction.
202  TypeID typeID = TypeID::get<void>();
203 
204  // Create the dialect, and then call ctor, which allocates its components.
205  auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
206  auto *dynDialect = ctx->getOrLoadDynamicDialect(
207  nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
208  assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
209  return dynDialect;
210  };
211 
212  insert(typeID, name, constructor);
213 }
214 
216  MLIRContext *ctx = dialect->getContext();
217  StringRef dialectName = dialect->getNamespace();
218 
219  // Functor used to try to apply the given extension.
220  auto applyExtension = [&](const DialectExtensionBase &extension) {
221  ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
222  // An empty set is equivalent to always invoke.
223  if (dialectNames.empty()) {
224  extension.apply(ctx, dialect);
225  return;
226  }
227 
228  // Handle the simple case of a single dialect name. In this case, the
229  // required dialect should be the current dialect.
230  if (dialectNames.size() == 1) {
231  if (dialectNames.front() == dialectName)
232  extension.apply(ctx, dialect);
233  return;
234  }
235 
236  // Otherwise, check to see if this extension requires this dialect.
237  const StringRef *nameIt = llvm::find(dialectNames, dialectName);
238  if (nameIt == dialectNames.end())
239  return;
240 
241  // If it does, ensure that all of the other required dialects have been
242  // loaded.
243  SmallVector<Dialect *> requiredDialects;
244  requiredDialects.reserve(dialectNames.size());
245  for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
246  ++it) {
247  // The current dialect is known to be loaded.
248  if (it == nameIt) {
249  requiredDialects.push_back(dialect);
250  continue;
251  }
252  // Otherwise, check if it is loaded.
253  Dialect *loadedDialect = ctx->getLoadedDialect(*it);
254  if (!loadedDialect)
255  return;
256  requiredDialects.push_back(loadedDialect);
257  }
258  extension.apply(ctx, requiredDialects);
259  };
260 
261  // Note: Additional extensions may be added while applying an extension.
262  for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
263  applyExtension(*extensions[i]);
264 }
265 
267  // Functor used to try to apply the given extension.
268  auto applyExtension = [&](const DialectExtensionBase &extension) {
269  ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
270  if (dialectNames.empty()) {
271  auto loadedDialects = ctx->getLoadedDialects();
272  extension.apply(ctx, loadedDialects);
273  return;
274  }
275 
276  // Check to see if all of the dialects for this extension are loaded.
277  SmallVector<Dialect *> requiredDialects;
278  requiredDialects.reserve(dialectNames.size());
279  for (StringRef dialectName : dialectNames) {
280  Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
281  if (!loadedDialect)
282  return;
283  requiredDialects.push_back(loadedDialect);
284  }
285  extension.apply(ctx, requiredDialects);
286  };
287 
288  // Note: Additional extensions may be added while applying an extension.
289  for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
290  applyExtension(*extensions[i]);
291 }
292 
294  // Treat any extensions conservatively.
295  if (!extensions.empty())
296  return false;
297  // Check that the current dialects fully overlap with the dialects in 'rhs'.
298  return llvm::all_of(
299  registry, [&](const auto &it) { return rhs.registry.count(it.first); });
300 }
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
virtual StringRef getFullSymbolSpec() const =0
Returns the full specification of the symbol being parsed.
This class represents an opaque dialect extension.
This class represents an interface overridden for a single dialect.
MLIRContext * getContext() const
Return the context that holds the parent dialect of this interface.
Definition: Dialect.cpp:120
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool isSubsetOf(const DialectRegistry &rhs) const
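Returns true if the current registry is a subset of 'rhs', i.e. if every dialect namespace registered here is also registered in 'rhs' (conservatively false when this registry holds any extensions).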
Definition: Dialect.cpp:293
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const
Return an allocation function for constructing the dialect identified by its namespace,...
Definition: Dialect.cpp:179
void insertDynamic(StringRef name, const DynamicDialectPopulationFunction &ctor)
Add a new dynamic dialect constructor in the registry.
Definition: Dialect.cpp:197
void applyExtensions(Dialect *dialect) const
Apply any held extensions that require the given dialect.
Definition: Dialect.cpp:215
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual ~Dialect()
virtual Type parseType(DialectAsmParser &parser) const
Parse a type registered to this dialect.
Definition: Dialect.cpp:66
virtual std::optional< ParseOpHook > getParseOperationHook(StringRef opName) const
Return the hook to parse an operation registered to this dialect, if any.
Definition: Dialect.cpp:79
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName="")
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition: Dialect.h:231
StringRef getNamespace() const
Definition: Dialect.h:57
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:92
virtual LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, unsigned resultIndex, NamedAttribute)
Verify an attribute from this dialect on the result at 'resultIndex' for the region at 'regionIndex' ...
Definition: Dialect.cpp:52
virtual LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, unsigned argIndex, NamedAttribute)
Verify an attribute from this dialect on the argument at 'argIndex' for the region at 'regionIndex' o...
Definition: Dialect.cpp:43
virtual llvm::unique_function< void(Operation *, OpAsmPrinter &printer)> getOperationPrinter(Operation *op) const
Print an operation registered to this dialect.
Definition: Dialect.cpp:84
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached to a construct owned by this dialec...
Definition: Dialect.h:248
bool allowsUnknownTypes() const
Return true if this dialect allows for unregistered types, i.e., types prefixed with the dialect name...
Definition: Dialect.h:70
MLIRContext * getContext() const
Definition: Dialect.h:55
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const
Parse an attribute registered to this dialect.
Definition: Dialect.cpp:58
void addInterface(std::unique_ptr< DialectInterface > interface)
Register a dialect interface with this dialect instance.
Definition: Dialect.cpp:98
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition: Dialect.h:60
bool hasPromisedInterface(TypeID interfaceRequestorID, TypeID interfaceID) const
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.h:254
Dialect(StringRef name, MLIRContext *context, TypeID id)
The constructor takes a unique namespace for this dialect as well as the context to bind to.
Definition: Dialect.cpp:32
A dialect that can be defined at runtime.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if a promise has been made for the interface/requestor pair.
Definition: Dialect.cpp:166
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID)
Checks if the given interface, which is attempting to be attached, is a promised interface of this di...
Definition: Dialect.cpp:160
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, StringRef interfaceName)
Checks if the given interface, which is attempting to be used, is a promised interface of this dialec...
Definition: Dialect.cpp:153
Include the generated interface declarations.
std::function< void(MLIRContext *, DynamicDialect *)> DynamicDialectPopulationFunction
std::function< Dialect *(MLIRContext *)> DialectAllocatorFunction
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26