MLIR  19.0.0git
IRCore.cpp
Go to the documentation of this file.
1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
16 #include "mlir-c/Debug.h"
17 #include "mlir-c/Diagnostics.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Support.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/SmallVector.h"
23 
24 #include <optional>
25 #include <utility>
26 
27 namespace py = pybind11;
28 using namespace py::literals;
29 using namespace mlir;
30 using namespace mlir::python;
31 
32 using llvm::SmallVector;
33 using llvm::StringRef;
34 using llvm::Twine;
35 
36 //------------------------------------------------------------------------------
37 // Docstrings (trivial, non-duplicated docstrings are included inline).
38 //------------------------------------------------------------------------------
39 
// Docstring text for parsing a type's assembly form. Per the text, a parse
// failure surfaces as MLIRError; the binding that attaches it is not visible
// here.
40 static const char kContextParseTypeDocstring[] =
41  R"(Parses the assembly form of a type.
42 
43 Returns a Type object or raises an MLIRError if the type cannot be parsed.
44 
45 See also: https://mlir.llvm.org/docs/LangRef/#type-system
46 )";
47 
// One-line docstrings for the Location construction helpers: call site,
// file/line/column, fused, and named locations.
// NOTE(review): the declarator of the call-site constant is not present in
// this listing, and kContextGetNameLocationDocString spells "DocString"
// unlike its siblings.
49  R"(Gets a Location representing a caller and callsite)";
50 
51 static const char kContextGetFileLocationDocstring[] =
52  R"(Gets a Location representing a file, line and column)";
53 
54 static const char kContextGetFusedLocationDocstring[] =
55  R"(Gets a Location representing a fused location with optional metadata)";
56 
57 static const char kContextGetNameLocationDocString[] =
58  R"(Gets a Location representing a named location with optional child location)";
59 
// Docstring text for parsing a module from a string. Per the text, failures
// raise MLIRError.
60 static const char kModuleParseDocstring[] =
61  R"(Parses a module's assembly format from a string.
62 
63 Returns a new MlirModule or raises an MLIRError if the parsing fails.
64 
65 See also: https://mlir.llvm.org/docs/LangRef/
66 )";
67 
// Docstring text for generic operation creation. It lists the accepted
// arguments and says the result is a detached Operation.
// NOTE(review): the argument names here should be kept in sync with the
// binding that uses this string; that binding is not visible in this view.
68 static const char kOperationCreateDocstring[] =
69  R"(Creates a new operation.
70 
71 Args:
72  name: Operation name (e.g. "dialect.operation").
73  results: Sequence of Type representing op result types.
74  attributes: Dict of str:Attribute.
75  successors: List of Block for the operation's successors.
76  regions: Number of regions to create.
77  location: A Location object (defaults to resolve from context manager).
78  ip: An InsertionPoint (defaults to resolve from context manager or set to
79  False to disable insertion, even with an insertion point set in the
80  context manager).
81  infer_type: Whether to infer result types.
82 Returns:
83  A new "detached" Operation object. Detached operations can be added
84  to blocks, which causes them to become "attached."
85 )";
86 
// Docstring text for printing an operation with the full set of printing
// options.
// NOTE(review): "use_local_Scope" is the only argument name here that is not
// all-lowercase snake_case; it likely should read "use_local_scope". Confirm
// against the print() binding's keyword name before changing the text.
87 static const char kOperationPrintDocstring[] =
88  R"(Prints the assembly form of the operation to a file like object.
89 
90 Args:
91  file: The file like object to write to. Defaults to sys.stdout.
92  binary: Whether to write bytes (True) or str (False). Defaults to False.
93  large_elements_limit: Whether to elide elements attributes above this
94  number of elements. Defaults to None (no limit).
95  enable_debug_info: Whether to print debug/location information. Defaults
96  to False.
97  pretty_debug_info: Whether to format debug information for easier reading
98  by a human (warning: the result is unparseable).
99  print_generic_op_form: Whether to print the generic assembly forms of all
100  ops. Defaults to False.
101  use_local_Scope: Whether to print in a way that is more optimized for
102  multi-threaded access but may not be consistent with how the overall
103  module prints.
104  assume_verified: By default, if not printing generic form, the verifier
105  will be run and if it fails, generic form will be printed with a comment
106  about failed verification. While a reasonable default for interactive use,
107  for systematic use, it is often better for the caller to verify explicitly
108  and report failures in a more robust fashion. Set this to True if doing this
109  in order to avoid running a redundant verification. If the IR is actually
110  invalid, behavior is undefined.
111 )";
112 
// Docstring text for the print variant that takes an AsmState instead of
// individual printing flags.
113 static const char kOperationPrintStateDocstring[] =
114  R"(Prints the assembly form of the operation to a file like object.
115 
116 Args:
117  file: The file like object to write to. Defaults to sys.stdout.
118  binary: Whether to write bytes (True) or str (False). Defaults to False.
119  state: AsmState capturing the operation numbering and flags.
120 )";
121 
// Docstring text for returning the assembly form as a value (bytes or str)
// rather than writing it; it defers to print() for the shared keywords.
122 static const char kOperationGetAsmDocstring[] =
123  R"(Gets the assembly form of the operation with all options available.
124 
125 Args:
126  binary: Whether to return a bytes (True) or str (False) object. Defaults to
127  False.
128  ... others ...: See the print() method for common keyword arguments for
129  configuring the printout.
130 Returns:
131  Either a bytes or str object, depending on the setting of the 'binary'
132  argument.
133 )";
134 
// Docstring text for writing an operation's bytecode form to a file-like
// object.
135 static const char kOperationPrintBytecodeDocstring[] =
136  R"(Write the bytecode form of the operation to a file like object.
137 
138 Args:
139  file: The file like object to write to.
140  desired_version: The version of bytecode to emit.
141 Returns:
142  The bytecode writer status.
143 )";
144 
// Docstring text for the default-options string conversion of an operation.
145 static const char kOperationStrDunderDocstring[] =
146  R"(Gets the assembly form of the operation with default options.
147 
148 If more advanced control over the assembly formatting or I/O options is needed,
149 use the dedicated print or get_asm method, which supports keyword arguments to
150 customize behavior.
151 )";
152 
// Shared docstring text for debug-dump methods.
153 static const char kDumpDocstring[] =
154  R"(Dumps a debug representation of the object to stderr.)";
155 
// Docstring attached to BlockList.append in PyBlockList::bind below.
156 static const char kAppendBlockDocstring[] =
157  R"(Appends a new block, with argument types as positional args.
158 
159 Returns:
160  The created block.
161 )";
162 
// Docstring text for the string form of a Value (block argument vs. op
// result).
163 static const char kValueDunderStrDocstring[] =
164  R"(Returns the string form of the value.
165 
166 If the value is a block argument, this is the assembly form of its type and the
167 position in the argument list. If the value is an operation result, this is
168 equivalent to printing the operation that produced it.
169 )";
170 
// Docstring text for printing a value as an operand reference.
171 static const char kGetNameAsOperand[] =
172  R"(Returns the string form of value as an operand (i.e., the ValueID).
173 )";
174 
// Docstring text for replace-all-uses-with. NOTE(review): the declarator line
// of this constant is not present in this listing.
176  R"(Replace all uses of value with the new value, updating anything in
177 the IR that uses 'self' to use the other value instead.
178 )";
179 
180 //------------------------------------------------------------------------------
181 // Utilities.
182 //------------------------------------------------------------------------------
183 
184 /// Helper for creating an @classmethod.
185 template <class Func, typename... Args>
186 py::object classmethod(Func f, Args... args) {
187  py::object cf = py::cpp_function(f, args...);
188  return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
189 }
190 
191 static py::object
192 createCustomDialectWrapper(const std::string &dialectNamespace,
193  py::object dialectDescriptor) {
194  auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
195  if (!dialectClass) {
196  // Use the base class.
197  return py::cast(PyDialect(std::move(dialectDescriptor)));
198  }
199 
200  // Create the custom implementation.
201  return (*dialectClass)(std::move(dialectDescriptor));
202 }
203 
204 static MlirStringRef toMlirStringRef(const std::string &s) {
205  return mlirStringRefCreate(s.data(), s.size());
206 }
207 
208 /// Create a block, using the current location context if no locations are
209 /// specified.
210 static MlirBlock createBlock(const py::sequence &pyArgTypes,
211  const std::optional<py::sequence> &pyArgLocs) {
212  SmallVector<MlirType> argTypes;
213  argTypes.reserve(pyArgTypes.size());
214  for (const auto &pyType : pyArgTypes)
215  argTypes.push_back(pyType.cast<PyType &>());
216 
218  if (pyArgLocs) {
219  argLocs.reserve(pyArgLocs->size());
220  for (const auto &pyLoc : *pyArgLocs)
221  argLocs.push_back(pyLoc.cast<PyLocation &>());
222  } else if (!argTypes.empty()) {
223  argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
224  }
225 
226  if (argTypes.size() != argLocs.size())
227  throw py::value_error(("Expected " + Twine(argTypes.size()) +
228  " locations, got: " + Twine(argLocs.size()))
229  .str());
230  return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
231 }
232 
233 /// Wrapper for the global LLVM debugging flag.
// Exposes mlirEnableGlobalDebug / mlirIsGlobalDebugEnabled to Python as the
// static property `_GlobalDebug.flag`. NOTE(review): the opening struct line
// is not present in this listing.
// Setter: the py::object argument is ignored. It is presumably the object
// pybind11 passes to static-property setters; confirm.
235  static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
236 
// Getter: the argument is ignored; reports the current global flag.
237  static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
238 
// Registers the module-local `_GlobalDebug` class with its static `flag`
// property.
239  static void bind(py::module &m) {
240  // Debug flags.
241  py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
242  .def_property_static("flag", &PyGlobalDebugFlag::get,
243  &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
244  }
245 };
246 
// Static facade over PyGlobals' attribute-builder registry, bound to Python
// as `AttrBuilder`. NOTE(review): the opening struct line is not present in
// this listing. The "dunde" prefix on two method names looks like a typo for
// "dunder"; these are C++-only names.
// Returns whether a builder is registered for `attributeKind`.
248  static bool dunderContains(const std::string &attributeKind) {
249  return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
250  }
// Returns the registered builder; raises KeyError when none exists.
251  static py::function dundeGetItemNamed(const std::string &attributeKind) {
252  auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
253  if (!builder)
254  throw py::key_error(attributeKind);
255  return *builder;
256  }
// Registers `func` as the builder for `attributeKind`. `replace` is forwarded
// to PyGlobals::registerAttributeBuilder; its handling of existing entries is
// defined there.
257  static void dundeSetItemNamed(const std::string &attributeKind,
258  py::function func, bool replace) {
259  PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
260  replace);
261  }
262 
// Binds contains/get/insert as static methods of the module-local
// `AttrBuilder` class.
263  static void bind(py::module &m) {
264  py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
265  .def_static("contains", &PyAttrBuilderMap::dunderContains)
266  .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
267  .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed,
268  "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
269  "Register an attribute builder for building MLIR "
270  "attributes from python values.");
271  }
272 };
273 
274 //------------------------------------------------------------------------------
275 // PyBlock
276 //------------------------------------------------------------------------------
277 
278 py::object PyBlock::getCapsule() {
279  return py::reinterpret_steal<py::object>(mlirPythonBlockToCapsule(get()));
280 }
281 
282 //------------------------------------------------------------------------------
283 // Collections.
284 //------------------------------------------------------------------------------
285 
286 namespace {
287 
288 class PyRegionIterator {
289 public:
290  PyRegionIterator(PyOperationRef operation)
291  : operation(std::move(operation)) {}
292 
293  PyRegionIterator &dunderIter() { return *this; }
294 
295  PyRegion dunderNext() {
296  operation->checkValid();
297  if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
298  throw py::stop_iteration();
299  }
300  MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
301  return PyRegion(operation, region);
302  }
303 
304  static void bind(py::module &m) {
305  py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
306  .def("__iter__", &PyRegionIterator::dunderIter)
307  .def("__next__", &PyRegionIterator::dunderNext);
308  }
309 
310 private:
311  PyOperationRef operation;
312  int nextIndex = 0;
313 };
314 
315 /// Regions of an op are fixed length and indexed numerically so are represented
316 /// with a sequence-like container.
317 class PyRegionList {
318 public:
319  PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
320 
321  PyRegionIterator dunderIter() {
322  operation->checkValid();
323  return PyRegionIterator(operation);
324  }
325 
326  intptr_t dunderLen() {
327  operation->checkValid();
328  return mlirOperationGetNumRegions(operation->get());
329  }
330 
331  PyRegion dunderGetItem(intptr_t index) {
332  // dunderLen checks validity.
333  if (index < 0 || index >= dunderLen()) {
334  throw py::index_error("attempt to access out of bounds region");
335  }
336  MlirRegion region = mlirOperationGetRegion(operation->get(), index);
337  return PyRegion(operation, region);
338  }
339 
340  static void bind(py::module &m) {
341  py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
342  .def("__len__", &PyRegionList::dunderLen)
343  .def("__iter__", &PyRegionList::dunderIter)
344  .def("__getitem__", &PyRegionList::dunderGetItem);
345  }
346 
347 private:
348  PyOperationRef operation;
349 };
350 
351 class PyBlockIterator {
352 public:
353  PyBlockIterator(PyOperationRef operation, MlirBlock next)
354  : operation(std::move(operation)), next(next) {}
355 
356  PyBlockIterator &dunderIter() { return *this; }
357 
358  PyBlock dunderNext() {
359  operation->checkValid();
360  if (mlirBlockIsNull(next)) {
361  throw py::stop_iteration();
362  }
363 
364  PyBlock returnBlock(operation, next);
365  next = mlirBlockGetNextInRegion(next);
366  return returnBlock;
367  }
368 
369  static void bind(py::module &m) {
370  py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
371  .def("__iter__", &PyBlockIterator::dunderIter)
372  .def("__next__", &PyBlockIterator::dunderNext);
373  }
374 
375 private:
376  PyOperationRef operation;
377  MlirBlock next;
378 };
379 
380 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
381 /// we present them as a more full-featured list-like container but optimize
382 /// it for forward iteration. Blocks are always owned by a region.
383 class PyBlockList {
384 public:
385  PyBlockList(PyOperationRef operation, MlirRegion region)
386  : operation(std::move(operation)), region(region) {}
387 
388  PyBlockIterator dunderIter() {
389  operation->checkValid();
390  return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
391  }
392 
393  intptr_t dunderLen() {
394  operation->checkValid();
395  intptr_t count = 0;
396  MlirBlock block = mlirRegionGetFirstBlock(region);
397  while (!mlirBlockIsNull(block)) {
398  count += 1;
399  block = mlirBlockGetNextInRegion(block);
400  }
401  return count;
402  }
403 
404  PyBlock dunderGetItem(intptr_t index) {
405  operation->checkValid();
406  if (index < 0) {
407  throw py::index_error("attempt to access out of bounds block");
408  }
409  MlirBlock block = mlirRegionGetFirstBlock(region);
410  while (!mlirBlockIsNull(block)) {
411  if (index == 0) {
412  return PyBlock(operation, block);
413  }
414  block = mlirBlockGetNextInRegion(block);
415  index -= 1;
416  }
417  throw py::index_error("attempt to access out of bounds block");
418  }
419 
420  PyBlock appendBlock(const py::args &pyArgTypes,
421  const std::optional<py::sequence> &pyArgLocs) {
422  operation->checkValid();
423  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
424  mlirRegionAppendOwnedBlock(region, block);
425  return PyBlock(operation, block);
426  }
427 
428  static void bind(py::module &m) {
429  py::class_<PyBlockList>(m, "BlockList", py::module_local())
430  .def("__getitem__", &PyBlockList::dunderGetItem)
431  .def("__iter__", &PyBlockList::dunderIter)
432  .def("__len__", &PyBlockList::dunderLen)
433  .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
434  py::arg("arg_locs") = std::nullopt);
435  }
436 
437 private:
438  PyOperationRef operation;
439  MlirRegion region;
440 };
441 
442 class PyOperationIterator {
443 public:
444  PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
445  : parentOperation(std::move(parentOperation)), next(next) {}
446 
447  PyOperationIterator &dunderIter() { return *this; }
448 
449  py::object dunderNext() {
450  parentOperation->checkValid();
451  if (mlirOperationIsNull(next)) {
452  throw py::stop_iteration();
453  }
454 
455  PyOperationRef returnOperation =
456  PyOperation::forOperation(parentOperation->getContext(), next);
457  next = mlirOperationGetNextInBlock(next);
458  return returnOperation->createOpView();
459  }
460 
461  static void bind(py::module &m) {
462  py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
463  .def("__iter__", &PyOperationIterator::dunderIter)
464  .def("__next__", &PyOperationIterator::dunderNext);
465  }
466 
467 private:
468  PyOperationRef parentOperation;
469  MlirOperation next;
470 };
471 
472 /// Operations are exposed by the C-API as a forward-only linked list. In
473 /// Python, we present them as a more full-featured list-like container but
474 /// optimize it for forward iteration. Iterable operations are always owned
475 /// by a block.
476 class PyOperationList {
477 public:
478  PyOperationList(PyOperationRef parentOperation, MlirBlock block)
479  : parentOperation(std::move(parentOperation)), block(block) {}
480 
// `__iter__`: validates the parent and returns an iterator over the block.
// NOTE(review): the iterator's second constructor argument (the starting
// operation) is not visible in this listing.
481  PyOperationIterator dunderIter() {
482  parentOperation->checkValid();
483  return PyOperationIterator(parentOperation,
485  }
486 
// `__len__`: counts operations by walking the block (linear time).
487  intptr_t dunderLen() {
488  parentOperation->checkValid();
489  intptr_t count = 0;
490  MlirOperation childOp = mlirBlockGetFirstOperation(block);
491  while (!mlirOperationIsNull(childOp)) {
492  count += 1;
493  childOp = mlirOperationGetNextInBlock(childOp);
494  }
495  return count;
496  }
497 
// `__getitem__`: op view of the operation at a non-negative `index`, found
// by walking the block. Negative or too-large indices raise IndexError.
498  py::object dunderGetItem(intptr_t index) {
499  parentOperation->checkValid();
500  if (index < 0) {
501  throw py::index_error("attempt to access out of bounds operation");
502  }
503  MlirOperation childOp = mlirBlockGetFirstOperation(block);
504  while (!mlirOperationIsNull(childOp)) {
505  if (index == 0) {
506  return PyOperation::forOperation(parentOperation->getContext(), childOp)
507  ->createOpView();
508  }
509  childOp = mlirOperationGetNextInBlock(childOp);
510  index -= 1;
511  }
512  throw py::index_error("attempt to access out of bounds operation");
513  }
514 
// Binds `OperationList` with __getitem__/__iter__/__len__.
515  static void bind(py::module &m) {
516  py::class_<PyOperationList>(m, "OperationList", py::module_local())
517  .def("__getitem__", &PyOperationList::dunderGetItem)
518  .def("__iter__", &PyOperationList::dunderIter)
519  .def("__len__", &PyOperationList::dunderLen);
520  }
521 
522 private:
523  PyOperationRef parentOperation;
524  MlirBlock block;
525 };
526 
527 class PyOpOperand {
528 public:
529  PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
530 
531  py::object getOwner() {
532  MlirOperation owner = mlirOpOperandGetOwner(opOperand);
533  PyMlirContextRef context =
534  PyMlirContext::forContext(mlirOperationGetContext(owner));
535  return PyOperation::forOperation(context, owner)->createOpView();
536  }
537 
538  size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
539 
540  static void bind(py::module &m) {
541  py::class_<PyOpOperand>(m, "OpOperand", py::module_local())
542  .def_property_readonly("owner", &PyOpOperand::getOwner)
543  .def_property_readonly("operand_number",
544  &PyOpOperand::getOperandNumber);
545  }
546 
547 private:
548  MlirOpOperand opOperand;
549 };
550 
551 class PyOpOperandIterator {
552 public:
553  PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
554 
555  PyOpOperandIterator &dunderIter() { return *this; }
556 
557  PyOpOperand dunderNext() {
558  if (mlirOpOperandIsNull(opOperand))
559  throw py::stop_iteration();
560 
561  PyOpOperand returnOpOperand(opOperand);
562  opOperand = mlirOpOperandGetNextUse(opOperand);
563  return returnOpOperand;
564  }
565 
566  static void bind(py::module &m) {
567  py::class_<PyOpOperandIterator>(m, "OpOperandIterator", py::module_local())
568  .def("__iter__", &PyOpOperandIterator::dunderIter)
569  .def("__next__", &PyOpOperandIterator::dunderNext);
570  }
571 
572 private:
573  MlirOpOperand opOperand;
574 };
575 
576 } // namespace
577 
578 //------------------------------------------------------------------------------
579 // PyMlirContext
580 //------------------------------------------------------------------------------
581 
582 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
583  py::gil_scoped_acquire acquire;
584  auto &liveContexts = getLiveContexts();
585  liveContexts[context.ptr] = this;
586 }
587 
// Destructor body (signature line not present in this listing): removes this
// wrapper from the live-context map under the GIL, then destroys the
// underlying MlirContext.
589  // Note that the only public way to construct an instance is via the
590  // forContext method, which always puts the associated handle into
591  // liveContexts.
592  py::gil_scoped_acquire acquire;
593  getLiveContexts().erase(context.ptr);
594  mlirContextDestroy(context);
595 }
596 
// Returns the capsule produced by mlirPythonContextToCapsule, adopting (not
// borrowing) its reference. The signature line is not present in this listing.
598  return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
599 }
600 
601 py::object PyMlirContext::createFromCapsule(py::object capsule) {
602  MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
603  if (mlirContextIsNull(rawContext))
604  throw py::error_already_set();
605  return forContext(rawContext).releaseObject();
606 }
607 
// Creates a fresh MlirContext with threading disabled and returns a
// heap-allocated wrapper for it. The wrapper's constructor registers it in
// the live-context map. The signature line is not present in this listing.
609  MlirContext context = mlirContextCreateWithThreading(false);
610  return new PyMlirContext(context);
611 }
612 
// Returns the wrapper for `context`: the existing one from the live-context
// map when present, otherwise a newly created one (signature line not present
// in this listing). Runs under the GIL.
614  py::gil_scoped_acquire acquire;
615  auto &liveContexts = getLiveContexts();
616  auto it = liveContexts.find(context.ptr);
617  if (it == liveContexts.end()) {
618  // Create.
619  PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
// NOTE(review): py::cast of a raw pointer uses pybind11's default policy for
// pointers. Confirm which return_value_policy applies and that this wrapper
// is eventually freed.
620  py::object pyRef = py::cast(unownedContextWrapper);
621  assert(pyRef && "cast to py::object failed");
// NOTE(review): the PyMlirContext constructor already inserts into
// liveContexts, so this assignment appears redundant.
622  liveContexts[context.ptr] = unownedContextWrapper;
623  return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
624  }
625  // Use existing.
626  py::object pyRef = py::cast(it->second);
627  return PyMlirContextRef(it->second, std::move(pyRef));
628 }
629 
630 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
631  static LiveContextMap liveContexts;
632  return liveContexts;
633 }
634 
635 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
636 
637 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
638 
639 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
640  std::vector<PyOperation *> liveObjects;
641  for (auto &entry : liveOperations)
642  liveObjects.push_back(entry.second.second);
643  return liveObjects;
644 }
645 
// Marks every recorded PyOperation wrapper invalid, empties liveOperations,
// and returns how many entries were cleared. The signature line is not
// present in this listing.
647  for (auto &op : liveOperations)
648  op.second.second->setInvalid();
649  size_t numInvalidated = liveOperations.size();
650  liveOperations.clear();
651  return numInvalidated;
652 }
653 
654 void PyMlirContext::clearOperation(MlirOperation op) {
655  auto it = liveOperations.find(op.ptr);
656  if (it != liveOperations.end()) {
657  it->second.second->setInvalid();
658  liveOperations.erase(it);
659  }
660 }
661 
// Invalidates the wrappers of every operation nested under `op`, but not `op`
// itself. The walk is pre-order, so the first op visited is the root, and it
// is skipped via `rootSeen`. The signature line is not present in this
// listing.
663  typedef struct {
664  PyOperation &rootOp;
665  bool rootSeen;
666  } callBackData;
667  callBackData data{op.getOperation(), false};
668  // Mark all ops below the op that the passmanager will be rooted
669  // at (but not op itself - note the preorder) as invalid.
670  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
671  void *userData) {
672  callBackData *data = static_cast<callBackData *>(userData);
673  if (LLVM_LIKELY(data->rootSeen))
674  data->rootOp.getOperation().getContext()->clearOperation(op);
675  else
676  data->rootSeen = true;
// NOTE(review): one line of this callback (embedded number 677) is absent
// from this listing; presumably the callback's return statement.
678  };
679  mlirOperationWalk(op.getOperation(), invalidatingCallback,
680  static_cast<void *>(&data), MlirWalkPreOrder);
681 }
// Overload taking a raw MlirOperation. NOTE(review): its body lines are not
// present in this listing, so its behavior cannot be described from here.
682 void PyMlirContext::clearOperationsInside(MlirOperation op) {
685 }
686 
687 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
688 
689 pybind11::object PyMlirContext::contextEnter() {
690  return PyThreadContextEntry::pushContext(*this);
691 }
692 
// `__exit__` hook; the exception arguments are accepted per the Python
// context-manager protocol. NOTE(review): the body line is not present in
// this listing.
693 void PyMlirContext::contextExit(const pybind11::object &excType,
694  const pybind11::object &excVal,
695  const pybind11::object &excTb) {
697 }
698 
// Registers a Python callable as a diagnostic handler on this context and
// returns the PyDiagnosticHandler object that wraps it. The handler object is
// kept alive by an extra reference until the C API invokes deleteCallback.
699 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
700  // Note that ownership is transferred to the delete callback below by way of
701  // an explicit inc_ref (borrow).
702  PyDiagnosticHandler *pyHandler =
703  new PyDiagnosticHandler(get(), std::move(callback));
704  py::object pyHandlerObject =
705  py::cast(pyHandler, py::return_value_policy::take_ownership);
706  pyHandlerObject.inc_ref();
707 
708  // In these C callbacks, the userData is a PyDiagnosticHandler* that is
709  // guaranteed to be known to pybind.
// Per diagnostic: wraps it in a PyDiagnostic, calls the Python callback, and
// invalidates the wrapper afterwards so it cannot be used after the callback
// returns. Exceptions from the callback are reported to stderr and recorded
// in `hadError`.
710  auto handlerCallback =
711  +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
// NOTE(review): this wrapper is created and py::cast before the GIL is
// acquired below, and pyDiagnosticObject is released after the GIL scope
// ends. Given the comment that this can run from arbitrary C++ contexts,
// confirm the GIL is held here.
712  PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
713  py::object pyDiagnosticObject =
714  py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
715 
716  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
717  bool result = false;
718  {
719  // Since this can be called from arbitrary C++ contexts, always get the
720  // gil.
721  py::gil_scoped_acquire gil;
722  try {
723  result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
724  } catch (std::exception &e) {
725  fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
726  e.what());
727  pyHandler->hadError = true;
728  }
729  }
730 
731  pyDiagnostic->invalidate();
// NOTE(review): the callback's return line is not present in this listing.
733  };
// Called when the handler is unregistered: clears registeredID and drops the
// extra reference taken above. NOTE(review): no GIL acquisition is visible
// around this dec_ref; confirm callers hold it.
734  auto deleteCallback = +[](void *userData) {
735  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
736  assert(pyHandler->registeredID && "handler is not registered");
737  pyHandler->registeredID.reset();
738 
739  // Decrement reference, balancing the inc_ref() above.
740  py::object pyHandlerObject =
741  py::cast(pyHandler, py::return_value_policy::reference);
742  pyHandlerObject.dec_ref();
743  };
744 
745  pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
746  get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
747  return pyHandlerObject;
748 }
749 
// Diagnostic handler used by ErrorCapture. It returns failure, leaving the
// diagnostic unhandled, when the context asks for errors to be emitted or
// when the (unseen) filter below rejects the diagnostic. Otherwise it records
// the diagnostic's info in `errors` and reports success.
750 MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
751  void *userData) {
752  auto *self = static_cast<ErrorCapture *>(userData);
753  // Check if the context requested we emit errors instead of capturing them.
754  if (self->ctx->emitErrorDiagnostics)
755  return mlirLogicalResultFailure();
756 
// NOTE(review): the condition guarding the next return is not present in
// this listing.
758  return mlirLogicalResultFailure();
759 
760  self->errors.emplace_back(PyDiagnostic(diag).getInfo());
761  return mlirLogicalResultSuccess();
762 }
763 
// Resolves the implicit context: throws std::runtime_error with a usage hint
// when none is available, otherwise returns it by reference. How `context` is
// obtained is on lines not present in this listing.
766  if (!context) {
767  throw std::runtime_error(
768  "An MLIR function requires a Context but none was provided in the call "
769  "or from the surrounding environment. Either pass to the function with "
770  "a 'context=' argument or establish a default using 'with Context():'");
771  }
772  return *context;
773 }
774 
775 //------------------------------------------------------------------------------
776 // PyThreadContextEntry management
777 //------------------------------------------------------------------------------
778 
779 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
780  static thread_local std::vector<PyThreadContextEntry> stack;
781  return stack;
782 }
783 
// Returns the innermost frame of this thread's stack, or nullptr when the
// stack is empty. The signature line is not present in this listing.
785  auto &stack = getStack();
786  if (stack.empty())
787  return nullptr;
788  return &stack.back();
789 }
790 
791 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
792  py::object insertionPoint,
793  py::object location) {
794  auto &stack = getStack();
795  stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
796  std::move(location));
797  // If the new stack has more than one entry and the context of the new top
798  // entry matches the previous, copy the insertionPoint and location from the
799  // previous entry if missing from the new top entry.
800  if (stack.size() > 1) {
801  auto &prev = *(stack.rbegin() + 1);
802  auto &current = stack.back();
803  if (current.context.is(prev.context)) {
804  // Default non-context objects from the previous entry.
805  if (!current.insertionPoint)
806  current.insertionPoint = prev.insertionPoint;
807  if (!current.location)
808  current.location = prev.location;
809  }
810  }
811 }
812 
// Frame accessors (signature lines not present in this listing). Each returns
// nullptr when the stored object is empty; otherwise it casts the Python
// object to the native wrapper pointer.
814  if (!context)
815  return nullptr;
816  return py::cast<PyMlirContext *>(context);
817 }
818 
820  if (!insertionPoint)
821  return nullptr;
822  return py::cast<PyInsertionPoint *>(insertionPoint);
823 }
824 
826  if (!location)
827  return nullptr;
828  return py::cast<PyLocation *>(location);
829 }
830 
// Thread defaults (signature lines not present in this listing). Each reads
// the corresponding field of the innermost frame, or returns nullptr when
// this thread's stack is empty.
832  auto *tos = getTopOfStack();
833  return tos ? tos->getContext() : nullptr;
834 }
835 
837  auto *tos = getTopOfStack();
838  return tos ? tos->getInsertionPoint() : nullptr;
839 }
840 
842  auto *tos = getTopOfStack();
843  return tos ? tos->getLocation() : nullptr;
844 }
845 
// Pushes a Context frame holding only the context's Python object and returns
// that object. push() may inherit insertion point and location from the
// frame below. The signature line is not present in this listing.
847  py::object contextObj = py::cast(context);
848  push(FrameKind::Context, /*context=*/contextObj,
849  /*insertionPoint=*/py::object(),
850  /*location=*/py::object());
851  return contextObj;
852 }
853 
855  auto &stack = getStack();
856  if (stack.empty())
857  throw std::runtime_error("Unbalanced Context enter/exit");
858  auto &tos = stack.back();
859  if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
860  throw std::runtime_error("Unbalanced Context enter/exit");
861  stack.pop_back();
862 }
863 
// Pushes an InsertionPoint frame. The frame's context is the one owning the
// insertion point's block's parent operation. Returns the insertion point's
// Python object. The function-name line is not present in this listing.
864 py::object
866  py::object contextObj =
867  insertionPoint.getBlock().getParentOperation()->getContext().getObject();
868  py::object insertionPointObj = py::cast(insertionPoint);
869  push(FrameKind::InsertionPoint,
870  /*context=*/contextObj,
871  /*insertionPoint=*/insertionPointObj,
872  /*location=*/py::object());
873  return insertionPointObj;
874 }
875 
877  auto &stack = getStack();
878  if (stack.empty())
879  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
880  auto &tos = stack.back();
881  if (tos.frameKind != FrameKind::InsertionPoint &&
882  tos.getInsertionPoint() != &insertionPoint)
883  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
884  stack.pop_back();
885 }
886 
// Pushes a Location frame whose context is the location's own context, and
// returns the location's Python object. The signature line is not present in
// this listing.
888  py::object contextObj = location.getContext().getObject();
889  py::object locationObj = py::cast(location);
890  push(FrameKind::Location, /*context=*/contextObj,
891  /*insertionPoint=*/py::object(),
892  /*location=*/locationObj);
893  return locationObj;
894 }
895 
897  auto &stack = getStack();
898  if (stack.empty())
899  throw std::runtime_error("Unbalanced Location enter/exit");
900  auto &tos = stack.back();
901  if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
902  throw std::runtime_error("Unbalanced Location enter/exit");
903  stack.pop_back();
904 }
905 
906 //------------------------------------------------------------------------------
907 // PyDiagnostic*
908 //------------------------------------------------------------------------------
909 
// Marks this diagnostic invalid, so later checkValid() calls throw, and
// recursively invalidates any notes that getNotes() already materialized.
// The signature line is not present in this listing.
911  valid = false;
912  if (materializedNotes) {
913  for (auto &noteObject : *materializedNotes) {
914  PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
915  note->invalidate();
916  }
917  }
918 }
919 
// Constructor tail: stores the context and takes ownership of the Python
// callback by move. The first signature line is not present in this listing.
921  py::object callback)
922  : context(context), callback(std::move(callback)) {}
923 
925 
// Unregisters the handler if it is still registered; a no-op otherwise. The
// assert expects detaching to have cleared `registeredID`, presumably through
// the deleteCallback installed in attachDiagnosticHandler (confirm that the C
// API invokes it synchronously). The signature line is not present in this
// listing.
927  if (!registeredID)
928  return;
929  MlirDiagnosticHandlerID localID = *registeredID;
930  mlirContextDetachDiagnosticHandler(context, localID);
931  assert(!registeredID && "should have unregistered");
932  // Not strictly necessary but keeps stale pointers from being around to cause
933  // issues.
934  context = {nullptr};
935 }
936 
937 void PyDiagnostic::checkValid() {
938  if (!valid) {
939  throw std::invalid_argument(
940  "Diagnostic is invalid (used outside of callback)");
941  }
942 }
943 
// Diagnostic accessors (signature lines not present in this listing). Each
// calls checkValid() first.
// Severity of the diagnostic.
945  checkValid();
946  return mlirDiagnosticGetSeverity(diagnostic);
947 }
948 
// Location of the diagnostic, wrapped with the context that owns it.
950  checkValid();
951  MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
952  MlirContext context = mlirLocationGetContext(loc);
953  return PyLocation(PyMlirContext::forContext(context), loc);
954 }
955 
// Message text, produced by printing the diagnostic into an io.StringIO and
// returning its contents.
957  checkValid();
958  py::object fileObject = py::module::import("io").attr("StringIO")();
959  PyFileAccumulator accum(fileObject, /*binary=*/false);
960  mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
961  return fileObject.attr("getvalue")();
962 }
963 
965  checkValid();
966  if (materializedNotes)
967  return *materializedNotes;
968  intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
969  materializedNotes = py::tuple(numNotes);
970  for (intptr_t i = 0; i < numNotes; ++i) {
971  MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
972  (*materializedNotes)[i] = PyDiagnostic(noteDiag);
973  }
974  return *materializedNotes;
975 }
976 
978  std::vector<DiagnosticInfo> notes;
979  for (py::handle n : getNotes())
980  notes.emplace_back(n.cast<PyDiagnostic>().getInfo());
981  return {getSeverity(), getLocation(), getMessage(), std::move(notes)};
982 }
983 
984 //------------------------------------------------------------------------------
985 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
986 //------------------------------------------------------------------------------
987 
988 MlirDialect PyDialects::getDialectForKey(const std::string &key,
989  bool attrError) {
990  MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
991  {key.data(), key.size()});
992  if (mlirDialectIsNull(dialect)) {
993  std::string msg = (Twine("Dialect '") + key + "' not found").str();
994  if (attrError)
995  throw py::attribute_error(msg);
996  throw py::index_error(msg);
997  }
998  return dialect;
999 }
1000 
1002  return py::reinterpret_steal<py::object>(
1004 }
1005 
1007  MlirDialectRegistry rawRegistry =
1008  mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1009  if (mlirDialectRegistryIsNull(rawRegistry))
1010  throw py::error_already_set();
1011  return PyDialectRegistry(rawRegistry);
1012 }
1013 
1014 //------------------------------------------------------------------------------
1015 // PyLocation
1016 //------------------------------------------------------------------------------
1017 
1019  return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
1020 }
1021 
1023  MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1024  if (mlirLocationIsNull(rawLoc))
1025  throw py::error_already_set();
1027  rawLoc);
1028 }
1029 
1031  return PyThreadContextEntry::pushLocation(*this);
1032 }
1033 
1034 void PyLocation::contextExit(const pybind11::object &excType,
1035  const pybind11::object &excVal,
1036  const pybind11::object &excTb) {
1038 }
1039 
1041  auto *location = PyThreadContextEntry::getDefaultLocation();
1042  if (!location) {
1043  throw std::runtime_error(
1044  "An MLIR function requires a Location but none was provided in the "
1045  "call or from the surrounding environment. Either pass to the function "
1046  "with a 'loc=' argument or establish a default using 'with loc:'");
1047  }
1048  return *location;
1049 }
1050 
1051 //------------------------------------------------------------------------------
1052 // PyModule
1053 //------------------------------------------------------------------------------
1054 
1055 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1056  : BaseContextObject(std::move(contextRef)), module(module) {}
1057 
1059  py::gil_scoped_acquire acquire;
1060  auto &liveModules = getContext()->liveModules;
1061  assert(liveModules.count(module.ptr) == 1 &&
1062  "destroying module not in live map");
1063  liveModules.erase(module.ptr);
1064  mlirModuleDestroy(module);
1065 }
1066 
1067 PyModuleRef PyModule::forModule(MlirModule module) {
1068  MlirContext context = mlirModuleGetContext(module);
1069  PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1070 
1071  py::gil_scoped_acquire acquire;
1072  auto &liveModules = contextRef->liveModules;
1073  auto it = liveModules.find(module.ptr);
1074  if (it == liveModules.end()) {
1075  // Create.
1076  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1077  // Note that the default return value policy on cast is automatic_reference,
1078  // which does not take ownership (delete will not be called).
1079  // Just be explicit.
1080  py::object pyRef =
1081  py::cast(unownedModule, py::return_value_policy::take_ownership);
1082  unownedModule->handle = pyRef;
1083  liveModules[module.ptr] =
1084  std::make_pair(unownedModule->handle, unownedModule);
1085  return PyModuleRef(unownedModule, std::move(pyRef));
1086  }
1087  // Use existing.
1088  PyModule *existing = it->second.second;
1089  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1090  return PyModuleRef(existing, std::move(pyRef));
1091 }
1092 
1093 py::object PyModule::createFromCapsule(py::object capsule) {
1094  MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1095  if (mlirModuleIsNull(rawModule))
1096  throw py::error_already_set();
1097  return forModule(rawModule).releaseObject();
1098 }
1099 
1100 py::object PyModule::getCapsule() {
1101  return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
1102 }
1103 
1104 //------------------------------------------------------------------------------
1105 // PyOperation
1106 //------------------------------------------------------------------------------
1107 
1108 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1109  : BaseContextObject(std::move(contextRef)), operation(operation) {}
1110 
1112  // If the operation has already been invalidated there is nothing to do.
1113  if (!valid)
1114  return;
1115  auto &liveOperations = getContext()->liveOperations;
1116  assert(liveOperations.count(operation.ptr) == 1 &&
1117  "destroying operation not in live map");
1118  liveOperations.erase(operation.ptr);
1119  if (!isAttached()) {
1120  mlirOperationDestroy(operation);
1121  }
1122 }
1123 
1124 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1125  MlirOperation operation,
1126  py::object parentKeepAlive) {
1127  auto &liveOperations = contextRef->liveOperations;
1128  // Create.
1129  PyOperation *unownedOperation =
1130  new PyOperation(std::move(contextRef), operation);
1131  // Note that the default return value policy on cast is automatic_reference,
1132  // which does not take ownership (delete will not be called).
1133  // Just be explicit.
1134  py::object pyRef =
1135  py::cast(unownedOperation, py::return_value_policy::take_ownership);
1136  unownedOperation->handle = pyRef;
1137  if (parentKeepAlive) {
1138  unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1139  }
1140  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
1141  return PyOperationRef(unownedOperation, std::move(pyRef));
1142 }
1143 
1145  MlirOperation operation,
1146  py::object parentKeepAlive) {
1147  auto &liveOperations = contextRef->liveOperations;
1148  auto it = liveOperations.find(operation.ptr);
1149  if (it == liveOperations.end()) {
1150  // Create.
1151  return createInstance(std::move(contextRef), operation,
1152  std::move(parentKeepAlive));
1153  }
1154  // Use existing.
1155  PyOperation *existing = it->second.second;
1156  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1157  return PyOperationRef(existing, std::move(pyRef));
1158 }
1159 
1161  MlirOperation operation,
1162  py::object parentKeepAlive) {
1163  auto &liveOperations = contextRef->liveOperations;
1164  assert(liveOperations.count(operation.ptr) == 0 &&
1165  "cannot create detached operation that already exists");
1166  (void)liveOperations;
1167 
1168  PyOperationRef created = createInstance(std::move(contextRef), operation,
1169  std::move(parentKeepAlive));
1170  created->attached = false;
1171  return created;
1172 }
1173 
1175  const std::string &sourceStr,
1176  const std::string &sourceName) {
1177  PyMlirContext::ErrorCapture errors(contextRef);
1178  MlirOperation op =
1179  mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1180  toMlirStringRef(sourceName));
1181  if (mlirOperationIsNull(op))
1182  throw MLIRError("Unable to parse operation assembly", errors.take());
1183  return PyOperation::createDetached(std::move(contextRef), op);
1184 }
1185 
1187  if (!valid) {
1188  throw std::runtime_error("the operation has been invalidated");
1189  }
1190 }
1191 
1192 void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1193  bool enableDebugInfo, bool prettyDebugInfo,
1194  bool printGenericOpForm, bool useLocalScope,
1195  bool assumeVerified, py::object fileObject,
1196  bool binary) {
1197  PyOperation &operation = getOperation();
1198  operation.checkValid();
1199  if (fileObject.is_none())
1200  fileObject = py::module::import("sys").attr("stdout");
1201 
1202  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1203  if (largeElementsLimit)
1204  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1205  if (enableDebugInfo)
1206  mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1207  /*prettyForm=*/prettyDebugInfo);
1208  if (printGenericOpForm)
1210  if (useLocalScope)
1212  if (assumeVerified)
1214 
1215  PyFileAccumulator accum(fileObject, binary);
1216  mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1217  accum.getUserData());
1219 }
1220 
1221 void PyOperationBase::print(PyAsmState &state, py::object fileObject,
1222  bool binary) {
1223  PyOperation &operation = getOperation();
1224  operation.checkValid();
1225  if (fileObject.is_none())
1226  fileObject = py::module::import("sys").attr("stdout");
1227  PyFileAccumulator accum(fileObject, binary);
1228  mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1229  accum.getUserData());
1230 }
1231 
1232 void PyOperationBase::writeBytecode(const py::object &fileObject,
1233  std::optional<int64_t> bytecodeVersion) {
1234  PyOperation &operation = getOperation();
1235  operation.checkValid();
1236  PyFileAccumulator accum(fileObject, /*binary=*/true);
1237 
1238  if (!bytecodeVersion.has_value())
1239  return mlirOperationWriteBytecode(operation, accum.getCallback(),
1240  accum.getUserData());
1241 
1242  MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1243  mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1245  operation, config, accum.getCallback(), accum.getUserData());
1247  if (mlirLogicalResultIsFailure(res))
1248  throw py::value_error((Twine("Unable to honor desired bytecode version ") +
1249  Twine(*bytecodeVersion))
1250  .str());
1251 }
1252 
1254  std::function<MlirWalkResult(MlirOperation)> callback,
1255  MlirWalkOrder walkOrder) {
1256  PyOperation &operation = getOperation();
1257  operation.checkValid();
1258  struct UserData {
1259  std::function<MlirWalkResult(MlirOperation)> callback;
1260  bool gotException;
1261  std::string exceptionWhat;
1262  py::object exceptionType;
1263  };
1264  UserData userData{callback, false, {}, {}};
1265  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1266  void *userData) {
1267  UserData *calleeUserData = static_cast<UserData *>(userData);
1268  try {
1269  return (calleeUserData->callback)(op);
1270  } catch (py::error_already_set &e) {
1271  calleeUserData->gotException = true;
1272  calleeUserData->exceptionWhat = e.what();
1273  calleeUserData->exceptionType = e.type();
1275  }
1276  };
1277  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1278  if (userData.gotException) {
1279  std::string message("Exception raised in callback: ");
1280  message.append(userData.exceptionWhat);
1281  throw std::runtime_error(message);
1282  }
1283 }
1284 
1285 py::object PyOperationBase::getAsm(bool binary,
1286  std::optional<int64_t> largeElementsLimit,
1287  bool enableDebugInfo, bool prettyDebugInfo,
1288  bool printGenericOpForm, bool useLocalScope,
1289  bool assumeVerified) {
1290  py::object fileObject;
1291  if (binary) {
1292  fileObject = py::module::import("io").attr("BytesIO")();
1293  } else {
1294  fileObject = py::module::import("io").attr("StringIO")();
1295  }
1296  print(/*largeElementsLimit=*/largeElementsLimit,
1297  /*enableDebugInfo=*/enableDebugInfo,
1298  /*prettyDebugInfo=*/prettyDebugInfo,
1299  /*printGenericOpForm=*/printGenericOpForm,
1300  /*useLocalScope=*/useLocalScope,
1301  /*assumeVerified=*/assumeVerified,
1302  /*fileObject=*/fileObject,
1303  /*binary=*/binary);
1304 
1305  return fileObject.attr("getvalue")();
1306 }
1307 
1309  PyOperation &operation = getOperation();
1310  PyOperation &otherOp = other.getOperation();
1311  operation.checkValid();
1312  otherOp.checkValid();
1313  mlirOperationMoveAfter(operation, otherOp);
1314  operation.parentKeepAlive = otherOp.parentKeepAlive;
1315 }
1316 
1318  PyOperation &operation = getOperation();
1319  PyOperation &otherOp = other.getOperation();
1320  operation.checkValid();
1321  otherOp.checkValid();
1322  mlirOperationMoveBefore(operation, otherOp);
1323  operation.parentKeepAlive = otherOp.parentKeepAlive;
1324 }
1325 
1327  PyOperation &op = getOperation();
1329  if (!mlirOperationVerify(op.get()))
1330  throw MLIRError("Verification failed", errors.take());
1331  return true;
1332 }
1333 
1334 std::optional<PyOperationRef> PyOperation::getParentOperation() {
1335  checkValid();
1336  if (!isAttached())
1337  throw py::value_error("Detached operations have no parent");
1338  MlirOperation operation = mlirOperationGetParentOperation(get());
1339  if (mlirOperationIsNull(operation))
1340  return {};
1341  return PyOperation::forOperation(getContext(), operation);
1342 }
1343 
1345  checkValid();
1346  std::optional<PyOperationRef> parentOperation = getParentOperation();
1347  MlirBlock block = mlirOperationGetBlock(get());
1348  assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1349  assert(parentOperation && "Operation has no parent");
1350  return PyBlock{std::move(*parentOperation), block};
1351 }
1352 
1354  checkValid();
1355  return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1356 }
1357 
1358 py::object PyOperation::createFromCapsule(py::object capsule) {
1359  MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1360  if (mlirOperationIsNull(rawOperation))
1361  throw py::error_already_set();
1362  MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1363  return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1364  .releaseObject();
1365 }
1366 
1368  const py::object &maybeIp) {
1369  // InsertPoint active?
1370  if (!maybeIp.is(py::cast(false))) {
1371  PyInsertionPoint *ip;
1372  if (maybeIp.is_none()) {
1374  } else {
1375  ip = py::cast<PyInsertionPoint *>(maybeIp);
1376  }
1377  if (ip)
1378  ip->insert(*op.get());
1379  }
1380 }
1381 
1382 py::object PyOperation::create(const std::string &name,
1383  std::optional<std::vector<PyType *>> results,
1384  std::optional<std::vector<PyValue *>> operands,
1385  std::optional<py::dict> attributes,
1386  std::optional<std::vector<PyBlock *>> successors,
1387  int regions, DefaultingPyLocation location,
1388  const py::object &maybeIp, bool inferType) {
1389  llvm::SmallVector<MlirValue, 4> mlirOperands;
1390  llvm::SmallVector<MlirType, 4> mlirResults;
1391  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
1393 
1394  // General parameter validation.
1395  if (regions < 0)
1396  throw py::value_error("number of regions must be >= 0");
1397 
1398  // Unpack/validate operands.
1399  if (operands) {
1400  mlirOperands.reserve(operands->size());
1401  for (PyValue *operand : *operands) {
1402  if (!operand)
1403  throw py::value_error("operand value cannot be None");
1404  mlirOperands.push_back(operand->get());
1405  }
1406  }
1407 
1408  // Unpack/validate results.
1409  if (results) {
1410  mlirResults.reserve(results->size());
1411  for (PyType *result : *results) {
1412  // TODO: Verify result type originate from the same context.
1413  if (!result)
1414  throw py::value_error("result type cannot be None");
1415  mlirResults.push_back(*result);
1416  }
1417  }
1418  // Unpack/validate attributes.
1419  if (attributes) {
1420  mlirAttributes.reserve(attributes->size());
1421  for (auto &it : *attributes) {
1422  std::string key;
1423  try {
1424  key = it.first.cast<std::string>();
1425  } catch (py::cast_error &err) {
1426  std::string msg = "Invalid attribute key (not a string) when "
1427  "attempting to create the operation \"" +
1428  name + "\" (" + err.what() + ")";
1429  throw py::cast_error(msg);
1430  }
1431  try {
1432  auto &attribute = it.second.cast<PyAttribute &>();
1433  // TODO: Verify attribute originates from the same context.
1434  mlirAttributes.emplace_back(std::move(key), attribute);
1435  } catch (py::reference_cast_error &) {
1436  // This exception seems thrown when the value is "None".
1437  std::string msg =
1438  "Found an invalid (`None`?) attribute value for the key \"" + key +
1439  "\" when attempting to create the operation \"" + name + "\"";
1440  throw py::cast_error(msg);
1441  } catch (py::cast_error &err) {
1442  std::string msg = "Invalid attribute value for the key \"" + key +
1443  "\" when attempting to create the operation \"" +
1444  name + "\" (" + err.what() + ")";
1445  throw py::cast_error(msg);
1446  }
1447  }
1448  }
1449  // Unpack/validate successors.
1450  if (successors) {
1451  mlirSuccessors.reserve(successors->size());
1452  for (auto *successor : *successors) {
1453  // TODO: Verify successor originate from the same context.
1454  if (!successor)
1455  throw py::value_error("successor block cannot be None");
1456  mlirSuccessors.push_back(successor->get());
1457  }
1458  }
1459 
1460  // Apply unpacked/validated to the operation state. Beyond this
1461  // point, exceptions cannot be thrown or else the state will leak.
1462  MlirOperationState state =
1463  mlirOperationStateGet(toMlirStringRef(name), location);
1464  if (!mlirOperands.empty())
1465  mlirOperationStateAddOperands(&state, mlirOperands.size(),
1466  mlirOperands.data());
1467  state.enableResultTypeInference = inferType;
1468  if (!mlirResults.empty())
1469  mlirOperationStateAddResults(&state, mlirResults.size(),
1470  mlirResults.data());
1471  if (!mlirAttributes.empty()) {
1472  // Note that the attribute names directly reference bytes in
1473  // mlirAttributes, so that vector must not be changed from here
1474  // on.
1475  llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1476  mlirNamedAttributes.reserve(mlirAttributes.size());
1477  for (auto &it : mlirAttributes)
1478  mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1480  toMlirStringRef(it.first)),
1481  it.second));
1482  mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1483  mlirNamedAttributes.data());
1484  }
1485  if (!mlirSuccessors.empty())
1486  mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1487  mlirSuccessors.data());
1488  if (regions) {
1490  mlirRegions.resize(regions);
1491  for (int i = 0; i < regions; ++i)
1492  mlirRegions[i] = mlirRegionCreate();
1493  mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1494  mlirRegions.data());
1495  }
1496 
1497  // Construct the operation.
1498  MlirOperation operation = mlirOperationCreate(&state);
1499  if (!operation.ptr)
1500  throw py::value_error("Operation creation failed");
1501  PyOperationRef created =
1502  PyOperation::createDetached(location->getContext(), operation);
1503  maybeInsertOperation(created, maybeIp);
1504 
1505  return created->createOpView();
1506 }
1507 
1508 py::object PyOperation::clone(const py::object &maybeIp) {
1509  MlirOperation clonedOperation = mlirOperationClone(operation);
1510  PyOperationRef cloned =
1511  PyOperation::createDetached(getContext(), clonedOperation);
1512  maybeInsertOperation(cloned, maybeIp);
1513 
1514  return cloned->createOpView();
1515 }
1516 
1518  checkValid();
1519  MlirIdentifier ident = mlirOperationGetName(get());
1520  MlirStringRef identStr = mlirIdentifierStr(ident);
1521  auto operationCls = PyGlobals::get().lookupOperationClass(
1522  StringRef(identStr.data, identStr.length));
1523  if (operationCls)
1524  return PyOpView::constructDerived(*operationCls, *getRef().get());
1525  return py::cast(PyOpView(getRef().getObject()));
1526 }
1527 
1529  checkValid();
1530  // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1531  // Python reference to a child operation is live. All children should also
1532  // have their `valid` bit set to false.
1533  auto &liveOperations = getContext()->liveOperations;
1534  if (liveOperations.count(operation.ptr))
1535  liveOperations.erase(operation.ptr);
1536  mlirOperationDestroy(operation);
1537  valid = false;
1538 }
1539 
1540 //------------------------------------------------------------------------------
1541 // PyOpView
1542 //------------------------------------------------------------------------------
1543 
1544 static void populateResultTypes(StringRef name, py::list resultTypeList,
1545  const py::object &resultSegmentSpecObj,
1546  std::vector<int32_t> &resultSegmentLengths,
1547  std::vector<PyType *> &resultTypes) {
1548  resultTypes.reserve(resultTypeList.size());
1549  if (resultSegmentSpecObj.is_none()) {
1550  // Non-variadic result unpacking.
1551  for (const auto &it : llvm::enumerate(resultTypeList)) {
1552  try {
1553  resultTypes.push_back(py::cast<PyType *>(it.value()));
1554  if (!resultTypes.back())
1555  throw py::cast_error();
1556  } catch (py::cast_error &err) {
1557  throw py::value_error((llvm::Twine("Result ") +
1558  llvm::Twine(it.index()) + " of operation \"" +
1559  name + "\" must be a Type (" + err.what() + ")")
1560  .str());
1561  }
1562  }
1563  } else {
1564  // Sized result unpacking.
1565  auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1566  if (resultSegmentSpec.size() != resultTypeList.size()) {
1567  throw py::value_error((llvm::Twine("Operation \"") + name +
1568  "\" requires " +
1569  llvm::Twine(resultSegmentSpec.size()) +
1570  " result segments but was provided " +
1571  llvm::Twine(resultTypeList.size()))
1572  .str());
1573  }
1574  resultSegmentLengths.reserve(resultTypeList.size());
1575  for (const auto &it :
1576  llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1577  int segmentSpec = std::get<1>(it.value());
1578  if (segmentSpec == 1 || segmentSpec == 0) {
1579  // Unpack unary element.
1580  try {
1581  auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1582  if (resultType) {
1583  resultTypes.push_back(resultType);
1584  resultSegmentLengths.push_back(1);
1585  } else if (segmentSpec == 0) {
1586  // Allowed to be optional.
1587  resultSegmentLengths.push_back(0);
1588  } else {
1589  throw py::cast_error("was None and result is not optional");
1590  }
1591  } catch (py::cast_error &err) {
1592  throw py::value_error((llvm::Twine("Result ") +
1593  llvm::Twine(it.index()) + " of operation \"" +
1594  name + "\" must be a Type (" + err.what() +
1595  ")")
1596  .str());
1597  }
1598  } else if (segmentSpec == -1) {
1599  // Unpack sequence by appending.
1600  try {
1601  if (std::get<0>(it.value()).is_none()) {
1602  // Treat it as an empty list.
1603  resultSegmentLengths.push_back(0);
1604  } else {
1605  // Unpack the list.
1606  auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1607  for (py::object segmentItem : segment) {
1608  resultTypes.push_back(py::cast<PyType *>(segmentItem));
1609  if (!resultTypes.back()) {
1610  throw py::cast_error("contained a None item");
1611  }
1612  }
1613  resultSegmentLengths.push_back(segment.size());
1614  }
1615  } catch (std::exception &err) {
1616  // NOTE: Sloppy to be using a catch-all here, but there are at least
1617  // three different unrelated exceptions that can be thrown in the
1618  // above "casts". Just keep the scope above small and catch them all.
1619  throw py::value_error((llvm::Twine("Result ") +
1620  llvm::Twine(it.index()) + " of operation \"" +
1621  name + "\" must be a Sequence of Types (" +
1622  err.what() + ")")
1623  .str());
1624  }
1625  } else {
1626  throw py::value_error("Unexpected segment spec");
1627  }
1628  }
1629  }
1630 }
1631 
1633  const py::object &cls, std::optional<py::list> resultTypeList,
1634  py::list operandList, std::optional<py::dict> attributes,
1635  std::optional<std::vector<PyBlock *>> successors,
1636  std::optional<int> regions, DefaultingPyLocation location,
1637  const py::object &maybeIp) {
1638  PyMlirContextRef context = location->getContext();
1639  // Class level operation construction metadata.
1640  std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1641  // Operand and result segment specs are either none, which does no
1642  // variadic unpacking, or a list of ints with segment sizes, where each
1643  // element is either a positive number (typically 1 for a scalar) or -1 to
1644  // indicate that it is derived from the length of the same-indexed operand
1645  // or result (implying that it is a list at that position).
1646  py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1647  py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1648 
1649  std::vector<int32_t> operandSegmentLengths;
1650  std::vector<int32_t> resultSegmentLengths;
1651 
1652  // Validate/determine region count.
1653  auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1654  int opMinRegionCount = std::get<0>(opRegionSpec);
1655  bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1656  if (!regions) {
1657  regions = opMinRegionCount;
1658  }
1659  if (*regions < opMinRegionCount) {
1660  throw py::value_error(
1661  (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1662  llvm::Twine(opMinRegionCount) +
1663  " regions but was built with regions=" + llvm::Twine(*regions))
1664  .str());
1665  }
1666  if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1667  throw py::value_error(
1668  (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1669  llvm::Twine(opMinRegionCount) +
1670  " regions but was built with regions=" + llvm::Twine(*regions))
1671  .str());
1672  }
1673 
1674  // Unpack results.
1675  std::vector<PyType *> resultTypes;
1676  if (resultTypeList.has_value()) {
1677  populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1678  resultSegmentLengths, resultTypes);
1679  }
1680 
1681  // Unpack operands.
1682  std::vector<PyValue *> operands;
1683  operands.reserve(operands.size());
1684  if (operandSegmentSpecObj.is_none()) {
1685  // Non-sized operand unpacking.
1686  for (const auto &it : llvm::enumerate(operandList)) {
1687  try {
1688  operands.push_back(py::cast<PyValue *>(it.value()));
1689  if (!operands.back())
1690  throw py::cast_error();
1691  } catch (py::cast_error &err) {
1692  throw py::value_error((llvm::Twine("Operand ") +
1693  llvm::Twine(it.index()) + " of operation \"" +
1694  name + "\" must be a Value (" + err.what() + ")")
1695  .str());
1696  }
1697  }
1698  } else {
1699  // Sized operand unpacking.
1700  auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1701  if (operandSegmentSpec.size() != operandList.size()) {
1702  throw py::value_error((llvm::Twine("Operation \"") + name +
1703  "\" requires " +
1704  llvm::Twine(operandSegmentSpec.size()) +
1705  "operand segments but was provided " +
1706  llvm::Twine(operandList.size()))
1707  .str());
1708  }
1709  operandSegmentLengths.reserve(operandList.size());
1710  for (const auto &it :
1711  llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1712  int segmentSpec = std::get<1>(it.value());
1713  if (segmentSpec == 1 || segmentSpec == 0) {
1714  // Unpack unary element.
1715  try {
1716  auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1717  if (operandValue) {
1718  operands.push_back(operandValue);
1719  operandSegmentLengths.push_back(1);
1720  } else if (segmentSpec == 0) {
1721  // Allowed to be optional.
1722  operandSegmentLengths.push_back(0);
1723  } else {
1724  throw py::cast_error("was None and operand is not optional");
1725  }
1726  } catch (py::cast_error &err) {
1727  throw py::value_error((llvm::Twine("Operand ") +
1728  llvm::Twine(it.index()) + " of operation \"" +
1729  name + "\" must be a Value (" + err.what() +
1730  ")")
1731  .str());
1732  }
1733  } else if (segmentSpec == -1) {
1734  // Unpack sequence by appending.
1735  try {
1736  if (std::get<0>(it.value()).is_none()) {
1737  // Treat it as an empty list.
1738  operandSegmentLengths.push_back(0);
1739  } else {
1740  // Unpack the list.
1741  auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1742  for (py::object segmentItem : segment) {
1743  operands.push_back(py::cast<PyValue *>(segmentItem));
1744  if (!operands.back()) {
1745  throw py::cast_error("contained a None item");
1746  }
1747  }
1748  operandSegmentLengths.push_back(segment.size());
1749  }
1750  } catch (std::exception &err) {
1751  // NOTE: Sloppy to be using a catch-all here, but there are at least
1752  // three different unrelated exceptions that can be thrown in the
1753  // above "casts". Just keep the scope above small and catch them all.
1754  throw py::value_error((llvm::Twine("Operand ") +
1755  llvm::Twine(it.index()) + " of operation \"" +
1756  name + "\" must be a Sequence of Values (" +
1757  err.what() + ")")
1758  .str());
1759  }
1760  } else {
1761  throw py::value_error("Unexpected segment spec");
1762  }
1763  }
1764  }
1765 
1766  // Merge operand/result segment lengths into attributes if needed.
1767  if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1768  // Dup.
1769  if (attributes) {
1770  attributes = py::dict(*attributes);
1771  } else {
1772  attributes = py::dict();
1773  }
1774  if (attributes->contains("resultSegmentSizes") ||
1775  attributes->contains("operandSegmentSizes")) {
1776  throw py::value_error("Manually setting a 'resultSegmentSizes' or "
1777  "'operandSegmentSizes' attribute is unsupported. "
1778  "Use Operation.create for such low-level access.");
1779  }
1780 
1781  // Add resultSegmentSizes attribute.
1782  if (!resultSegmentLengths.empty()) {
1783  MlirAttribute segmentLengthAttr =
1784  mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1785  resultSegmentLengths.data());
1786  (*attributes)["resultSegmentSizes"] =
1787  PyAttribute(context, segmentLengthAttr);
1788  }
1789 
1790  // Add operandSegmentSizes attribute.
1791  if (!operandSegmentLengths.empty()) {
1792  MlirAttribute segmentLengthAttr =
1793  mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1794  operandSegmentLengths.data());
1795  (*attributes)["operandSegmentSizes"] =
1796  PyAttribute(context, segmentLengthAttr);
1797  }
1798  }
1799 
1800  // Delegate to create.
1801  return PyOperation::create(name,
1802  /*results=*/std::move(resultTypes),
1803  /*operands=*/std::move(operands),
1804  /*attributes=*/std::move(attributes),
1805  /*successors=*/std::move(successors),
1806  /*regions=*/*regions, location, maybeIp,
1807  !resultTypeList);
1808 }
1809 
1810 pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
1811  const PyOperation &operation) {
1812  // TODO: pybind11 2.6 supports a more direct form.
1813  // Upgrade many years from now.
1814  // auto opViewType = py::type::of<PyOpView>();
1815  py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1816  py::object instance = cls.attr("__new__")(cls);
1817  opViewType.attr("__init__")(instance, operation);
1818  return instance;
1819 }
1820 
1821 PyOpView::PyOpView(const py::object &operationObject)
1822  // Casting through the PyOperationBase base-class and then back to the
1823  // Operation lets us accept any PyOperationBase subclass.
1824  : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1825  operationObject(operation.getRef().getObject()) {}
1826 
1827 //------------------------------------------------------------------------------
1828 // PyInsertionPoint.
1829 //------------------------------------------------------------------------------
1830 
1832 
1834  : refOperation(beforeOperationBase.getOperation().getRef()),
1835  block((*refOperation)->getBlock()) {}
1836 
1838  PyOperation &operation = operationBase.getOperation();
1839  if (operation.isAttached())
1840  throw py::value_error(
1841  "Attempt to insert operation that is already attached");
1842  block.getParentOperation()->checkValid();
1843  MlirOperation beforeOp = {nullptr};
1844  if (refOperation) {
1845  // Insert before operation.
1846  (*refOperation)->checkValid();
1847  beforeOp = (*refOperation)->get();
1848  } else {
1849  // Insert at end (before null) is only valid if the block does not
1850  // already end in a known terminator (violating this will cause assertion
1851  // failures later).
1852  if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1853  throw py::index_error("Cannot insert operation at the end of a block "
1854  "that already has a terminator. Did you mean to "
1855  "use 'InsertionPoint.at_block_terminator(block)' "
1856  "versus 'InsertionPoint(block)'?");
1857  }
1858  }
1859  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1860  operation.setAttached();
1861 }
1862 
1864  MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1865  if (mlirOperationIsNull(firstOp)) {
1866  // Just insert at end.
1867  return PyInsertionPoint(block);
1868  }
1869 
1870  // Insert before first op.
1872  block.getParentOperation()->getContext(), firstOp);
1873  return PyInsertionPoint{block, std::move(firstOpRef)};
1874 }
1875 
1877  MlirOperation terminator = mlirBlockGetTerminator(block.get());
1878  if (mlirOperationIsNull(terminator))
1879  throw py::value_error("Block has no terminator");
1880  PyOperationRef terminatorOpRef = PyOperation::forOperation(
1881  block.getParentOperation()->getContext(), terminator);
1882  return PyInsertionPoint{block, std::move(terminatorOpRef)};
1883 }
1884 
1887 }
1888 
1889 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1890  const pybind11::object &excVal,
1891  const pybind11::object &excTb) {
1893 }
1894 
1895 //------------------------------------------------------------------------------
1896 // PyAttribute.
1897 //------------------------------------------------------------------------------
1898 
1899 bool PyAttribute::operator==(const PyAttribute &other) const {
1900  return mlirAttributeEqual(attr, other.attr);
1901 }
1902 
1904  return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1905 }
1906 
1908  MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1909  if (mlirAttributeIsNull(rawAttr))
1910  throw py::error_already_set();
1911  return PyAttribute(
1913 }
1914 
1915 //------------------------------------------------------------------------------
1916 // PyNamedAttribute.
1917 //------------------------------------------------------------------------------
1918 
1919 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1920  : ownedName(new std::string(std::move(ownedName))) {
1923  toMlirStringRef(*this->ownedName)),
1924  attr);
1925 }
1926 
1927 //------------------------------------------------------------------------------
1928 // PyType.
1929 //------------------------------------------------------------------------------
1930 
1931 bool PyType::operator==(const PyType &other) const {
1932  return mlirTypeEqual(type, other.type);
1933 }
1934 
1935 py::object PyType::getCapsule() {
1936  return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1937 }
1938 
1939 PyType PyType::createFromCapsule(py::object capsule) {
1940  MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1941  if (mlirTypeIsNull(rawType))
1942  throw py::error_already_set();
1944  rawType);
1945 }
1946 
1947 //------------------------------------------------------------------------------
1948 // PyTypeID.
1949 //------------------------------------------------------------------------------
1950 
1951 py::object PyTypeID::getCapsule() {
1952  return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
1953 }
1954 
1956  MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
1957  if (mlirTypeIDIsNull(mlirTypeID))
1958  throw py::error_already_set();
1959  return PyTypeID(mlirTypeID);
1960 }
1961 bool PyTypeID::operator==(const PyTypeID &other) const {
1962  return mlirTypeIDEqual(typeID, other.typeID);
1963 }
1964 
1965 //------------------------------------------------------------------------------
1966 // PyValue and subclasses.
1967 //------------------------------------------------------------------------------
1968 
1969 pybind11::object PyValue::getCapsule() {
1970  return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1971 }
1972 
1973 pybind11::object PyValue::maybeDownCast() {
1974  MlirType type = mlirValueGetType(get());
1975  MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
1976  assert(!mlirTypeIDIsNull(mlirTypeID) &&
1977  "mlirTypeID was expected to be non-null.");
1978  std::optional<pybind11::function> valueCaster =
1980  // py::return_value_policy::move means use std::move to move the return value
1981  // contents into a new instance that will be owned by Python.
1982  py::object thisObj = py::cast(this, py::return_value_policy::move);
1983  if (!valueCaster)
1984  return thisObj;
1985  return valueCaster.value()(thisObj);
1986 }
1987 
1988 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1989  MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1990  if (mlirValueIsNull(value))
1991  throw py::error_already_set();
1992  MlirOperation owner;
1993  if (mlirValueIsAOpResult(value))
1994  owner = mlirOpResultGetOwner(value);
1995  if (mlirValueIsABlockArgument(value))
1997  if (mlirOperationIsNull(owner))
1998  throw py::error_already_set();
1999  MlirContext ctx = mlirOperationGetContext(owner);
2000  PyOperationRef ownerRef =
2002  return PyValue(ownerRef, value);
2003 }
2004 
2005 //------------------------------------------------------------------------------
2006 // PySymbolTable.
2007 //------------------------------------------------------------------------------
2008 
2010  : operation(operation.getOperation().getRef()) {
2011  symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2012  if (mlirSymbolTableIsNull(symbolTable)) {
2013  throw py::cast_error("Operation is not a Symbol Table.");
2014  }
2015 }
2016 
2017 py::object PySymbolTable::dunderGetItem(const std::string &name) {
2018  operation->checkValid();
2019  MlirOperation symbol = mlirSymbolTableLookup(
2020  symbolTable, mlirStringRefCreate(name.data(), name.length()));
2021  if (mlirOperationIsNull(symbol))
2022  throw py::key_error("Symbol '" + name + "' not in the symbol table.");
2023 
2024  return PyOperation::forOperation(operation->getContext(), symbol,
2025  operation.getObject())
2026  ->createOpView();
2027 }
2028 
2030  operation->checkValid();
2031  symbol.getOperation().checkValid();
2032  mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2033  // The operation is also erased, so we must invalidate it. There may be Python
2034  // references to this operation so we don't want to delete it from the list of
2035  // live operations here.
2036  symbol.getOperation().valid = false;
2037 }
2038 
2039 void PySymbolTable::dunderDel(const std::string &name) {
2040  py::object operation = dunderGetItem(name);
2041  erase(py::cast<PyOperationBase &>(operation));
2042 }
2043 
2044 MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2045  operation->checkValid();
2046  symbol.getOperation().checkValid();
2047  MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2049  if (mlirAttributeIsNull(symbolAttr))
2050  throw py::value_error("Expected operation to have a symbol name.");
2051  return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2052 }
2053 
2055  // Op must already be a symbol.
2056  PyOperation &operation = symbol.getOperation();
2057  operation.checkValid();
2059  MlirAttribute existingNameAttr =
2060  mlirOperationGetAttributeByName(operation.get(), attrName);
2061  if (mlirAttributeIsNull(existingNameAttr))
2062  throw py::value_error("Expected operation to have a symbol name.");
2063  return existingNameAttr;
2064 }
2065 
2067  const std::string &name) {
2068  // Op must already be a symbol.
2069  PyOperation &operation = symbol.getOperation();
2070  operation.checkValid();
2072  MlirAttribute existingNameAttr =
2073  mlirOperationGetAttributeByName(operation.get(), attrName);
2074  if (mlirAttributeIsNull(existingNameAttr))
2075  throw py::value_error("Expected operation to have a symbol name.");
2076  MlirAttribute newNameAttr =
2077  mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2078  mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2079 }
2080 
2082  PyOperation &operation = symbol.getOperation();
2083  operation.checkValid();
2085  MlirAttribute existingVisAttr =
2086  mlirOperationGetAttributeByName(operation.get(), attrName);
2087  if (mlirAttributeIsNull(existingVisAttr))
2088  throw py::value_error("Expected operation to have a symbol visibility.");
2089  return existingVisAttr;
2090 }
2091 
2093  const std::string &visibility) {
2094  if (visibility != "public" && visibility != "private" &&
2095  visibility != "nested")
2096  throw py::value_error(
2097  "Expected visibility to be 'public', 'private' or 'nested'");
2098  PyOperation &operation = symbol.getOperation();
2099  operation.checkValid();
2101  MlirAttribute existingVisAttr =
2102  mlirOperationGetAttributeByName(operation.get(), attrName);
2103  if (mlirAttributeIsNull(existingVisAttr))
2104  throw py::value_error("Expected operation to have a symbol visibility.");
2105  MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2106  toMlirStringRef(visibility));
2107  mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2108 }
2109 
2110 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2111  const std::string &newSymbol,
2112  PyOperationBase &from) {
2113  PyOperation &fromOperation = from.getOperation();
2114  fromOperation.checkValid();
2116  toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2117  from.getOperation())))
2118 
2119  throw py::value_error("Symbol rename failed");
2120 }
2121 
2123  bool allSymUsesVisible,
2124  py::object callback) {
2125  PyOperation &fromOperation = from.getOperation();
2126  fromOperation.checkValid();
2127  struct UserData {
2128  PyMlirContextRef context;
2129  py::object callback;
2130  bool gotException;
2131  std::string exceptionWhat;
2132  py::object exceptionType;
2133  };
2134  UserData userData{
2135  fromOperation.getContext(), std::move(callback), false, {}, {}};
2137  fromOperation.get(), allSymUsesVisible,
2138  [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2139  UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2140  auto pyFoundOp =
2141  PyOperation::forOperation(calleeUserData->context, foundOp);
2142  if (calleeUserData->gotException)
2143  return;
2144  try {
2145  calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2146  } catch (py::error_already_set &e) {
2147  calleeUserData->gotException = true;
2148  calleeUserData->exceptionWhat = e.what();
2149  calleeUserData->exceptionType = e.type();
2150  }
2151  },
2152  static_cast<void *>(&userData));
2153  if (userData.gotException) {
2154  std::string message("Exception raised in callback: ");
2155  message.append(userData.exceptionWhat);
2156  throw std::runtime_error(message);
2157  }
2158 }
2159 
2160 namespace {
2161 /// CRTP base class for Python MLIR values that subclass Value and should be
2162 /// castable from it. The value hierarchy is one level deep and is not supposed
2163 /// to accommodate other levels unless core MLIR changes.
2164 template <typename DerivedTy>
2165 class PyConcreteValue : public PyValue {
2166 public:
2167  // Derived classes must define statics for:
2168  // IsAFunctionTy isaFunction
2169  // const char *pyClassName
2170  // and redefine bindDerived.
2171  using ClassTy = py::class_<DerivedTy, PyValue>;
2172  using IsAFunctionTy = bool (*)(MlirValue);
2173 
2174  PyConcreteValue() = default;
2175  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
2176  : PyValue(operationRef, value) {}
2177  PyConcreteValue(PyValue &orig)
2178  : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
2179 
2180  /// Attempts to cast the original value to the derived type and throws on
2181  /// type mismatches.
2182  static MlirValue castFrom(PyValue &orig) {
2183  if (!DerivedTy::isaFunction(orig.get())) {
2184  auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
2185  throw py::value_error((Twine("Cannot cast value to ") +
2186  DerivedTy::pyClassName + " (from " + origRepr +
2187  ")")
2188  .str());
2189  }
2190  return orig.get();
2191  }
2192 
2193  /// Binds the Python module objects to functions of this class.
2194  static void bind(py::module &m) {
2195  auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
2196  cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
2197  cls.def_static(
2198  "isinstance",
2199  [](PyValue &otherValue) -> bool {
2200  return DerivedTy::isaFunction(otherValue);
2201  },
2202  py::arg("other_value"));
2204  [](DerivedTy &self) { return self.maybeDownCast(); });
2205  DerivedTy::bindDerived(cls);
2206  }
2207 
2208  /// Implemented by derived classes to add methods to the Python subclass.
2209  static void bindDerived(ClassTy &m) {}
2210 };
2211 
2212 /// Python wrapper for MlirBlockArgument.
2213 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2214 public:
2215  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2216  static constexpr const char *pyClassName = "BlockArgument";
2217  using PyConcreteValue::PyConcreteValue;
2218 
2219  static void bindDerived(ClassTy &c) {
2220  c.def_property_readonly("owner", [](PyBlockArgument &self) {
2221  return PyBlock(self.getParentOperation(),
2222  mlirBlockArgumentGetOwner(self.get()));
2223  });
2224  c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
2225  return mlirBlockArgumentGetArgNumber(self.get());
2226  });
2227  c.def(
2228  "set_type",
2229  [](PyBlockArgument &self, PyType type) {
2230  return mlirBlockArgumentSetType(self.get(), type);
2231  },
2232  py::arg("type"));
2233  }
2234 };
2235 
2236 /// Python wrapper for MlirOpResult.
2237 class PyOpResult : public PyConcreteValue<PyOpResult> {
2238 public:
2239  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
2240  static constexpr const char *pyClassName = "OpResult";
2241  using PyConcreteValue::PyConcreteValue;
2242 
2243  static void bindDerived(ClassTy &c) {
2244  c.def_property_readonly("owner", [](PyOpResult &self) {
2245  assert(
2246  mlirOperationEqual(self.getParentOperation()->get(),
2247  mlirOpResultGetOwner(self.get())) &&
2248  "expected the owner of the value in Python to match that in the IR");
2249  return self.getParentOperation().getObject();
2250  });
2251  c.def_property_readonly("result_number", [](PyOpResult &self) {
2252  return mlirOpResultGetResultNumber(self.get());
2253  });
2254  }
2255 };
2256 
2257 /// Returns the list of types of the values held by container.
2258 template <typename Container>
2259 static std::vector<MlirType> getValueTypes(Container &container,
2260  PyMlirContextRef &context) {
2261  std::vector<MlirType> result;
2262  result.reserve(container.size());
2263  for (int i = 0, e = container.size(); i < e; ++i) {
2264  result.push_back(mlirValueGetType(container.getElement(i).get()));
2265  }
2266  return result;
2267 }
2268 
2269 /// A list of block arguments. Internally, these are stored as consecutive
2270 /// elements, random access is cheap. The argument list is associated with the
2271 /// operation that contains the block (detached blocks are not allowed in
2272 /// Python bindings) and extends its lifetime.
2273 class PyBlockArgumentList
2274  : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2275 public:
2276  static constexpr const char *pyClassName = "BlockArgumentList";
2278 
2279  PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2280  intptr_t startIndex = 0, intptr_t length = -1,
2281  intptr_t step = 1)
2282  : Sliceable(startIndex,
2283  length == -1 ? mlirBlockGetNumArguments(block) : length,
2284  step),
2285  operation(std::move(operation)), block(block) {}
2286 
2287  static void bindDerived(ClassTy &c) {
2288  c.def_property_readonly("types", [](PyBlockArgumentList &self) {
2289  return getValueTypes(self, self.operation->getContext());
2290  });
2291  }
2292 
2293 private:
2294  /// Give the parent CRTP class access to hook implementations below.
2295  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2296 
2297  /// Returns the number of arguments in the list.
2298  intptr_t getRawNumElements() {
2299  operation->checkValid();
2300  return mlirBlockGetNumArguments(block);
2301  }
2302 
2303  /// Returns `pos`-the element in the list.
2304  PyBlockArgument getRawElement(intptr_t pos) {
2305  MlirValue argument = mlirBlockGetArgument(block, pos);
2306  return PyBlockArgument(operation, argument);
2307  }
2308 
2309  /// Returns a sublist of this list.
2310  PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2311  intptr_t step) {
2312  return PyBlockArgumentList(operation, block, startIndex, length, step);
2313  }
2314 
2315  PyOperationRef operation;
2316  MlirBlock block;
2317 };
2318 
2319 /// A list of operation operands. Internally, these are stored as consecutive
2320 /// elements, random access is cheap. The (returned) operand list is associated
2321 /// with the operation whose operands these are, and thus extends the lifetime
2322 /// of this operation.
2323 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2324 public:
2325  static constexpr const char *pyClassName = "OpOperandList";
2326  using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2327 
2328  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2329  intptr_t length = -1, intptr_t step = 1)
2330  : Sliceable(startIndex,
2331  length == -1 ? mlirOperationGetNumOperands(operation->get())
2332  : length,
2333  step),
2334  operation(operation) {}
2335 
2336  void dunderSetItem(intptr_t index, PyValue value) {
2337  index = wrapIndex(index);
2338  mlirOperationSetOperand(operation->get(), index, value.get());
2339  }
2340 
2341  static void bindDerived(ClassTy &c) {
2342  c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2343  }
2344 
2345 private:
2346  /// Give the parent CRTP class access to hook implementations below.
2347  friend class Sliceable<PyOpOperandList, PyValue>;
2348 
2349  intptr_t getRawNumElements() {
2350  operation->checkValid();
2351  return mlirOperationGetNumOperands(operation->get());
2352  }
2353 
2354  PyValue getRawElement(intptr_t pos) {
2355  MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2356  MlirOperation owner;
2357  if (mlirValueIsAOpResult(operand))
2358  owner = mlirOpResultGetOwner(operand);
2359  else if (mlirValueIsABlockArgument(operand))
2361  else
2362  assert(false && "Value must be an block arg or op result.");
2363  PyOperationRef pyOwner =
2364  PyOperation::forOperation(operation->getContext(), owner);
2365  return PyValue(pyOwner, operand);
2366  }
2367 
2368  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2369  return PyOpOperandList(operation, startIndex, length, step);
2370  }
2371 
2372  PyOperationRef operation;
2373 };
2374 
2375 /// A list of operation results. Internally, these are stored as consecutive
2376 /// elements, random access is cheap. The (returned) result list is associated
2377 /// with the operation whose results these are, and thus extends the lifetime of
2378 /// this operation.
2379 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2380 public:
2381  static constexpr const char *pyClassName = "OpResultList";
2382  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
2383 
2384  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2385  intptr_t length = -1, intptr_t step = 1)
2386  : Sliceable(startIndex,
2387  length == -1 ? mlirOperationGetNumResults(operation->get())
2388  : length,
2389  step),
2390  operation(std::move(operation)) {}
2391 
2392  static void bindDerived(ClassTy &c) {
2393  c.def_property_readonly("types", [](PyOpResultList &self) {
2394  return getValueTypes(self, self.operation->getContext());
2395  });
2396  c.def_property_readonly("owner", [](PyOpResultList &self) {
2397  return self.operation->createOpView();
2398  });
2399  }
2400 
2401 private:
2402  /// Give the parent CRTP class access to hook implementations below.
2403  friend class Sliceable<PyOpResultList, PyOpResult>;
2404 
2405  intptr_t getRawNumElements() {
2406  operation->checkValid();
2407  return mlirOperationGetNumResults(operation->get());
2408  }
2409 
2410  PyOpResult getRawElement(intptr_t index) {
2411  PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2412  return PyOpResult(value);
2413  }
2414 
2415  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2416  return PyOpResultList(operation, startIndex, length, step);
2417  }
2418 
2419  PyOperationRef operation;
2420 };
2421 
2422 /// A list of operation successors. Internally, these are stored as consecutive
2423 /// elements, random access is cheap. The (returned) successor list is
2424 /// associated with the operation whose successors these are, and thus extends
2425 /// the lifetime of this operation.
2426 class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2427 public:
2428  static constexpr const char *pyClassName = "OpSuccessors";
2429 
2430  PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2431  intptr_t length = -1, intptr_t step = 1)
2432  : Sliceable(startIndex,
2433  length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2434  : length,
2435  step),
2436  operation(operation) {}
2437 
2438  void dunderSetItem(intptr_t index, PyBlock block) {
2439  index = wrapIndex(index);
2440  mlirOperationSetSuccessor(operation->get(), index, block.get());
2441  }
2442 
2443  static void bindDerived(ClassTy &c) {
2444  c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2445  }
2446 
2447 private:
2448  /// Give the parent CRTP class access to hook implementations below.
2449  friend class Sliceable<PyOpSuccessors, PyBlock>;
2450 
2451  intptr_t getRawNumElements() {
2452  operation->checkValid();
2453  return mlirOperationGetNumSuccessors(operation->get());
2454  }
2455 
2456  PyBlock getRawElement(intptr_t pos) {
2457  MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2458  return PyBlock(operation, block);
2459  }
2460 
2461  PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2462  return PyOpSuccessors(operation, startIndex, length, step);
2463  }
2464 
2465  PyOperationRef operation;
2466 };
2467 
2468 /// A list of operation attributes. Can be indexed by name, producing
2469 /// attributes, or by index, producing named attributes.
2470 class PyOpAttributeMap {
2471 public:
2472  PyOpAttributeMap(PyOperationRef operation)
2473  : operation(std::move(operation)) {}
2474 
2475  MlirAttribute dunderGetItemNamed(const std::string &name) {
2476  MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2477  toMlirStringRef(name));
2478  if (mlirAttributeIsNull(attr)) {
2479  throw py::key_error("attempt to access a non-existent attribute");
2480  }
2481  return attr;
2482  }
2483 
2484  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2485  if (index < 0 || index >= dunderLen()) {
2486  throw py::index_error("attempt to access out of bounds attribute");
2487  }
2488  MlirNamedAttribute namedAttr =
2489  mlirOperationGetAttribute(operation->get(), index);
2490  return PyNamedAttribute(
2491  namedAttr.attribute,
2492  std::string(mlirIdentifierStr(namedAttr.name).data,
2493  mlirIdentifierStr(namedAttr.name).length));
2494  }
2495 
2496  void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2497  mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2498  attr);
2499  }
2500 
2501  void dunderDelItem(const std::string &name) {
2502  int removed = mlirOperationRemoveAttributeByName(operation->get(),
2503  toMlirStringRef(name));
2504  if (!removed)
2505  throw py::key_error("attempt to delete a non-existent attribute");
2506  }
2507 
2508  intptr_t dunderLen() {
2509  return mlirOperationGetNumAttributes(operation->get());
2510  }
2511 
2512  bool dunderContains(const std::string &name) {
2514  operation->get(), toMlirStringRef(name)));
2515  }
2516 
2517  static void bind(py::module &m) {
2518  py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2519  .def("__contains__", &PyOpAttributeMap::dunderContains)
2520  .def("__len__", &PyOpAttributeMap::dunderLen)
2521  .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2522  .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2523  .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2524  .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2525  }
2526 
2527 private:
2528  PyOperationRef operation;
2529 };
2530 
2531 } // namespace
2532 
2533 //------------------------------------------------------------------------------
2534 // Populates the core exports of the 'ir' submodule.
2535 //------------------------------------------------------------------------------
2536 
2537 void mlir::python::populateIRCore(py::module &m) {
2538  //----------------------------------------------------------------------------
2539  // Enums.
2540  //----------------------------------------------------------------------------
2541  py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2542  .value("ERROR", MlirDiagnosticError)
2543  .value("WARNING", MlirDiagnosticWarning)
2544  .value("NOTE", MlirDiagnosticNote)
2545  .value("REMARK", MlirDiagnosticRemark);
2546 
2547  py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
2548  .value("PRE_ORDER", MlirWalkPreOrder)
2549  .value("POST_ORDER", MlirWalkPostOrder);
2550 
2551  py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
2552  .value("ADVANCE", MlirWalkResultAdvance)
2553  .value("INTERRUPT", MlirWalkResultInterrupt)
2554  .value("SKIP", MlirWalkResultSkip);
2555 
2556  //----------------------------------------------------------------------------
2557  // Mapping of Diagnostics.
2558  //----------------------------------------------------------------------------
2559  py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2560  .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2561  .def_property_readonly("location", &PyDiagnostic::getLocation)
2562  .def_property_readonly("message", &PyDiagnostic::getMessage)
2563  .def_property_readonly("notes", &PyDiagnostic::getNotes)
2564  .def("__str__", [](PyDiagnostic &self) -> py::str {
2565  if (!self.isValid())
2566  return "<Invalid Diagnostic>";
2567  return self.getMessage();
2568  });
2569 
2570  py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo",
2571  py::module_local())
2572  .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); }))
2573  .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity)
2574  .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location)
2575  .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message)
2576  .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes)
2577  .def("__str__",
2578  [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2579 
2580  py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2581  .def("detach", &PyDiagnosticHandler::detach)
2582  .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2583  .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2584  .def("__enter__", &PyDiagnosticHandler::contextEnter)
2585  .def("__exit__", &PyDiagnosticHandler::contextExit);
2586 
2587  //----------------------------------------------------------------------------
2588  // Mapping of MlirContext.
2589  // Note that this is exported as _BaseContext. The containing, Python level
2590  // __init__.py will subclass it with site-specific functionality and set a
2591  // "Context" attribute on this module.
2592  //----------------------------------------------------------------------------
2593  py::class_<PyMlirContext>(m, "_BaseContext", py::module_local())
2594  .def(py::init<>(&PyMlirContext::createNewContextForInit))
2595  .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2596  .def("_get_context_again",
2597  [](PyMlirContext &self) {
2598  PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2599  return ref.releaseObject();
2600  })
2601  .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2602  .def("_get_live_operation_objects",
2603  &PyMlirContext::getLiveOperationObjects)
2604  .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2605  .def("_clear_live_operations_inside",
2606  py::overload_cast<MlirOperation>(
2607  &PyMlirContext::clearOperationsInside))
2608  .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2609  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2610  &PyMlirContext::getCapsule)
2611  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2612  .def("__enter__", &PyMlirContext::contextEnter)
2613  .def("__exit__", &PyMlirContext::contextExit)
2614  .def_property_readonly_static(
2615  "current",
2616  [](py::object & /*class*/) {
2617  auto *context = PyThreadContextEntry::getDefaultContext();
2618  if (!context)
2619  return py::none().cast<py::object>();
2620  return py::cast(context);
2621  },
2622  "Gets the Context bound to the current thread or raises ValueError")
2623  .def_property_readonly(
2624  "dialects",
2625  [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2626  "Gets a container for accessing dialects by name")
2627  .def_property_readonly(
2628  "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2629  "Alias for 'dialect'")
2630  .def(
2631  "get_dialect_descriptor",
2632  [=](PyMlirContext &self, std::string &name) {
2633  MlirDialect dialect = mlirContextGetOrLoadDialect(
2634  self.get(), {name.data(), name.size()});
2635  if (mlirDialectIsNull(dialect)) {
2636  throw py::value_error(
2637  (Twine("Dialect '") + name + "' not found").str());
2638  }
2639  return PyDialectDescriptor(self.getRef(), dialect);
2640  },
2641  py::arg("dialect_name"),
2642  "Gets or loads a dialect by name, returning its descriptor object")
2643  .def_property(
2644  "allow_unregistered_dialects",
2645  [](PyMlirContext &self) -> bool {
2647  },
2648  [](PyMlirContext &self, bool value) {
2650  })
2651  .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2652  py::arg("callback"),
2653  "Attaches a diagnostic handler that will receive callbacks")
2654  .def(
2655  "enable_multithreading",
2656  [](PyMlirContext &self, bool enable) {
2657  mlirContextEnableMultithreading(self.get(), enable);
2658  },
2659  py::arg("enable"))
2660  .def(
2661  "is_registered_operation",
2662  [](PyMlirContext &self, std::string &name) {
2664  self.get(), MlirStringRef{name.data(), name.size()});
2665  },
2666  py::arg("operation_name"))
2667  .def(
2668  "append_dialect_registry",
2669  [](PyMlirContext &self, PyDialectRegistry &registry) {
2670  mlirContextAppendDialectRegistry(self.get(), registry);
2671  },
2672  py::arg("registry"))
2673  .def_property("emit_error_diagnostics", nullptr,
2674  &PyMlirContext::setEmitErrorDiagnostics,
2675  "Emit error diagnostics to diagnostic handlers. By default "
2676  "error diagnostics are captured and reported through "
2677  "MLIRError exceptions.")
2678  .def("load_all_available_dialects", [](PyMlirContext &self) {
2680  });
2681 
2682  //----------------------------------------------------------------------------
2683  // Mapping of PyDialectDescriptor
2684  //----------------------------------------------------------------------------
2685  py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2686  .def_property_readonly("namespace",
2687  [](PyDialectDescriptor &self) {
2688  MlirStringRef ns =
2689  mlirDialectGetNamespace(self.get());
2690  return py::str(ns.data, ns.length);
2691  })
2692  .def("__repr__", [](PyDialectDescriptor &self) {
2694  std::string repr("<DialectDescriptor ");
2695  repr.append(ns.data, ns.length);
2696  repr.append(">");
2697  return repr;
2698  });
2699 
2700  //----------------------------------------------------------------------------
2701  // Mapping of PyDialects
2702  //----------------------------------------------------------------------------
2703  py::class_<PyDialects>(m, "Dialects", py::module_local())
2704  .def("__getitem__",
2705  [=](PyDialects &self, std::string keyName) {
2706  MlirDialect dialect =
2707  self.getDialectForKey(keyName, /*attrError=*/false);
2708  py::object descriptor =
2709  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2710  return createCustomDialectWrapper(keyName, std::move(descriptor));
2711  })
2712  .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2713  MlirDialect dialect =
2714  self.getDialectForKey(attrName, /*attrError=*/true);
2715  py::object descriptor =
2716  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2717  return createCustomDialectWrapper(attrName, std::move(descriptor));
2718  });
2719 
2720  //----------------------------------------------------------------------------
2721  // Mapping of PyDialect
2722  //----------------------------------------------------------------------------
2723  py::class_<PyDialect>(m, "Dialect", py::module_local())
2724  .def(py::init<py::object>(), py::arg("descriptor"))
2725  .def_property_readonly(
2726  "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2727  .def("__repr__", [](py::object self) {
2728  auto clazz = self.attr("__class__");
2729  return py::str("<Dialect ") +
2730  self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2731  clazz.attr("__module__") + py::str(".") +
2732  clazz.attr("__name__") + py::str(")>");
2733  });
2734 
2735  //----------------------------------------------------------------------------
2736  // Mapping of PyDialectRegistry
2737  //----------------------------------------------------------------------------
2738  py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local())
2739  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2740  &PyDialectRegistry::getCapsule)
2741  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2742  .def(py::init<>());
2743 
2744  //----------------------------------------------------------------------------
2745  // Mapping of Location
2746  //----------------------------------------------------------------------------
2747  py::class_<PyLocation>(m, "Location", py::module_local())
2748  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2749  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2750  .def("__enter__", &PyLocation::contextEnter)
2751  .def("__exit__", &PyLocation::contextExit)
2752  .def("__eq__",
2753  [](PyLocation &self, PyLocation &other) -> bool {
2754  return mlirLocationEqual(self, other);
2755  })
2756  .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2757  .def_property_readonly_static(
2758  "current",
2759  [](py::object & /*class*/) {
2760  auto *loc = PyThreadContextEntry::getDefaultLocation();
2761  if (!loc)
2762  throw py::value_error("No current Location");
2763  return loc;
2764  },
2765  "Gets the Location bound to the current thread or raises ValueError")
2766  .def_static(
2767  "unknown",
2768  [](DefaultingPyMlirContext context) {
2769  return PyLocation(context->getRef(),
2770  mlirLocationUnknownGet(context->get()));
2771  },
2772  py::arg("context") = py::none(),
2773  "Gets a Location representing an unknown location")
2774  .def_static(
2775  "callsite",
2776  [](PyLocation callee, const std::vector<PyLocation> &frames,
2777  DefaultingPyMlirContext context) {
2778  if (frames.empty())
2779  throw py::value_error("No caller frames provided");
2780  MlirLocation caller = frames.back().get();
2781  for (const PyLocation &frame :
2782  llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2783  caller = mlirLocationCallSiteGet(frame.get(), caller);
2784  return PyLocation(context->getRef(),
2785  mlirLocationCallSiteGet(callee.get(), caller));
2786  },
2787  py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2789  .def_static(
2790  "file",
2791  [](std::string filename, int line, int col,
2792  DefaultingPyMlirContext context) {
2793  return PyLocation(
2794  context->getRef(),
2796  context->get(), toMlirStringRef(filename), line, col));
2797  },
2798  py::arg("filename"), py::arg("line"), py::arg("col"),
2799  py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2800  .def_static(
2801  "fused",
2802  [](const std::vector<PyLocation> &pyLocations,
2803  std::optional<PyAttribute> metadata,
2804  DefaultingPyMlirContext context) {
2806  locations.reserve(pyLocations.size());
2807  for (auto &pyLocation : pyLocations)
2808  locations.push_back(pyLocation.get());
2809  MlirLocation location = mlirLocationFusedGet(
2810  context->get(), locations.size(), locations.data(),
2811  metadata ? metadata->get() : MlirAttribute{0});
2812  return PyLocation(context->getRef(), location);
2813  },
2814  py::arg("locations"), py::arg("metadata") = py::none(),
2815  py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2816  .def_static(
2817  "name",
2818  [](std::string name, std::optional<PyLocation> childLoc,
2819  DefaultingPyMlirContext context) {
2820  return PyLocation(
2821  context->getRef(),
2823  context->get(), toMlirStringRef(name),
2824  childLoc ? childLoc->get()
2825  : mlirLocationUnknownGet(context->get())));
2826  },
2827  py::arg("name"), py::arg("childLoc") = py::none(),
2828  py::arg("context") = py::none(), kContextGetNameLocationDocString)
2829  .def_static(
2830  "from_attr",
2831  [](PyAttribute &attribute, DefaultingPyMlirContext context) {
2832  return PyLocation(context->getRef(),
2833  mlirLocationFromAttribute(attribute));
2834  },
2835  py::arg("attribute"), py::arg("context") = py::none(),
2836  "Gets a Location from a LocationAttr")
2837  .def_property_readonly(
2838  "context",
2839  [](PyLocation &self) { return self.getContext().getObject(); },
2840  "Context that owns the Location")
2841  .def_property_readonly(
2842  "attr",
2843  [](PyLocation &self) { return mlirLocationGetAttribute(self); },
2844  "Get the underlying LocationAttr")
2845  .def(
2846  "emit_error",
2847  [](PyLocation &self, std::string message) {
2848  mlirEmitError(self, message.c_str());
2849  },
2850  py::arg("message"), "Emits an error at this location")
2851  .def("__repr__", [](PyLocation &self) {
2852  PyPrintAccumulator printAccum;
2853  mlirLocationPrint(self, printAccum.getCallback(),
2854  printAccum.getUserData());
2855  return printAccum.join();
2856  });
2857 
2858  //----------------------------------------------------------------------------
2859  // Mapping of Module
2860  //----------------------------------------------------------------------------
2861  py::class_<PyModule>(m, "Module", py::module_local())
2862  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2863  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2864  .def_static(
2865  "parse",
2866  [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
2867  PyMlirContext::ErrorCapture errors(context->getRef());
2868  MlirModule module = mlirModuleCreateParse(
2869  context->get(), toMlirStringRef(moduleAsm));
2870  if (mlirModuleIsNull(module))
2871  throw MLIRError("Unable to parse module assembly", errors.take());
2872  return PyModule::forModule(module).releaseObject();
2873  },
2874  py::arg("asm"), py::arg("context") = py::none(),
2876  .def_static(
2877  "create",
2878  [](DefaultingPyLocation loc) {
2879  MlirModule module = mlirModuleCreateEmpty(loc);
2880  return PyModule::forModule(module).releaseObject();
2881  },
2882  py::arg("loc") = py::none(), "Creates an empty module")
2883  .def_property_readonly(
2884  "context",
2885  [](PyModule &self) { return self.getContext().getObject(); },
2886  "Context that created the Module")
2887  .def_property_readonly(
2888  "operation",
2889  [](PyModule &self) {
2890  return PyOperation::forOperation(self.getContext(),
2891  mlirModuleGetOperation(self.get()),
2892  self.getRef().releaseObject())
2893  .releaseObject();
2894  },
2895  "Accesses the module as an operation")
2896  .def_property_readonly(
2897  "body",
2898  [](PyModule &self) {
2899  PyOperationRef moduleOp = PyOperation::forOperation(
2900  self.getContext(), mlirModuleGetOperation(self.get()),
2901  self.getRef().releaseObject());
2902  PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2903  return returnBlock;
2904  },
2905  "Return the block for this module")
2906  .def(
2907  "dump",
2908  [](PyModule &self) {
2910  },
2912  .def(
2913  "__str__",
2914  [](py::object self) {
2915  // Defer to the operation's __str__.
2916  return self.attr("operation").attr("__str__")();
2917  },
2919 
2920  //----------------------------------------------------------------------------
2921  // Mapping of Operation.
2922  //----------------------------------------------------------------------------
2923  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2924  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2925  [](PyOperationBase &self) {
2926  return self.getOperation().getCapsule();
2927  })
2928  .def("__eq__",
2929  [](PyOperationBase &self, PyOperationBase &other) {
2930  return &self.getOperation() == &other.getOperation();
2931  })
2932  .def("__eq__",
2933  [](PyOperationBase &self, py::object other) { return false; })
2934  .def("__hash__",
2935  [](PyOperationBase &self) {
2936  return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2937  })
2938  .def_property_readonly("attributes",
2939  [](PyOperationBase &self) {
2940  return PyOpAttributeMap(
2941  self.getOperation().getRef());
2942  })
2943  .def_property_readonly(
2944  "context",
2945  [](PyOperationBase &self) {
2946  PyOperation &concreteOperation = self.getOperation();
2947  concreteOperation.checkValid();
2948  return concreteOperation.getContext().getObject();
2949  },
2950  "Context that owns the Operation")
2951  .def_property_readonly("name",
2952  [](PyOperationBase &self) {
2953  auto &concreteOperation = self.getOperation();
2954  concreteOperation.checkValid();
2955  MlirOperation operation =
2956  concreteOperation.get();
2958  mlirOperationGetName(operation));
2959  return py::str(name.data, name.length);
2960  })
2961  .def_property_readonly("operands",
2962  [](PyOperationBase &self) {
2963  return PyOpOperandList(
2964  self.getOperation().getRef());
2965  })
2966  .def_property_readonly("regions",
2967  [](PyOperationBase &self) {
2968  return PyRegionList(
2969  self.getOperation().getRef());
2970  })
2971  .def_property_readonly(
2972  "results",
2973  [](PyOperationBase &self) {
2974  return PyOpResultList(self.getOperation().getRef());
2975  },
2976  "Returns the list of Operation results.")
2977  .def_property_readonly(
2978  "result",
2979  [](PyOperationBase &self) {
2980  auto &operation = self.getOperation();
2981  auto numResults = mlirOperationGetNumResults(operation);
2982  if (numResults != 1) {
2983  auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2984  throw py::value_error(
2985  (Twine("Cannot call .result on operation ") +
2986  StringRef(name.data, name.length) + " which has " +
2987  Twine(numResults) +
2988  " results (it is only valid for operations with a "
2989  "single result)")
2990  .str());
2991  }
2992  return PyOpResult(operation.getRef(),
2993  mlirOperationGetResult(operation, 0))
2994  .maybeDownCast();
2995  },
2996  "Shortcut to get an op result if it has only one (throws an error "
2997  "otherwise).")
2998  .def_property_readonly(
2999  "location",
3000  [](PyOperationBase &self) {
3001  PyOperation &operation = self.getOperation();
3002  return PyLocation(operation.getContext(),
3003  mlirOperationGetLocation(operation.get()));
3004  },
3005  "Returns the source location the operation was defined or derived "
3006  "from.")
3007  .def_property_readonly("parent",
3008  [](PyOperationBase &self) -> py::object {
3009  auto parent =
3010  self.getOperation().getParentOperation();
3011  if (parent)
3012  return parent->getObject();
3013  return py::none();
3014  })
3015  .def(
3016  "__str__",
3017  [](PyOperationBase &self) {
3018  return self.getAsm(/*binary=*/false,
3019  /*largeElementsLimit=*/std::nullopt,
3020  /*enableDebugInfo=*/false,
3021  /*prettyDebugInfo=*/false,
3022  /*printGenericOpForm=*/false,
3023  /*useLocalScope=*/false,
3024  /*assumeVerified=*/false);
3025  },
3026  "Returns the assembly form of the operation.")
3027  .def("print",
3028  py::overload_cast<PyAsmState &, pybind11::object, bool>(
3030  py::arg("state"), py::arg("file") = py::none(),
3031  py::arg("binary") = false, kOperationPrintStateDocstring)
3032  .def("print",
3033  py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3034  bool, py::object, bool>(&PyOperationBase::print),
3035  // Careful: Lots of arguments must match up with print method.
3036  py::arg("large_elements_limit") = py::none(),
3037  py::arg("enable_debug_info") = false,
3038  py::arg("pretty_debug_info") = false,
3039  py::arg("print_generic_op_form") = false,
3040  py::arg("use_local_scope") = false,
3041  py::arg("assume_verified") = false, py::arg("file") = py::none(),
3042  py::arg("binary") = false, kOperationPrintDocstring)
3043  .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
3044  py::arg("desired_version") = py::none(),
3046  .def("get_asm", &PyOperationBase::getAsm,
3047  // Careful: Lots of arguments must match up with get_asm method.
3048  py::arg("binary") = false,
3049  py::arg("large_elements_limit") = py::none(),
3050  py::arg("enable_debug_info") = false,
3051  py::arg("pretty_debug_info") = false,
3052  py::arg("print_generic_op_form") = false,
3053  py::arg("use_local_scope") = false,
3054  py::arg("assume_verified") = false, kOperationGetAsmDocstring)
3055  .def("verify", &PyOperationBase::verify,
3056  "Verify the operation. Raises MLIRError if verification fails, and "
3057  "returns true otherwise.")
3058  .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
3059  "Puts self immediately after the other operation in its parent "
3060  "block.")
3061  .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
3062  "Puts self immediately before the other operation in its parent "
3063  "block.")
3064  .def(
3065  "clone",
3066  [](PyOperationBase &self, py::object ip) {
3067  return self.getOperation().clone(ip);
3068  },
3069  py::arg("ip") = py::none())
3070  .def(
3071  "detach_from_parent",
3072  [](PyOperationBase &self) {
3073  PyOperation &operation = self.getOperation();
3074  operation.checkValid();
3075  if (!operation.isAttached())
3076  throw py::value_error("Detached operation has no parent.");
3077 
3078  operation.detachFromParent();
3079  return operation.createOpView();
3080  },
3081  "Detaches the operation from its parent block.")
3082  .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3083  .def("walk", &PyOperationBase::walk, py::arg("callback"),
3084  py::arg("walk_order") = MlirWalkPostOrder);
3085 
3086  py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
3087  .def_static("create", &PyOperation::create, py::arg("name"),
3088  py::arg("results") = py::none(),
3089  py::arg("operands") = py::none(),
3090  py::arg("attributes") = py::none(),
3091  py::arg("successors") = py::none(), py::arg("regions") = 0,
3092  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3093  py::arg("infer_type") = false, kOperationCreateDocstring)
3094  .def_static(
3095  "parse",
3096  [](const std::string &sourceStr, const std::string &sourceName,
3097  DefaultingPyMlirContext context) {
3098  return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3099  ->createOpView();
3100  },
3101  py::arg("source"), py::kw_only(), py::arg("source_name") = "",
3102  py::arg("context") = py::none(),
3103  "Parses an operation. Supports both text assembly format and binary "
3104  "bytecode format.")
3105  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3106  &PyOperation::getCapsule)
3107  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3108  .def_property_readonly("operation", [](py::object self) { return self; })
3109  .def_property_readonly("opview", &PyOperation::createOpView)
3110  .def_property_readonly(
3111  "successors",
3112  [](PyOperationBase &self) {
3113  return PyOpSuccessors(self.getOperation().getRef());
3114  },
3115  "Returns the list of Operation successors.");
3116 
3117  auto opViewClass =
3118  py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
3119  .def(py::init<py::object>(), py::arg("operation"))
3120  .def_property_readonly("operation", &PyOpView::getOperationObject)
3121  .def_property_readonly("opview", [](py::object self) { return self; })
3122  .def(
3123  "__str__",
3124  [](PyOpView &self) { return py::str(self.getOperationObject()); })
3125  .def_property_readonly(
3126  "successors",
3127  [](PyOperationBase &self) {
3128  return PyOpSuccessors(self.getOperation().getRef());
3129  },
3130  "Returns the list of Operation successors.");
3131  opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
3132  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
3133  opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
3134  opViewClass.attr("build_generic") = classmethod(
3135  &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
3136  py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
3137  py::arg("successors") = py::none(), py::arg("regions") = py::none(),
3138  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3139  "Builds a specific, generated OpView based on class level attributes.");
3140  opViewClass.attr("parse") = classmethod(
3141  [](const py::object &cls, const std::string &sourceStr,
3142  const std::string &sourceName, DefaultingPyMlirContext context) {
3143  PyOperationRef parsed =
3144  PyOperation::parse(context->getRef(), sourceStr, sourceName);
3145 
3146  // Check if the expected operation was parsed, and cast to to the
3147  // appropriate `OpView` subclass if successful.
3148  // NOTE: This accesses attributes that have been automatically added to
3149  // `OpView` subclasses, and is not intended to be used on `OpView`
3150  // directly.
3151  std::string clsOpName =
3152  py::cast<std::string>(cls.attr("OPERATION_NAME"));
3153  MlirStringRef identifier =
3155  std::string_view parsedOpName(identifier.data, identifier.length);
3156  if (clsOpName != parsedOpName)
3157  throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3158  parsedOpName + "'");
3159  return PyOpView::constructDerived(cls, *parsed.get());
3160  },
3161  py::arg("cls"), py::arg("source"), py::kw_only(),
3162  py::arg("source_name") = "", py::arg("context") = py::none(),
3163  "Parses a specific, generated OpView based on class level attributes");
3164 
3165  //----------------------------------------------------------------------------
3166  // Mapping of PyRegion.
3167  //----------------------------------------------------------------------------
3168  py::class_<PyRegion>(m, "Region", py::module_local())
3169  .def_property_readonly(
3170  "blocks",
3171  [](PyRegion &self) {
3172  return PyBlockList(self.getParentOperation(), self.get());
3173  },
3174  "Returns a forward-optimized sequence of blocks.")
3175  .def_property_readonly(
3176  "owner",
3177  [](PyRegion &self) {
3178  return self.getParentOperation()->createOpView();
3179  },
3180  "Returns the operation owning this region.")
3181  .def(
3182  "__iter__",
3183  [](PyRegion &self) {
3184  self.checkValid();
3185  MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3186  return PyBlockIterator(self.getParentOperation(), firstBlock);
3187  },
3188  "Iterates over blocks in the region.")
3189  .def("__eq__",
3190  [](PyRegion &self, PyRegion &other) {
3191  return self.get().ptr == other.get().ptr;
3192  })
3193  .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
3194 
3195  //----------------------------------------------------------------------------
3196  // Mapping of PyBlock.
3197  //----------------------------------------------------------------------------
3198  py::class_<PyBlock>(m, "Block", py::module_local())
3199  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3200  .def_property_readonly(
3201  "owner",
3202  [](PyBlock &self) {
3203  return self.getParentOperation()->createOpView();
3204  },
3205  "Returns the owning operation of this block.")
3206  .def_property_readonly(
3207  "region",
3208  [](PyBlock &self) {
3209  MlirRegion region = mlirBlockGetParentRegion(self.get());
3210  return PyRegion(self.getParentOperation(), region);
3211  },
3212  "Returns the owning region of this block.")
3213  .def_property_readonly(
3214  "arguments",
3215  [](PyBlock &self) {
3216  return PyBlockArgumentList(self.getParentOperation(), self.get());
3217  },
3218  "Returns a list of block arguments.")
3219  .def_property_readonly(
3220  "operations",
3221  [](PyBlock &self) {
3222  return PyOperationList(self.getParentOperation(), self.get());
3223  },
3224  "Returns a forward-optimized sequence of operations.")
3225  .def_static(
3226  "create_at_start",
3227  [](PyRegion &parent, const py::list &pyArgTypes,
3228  const std::optional<py::sequence> &pyArgLocs) {
3229  parent.checkValid();
3230  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3231  mlirRegionInsertOwnedBlock(parent, 0, block);
3232  return PyBlock(parent.getParentOperation(), block);
3233  },
3234  py::arg("parent"), py::arg("arg_types") = py::list(),
3235  py::arg("arg_locs") = std::nullopt,
3236  "Creates and returns a new Block at the beginning of the given "
3237  "region (with given argument types and locations).")
3238  .def(
3239  "append_to",
3240  [](PyBlock &self, PyRegion &region) {
3241  MlirBlock b = self.get();
3243  mlirBlockDetach(b);
3244  mlirRegionAppendOwnedBlock(region.get(), b);
3245  },
3246  "Append this block to a region, transferring ownership if necessary")
3247  .def(
3248  "create_before",
3249  [](PyBlock &self, const py::args &pyArgTypes,
3250  const std::optional<py::sequence> &pyArgLocs) {
3251  self.checkValid();
3252  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3253  MlirRegion region = mlirBlockGetParentRegion(self.get());
3254  mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3255  return PyBlock(self.getParentOperation(), block);
3256  },
3257  py::arg("arg_locs") = std::nullopt,
3258  "Creates and returns a new Block before this block "
3259  "(with given argument types and locations).")
3260  .def(
3261  "create_after",
3262  [](PyBlock &self, const py::args &pyArgTypes,
3263  const std::optional<py::sequence> &pyArgLocs) {
3264  self.checkValid();
3265  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3266  MlirRegion region = mlirBlockGetParentRegion(self.get());
3267  mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3268  return PyBlock(self.getParentOperation(), block);
3269  },
3270  py::arg("arg_locs") = std::nullopt,
3271  "Creates and returns a new Block after this block "
3272  "(with given argument types and locations).")
3273  .def(
3274  "__iter__",
3275  [](PyBlock &self) {
3276  self.checkValid();
3277  MlirOperation firstOperation =
3279  return PyOperationIterator(self.getParentOperation(),
3280  firstOperation);
3281  },
3282  "Iterates over operations in the block.")
3283  .def("__eq__",
3284  [](PyBlock &self, PyBlock &other) {
3285  return self.get().ptr == other.get().ptr;
3286  })
3287  .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
3288  .def("__hash__",
3289  [](PyBlock &self) {
3290  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3291  })
3292  .def(
3293  "__str__",
3294  [](PyBlock &self) {
3295  self.checkValid();
3296  PyPrintAccumulator printAccum;
3297  mlirBlockPrint(self.get(), printAccum.getCallback(),
3298  printAccum.getUserData());
3299  return printAccum.join();
3300  },
3301  "Returns the assembly form of the block.")
3302  .def(
3303  "append",
3304  [](PyBlock &self, PyOperationBase &operation) {
3305  if (operation.getOperation().isAttached())
3306  operation.getOperation().detachFromParent();
3307 
3308  MlirOperation mlirOperation = operation.getOperation().get();
3309  mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3310  operation.getOperation().setAttached(
3311  self.getParentOperation().getObject());
3312  },
3313  py::arg("operation"),
3314  "Appends an operation to this block. If the operation is currently "
3315  "in another block, it will be moved.");
3316 
3317  //----------------------------------------------------------------------------
3318  // Mapping of PyInsertionPoint.
3319  //----------------------------------------------------------------------------
3320 
3321  py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
3322  .def(py::init<PyBlock &>(), py::arg("block"),
3323  "Inserts after the last operation but still inside the block.")
3324  .def("__enter__", &PyInsertionPoint::contextEnter)
3325  .def("__exit__", &PyInsertionPoint::contextExit)
3326  .def_property_readonly_static(
3327  "current",
3328  [](py::object & /*class*/) {
3329  auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3330  if (!ip)
3331  throw py::value_error("No current InsertionPoint");
3332  return ip;
3333  },
3334  "Gets the InsertionPoint bound to the current thread or raises "
3335  "ValueError if none has been set")
3336  .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
3337  "Inserts before a referenced operation.")
3338  .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3339  py::arg("block"), "Inserts at the beginning of the block.")
3340  .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3341  py::arg("block"), "Inserts before the block terminator.")
3342  .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
3343  "Inserts an operation.")
3344  .def_property_readonly(
3345  "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3346  "Returns the block that this InsertionPoint points to.")
3347  .def_property_readonly(
3348  "ref_operation",
3349  [](PyInsertionPoint &self) -> py::object {
3350  auto refOperation = self.getRefOperation();
3351  if (refOperation)
3352  return refOperation->getObject();
3353  return py::none();
3354  },
3355  "The reference operation before which new operations are "
3356  "inserted, or None if the insertion point is at the end of "
3357  "the block");
3358 
3359  //----------------------------------------------------------------------------
3360  // Mapping of PyAttribute.
3361  //----------------------------------------------------------------------------
3362  py::class_<PyAttribute>(m, "Attribute", py::module_local())
3363  // Delegate to the PyAttribute copy constructor, which will also lifetime
3364  // extend the backing context which owns the MlirAttribute.
3365  .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
3366  "Casts the passed attribute to the generic Attribute")
3367  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3368  &PyAttribute::getCapsule)
3369  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3370  .def_static(
3371  "parse",
3372  [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3373  PyMlirContext::ErrorCapture errors(context->getRef());
3374  MlirAttribute attr = mlirAttributeParseGet(
3375  context->get(), toMlirStringRef(attrSpec));
3376  if (mlirAttributeIsNull(attr))
3377  throw MLIRError("Unable to parse attribute", errors.take());
3378  return attr;
3379  },
3380  py::arg("asm"), py::arg("context") = py::none(),
3381  "Parses an attribute from an assembly form. Raises an MLIRError on "
3382  "failure.")
3383  .def_property_readonly(
3384  "context",
3385  [](PyAttribute &self) { return self.getContext().getObject(); },
3386  "Context that owns the Attribute")
3387  .def_property_readonly(
3388  "type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
3389  .def(
3390  "get_named",
3391  [](PyAttribute &self, std::string name) {
3392  return PyNamedAttribute(self, std::move(name));
3393  },
3394  py::keep_alive<0, 1>(), "Binds a name to the attribute")
3395  .def("__eq__",
3396  [](PyAttribute &self, PyAttribute &other) { return self == other; })
3397  .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
3398  .def("__hash__",
3399  [](PyAttribute &self) {
3400  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3401  })
3402  .def(
3403  "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3405  .def(
3406  "__str__",
3407  [](PyAttribute &self) {
3408  PyPrintAccumulator printAccum;
3409  mlirAttributePrint(self, printAccum.getCallback(),
3410  printAccum.getUserData());
3411  return printAccum.join();
3412  },
3413  "Returns the assembly form of the Attribute.")
3414  .def("__repr__",
3415  [](PyAttribute &self) {
3416  // Generally, assembly formats are not printed for __repr__ because
3417  // this can cause exceptionally long debug output and exceptions.
3418  // However, attribute values are generally considered useful and
3419  // are printed. This may need to be re-evaluated if debug dumps end
3420  // up being excessive.
3421  PyPrintAccumulator printAccum;
3422  printAccum.parts.append("Attribute(");
3423  mlirAttributePrint(self, printAccum.getCallback(),
3424  printAccum.getUserData());
3425  printAccum.parts.append(")");
3426  return printAccum.join();
3427  })
3428  .def_property_readonly(
3429  "typeid",
3430  [](PyAttribute &self) -> MlirTypeID {
3431  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3432  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3433  "mlirTypeID was expected to be non-null.");
3434  return mlirTypeID;
3435  })
3437  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3438  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3439  "mlirTypeID was expected to be non-null.");
3440  std::optional<pybind11::function> typeCaster =
3441  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3442  mlirAttributeGetDialect(self));
3443  if (!typeCaster)
3444  return py::cast(self);
3445  return typeCaster.value()(self);
3446  });
3447 
3448  //----------------------------------------------------------------------------
3449  // Mapping of PyNamedAttribute
3450  //----------------------------------------------------------------------------
3451  py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
3452  .def("__repr__",
3453  [](PyNamedAttribute &self) {
3454  PyPrintAccumulator printAccum;
3455  printAccum.parts.append("NamedAttribute(");
3456  printAccum.parts.append(
3457  py::str(mlirIdentifierStr(self.namedAttr.name).data,
3458  mlirIdentifierStr(self.namedAttr.name).length));
3459  printAccum.parts.append("=");
3460  mlirAttributePrint(self.namedAttr.attribute,
3461  printAccum.getCallback(),
3462  printAccum.getUserData());
3463  printAccum.parts.append(")");
3464  return printAccum.join();
3465  })
3466  .def_property_readonly(
3467  "name",
3468  [](PyNamedAttribute &self) {
3469  return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3470  mlirIdentifierStr(self.namedAttr.name).length);
3471  },
3472  "The name of the NamedAttribute binding")
3473  .def_property_readonly(
3474  "attr",
3475  [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3476  py::keep_alive<0, 1>(),
3477  "The underlying generic attribute of the NamedAttribute binding");
3478 
3479  //----------------------------------------------------------------------------
3480  // Mapping of PyType.
3481  //----------------------------------------------------------------------------
3482  py::class_<PyType>(m, "Type", py::module_local())
3483  // Delegate to the PyType copy constructor, which will also lifetime
3484  // extend the backing context which owns the MlirType.
3485  .def(py::init<PyType &>(), py::arg("cast_from_type"),
3486  "Casts the passed type to the generic Type")
3487  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3488  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3489  .def_static(
3490  "parse",
3491  [](std::string typeSpec, DefaultingPyMlirContext context) {
3492  PyMlirContext::ErrorCapture errors(context->getRef());
3493  MlirType type =
3494  mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3495  if (mlirTypeIsNull(type))
3496  throw MLIRError("Unable to parse type", errors.take());
3497  return type;
3498  },
3499  py::arg("asm"), py::arg("context") = py::none(),
3501  .def_property_readonly(
3502  "context", [](PyType &self) { return self.getContext().getObject(); },
3503  "Context that owns the Type")
3504  .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3505  .def("__eq__", [](PyType &self, py::object &other) { return false; })
3506  .def("__hash__",
3507  [](PyType &self) {
3508  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3509  })
3510  .def(
3511  "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3512  .def(
3513  "__str__",
3514  [](PyType &self) {
3515  PyPrintAccumulator printAccum;
3516  mlirTypePrint(self, printAccum.getCallback(),
3517  printAccum.getUserData());
3518  return printAccum.join();
3519  },
3520  "Returns the assembly form of the type.")
3521  .def("__repr__",
3522  [](PyType &self) {
3523  // Generally, assembly formats are not printed for __repr__ because
3524  // this can cause exceptionally long debug output and exceptions.
3525  // However, types are an exception as they typically have compact
3526  // assembly forms and printing them is useful.
3527  PyPrintAccumulator printAccum;
3528  printAccum.parts.append("Type(");
3529  mlirTypePrint(self, printAccum.getCallback(),
3530  printAccum.getUserData());
3531  printAccum.parts.append(")");
3532  return printAccum.join();
3533  })
3535  [](PyType &self) {
3536  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3537  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3538  "mlirTypeID was expected to be non-null.");
3539  std::optional<pybind11::function> typeCaster =
3540  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3541  mlirTypeGetDialect(self));
3542  if (!typeCaster)
3543  return py::cast(self);
3544  return typeCaster.value()(self);
3545  })
3546  .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
3547  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3548  if (!mlirTypeIDIsNull(mlirTypeID))
3549  return mlirTypeID;
3550  auto origRepr =
3551  pybind11::repr(pybind11::cast(self)).cast<std::string>();
3552  throw py::value_error(
3553  (origRepr + llvm::Twine(" has no typeid.")).str());
3554  });
3555 
3556  //----------------------------------------------------------------------------
3557  // Mapping of PyTypeID.
3558  //----------------------------------------------------------------------------
3559  py::class_<PyTypeID>(m, "TypeID", py::module_local())
3560  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3561  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3562  // Note, this tests whether the underlying TypeIDs are the same,
3563  // not whether the wrapper MlirTypeIDs are the same, nor whether
3564  // the Python objects are the same (i.e., PyTypeID is a value type).
3565  .def("__eq__",
3566  [](PyTypeID &self, PyTypeID &other) { return self == other; })
3567  .def("__eq__",
3568  [](PyTypeID &self, const py::object &other) { return false; })
3569  // Note, this gives the hash value of the underlying TypeID, not the
3570  // hash value of the Python object, nor the hash value of the
3571  // MlirTypeID wrapper.
3572  .def("__hash__", [](PyTypeID &self) {
3573  return static_cast<size_t>(mlirTypeIDHashValue(self));
3574  });
3575 
3576  //----------------------------------------------------------------------------
3577  // Mapping of Value.
3578  //----------------------------------------------------------------------------
3579  py::class_<PyValue>(m, "Value", py::module_local())
3580  .def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"))
3581  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3582  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3583  .def_property_readonly(
3584  "context",
3585  [](PyValue &self) { return self.getParentOperation()->getContext(); },
3586  "Context in which the value lives.")
3587  .def(
3588  "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3590  .def_property_readonly(
3591  "owner",
3592  [](PyValue &self) -> py::object {
3593  MlirValue v = self.get();
3594  if (mlirValueIsAOpResult(v)) {
3595  assert(
3596  mlirOperationEqual(self.getParentOperation()->get(),
3597  mlirOpResultGetOwner(self.get())) &&
3598  "expected the owner of the value in Python to match that in "
3599  "the IR");
3600  return self.getParentOperation().getObject();
3601  }
3602 
3603  if (mlirValueIsABlockArgument(v)) {
3604  MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3605  return py::cast(PyBlock(self.getParentOperation(), block));
3606  }
3607 
3608  assert(false && "Value must be a block argument or an op result");
3609  return py::none();
3610  })
3611  .def_property_readonly("uses",
3612  [](PyValue &self) {
3613  return PyOpOperandIterator(
3614  mlirValueGetFirstUse(self.get()));
3615  })
3616  .def("__eq__",
3617  [](PyValue &self, PyValue &other) {
3618  return self.get().ptr == other.get().ptr;
3619  })
3620  .def("__eq__", [](PyValue &self, py::object other) { return false; })
3621  .def("__hash__",
3622  [](PyValue &self) {
3623  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3624  })
3625  .def(
3626  "__str__",
3627  [](PyValue &self) {
3628  PyPrintAccumulator printAccum;
3629  printAccum.parts.append("Value(");
3630  mlirValuePrint(self.get(), printAccum.getCallback(),
3631  printAccum.getUserData());
3632  printAccum.parts.append(")");
3633  return printAccum.join();
3634  },
3636  .def(
3637  "get_name",
3638  [](PyValue &self, bool useLocalScope) {
3639  PyPrintAccumulator printAccum;
3640  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3641  if (useLocalScope)
3643  MlirAsmState valueState =
3644  mlirAsmStateCreateForValue(self.get(), flags);
3645  mlirValuePrintAsOperand(self.get(), valueState,
3646  printAccum.getCallback(),
3647  printAccum.getUserData());
3649  mlirAsmStateDestroy(valueState);
3650  return printAccum.join();
3651  },
3652  py::arg("use_local_scope") = false)
3653  .def(
3654  "get_name",
3655  [](PyValue &self, std::reference_wrapper<PyAsmState> state) {
3656  PyPrintAccumulator printAccum;
3657  MlirAsmState valueState = state.get().get();
3658  mlirValuePrintAsOperand(self.get(), valueState,
3659  printAccum.getCallback(),
3660  printAccum.getUserData());
3661  return printAccum.join();
3662  },
3663  py::arg("state"), kGetNameAsOperand)
3664  .def_property_readonly(
3665  "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
3666  .def(
3667  "set_type",
3668  [](PyValue &self, const PyType &type) {
3669  return mlirValueSetType(self.get(), type);
3670  },
3671  py::arg("type"))
3672  .def(
3673  "replace_all_uses_with",
3674  [](PyValue &self, PyValue &with) {
3675  mlirValueReplaceAllUsesOfWith(self.get(), with.get());
3676  },
3679  [](PyValue &self) { return self.maybeDownCast(); });
3680  PyBlockArgument::bind(m);
3681  PyOpResult::bind(m);
3682  PyOpOperand::bind(m);
3683 
3684  py::class_<PyAsmState>(m, "AsmState", py::module_local())
3685  .def(py::init<PyValue &, bool>(), py::arg("value"),
3686  py::arg("use_local_scope") = false)
3687  .def(py::init<PyOperationBase &, bool>(), py::arg("op"),
3688  py::arg("use_local_scope") = false);
3689 
3690  //----------------------------------------------------------------------------
3691  // Mapping of SymbolTable.
3692  //----------------------------------------------------------------------------
3693  py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3694  .def(py::init<PyOperationBase &>())
3695  .def("__getitem__", &PySymbolTable::dunderGetItem)
3696  .def("insert", &PySymbolTable::insert, py::arg("operation"))
3697  .def("erase", &PySymbolTable::erase, py::arg("operation"))
3698  .def("__delitem__", &PySymbolTable::dunderDel)
3699  .def("__contains__",
3700  [](PySymbolTable &table, const std::string &name) {
3702  table, mlirStringRefCreate(name.data(), name.length())));
3703  })
3704  // Static helpers.
3705  .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3706  py::arg("symbol"), py::arg("name"))
3707  .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3708  py::arg("symbol"))
3709  .def_static("get_visibility", &PySymbolTable::getVisibility,
3710  py::arg("symbol"))
3711  .def_static("set_visibility", &PySymbolTable::setVisibility,
3712  py::arg("symbol"), py::arg("visibility"))
3713  .def_static("replace_all_symbol_uses",
3714  &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3715  py::arg("new_symbol"), py::arg("from_op"))
3716  .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3717  py::arg("from_op"), py::arg("all_sym_uses_visible"),
3718  py::arg("callback"));
3719 
3720  // Container bindings.
3721  PyBlockArgumentList::bind(m);
3722  PyBlockIterator::bind(m);
3723  PyBlockList::bind(m);
3724  PyOperationIterator::bind(m);
3725  PyOperationList::bind(m);
3726  PyOpAttributeMap::bind(m);
3727  PyOpOperandIterator::bind(m);
3728  PyOpOperandList::bind(m);
3729  PyOpResultList::bind(m);
3730  PyOpSuccessors::bind(m);
3731  PyRegionIterator::bind(m);
3732  PyRegionList::bind(m);
3733 
3734  // Debug bindings.
3736 
3737  // Attribute builder getter.
3739 
3740  py::register_local_exception_translator([](std::exception_ptr p) {
3741  // We can't define exceptions with custom fields through pybind, so instead
3742  // the exception class is defined in python and imported here.
3743  try {
3744  if (p)
3745  std::rethrow_exception(p);
3746  } catch (const MLIRError &e) {
3747  py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
3748  .attr("MLIRError")(e.message, e.errorDiagnostics);
3749  PyErr_SetObject(PyExc_Exception, obj.ptr());
3750  }
3751  });
3752 }
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition: Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition: Debug.cpp:16
static const char kOperationPrintStateDocstring[]
Definition: IRCore.cpp:113
static const char kValueReplaceAllUsesWithDocstring[]
Definition: IRCore.cpp:175
static py::object createCustomDialectWrapper(const std::string &dialectNamespace, py::object dialectDescriptor)
Definition: IRCore.cpp:192
static const char kContextGetNameLocationDocString[]
Definition: IRCore.cpp:57
static const char kGetNameAsOperand[]
Definition: IRCore.cpp:171
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:204
static const char kModuleParseDocstring[]
Definition: IRCore.cpp:60
static const char kOperationStrDunderDocstring[]
Definition: IRCore.cpp:145
static const char kOperationPrintDocstring[]
Definition: IRCore.cpp:87
static const char kContextGetFileLocationDocstring[]
Definition: IRCore.cpp:51
static const char kDumpDocstring[]
Definition: IRCore.cpp:153
static const char kAppendBlockDocstring[]
Definition: IRCore.cpp:156
static const char kContextGetFusedLocationDocstring[]
Definition: IRCore.cpp:54
static void maybeInsertOperation(PyOperationRef &op, const py::object &maybeIp)
Definition: IRCore.cpp:1367
static MlirBlock createBlock(const py::sequence &pyArgTypes, const std::optional< py::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition: IRCore.cpp:210
static const char kOperationPrintBytecodeDocstring[]
Definition: IRCore.cpp:135
static const char kOperationGetAsmDocstring[]
Definition: IRCore.cpp:122
static void populateResultTypes(StringRef name, py::list resultTypeList, const py::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition: IRCore.cpp:1544
static const char kOperationCreateDocstring[]
Definition: IRCore.cpp:68
static const char kContextParseTypeDocstring[]
Definition: IRCore.cpp:40
static const char kContextGetCallSiteLocationDocstring[]
Definition: IRCore.cpp:48
static const char kValueDunderStrDocstring[]
Definition: IRCore.cpp:163
py::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition: IRCore.cpp:186
static MLIRContext * getContext(OpFoldResult val)
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition: Interop.h:272
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition: Interop.h:117
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition: Interop.h:327
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition: Interop.h:317
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:96
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition: Interop.h:188
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition: Interop.h:179
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition: Interop.h:254
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition: Interop.h:109
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition: Interop.h:281
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts an MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition: Interop.h:223
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition: Interop.h:336
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition: Interop.h:234
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition: Interop.h:346
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition: Interop.h:244
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition: Interop.h:56
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition: Interop.h:355
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition: Interop.h:433
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition: Interop.h:197
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition: Interop.h:309
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition: Interop.h:263
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition: Interop.h:424
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition: Interop.h:215
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::context getDefaultContext()
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Accumulates into a Python file-like object, either writing text (default) or binary.
Definition: PybindUtils.h:125
MlirStringCallback getCallback()
Definition: PybindUtils.h:132
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
Definition: PybindUtils.h:209
Base class for all objects that directly or indirectly depend on an MlirContext.
Definition: IRModule.h:295
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:303
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:510
static PyLocation & resolve()
Definition: IRCore.cpp:1040
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:284
static PyMlirContext & resolve()
Definition: IRCore.cpp:764
ReferrentTy * get() const
Definition: PybindUtils.h:47
Wrapper around an MlirAsmState.
Definition: IRModule.h:776
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:993
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition: IRModule.h:995
static PyAttribute createFromCapsule(pybind11::object capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition: IRCore.cpp:1907
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition: IRCore.cpp:1903
bool operator==(const PyAttribute &other) const
Definition: IRCore.cpp:1899
Wrapper around an MlirBlock.
Definition: IRModule.h:811
MlirBlock get()
Definition: IRModule.h:818
PyOperationRef & getParentOperation()
Definition: IRModule.h:819
Represents a diagnostic handler attached to the context.
Definition: IRModule.h:391
void detach()
Detaches the handler. Does nothing if not attached.
Definition: IRCore.cpp:926
PyDiagnosticHandler(MlirContext context, pybind11::object callback)
Definition: IRCore.cpp:920
Python class mirroring the C MlirDiagnostic struct.
Definition: IRModule.h:341
pybind11::str getMessage()
Definition: IRCore.cpp:956
PyLocation getLocation()
Definition: IRCore.cpp:949
DiagnosticInfo getInfo()
Definition: IRCore.cpp:977
PyDiagnostic(MlirDiagnostic diagnostic)
Definition: IRModule.h:343
MlirDiagnosticSeverity getSeverity()
Definition: IRCore.cpp:944
pybind11::tuple getNotes()
Definition: IRCore.cpp:964
Wrapper around an MlirDialect.
Definition: IRModule.h:446
Wrapper around an MlirDialectRegistry.
Definition: IRModule.h:483
static PyDialectRegistry createFromCapsule(pybind11::object capsule)
Definition: IRCore.cpp:1006
pybind11::object getCapsule()
Definition: IRCore.cpp:1001
User-level dialect object.
Definition: IRModule.h:470
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition: IRModule.h:459
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition: IRCore.cpp:988
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition: Globals.h:34
std::optional< pybind11::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition: IRModule.cpp:172
std::optional< pybind11::function > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:145
An insertion point maintains a pointer to a Block and a reference operation.
Definition: IRModule.h:835
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition: IRCore.cpp:1876
PyInsertionPoint(PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition: IRCore.cpp:1831
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition: IRCore.cpp:1863
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition: IRCore.cpp:1837
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:1889
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:1885
Wrapper around an MlirLocation.
Definition: IRModule.h:310
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition: IRModule.h:312
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition: IRCore.cpp:1018
static PyLocation createFromCapsule(pybind11::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition: IRCore.cpp:1022
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:1034
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:1030
MlirLocation get() const
Definition: IRModule.h:316
pybind11::object attachDiagnosticHandler(pybind11::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition: IRCore.cpp:699
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:184
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition: IRModule.h:188
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition: IRCore.cpp:601
void clearOperationsInside(PyOperationBase &op)
Clears all operations nested inside the given op using clearOperation(MlirOperation).
Definition: IRCore.cpp:662
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition: IRCore.cpp:635
static PyMlirContext * createNewContextForInit()
For the case of a python init (py::init) method, pybind11 is quite strict about needing to return a p...
Definition: IRCore.cpp:608
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition: IRCore.cpp:687
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:689
size_t clearLiveOperations()
Clears the live operations map, returning the number of entries which were invalidated.
Definition: IRCore.cpp:646
std::vector< PyOperation * > getLiveOperationObjects()
Get a list of Python objects which are still in the live context map.
Definition: IRCore.cpp:639
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:693
void clearOperation(MlirOperation op)
Removes an operation from the live operations map and sets it invalid.
Definition: IRCore.cpp:654
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition: IRCore.cpp:613
size_t getLiveOperationCount()
Gets the count of live operations associated with this context.
Definition: IRCore.cpp:637
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition: IRCore.cpp:597
MlirModule get()
Gets the backing MlirModule.
Definition: IRModule.h:533
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition: IRCore.cpp:1067
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition: IRCore.cpp:1100
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition: IRCore.cpp:1093
PyModule(PyModule &)=delete
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition: IRModule.h:1017
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition: IRCore.cpp:1919
MlirNamedAttribute namedAttr
Definition: IRModule.h:1026
pybind11::object getObject()
Definition: IRModule.h:87
pybind11::object releaseObject()
Releases the object held by this instance, returning it.
Definition: IRModule.h:75
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition: IRModule.h:725
PyOpView(const pybind11::object &operationObject)
Definition: IRCore.cpp:1821
static pybind11::object buildGeneric(const pybind11::object &cls, std::optional< pybind11::list > resultTypeList, pybind11::list operandList, std::optional< pybind11::dict > attributes, std::optional< std::vector< PyBlock * >> successors, std::optional< int > regions, DefaultingPyLocation location, const pybind11::object &maybeIp)
Definition: IRCore.cpp:1632
static pybind11::object constructDerived(const pybind11::object &cls, const PyOperation &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition: IRCore.cpp:1810
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition: IRModule.h:563
void walk(std::function< MlirWalkResult(MlirOperation)> callback, MlirWalkOrder walkOrder)
Definition: IRCore.cpp:1253
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
void writeBytecode(const pybind11::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition: IRCore.cpp:1232
pybind11::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified)
Definition: IRCore.cpp:1285
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition: IRCore.cpp:1308
void print(std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, py::object fileObject, bool binary)
Implements the bound 'print' method and helps with others.
Definition: IRCore.cpp:1192
void moveBefore(PyOperationBase &other)
Definition: IRCore.cpp:1317
bool verify()
Verify the operation.
Definition: IRCore.cpp:1326
pybind11::object clone(const pybind11::object &ip)
Clones this operation.
Definition: IRCore.cpp:1508
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition: IRModule.h:631
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition: IRCore.cpp:1528
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition: IRCore.cpp:1353
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition: IRModule.h:609
PyOperationRef getRef()
Definition: IRModule.h:644
MlirOperation get() const
Definition: IRModule.h:639
void setAttached(const pybind11::object &parent=pybind11::object())
Definition: IRModule.h:650
static pybind11::object create(const std::string &name, std::optional< std::vector< PyType * >> results, std::optional< std::vector< PyValue * >> operands, std::optional< pybind11::dict > attributes, std::optional< std::vector< PyBlock * >> successors, int regions, DefaultingPyLocation location, const pybind11::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition: IRCore.cpp:1382
pybind11::object createOpView()
Creates an OpView suitable for this operation.
Definition: IRCore.cpp:1517
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive=pybind11::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition: IRCore.cpp:1144
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition: IRCore.cpp:1334
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition: IRCore.cpp:1344
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive=pybind11::object())
Creates a detached operation.
Definition: IRCore.cpp:1160
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition: IRCore.cpp:1174
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition: IRCore.cpp:1358
void checkValid() const
Definition: IRCore.cpp:1186
Wrapper around an MlirRegion.
Definition: IRModule.h:757
PyOperationRef & getParentOperation()
Definition: IRModule.h:766
MlirRegion get()
Definition: IRModule.h:765
Bindings for MLIR symbol tables.
Definition: IRModule.h:1223
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition: IRCore.cpp:2039
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition: IRCore.cpp:2110
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition: IRCore.cpp:2092
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition: IRCore.cpp:2066
MlirAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition: IRCore.cpp:2044
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition: IRCore.cpp:2029
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition: IRCore.cpp:2009
static MlirAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition: IRCore.cpp:2054
pybind11::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition: IRCore.cpp:2017
static MlirAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition: IRCore.cpp:2081
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, pybind11::object callback)
Walks all symbol tables under and including 'from'.
Definition: IRCore.cpp:2122
Tracks an entry in the thread context stack.
Definition: IRModule.h:106
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition: IRCore.cpp:784
static void popLocation(PyLocation &location)
Definition: IRCore.cpp:896
static pybind11::object pushContext(PyMlirContext &context)
Definition: IRCore.cpp:846
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition: IRCore.cpp:841
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition: IRCore.cpp:876
static void popContext(PyMlirContext &context)
Definition: IRCore.cpp:854
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and returns nullptr if not defined.
Definition: IRCore.cpp:836
static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint)
Definition: IRCore.cpp:865
static pybind11::object pushLocation(PyLocation &location)
Definition: IRCore.cpp:887
PyMlirContext * getContext()
Definition: IRCore.cpp:813
static PyMlirContext * getDefaultContext()
Gets the top of stack context and returns nullptr if not defined.
Definition: IRCore.cpp:831
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition: IRCore.cpp:779
PyInsertionPoint * getInsertionPoint()
Definition: IRCore.cpp:819
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition: IRModule.h:895
bool operator==(const PyTypeID &other) const
Definition: IRCore.cpp:1961
static PyTypeID createFromCapsule(pybind11::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition: IRCore.cpp:1955
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition: IRCore.cpp:1951
PyTypeID(MlirTypeID typeID)
Definition: IRModule.h:897
Wrapper around the generic MlirType.
Definition: IRModule.h:871
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition: IRCore.cpp:1935
static PyType createFromCapsule(pybind11::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition: IRCore.cpp:1939
PyType(PyMlirContextRef contextRef, MlirType type)
Definition: IRModule.h:873
bool operator==(const PyType &other) const
Definition: IRCore.cpp:1931
Wrapper around the generic MlirValue.
Definition: IRModule.h:1124
static PyValue createFromCapsule(pybind11::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition: IRCore.cpp:1988
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition: IRModule.h:1130
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition: IRCore.cpp:1969
pybind11::object maybeDownCast()
Definition: IRCore.cpp:1973
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
Definition: Diagnostics.cpp:44
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
Definition: Diagnostics.cpp:28
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
Definition: Diagnostics.cpp:18
MlirDiagnosticSeverity
Severity of a diagnostic.
Definition: Diagnostics.h:32
@ MlirDiagnosticNote
Definition: Diagnostics.h:35
@ MlirDiagnosticRemark
Definition: Diagnostics.h:36
@ MlirDiagnosticWarning
Definition: Diagnostics.h:34
@ MlirDiagnosticError
Definition: Diagnostics.h:33
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
Definition: Diagnostics.cpp:50
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
Definition: Diagnostics.cpp:78
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
Definition: Diagnostics.cpp:56
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
Definition: Diagnostics.cpp:72
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition: Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
Definition: Diagnostics.cpp:24
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition: IR.cpp:243
MLIR_CAPI_EXPORTED intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Returns the position of the value in the argument list of its block.
Definition: IR.cpp:944
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1019
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition: IR.h:723
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition: IR.cpp:692
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
Definition: IR.cpp:520
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates a File/Line/Column location owned by the given context.
Definition: IR.cpp:251
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition: IR.cpp:1173
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition: IR.cpp:127
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op)
Returns the number of results of the operation.
Definition: IR.cpp:579
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition: IR.cpp:1152
MlirWalkOrder
Traversal order for operation walk.
Definition: IR.h:716
@ MlirWalkPreOrder
Definition: IR.h:717
@ MlirWalkPostOrder
Definition: IR.h:718
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Return pos-th attribute of the operation.
Definition: IR.cpp:653
MLIR_CAPI_EXPORTED void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition: IR.cpp:368
MLIR_CAPI_EXPORTED void mlirModuleDestroy(MlirModule module)
Takes a module owned by the caller and deletes it.
Definition: IR.cpp:321
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition: IR.cpp:1100
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition: IR.cpp:1157
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Use local scope when printing the operation.
Definition: IR.cpp:214
MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value)
Returns 1 if the value is a block argument, 0 otherwise.
Definition: IR.cpp:932
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition: IR.cpp:82
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:1121
static bool mlirModuleIsNull(MlirModule module)
Checks whether a module is null.
Definition: IR.h:314
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition: IR.cpp:1034
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition: IR.cpp:1017
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute)
Gets the type of this attribute.
Definition: IR.cpp:1073
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition: IR.cpp:71
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition: IR.cpp:792
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition: IR.cpp:995
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Returns pos-th successor of the operation.
Definition: IR.cpp:591
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition: IR.cpp:978
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition: IR.cpp:279
MLIR_CAPI_EXPORTED void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Prints a type by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1054
MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Sets an attribute by name, replacing the existing if it exists or adding a new one otherwise.
Definition: IR.cpp:663
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition: IR.cpp:1005
MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute)
Gets the dialect of the attribute.
Definition: IR.cpp:1084
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1092
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition: IR.cpp:831
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition: IR.cpp:716
static bool mlirValueIsNull(MlirValue value)
Returns whether the value is null.
Definition: IR.h:883
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition: IR.cpp:684
MlirWalkResult
Operation walk result.
Definition: IR.h:709
@ MlirWalkResultInterrupt
Definition: IR.h:711
@ MlirWalkResultSkip
Definition: IR.h:712
@ MlirWalkResultAdvance
Definition: IR.h:710
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition: IR.cpp:772
MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Returns an attribute attached to the operation given its name.
Definition: IR.cpp:658
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition: IR.h:984
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e.
Definition: IR.cpp:98
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Returns the number of successor blocks of the operation.
Definition: IR.cpp:587
MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op)
Creates a deep copy of an operation.
Definition: IR.cpp:494
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition: IR.cpp:900
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Always print operations in the generic form.
Definition: IR.cpp:210
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition: IR.cpp:262
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition: IR.cpp:881
MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state)
Destroys an asm state created with one of the mlirAsmStateCreate* functions.
Definition: IR.cpp:186
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition: IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition: IR.cpp:93
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningl...
Definition: IR.cpp:200
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition: IR.cpp:778
MLIR_CAPI_EXPORTED void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Sets the type of the block argument to the given type.
Definition: IR.cpp:949
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op)
Gets the context this operation is associated with.
Definition: IR.cpp:506
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition: IR.cpp:815
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition: IR.h:805
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition: IR.cpp:856
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition: IR.cpp:914
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition: IR.cpp:1147
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition: IR.cpp:1038
MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value)
Prints the value to the standard error stream.
Definition: IR.cpp:970
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location)
Creates a new, empty module and transfers ownership to the caller.
Definition: IR.cpp:301
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition: IR.cpp:1003
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition: IR.cpp:1137
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition: IR.cpp:283
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetBlock(MlirOperation op)
Gets the block that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:524
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition: IR.h:282
MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Checks whether two operation handles point to the same operation.
Definition: IR.cpp:502
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition: IR.cpp:678
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition: IR.cpp:985
MLIR_CAPI_EXPORTED void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:291
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition: IR.cpp:708
MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module)
Views the module as a generic operation.
Definition: IR.cpp:327
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition: IR.cpp:1050
MLIR_CAPI_EXPORTED MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Constructs an operation state from a name and a location.
Definition: IR.cpp:339
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition: IR.cpp:1013
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition: IR.cpp:846
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Returns the operation immediately following the given operation in its enclosing block.
Definition: IR.cpp:556
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Gets the operation that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:528
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module)
Gets the context that a module was created with.
Definition: IR.cpp:313
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition: IR.cpp:247
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Do not verify the operation when using custom operation printers.
Definition: IR.cpp:218
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition: IR.cpp:1042
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition: IR.cpp:1133
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition: IR.h:173
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Destroys a bytecode writer config created with mlirBytecodeWriterConfigCreate.
Definition: IR.cpp:230
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Returns pos-th operand of the operation.
Definition: IR.cpp:564
MLIR_CAPI_EXPORTED void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition: IR.cpp:380
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition: IR.cpp:835
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition: IR.cpp:258
MLIR_CAPI_EXPORTED MlirOperation mlirOpResultGetOwner(MlirValue value)
Returns an operation that produced this value as its result.
Definition: IR.cpp:953
MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value)
Returns 1 if the value is an operation result, 0 otherwise.
Definition: IR.cpp:936
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op)
Returns the number of operands of the operation.
Definition: IR.cpp:560
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition: IR.h:235
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition: IR.cpp:733
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
Definition: IR.cpp:53
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition: IR.cpp:827
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumRegions(MlirOperation op)
Returns the number of regions attached to the given operation.
Definition: IR.cpp:532
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition: IR.cpp:287
MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Removes an attribute by name.
Definition: IR.cpp:668
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition: IR.cpp:1098
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition: IR.cpp:1162
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Parses an attribute. The attribute is owned by the context.
Definition: IR.cpp:1065
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Parses a module from the string and transfers ownership to the caller.
Definition: IR.cpp:305
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition: IR.cpp:768
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition: IR.cpp:839
MLIR_CAPI_EXPORTED void mlirTypeDump(MlirType type)
Prints the type to the standard error stream.
Definition: IR.cpp:1059
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Returns pos-th result of the operation.
Definition: IR.cpp:583
MLIR_CAPI_EXPORTED MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate(void)
Creates a new bytecode writer config with defaults, intended for customization.
Definition: IR.cpp:226
MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Gets the context that an attribute was created with.
Definition: IR.cpp:1069
MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Returns the block in which this value is defined as an argument.
Definition: IR.cpp:940
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition: IR.h:744
MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op)
Takes an operation owned by the caller and destroys it.
Definition: IR.cpp:498
MLIR_CAPI_EXPORTED MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Returns pos-th region attached to the operation.
Definition: IR.cpp:536
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition: IR.cpp:1046
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition: IR.cpp:1109
MLIR_CAPI_EXPORTED void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Sets the pos-th successor of the operation.
Definition: IR.cpp:644
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition: IR.cpp:106
MLIR_CAPI_EXPORTED void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition: IR.cpp:372
MLIR_CAPI_EXPORTED void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition: IR.cpp:376
MLIR_CAPI_EXPORTED MlirBlock mlirModuleGetBody(MlirModule module)
Gets the body of the module, i.e. the only block it contains.
Definition: IR.cpp:317
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Destroys printing flags created with mlirOpPrintingFlagsCreate.
Definition: IR.cpp:196
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition: IR.cpp:270
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Sets the threading mode (must be set to false to use mlir-print-ir-after-all).
Definition: IR.cpp:102
MLIR_CAPI_EXPORTED void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Prints a block by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:918
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Sets the version to emit in the writer config.
Definition: IR.cpp:234
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition: IR.cpp:1129
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition: IR.cpp:755
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition: IR.cpp:895
MLIR_CAPI_EXPORTED void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Adds a list of components to the operation state.
Definition: IR.cpp:363
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Enable or disable printing of debug information (based on enable).
Definition: IR.cpp:205
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op)
Gets the location of the operation.
Definition: IR.cpp:510
MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute)
Gets the type id of the attribute.
Definition: IR.cpp:1080
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Sets the pos-th operand of the operation.
Definition: IR.cpp:568
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition: IR.cpp:706
MLIR_CAPI_EXPORTED intptr_t mlirOpResultGetResultNumber(MlirValue value)
Returns the position of the value in the list of results of the operation that produced it.
Definition: IR.cpp:957
MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void)
Creates new printing flags with defaults, intended for customization.
Definition: IR.cpp:192
MLIR_CAPI_EXPORTED MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Creates new AsmState from value.
Definition: IR.cpp:168
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state)
Creates an operation and transfers ownership to the caller.
Definition: IR.cpp:448
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition: IR.h:1074
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition: IR.cpp:75
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition: IR.cpp:712
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Returns the number of attributes attached to the operation.
Definition: IR.cpp:649
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:972
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition: IR.cpp:699
MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type)
Set the type of the value.
Definition: IR.cpp:966
MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value)
Returns the type of the value.
Definition: IR.cpp:962
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition: IR.cpp:69
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Parses an operation, giving ownership to the caller.
Definition: IR.cpp:485
MLIR_CAPI_EXPORTED bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Checks if two attributes are equal.
Definition: IR.cpp:1088
static bool mlirOperationIsNull(MlirOperation op)
Checks whether the underlying operation is null.
Definition: IR.h:507
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition: IR.cpp:761
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition: Support.h:138
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition: Support.cpp:51
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition: Support.h:132
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition: Support.h:163
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition: Support.cpp:47
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
llvm::hash_code hash_value(const MPInt &x)
Redeclarations of friend declaration above to make it discoverable by lookups.
Definition: MPInt.cpp:17
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition: IRModule.h:161
PyObjectRef< PyModule > PyModuleRef
Definition: IRModule.h:522
void populateIRCore(pybind11::module &m)
PyObjectRef< PyOperation > PyOperationRef
Definition: IRModule.h:605
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition: Diagnostics.h:26
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
Named MLIR attribute.
Definition: IR.h:76
MlirAttribute attribute
Definition: IR.h:78
MlirIdentifier name
Definition: IR.h:77
An auxiliary class for constructing operations.
Definition: IR.h:340
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
static bool dunderContains(const std::string &attributeKind)
Definition: IRCore.cpp:248
static void dundeSetItemNamed(const std::string &attributeKind, py::function func, bool replace)
Definition: IRCore.cpp:257
static py::function dundeGetItemNamed(const std::string &attributeKind)
Definition: IRCore.cpp:251
static void bind(py::module &m)
Definition: IRCore.cpp:263
Wrapper for the global LLVM debugging flag.
Definition: IRCore.cpp:234
static void bind(py::module &m)
Definition: IRCore.cpp:239
static bool get(const py::object &)
Definition: IRCore.cpp:237
static void set(py::object &o, bool enable)
Definition: IRCore.cpp:235
Accumulates into a python string from a method that accepts an MlirStringCallback.
Definition: PybindUtils.h:102
pybind11::list parts
Definition: PybindUtils.h:103
pybind11::str join()
Definition: PybindUtils.h:117
MlirStringCallback getCallback()
Definition: PybindUtils.h:107
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1275
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Definition: IRModule.h:1280
Materialized diagnostic information.
Definition: IRModule.h:353
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:419
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition: IRModule.h:429
ErrorCapture(PyMlirContextRef ctx)
Definition: IRModule.h:420