MLIR  16.0.0git
IRModule.cpp
Go to the documentation of this file.
1 //===- IRModule.cpp - IR pybind module ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 #include "Globals.h"
11 #include "PybindUtils.h"
12 
13 #include <vector>
14 
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 // -----------------------------------------------------------------------------
22 // PyGlobals
23 // -----------------------------------------------------------------------------
24 
25 PyGlobals *PyGlobals::instance = nullptr;
26 
28  assert(!instance && "PyGlobals already constructed");
29  instance = this;
30  // The default search path include {mlir.}dialects, where {mlir.} is the
31  // package prefix configured at compile time.
32  dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
33 }
34 
35 PyGlobals::~PyGlobals() { instance = nullptr; }
36 
37 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
38  if (loadedDialectModulesCache.contains(dialectNamespace))
39  return;
40  // Since re-entrancy is possible, make a copy of the search prefixes.
41  std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
42  py::object loaded;
43  for (std::string moduleName : localSearchPrefixes) {
44  moduleName.push_back('.');
45  moduleName.append(dialectNamespace.data(), dialectNamespace.size());
46 
47  try {
48  loaded = py::module::import(moduleName.c_str());
49  } catch (py::error_already_set &e) {
50  if (e.matches(PyExc_ModuleNotFoundError)) {
51  continue;
52  }
53  throw;
54  }
55  break;
56  }
57 
58  // Note: Iterator cannot be shared from prior to loading, since re-entrancy
59  // may have occurred, which may do anything.
60  loadedDialectModulesCache.insert(dialectNamespace);
61 }
62 
63 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
64  py::object pyClass) {
65  py::object &found = dialectClassMap[dialectNamespace];
66  if (found) {
67  throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
68  dialectNamespace +
69  "' is already registered.");
70  }
71  found = std::move(pyClass);
72 }
73 
74 void PyGlobals::registerOperationImpl(const std::string &operationName,
75  py::object pyClass,
76  py::object rawOpViewClass) {
77  py::object &found = operationClassMap[operationName];
78  if (found) {
79  throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
80  operationName +
81  "' is already registered.");
82  }
83  found = std::move(pyClass);
84  rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
85 }
86 
88 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
89  loadDialectModule(dialectNamespace);
90  // Fast match against the class map first (common case).
91  const auto foundIt = dialectClassMap.find(dialectNamespace);
92  if (foundIt != dialectClassMap.end()) {
93  if (foundIt->second.is_none())
94  return llvm::None;
95  assert(foundIt->second && "py::object is defined");
96  return foundIt->second;
97  }
98 
99  // Not found and loading did not yield a registration. Negative cache.
100  dialectClassMap[dialectNamespace] = py::none();
101  return llvm::None;
102 }
103 
105 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
106  {
107  auto foundIt = rawOpViewClassMapCache.find(operationName);
108  if (foundIt != rawOpViewClassMapCache.end()) {
109  if (foundIt->second.is_none())
110  return llvm::None;
111  assert(foundIt->second && "py::object is defined");
112  return foundIt->second;
113  }
114  }
115 
116  // Not found. Load the dialect namespace.
117  auto split = operationName.split('.');
118  llvm::StringRef dialectNamespace = split.first;
119  loadDialectModule(dialectNamespace);
120 
121  // Attempt to find from the canonical map and cache.
122  {
123  auto foundIt = rawOpViewClassMap.find(operationName);
124  if (foundIt != rawOpViewClassMap.end()) {
125  if (foundIt->second.is_none())
126  return llvm::None;
127  assert(foundIt->second && "py::object is defined");
128  // Positive cache.
129  rawOpViewClassMapCache[operationName] = foundIt->second;
130  return foundIt->second;
131  }
132  // Negative cache.
133  rawOpViewClassMap[operationName] = py::none();
134  return llvm::None;
135  }
136 }
137 
139  loadedDialectModulesCache.clear();
140  rawOpViewClassMapCache.clear();
141 }
Include the generated interface declarations.
Globals that are always accessible once the extension has been initialized.
Definition: Globals.h:25
void loadDialectModule(llvm::StringRef dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition: IRModule.cpp:37
void clearImportCache()
Clears positive and negative caches regarding what implementations are available. ...
Definition: IRModule.cpp:138
void registerOperationImpl(const std::string &operationName, pybind11::object pyClass, pybind11::object rawOpViewClass)
Adds a concrete implementation operation class.
Definition: IRModule.cpp:74
pybind11::error_already_set SetPyError(PyObject *excClass, const llvm::Twine &message)
Definition: PybindUtils.cpp:12
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition: Interop.h:56
llvm::Optional< pybind11::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition: IRModule.cpp:88
llvm::Optional< pybind11::object > lookupRawOpViewClass(llvm::StringRef operationName)
Looks up a registered raw OpView class by operation name.
Definition: IRModule.cpp:105
void registerDialectImpl(const std::string &dialectNamespace, pybind11::object pyClass)
Adds a concrete implementation dialect class.
Definition: IRModule.cpp:63