MLIR  19.0.0git
IRCore.cpp
Go to the documentation of this file.
1 //===- IRCore.cpp - IR Submodules of pybind module ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
16 #include "mlir-c/Debug.h"
17 #include "mlir-c/Diagnostics.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Support.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/SmallVector.h"
23 
24 #include <optional>
25 #include <utility>
26 
27 namespace py = pybind11;
28 using namespace py::literals;
29 using namespace mlir;
30 using namespace mlir::python;
31 
32 using llvm::SmallVector;
33 using llvm::StringRef;
34 using llvm::Twine;
35 
36 //------------------------------------------------------------------------------
37 // Docstrings (trivial, non-duplicated docstrings are included inline).
38 //------------------------------------------------------------------------------
39 
40 static const char kContextParseTypeDocstring[] =
41  R"(Parses the assembly form of a type.
42 
43 Returns a Type object or raises an MLIRError if the type cannot be parsed.
44 
45 See also: https://mlir.llvm.org/docs/LangRef/#type-system
46 )";
47 
49  R"(Gets a Location representing a caller and callsite)";
50 
51 static const char kContextGetFileLocationDocstring[] =
52  R"(Gets a Location representing a file, line and column)";
53 
54 static const char kContextGetFusedLocationDocstring[] =
55  R"(Gets a Location representing a fused location with optional metadata)";
56 
57 static const char kContextGetNameLocationDocString[] =
58  R"(Gets a Location representing a named location with optional child location)";
59 
60 static const char kModuleParseDocstring[] =
61  R"(Parses a module's assembly format from a string.
62 
63 Returns a new MlirModule or raises an MLIRError if the parsing fails.
64 
65 See also: https://mlir.llvm.org/docs/LangRef/
66 )";
67 
68 static const char kOperationCreateDocstring[] =
69  R"(Creates a new operation.
70 
71 Args:
72  name: Operation name (e.g. "dialect.operation").
73  results: Sequence of Type representing op result types.
74  attributes: Dict of str:Attribute.
75  successors: List of Block for the operation's successors.
76  regions: Number of regions to create.
77  location: A Location object (defaults to resolve from context manager).
78  ip: An InsertionPoint (defaults to resolve from context manager or set to
79  False to disable insertion, even with an insertion point set in the
80  context manager).
81  infer_type: Whether to infer result types.
82 Returns:
83  A new "detached" Operation object. Detached operations can be added
84  to blocks, which causes them to become "attached."
85 )";
86 
// Docstring attached to the Python-facing operation print binding. The text
// is surfaced verbatim by Python's help(), so the keyword names listed here
// should be spelled the way callers type them (lowercase snake_case).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
 file: The file like object to write to. Defaults to sys.stdout.
 binary: Whether to write bytes (True) or str (False). Defaults to False.
 large_elements_limit: Whether to elide elements attributes above this
 number of elements. Defaults to None (no limit).
 enable_debug_info: Whether to print debug/location information. Defaults
 to False.
 pretty_debug_info: Whether to format debug information for easier reading
 by a human (warning: the result is unparseable).
 print_generic_op_form: Whether to print the generic assembly forms of all
 ops. Defaults to False.
 use_local_scope: Whether to print in a way that is more optimized for
 multi-threaded access but may not be consistent with how the overall
 module prints.
 assume_verified: By default, if not printing generic form, the verifier
 will be run and if it fails, generic form will be printed with a comment
 about failed verification. While a reasonable default for interactive use,
 for systematic use, it is often better for the caller to verify explicitly
 and report failures in a more robust fashion. Set this to True if doing this
 in order to avoid running a redundant verification. If the IR is actually
 invalid, behavior is undefined.
)";
112 
113 static const char kOperationPrintStateDocstring[] =
114  R"(Prints the assembly form of the operation to a file like object.
115 
116 Args:
117  file: The file like object to write to. Defaults to sys.stdout.
118  binary: Whether to write bytes (True) or str (False). Defaults to False.
119  state: AsmState capturing the operation numbering and flags.
120 )";
121 
122 static const char kOperationGetAsmDocstring[] =
123  R"(Gets the assembly form of the operation with all options available.
124 
125 Args:
126  binary: Whether to return a bytes (True) or str (False) object. Defaults to
127  False.
128  ... others ...: See the print() method for common keyword arguments for
129  configuring the printout.
130 Returns:
131  Either a bytes or str object, depending on the setting of the 'binary'
132  argument.
133 )";
134 
135 static const char kOperationPrintBytecodeDocstring[] =
136  R"(Write the bytecode form of the operation to a file like object.
137 
138 Args:
139  file: The file like object to write to.
140  desired_version: The version of bytecode to emit.
141 Returns:
142  The bytecode writer status.
143 )";
144 
145 static const char kOperationStrDunderDocstring[] =
146  R"(Gets the assembly form of the operation with default options.
147 
148 If more advanced control over the assembly formatting or I/O options is needed,
149 use the dedicated print or get_asm method, which supports keyword arguments to
150 customize behavior.
151 )";
152 
153 static const char kDumpDocstring[] =
154  R"(Dumps a debug representation of the object to stderr.)";
155 
156 static const char kAppendBlockDocstring[] =
157  R"(Appends a new block, with argument types as positional args.
158 
159 Returns:
160  The created block.
161 )";
162 
163 static const char kValueDunderStrDocstring[] =
164  R"(Returns the string form of the value.
165 
166 If the value is a block argument, this is the assembly form of its type and the
167 position in the argument list. If the value is an operation result, this is
168 equivalent to printing the operation that produced it.
169 )";
170 
171 static const char kGetNameAsOperand[] =
172  R"(Returns the string form of value as an operand (i.e., the ValueID).
173 )";
174 
176  R"(Replace all uses of value with the new value, updating anything in
177 the IR that uses 'self' to use the other value instead.
178 )";
179 
180 //------------------------------------------------------------------------------
181 // Utilities.
182 //------------------------------------------------------------------------------
183 
184 /// Helper for creating an @classmethod.
185 template <class Func, typename... Args>
186 py::object classmethod(Func f, Args... args) {
187  py::object cf = py::cpp_function(f, args...);
188  return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
189 }
190 
191 static py::object
192 createCustomDialectWrapper(const std::string &dialectNamespace,
193  py::object dialectDescriptor) {
194  auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
195  if (!dialectClass) {
196  // Use the base class.
197  return py::cast(PyDialect(std::move(dialectDescriptor)));
198  }
199 
200  // Create the custom implementation.
201  return (*dialectClass)(std::move(dialectDescriptor));
202 }
203 
204 static MlirStringRef toMlirStringRef(const std::string &s) {
205  return mlirStringRefCreate(s.data(), s.size());
206 }
207 
208 /// Create a block, using the current location context if no locations are
209 /// specified.
210 static MlirBlock createBlock(const py::sequence &pyArgTypes,
211  const std::optional<py::sequence> &pyArgLocs) {
212  SmallVector<MlirType> argTypes;
213  argTypes.reserve(pyArgTypes.size());
214  for (const auto &pyType : pyArgTypes)
215  argTypes.push_back(pyType.cast<PyType &>());
216 
218  if (pyArgLocs) {
219  argLocs.reserve(pyArgLocs->size());
220  for (const auto &pyLoc : *pyArgLocs)
221  argLocs.push_back(pyLoc.cast<PyLocation &>());
222  } else if (!argTypes.empty()) {
223  argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
224  }
225 
226  if (argTypes.size() != argLocs.size())
227  throw py::value_error(("Expected " + Twine(argTypes.size()) +
228  " locations, got: " + Twine(argLocs.size()))
229  .str());
230  return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
231 }
232 
233 /// Wrapper for the global LLVM debugging flag.
235  static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
236 
237  static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
238 
239  static void bind(py::module &m) {
240  // Debug flags.
241  py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
242  .def_property_static("flag", &PyGlobalDebugFlag::get,
243  &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
244  .def_static(
245  "set_types",
246  [](const std::string &type) {
247  mlirSetGlobalDebugType(type.c_str());
248  },
249  "types"_a, "Sets specific debug types to be produced by LLVM")
250  .def_static("set_types", [](const std::vector<std::string> &types) {
251  std::vector<const char *> pointers;
252  pointers.reserve(types.size());
253  for (const std::string &str : types)
254  pointers.push_back(str.c_str());
255  mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
256  });
257  }
258 };
259 
261  static bool dunderContains(const std::string &attributeKind) {
262  return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
263  }
264  static py::function dundeGetItemNamed(const std::string &attributeKind) {
265  auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
266  if (!builder)
267  throw py::key_error(attributeKind);
268  return *builder;
269  }
270  static void dundeSetItemNamed(const std::string &attributeKind,
271  py::function func, bool replace) {
272  PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
273  replace);
274  }
275 
276  static void bind(py::module &m) {
277  py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
278  .def_static("contains", &PyAttrBuilderMap::dunderContains)
279  .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
280  .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed,
281  "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
282  "Register an attribute builder for building MLIR "
283  "attributes from python values.");
284  }
285 };
286 
287 //------------------------------------------------------------------------------
288 // PyBlock
289 //------------------------------------------------------------------------------
290 
291 py::object PyBlock::getCapsule() {
292  return py::reinterpret_steal<py::object>(mlirPythonBlockToCapsule(get()));
293 }
294 
295 //------------------------------------------------------------------------------
296 // Collections.
297 //------------------------------------------------------------------------------
298 
299 namespace {
300 
301 class PyRegionIterator {
302 public:
303  PyRegionIterator(PyOperationRef operation)
304  : operation(std::move(operation)) {}
305 
306  PyRegionIterator &dunderIter() { return *this; }
307 
308  PyRegion dunderNext() {
309  operation->checkValid();
310  if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
311  throw py::stop_iteration();
312  }
313  MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
314  return PyRegion(operation, region);
315  }
316 
317  static void bind(py::module &m) {
318  py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
319  .def("__iter__", &PyRegionIterator::dunderIter)
320  .def("__next__", &PyRegionIterator::dunderNext);
321  }
322 
323 private:
324  PyOperationRef operation;
325  int nextIndex = 0;
326 };
327 
328 /// Regions of an op are fixed length and indexed numerically so are represented
329 /// with a sequence-like container.
330 class PyRegionList {
331 public:
332  PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
333 
334  PyRegionIterator dunderIter() {
335  operation->checkValid();
336  return PyRegionIterator(operation);
337  }
338 
339  intptr_t dunderLen() {
340  operation->checkValid();
341  return mlirOperationGetNumRegions(operation->get());
342  }
343 
344  PyRegion dunderGetItem(intptr_t index) {
345  // dunderLen checks validity.
346  if (index < 0 || index >= dunderLen()) {
347  throw py::index_error("attempt to access out of bounds region");
348  }
349  MlirRegion region = mlirOperationGetRegion(operation->get(), index);
350  return PyRegion(operation, region);
351  }
352 
353  static void bind(py::module &m) {
354  py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
355  .def("__len__", &PyRegionList::dunderLen)
356  .def("__iter__", &PyRegionList::dunderIter)
357  .def("__getitem__", &PyRegionList::dunderGetItem);
358  }
359 
360 private:
361  PyOperationRef operation;
362 };
363 
364 class PyBlockIterator {
365 public:
366  PyBlockIterator(PyOperationRef operation, MlirBlock next)
367  : operation(std::move(operation)), next(next) {}
368 
369  PyBlockIterator &dunderIter() { return *this; }
370 
371  PyBlock dunderNext() {
372  operation->checkValid();
373  if (mlirBlockIsNull(next)) {
374  throw py::stop_iteration();
375  }
376 
377  PyBlock returnBlock(operation, next);
378  next = mlirBlockGetNextInRegion(next);
379  return returnBlock;
380  }
381 
382  static void bind(py::module &m) {
383  py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
384  .def("__iter__", &PyBlockIterator::dunderIter)
385  .def("__next__", &PyBlockIterator::dunderNext);
386  }
387 
388 private:
389  PyOperationRef operation;
390  MlirBlock next;
391 };
392 
393 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
394 /// we present them as a more full-featured list-like container but optimize
395 /// it for forward iteration. Blocks are always owned by a region.
396 class PyBlockList {
397 public:
398  PyBlockList(PyOperationRef operation, MlirRegion region)
399  : operation(std::move(operation)), region(region) {}
400 
401  PyBlockIterator dunderIter() {
402  operation->checkValid();
403  return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
404  }
405 
406  intptr_t dunderLen() {
407  operation->checkValid();
408  intptr_t count = 0;
409  MlirBlock block = mlirRegionGetFirstBlock(region);
410  while (!mlirBlockIsNull(block)) {
411  count += 1;
412  block = mlirBlockGetNextInRegion(block);
413  }
414  return count;
415  }
416 
417  PyBlock dunderGetItem(intptr_t index) {
418  operation->checkValid();
419  if (index < 0) {
420  throw py::index_error("attempt to access out of bounds block");
421  }
422  MlirBlock block = mlirRegionGetFirstBlock(region);
423  while (!mlirBlockIsNull(block)) {
424  if (index == 0) {
425  return PyBlock(operation, block);
426  }
427  block = mlirBlockGetNextInRegion(block);
428  index -= 1;
429  }
430  throw py::index_error("attempt to access out of bounds block");
431  }
432 
433  PyBlock appendBlock(const py::args &pyArgTypes,
434  const std::optional<py::sequence> &pyArgLocs) {
435  operation->checkValid();
436  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
437  mlirRegionAppendOwnedBlock(region, block);
438  return PyBlock(operation, block);
439  }
440 
441  static void bind(py::module &m) {
442  py::class_<PyBlockList>(m, "BlockList", py::module_local())
443  .def("__getitem__", &PyBlockList::dunderGetItem)
444  .def("__iter__", &PyBlockList::dunderIter)
445  .def("__len__", &PyBlockList::dunderLen)
446  .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
447  py::arg("arg_locs") = std::nullopt);
448  }
449 
450 private:
451  PyOperationRef operation;
452  MlirRegion region;
453 };
454 
455 class PyOperationIterator {
456 public:
457  PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
458  : parentOperation(std::move(parentOperation)), next(next) {}
459 
460  PyOperationIterator &dunderIter() { return *this; }
461 
462  py::object dunderNext() {
463  parentOperation->checkValid();
464  if (mlirOperationIsNull(next)) {
465  throw py::stop_iteration();
466  }
467 
468  PyOperationRef returnOperation =
469  PyOperation::forOperation(parentOperation->getContext(), next);
470  next = mlirOperationGetNextInBlock(next);
471  return returnOperation->createOpView();
472  }
473 
474  static void bind(py::module &m) {
475  py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
476  .def("__iter__", &PyOperationIterator::dunderIter)
477  .def("__next__", &PyOperationIterator::dunderNext);
478  }
479 
480 private:
481  PyOperationRef parentOperation;
482  MlirOperation next;
483 };
484 
485 /// Operations are exposed by the C-API as a forward-only linked list. In
486 /// Python, we present them as a more full-featured list-like container but
487 /// optimize it for forward iteration. Iterable operations are always owned
488 /// by a block.
489 class PyOperationList {
490 public:
491  PyOperationList(PyOperationRef parentOperation, MlirBlock block)
492  : parentOperation(std::move(parentOperation)), block(block) {}
493 
494  PyOperationIterator dunderIter() {
495  parentOperation->checkValid();
496  return PyOperationIterator(parentOperation,
498  }
499 
500  intptr_t dunderLen() {
501  parentOperation->checkValid();
502  intptr_t count = 0;
503  MlirOperation childOp = mlirBlockGetFirstOperation(block);
504  while (!mlirOperationIsNull(childOp)) {
505  count += 1;
506  childOp = mlirOperationGetNextInBlock(childOp);
507  }
508  return count;
509  }
510 
511  py::object dunderGetItem(intptr_t index) {
512  parentOperation->checkValid();
513  if (index < 0) {
514  throw py::index_error("attempt to access out of bounds operation");
515  }
516  MlirOperation childOp = mlirBlockGetFirstOperation(block);
517  while (!mlirOperationIsNull(childOp)) {
518  if (index == 0) {
519  return PyOperation::forOperation(parentOperation->getContext(), childOp)
520  ->createOpView();
521  }
522  childOp = mlirOperationGetNextInBlock(childOp);
523  index -= 1;
524  }
525  throw py::index_error("attempt to access out of bounds operation");
526  }
527 
528  static void bind(py::module &m) {
529  py::class_<PyOperationList>(m, "OperationList", py::module_local())
530  .def("__getitem__", &PyOperationList::dunderGetItem)
531  .def("__iter__", &PyOperationList::dunderIter)
532  .def("__len__", &PyOperationList::dunderLen);
533  }
534 
535 private:
536  PyOperationRef parentOperation;
537  MlirBlock block;
538 };
539 
540 class PyOpOperand {
541 public:
542  PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
543 
544  py::object getOwner() {
545  MlirOperation owner = mlirOpOperandGetOwner(opOperand);
546  PyMlirContextRef context =
547  PyMlirContext::forContext(mlirOperationGetContext(owner));
548  return PyOperation::forOperation(context, owner)->createOpView();
549  }
550 
551  size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
552 
553  static void bind(py::module &m) {
554  py::class_<PyOpOperand>(m, "OpOperand", py::module_local())
555  .def_property_readonly("owner", &PyOpOperand::getOwner)
556  .def_property_readonly("operand_number",
557  &PyOpOperand::getOperandNumber);
558  }
559 
560 private:
561  MlirOpOperand opOperand;
562 };
563 
564 class PyOpOperandIterator {
565 public:
566  PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
567 
568  PyOpOperandIterator &dunderIter() { return *this; }
569 
570  PyOpOperand dunderNext() {
571  if (mlirOpOperandIsNull(opOperand))
572  throw py::stop_iteration();
573 
574  PyOpOperand returnOpOperand(opOperand);
575  opOperand = mlirOpOperandGetNextUse(opOperand);
576  return returnOpOperand;
577  }
578 
579  static void bind(py::module &m) {
580  py::class_<PyOpOperandIterator>(m, "OpOperandIterator", py::module_local())
581  .def("__iter__", &PyOpOperandIterator::dunderIter)
582  .def("__next__", &PyOpOperandIterator::dunderNext);
583  }
584 
585 private:
586  MlirOpOperand opOperand;
587 };
588 
589 } // namespace
590 
591 //------------------------------------------------------------------------------
592 // PyMlirContext
593 //------------------------------------------------------------------------------
594 
595 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
596  py::gil_scoped_acquire acquire;
597  auto &liveContexts = getLiveContexts();
598  liveContexts[context.ptr] = this;
599 }
600 
602  // Note that the only public way to construct an instance is via the
603  // forContext method, which always puts the associated handle into
604  // liveContexts.
605  py::gil_scoped_acquire acquire;
606  getLiveContexts().erase(context.ptr);
607  mlirContextDestroy(context);
608 }
609 
611  return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
612 }
613 
614 py::object PyMlirContext::createFromCapsule(py::object capsule) {
615  MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
616  if (mlirContextIsNull(rawContext))
617  throw py::error_already_set();
618  return forContext(rawContext).releaseObject();
619 }
620 
622  MlirContext context = mlirContextCreateWithThreading(false);
623  return new PyMlirContext(context);
624 }
625 
627  py::gil_scoped_acquire acquire;
628  auto &liveContexts = getLiveContexts();
629  auto it = liveContexts.find(context.ptr);
630  if (it == liveContexts.end()) {
631  // Create.
632  PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
633  py::object pyRef = py::cast(unownedContextWrapper);
634  assert(pyRef && "cast to py::object failed");
635  liveContexts[context.ptr] = unownedContextWrapper;
636  return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
637  }
638  // Use existing.
639  py::object pyRef = py::cast(it->second);
640  return PyMlirContextRef(it->second, std::move(pyRef));
641 }
642 
643 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
644  static LiveContextMap liveContexts;
645  return liveContexts;
646 }
647 
648 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
649 
650 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
651 
652 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
653  std::vector<PyOperation *> liveObjects;
654  for (auto &entry : liveOperations)
655  liveObjects.push_back(entry.second.second);
656  return liveObjects;
657 }
658 
660  for (auto &op : liveOperations)
661  op.second.second->setInvalid();
662  size_t numInvalidated = liveOperations.size();
663  liveOperations.clear();
664  return numInvalidated;
665 }
666 
667 void PyMlirContext::clearOperation(MlirOperation op) {
668  auto it = liveOperations.find(op.ptr);
669  if (it != liveOperations.end()) {
670  it->second.second->setInvalid();
671  liveOperations.erase(it);
672  }
673 }
674 
676  typedef struct {
677  PyOperation &rootOp;
678  bool rootSeen;
679  } callBackData;
680  callBackData data{op.getOperation(), false};
681  // Mark all ops below the op that the passmanager will be rooted
682  // at (but not op itself - note the preorder) as invalid.
683  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
684  void *userData) {
685  callBackData *data = static_cast<callBackData *>(userData);
686  if (LLVM_LIKELY(data->rootSeen))
687  data->rootOp.getOperation().getContext()->clearOperation(op);
688  else
689  data->rootSeen = true;
691  };
692  mlirOperationWalk(op.getOperation(), invalidatingCallback,
693  static_cast<void *>(&data), MlirWalkPreOrder);
694 }
695 void PyMlirContext::clearOperationsInside(MlirOperation op) {
698 }
699 
700 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
701 
702 pybind11::object PyMlirContext::contextEnter() {
703  return PyThreadContextEntry::pushContext(*this);
704 }
705 
706 void PyMlirContext::contextExit(const pybind11::object &excType,
707  const pybind11::object &excVal,
708  const pybind11::object &excTb) {
710 }
711 
712 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
713  // Note that ownership is transferred to the delete callback below by way of
714  // an explicit inc_ref (borrow).
715  PyDiagnosticHandler *pyHandler =
716  new PyDiagnosticHandler(get(), std::move(callback));
717  py::object pyHandlerObject =
718  py::cast(pyHandler, py::return_value_policy::take_ownership);
719  pyHandlerObject.inc_ref();
720 
721  // In these C callbacks, the userData is a PyDiagnosticHandler* that is
722  // guaranteed to be known to pybind.
723  auto handlerCallback =
724  +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
725  PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
726  py::object pyDiagnosticObject =
727  py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
728 
729  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
730  bool result = false;
731  {
732  // Since this can be called from arbitrary C++ contexts, always get the
733  // gil.
734  py::gil_scoped_acquire gil;
735  try {
736  result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
737  } catch (std::exception &e) {
738  fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
739  e.what());
740  pyHandler->hadError = true;
741  }
742  }
743 
744  pyDiagnostic->invalidate();
746  };
747  auto deleteCallback = +[](void *userData) {
748  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
749  assert(pyHandler->registeredID && "handler is not registered");
750  pyHandler->registeredID.reset();
751 
752  // Decrement reference, balancing the inc_ref() above.
753  py::object pyHandlerObject =
754  py::cast(pyHandler, py::return_value_policy::reference);
755  pyHandlerObject.dec_ref();
756  };
757 
758  pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
759  get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
760  return pyHandlerObject;
761 }
762 
763 MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
764  void *userData) {
765  auto *self = static_cast<ErrorCapture *>(userData);
766  // Check if the context requested we emit errors instead of capturing them.
767  if (self->ctx->emitErrorDiagnostics)
768  return mlirLogicalResultFailure();
769 
771  return mlirLogicalResultFailure();
772 
773  self->errors.emplace_back(PyDiagnostic(diag).getInfo());
774  return mlirLogicalResultSuccess();
775 }
776 
779  if (!context) {
780  throw std::runtime_error(
781  "An MLIR function requires a Context but none was provided in the call "
782  "or from the surrounding environment. Either pass to the function with "
783  "a 'context=' argument or establish a default using 'with Context():'");
784  }
785  return *context;
786 }
787 
788 //------------------------------------------------------------------------------
789 // PyThreadContextEntry management
790 //------------------------------------------------------------------------------
791 
792 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
793  static thread_local std::vector<PyThreadContextEntry> stack;
794  return stack;
795 }
796 
798  auto &stack = getStack();
799  if (stack.empty())
800  return nullptr;
801  return &stack.back();
802 }
803 
804 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
805  py::object insertionPoint,
806  py::object location) {
807  auto &stack = getStack();
808  stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
809  std::move(location));
810  // If the new stack has more than one entry and the context of the new top
811  // entry matches the previous, copy the insertionPoint and location from the
812  // previous entry if missing from the new top entry.
813  if (stack.size() > 1) {
814  auto &prev = *(stack.rbegin() + 1);
815  auto &current = stack.back();
816  if (current.context.is(prev.context)) {
817  // Default non-context objects from the previous entry.
818  if (!current.insertionPoint)
819  current.insertionPoint = prev.insertionPoint;
820  if (!current.location)
821  current.location = prev.location;
822  }
823  }
824 }
825 
827  if (!context)
828  return nullptr;
829  return py::cast<PyMlirContext *>(context);
830 }
831 
833  if (!insertionPoint)
834  return nullptr;
835  return py::cast<PyInsertionPoint *>(insertionPoint);
836 }
837 
839  if (!location)
840  return nullptr;
841  return py::cast<PyLocation *>(location);
842 }
843 
845  auto *tos = getTopOfStack();
846  return tos ? tos->getContext() : nullptr;
847 }
848 
850  auto *tos = getTopOfStack();
851  return tos ? tos->getInsertionPoint() : nullptr;
852 }
853 
855  auto *tos = getTopOfStack();
856  return tos ? tos->getLocation() : nullptr;
857 }
858 
860  py::object contextObj = py::cast(context);
861  push(FrameKind::Context, /*context=*/contextObj,
862  /*insertionPoint=*/py::object(),
863  /*location=*/py::object());
864  return contextObj;
865 }
866 
868  auto &stack = getStack();
869  if (stack.empty())
870  throw std::runtime_error("Unbalanced Context enter/exit");
871  auto &tos = stack.back();
872  if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
873  throw std::runtime_error("Unbalanced Context enter/exit");
874  stack.pop_back();
875 }
876 
877 py::object
879  py::object contextObj =
880  insertionPoint.getBlock().getParentOperation()->getContext().getObject();
881  py::object insertionPointObj = py::cast(insertionPoint);
882  push(FrameKind::InsertionPoint,
883  /*context=*/contextObj,
884  /*insertionPoint=*/insertionPointObj,
885  /*location=*/py::object());
886  return insertionPointObj;
887 }
888 
890  auto &stack = getStack();
891  if (stack.empty())
892  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
893  auto &tos = stack.back();
894  if (tos.frameKind != FrameKind::InsertionPoint &&
895  tos.getInsertionPoint() != &insertionPoint)
896  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
897  stack.pop_back();
898 }
899 
901  py::object contextObj = location.getContext().getObject();
902  py::object locationObj = py::cast(location);
903  push(FrameKind::Location, /*context=*/contextObj,
904  /*insertionPoint=*/py::object(),
905  /*location=*/locationObj);
906  return locationObj;
907 }
908 
910  auto &stack = getStack();
911  if (stack.empty())
912  throw std::runtime_error("Unbalanced Location enter/exit");
913  auto &tos = stack.back();
914  if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
915  throw std::runtime_error("Unbalanced Location enter/exit");
916  stack.pop_back();
917 }
918 
919 //------------------------------------------------------------------------------
920 // PyDiagnostic*
921 //------------------------------------------------------------------------------
922 
924  valid = false;
925  if (materializedNotes) {
926  for (auto &noteObject : *materializedNotes) {
927  PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
928  note->invalidate();
929  }
930  }
931 }
932 
934  py::object callback)
935  : context(context), callback(std::move(callback)) {}
936 
938 
940  if (!registeredID)
941  return;
942  MlirDiagnosticHandlerID localID = *registeredID;
943  mlirContextDetachDiagnosticHandler(context, localID);
944  assert(!registeredID && "should have unregistered");
945  // Not strictly necessary but keeps stale pointers from being around to cause
946  // issues.
947  context = {nullptr};
948 }
949 
950 void PyDiagnostic::checkValid() {
951  if (!valid) {
952  throw std::invalid_argument(
953  "Diagnostic is invalid (used outside of callback)");
954  }
955 }
956 
// NOTE(review): the signature line of this function is absent from this
// listing. Returns the severity of the wrapped diagnostic after verifying it
// is still valid.
958  checkValid();
959  return mlirDiagnosticGetSeverity(diagnostic);
960 }
961 
// NOTE(review): the signature line of this function is absent from this
// listing. Returns the diagnostic's location wrapped as a PyLocation.
963  checkValid();
964  MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
// The owning context is recovered from the location itself.
965  MlirContext context = mlirLocationGetContext(loc);
966  return PyLocation(PyMlirContext::forContext(context), loc);
967 }
968 
// NOTE(review): the signature line of this function is absent from this
// listing. Prints the diagnostic into a Python io.StringIO buffer and returns
// the accumulated text via the buffer's getvalue().
970  checkValid();
971  py::object fileObject = py::module::import("io").attr("StringIO")();
// Text mode (binary=false) to match the StringIO sink.
972  PyFileAccumulator accum(fileObject, /*binary=*/false);
973  mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
974  return fileObject.attr("getvalue")();
975 }
976 
// NOTE(review): the signature line of this function is absent from this
// listing. Returns the tuple of notes attached to this diagnostic, building it
// on first use and reusing the cached tuple afterwards.
978  checkValid();
// Fast path: notes were already materialized by an earlier call.
979  if (materializedNotes)
980  return *materializedNotes;
981  intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
982  materializedNotes = py::tuple(numNotes);
// Wrap each note as its own PyDiagnostic and store it in the cached tuple.
983  for (intptr_t i = 0; i < numNotes; ++i) {
984  MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
985  (*materializedNotes)[i] = PyDiagnostic(noteDiag);
986  }
987  return *materializedNotes;
988 }
989 
// NOTE(review): the signature line of this function is absent from this
// listing. Builds a DiagnosticInfo aggregate from this diagnostic's severity,
// location, message and the DiagnosticInfo of every note (recursively).
991  std::vector<DiagnosticInfo> notes;
992  for (py::handle n : getNotes())
993  notes.emplace_back(n.cast<PyDiagnostic>().getInfo());
994  return {getSeverity(), getLocation(), getMessage(), std::move(notes)};
995 }
996 
997 //------------------------------------------------------------------------------
998 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
999 //------------------------------------------------------------------------------
1000 
1001 MlirDialect PyDialects::getDialectForKey(const std::string &key,
1002  bool attrError) {
1003  MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1004  {key.data(), key.size()});
1005  if (mlirDialectIsNull(dialect)) {
1006  std::string msg = (Twine("Dialect '") + key + "' not found").str();
1007  if (attrError)
1008  throw py::attribute_error(msg);
1009  throw py::index_error(msg);
1010  }
1011  return dialect;
1012 }
1013 
// NOTE(review): both the signature line and the argument line (inner line
// 1016) of this function are absent from this listing, so the expression below
// is incomplete as shown. It returns a py::object that steals the reference
// produced by the missing argument expression.
1015  return py::reinterpret_steal<py::object>(
1017 }
1018 
// NOTE(review): the signature line of this function is absent from this
// listing. Unwraps a dialect registry from a Python capsule and wraps it in a
// PyDialectRegistry.
1020  MlirDialectRegistry rawRegistry =
1021  mlirPythonCapsuleToDialectRegistry(capsule.ptr());
// A null registry is treated as a failed conversion; error_already_set
// presumably propagates a Python error set by the conversion -- confirm.
1022  if (mlirDialectRegistryIsNull(rawRegistry))
1023  throw py::error_already_set();
1024  return PyDialectRegistry(rawRegistry);
1025 }
1026 
1027 //------------------------------------------------------------------------------
1028 // PyLocation
1029 //------------------------------------------------------------------------------
1030 
// NOTE(review): the signature line of this function is absent from this
// listing. Wraps this location in a Python capsule; the returned py::object
// steals the reference produced by mlirPythonLocationToCapsule.
1032  return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
1033 }
1034 
// NOTE(review): the signature line and the start of the return statement
// (inner line 1039) are absent from this listing, so the trailing `rawLoc);`
// is the tail of an expression that is not fully visible.
1036  MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
// A null location is treated as a failed conversion.
1037  if (mlirLocationIsNull(rawLoc))
1038  throw py::error_already_set();
1040  rawLoc);
1041 }
1042 
// NOTE(review): the signature line of this function is absent from this
// listing. Pushes this location via PyThreadContextEntry::pushLocation and
// returns its result.
1044  return PyThreadContextEntry::pushLocation(*this);
1045 }
1046 
// Exit hook taking the three exception-related objects (type, value,
// traceback); none of them is referenced in the visible lines.
// NOTE(review): the body statement (inner line 1050) is absent from this
// listing; presumably it pops this location from the thread context -- confirm.
1047 void PyLocation::contextExit(const pybind11::object &excType,
1048  const pybind11::object &excVal,
1049  const pybind11::object &excTb) {
1051 }
1052 
// NOTE(review): the signature line of this function is absent from this
// listing. Returns the default location taken from the thread context, or
// throws std::runtime_error when none has been established.
1054  auto *location = PyThreadContextEntry::getDefaultLocation();
1055  if (!location) {
1056  throw std::runtime_error(
1057  "An MLIR function requires a Location but none was provided in the "
1058  "call or from the surrounding environment. Either pass to the function "
1059  "with a 'loc=' argument or establish a default using 'with loc:'");
1060  }
1061  return *location;
1062 }
1063 
1064 //------------------------------------------------------------------------------
1065 // PyModule
1066 //------------------------------------------------------------------------------
1067 
// Constructs a wrapper around `module`, handing the context reference to the
// BaseContextObject base by move.
1068 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1069  : BaseContextObject(std::move(contextRef)), module(module) {}
1070 
// NOTE(review): the signature line of this function is absent from this
// listing; the body reads as the PyModule destructor. It removes the module
// from the context's live-module map and destroys the underlying MlirModule.
// The GIL is acquired for the duration of the body.
1072  py::gil_scoped_acquire acquire;
1073  auto &liveModules = getContext()->liveModules;
// Every module is expected to be registered in the map exactly once.
1074  assert(liveModules.count(module.ptr) == 1 &&
1075  "destroying module not in live map");
1076  liveModules.erase(module.ptr);
1077  mlirModuleDestroy(module);
1078 }
1079 
1080 PyModuleRef PyModule::forModule(MlirModule module) {
1081  MlirContext context = mlirModuleGetContext(module);
1082  PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1083 
1084  py::gil_scoped_acquire acquire;
1085  auto &liveModules = contextRef->liveModules;
1086  auto it = liveModules.find(module.ptr);
1087  if (it == liveModules.end()) {
1088  // Create.
1089  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1090  // Note that the default return value policy on cast is automatic_reference,
1091  // which does not take ownership (delete will not be called).
1092  // Just be explicit.
1093  py::object pyRef =
1094  py::cast(unownedModule, py::return_value_policy::take_ownership);
1095  unownedModule->handle = pyRef;
1096  liveModules[module.ptr] =
1097  std::make_pair(unownedModule->handle, unownedModule);
1098  return PyModuleRef(unownedModule, std::move(pyRef));
1099  }
1100  // Use existing.
1101  PyModule *existing = it->second.second;
1102  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1103  return PyModuleRef(existing, std::move(pyRef));
1104 }
1105 
1106 py::object PyModule::createFromCapsule(py::object capsule) {
1107  MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1108  if (mlirModuleIsNull(rawModule))
1109  throw py::error_already_set();
1110  return forModule(rawModule).releaseObject();
1111 }
1112 
1113 py::object PyModule::getCapsule() {
1114  return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
1115 }
1116 
1117 //------------------------------------------------------------------------------
1118 // PyOperation
1119 //------------------------------------------------------------------------------
1120 
// Constructs a wrapper around `operation`, handing the context reference to
// the BaseContextObject base by move.
1121 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1122  : BaseContextObject(std::move(contextRef)), operation(operation) {}
1123 
// NOTE(review): the signature line of this function is absent from this
// listing; the body reads as the PyOperation destructor. It unregisters the
// operation from the context's live-operation map and destroys the underlying
// MlirOperation only when the operation is not attached.
1125  // If the operation has already been invalidated there is nothing to do.
1126  if (!valid)
1127  return;
1128  auto &liveOperations = getContext()->liveOperations;
// A valid operation is expected to be registered in the map exactly once.
1129  assert(liveOperations.count(operation.ptr) == 1 &&
1130  "destroying operation not in live map");
1131  liveOperations.erase(operation.ptr);
// Only detached operations are destroyed here.
1132  if (!isAttached()) {
1133  mlirOperationDestroy(operation);
1134  }
1135 }
1136 
1137 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1138  MlirOperation operation,
1139  py::object parentKeepAlive) {
1140  auto &liveOperations = contextRef->liveOperations;
1141  // Create.
1142  PyOperation *unownedOperation =
1143  new PyOperation(std::move(contextRef), operation);
1144  // Note that the default return value policy on cast is automatic_reference,
1145  // which does not take ownership (delete will not be called).
1146  // Just be explicit.
1147  py::object pyRef =
1148  py::cast(unownedOperation, py::return_value_policy::take_ownership);
1149  unownedOperation->handle = pyRef;
1150  if (parentKeepAlive) {
1151  unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1152  }
1153  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
1154  return PyOperationRef(unownedOperation, std::move(pyRef));
1155 }
1156 
// NOTE(review): the first signature line of this function is absent from this
// listing. Returns the wrapper already registered for `operation` in the
// context's live-operation map, or creates one through createInstance.
1158  MlirOperation operation,
1159  py::object parentKeepAlive) {
1160  auto &liveOperations = contextRef->liveOperations;
1161  auto it = liveOperations.find(operation.ptr);
1162  if (it == liveOperations.end()) {
1163  // Create.
1164  return createInstance(std::move(contextRef), operation,
1165  std::move(parentKeepAlive));
1166  }
1167  // Use existing.
// Note: `parentKeepAlive` is not used on this path.
1168  PyOperation *existing = it->second.second;
1169  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1170  return PyOperationRef(existing, std::move(pyRef));
1171 }
1172 
// NOTE(review): the first signature line of this function is absent from this
// listing. Creates a new wrapper for `operation` and marks it as not attached;
// the operation must not already be present in the live-operation map.
1174  MlirOperation operation,
1175  py::object parentKeepAlive) {
1176  auto &liveOperations = contextRef->liveOperations;
1177  assert(liveOperations.count(operation.ptr) == 0 &&
1178  "cannot create detached operation that already exists");
// The cast to void keeps `liveOperations` referenced when the assert is
// compiled out.
1179  (void)liveOperations;
1180 
1181  PyOperationRef created = createInstance(std::move(contextRef), operation,
1182  std::move(parentKeepAlive));
1183  created->attached = false;
1184  return created;
1185 }
1186 
// NOTE(review): the first signature line of this function is absent from this
// listing. Parses `sourceStr` (labelled `sourceName`) into an operation and
// returns it as a detached PyOperation; throws MLIRError carrying the captured
// diagnostics when parsing yields a null operation.
1188  const std::string &sourceStr,
1189  const std::string &sourceName) {
// Diagnostics captured here are handed to the MLIRError thrown on failure.
1190  PyMlirContext::ErrorCapture errors(contextRef);
1191  MlirOperation op =
1192  mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1193  toMlirStringRef(sourceName));
1194  if (mlirOperationIsNull(op))
1195  throw MLIRError("Unable to parse operation assembly", errors.take());
1196  return PyOperation::createDetached(std::move(contextRef), op);
1197 }
1198 
// NOTE(review): the signature line of this function is absent from this
// listing. Throws std::runtime_error when the `valid` flag is false.
1200  if (!valid) {
1201  throw std::runtime_error("the operation has been invalidated");
1202  }
1203 }
1204 
// Prints this operation to `fileObject` using printing flags assembled from
// the arguments. A `fileObject` of None is replaced by Python's sys.stdout.
// `binary` is forwarded to the PyFileAccumulator that feeds the file object.
// Throws (via checkValid) when the operation has been invalidated.
1205 void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1206  bool enableDebugInfo, bool prettyDebugInfo,
1207  bool printGenericOpForm, bool useLocalScope,
1208  bool assumeVerified, py::object fileObject,
1209  bool binary) {
1210  PyOperation &operation = getOperation();
1211  operation.checkValid();
1212  if (fileObject.is_none())
1213  fileObject = py::module::import("sys").attr("stdout");
1214 
1215  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
// Large-element elision is only configured when a limit was supplied.
1216  if (largeElementsLimit)
1217  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1218  if (enableDebugInfo)
1219  mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1220  /*prettyForm=*/prettyDebugInfo);
// NOTE(review): the statement controlled by each of the next three `if`s is
// absent from this listing (inner line numbers jump 1221 -> 1223 -> 1225 ->
// 1227). As shown the conditions would nest, which is presumably not the
// intent -- restore the missing calls from the full file.
1221  if (printGenericOpForm)
1223  if (useLocalScope)
1225  if (assumeVerified)
1227 
1228  PyFileAccumulator accum(fileObject, binary);
1229  mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1230  accum.getUserData());
// NOTE(review): inner line 1231 is absent; no release of `flags` is visible
// in this listing -- presumably the missing line destroys them; confirm.
1232 }
1233 
1234 void PyOperationBase::print(PyAsmState &state, py::object fileObject,
1235  bool binary) {
1236  PyOperation &operation = getOperation();
1237  operation.checkValid();
1238  if (fileObject.is_none())
1239  fileObject = py::module::import("sys").attr("stdout");
1240  PyFileAccumulator accum(fileObject, binary);
1241  mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1242  accum.getUserData());
1243 }
1244 
// Writes this operation as bytecode to `fileObject` (opened through a binary
// PyFileAccumulator). Without `bytecodeVersion` the plain writer is used; with
// it, a writer config requesting that version is used and py::value_error is
// thrown when the result reports failure.
1245 void PyOperationBase::writeBytecode(const py::object &fileObject,
1246  std::optional<int64_t> bytecodeVersion) {
1247  PyOperation &operation = getOperation();
1248  operation.checkValid();
1249  PyFileAccumulator accum(fileObject, /*binary=*/true);
1250 
// No explicit version requested: use the default writer and return.
1251  if (!bytecodeVersion.has_value())
1252  return mlirOperationWriteBytecode(operation, accum.getCallback(),
1253  accum.getUserData());
1254 
1255  MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1256  mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
// NOTE(review): inner lines 1257 and 1259 are absent from this listing. The
// line below is the argument tail of a call whose head (presumably the
// declaration of `res`) is missing, and no release of `config` is visible --
// restore both from the full file.
1258  operation, config, accum.getCallback(), accum.getUserData());
1260  if (mlirLogicalResultIsFailure(res))
1261  throw py::value_error((Twine("Unable to honor desired bytecode version ") +
1262  Twine(*bytecodeVersion))
1263  .str());
1264 }
1265 
// NOTE(review): the first signature line of this function is absent from this
// listing. Walks this operation in `walkOrder`, invoking `callback` for each
// visited operation. A py::error_already_set raised by the callback is
// recorded and, once the walk has returned, rethrown as std::runtime_error.
1267  std::function<MlirWalkResult(MlirOperation)> callback,
1268  MlirWalkOrder walkOrder) {
1269  PyOperation &operation = getOperation();
1270  operation.checkValid();
// State shared with the C callback through its `void *userData` argument.
1271  struct UserData {
1272  std::function<MlirWalkResult(MlirOperation)> callback;
1273  bool gotException;
1274  std::string exceptionWhat;
1275  py::object exceptionType;
1276  };
1277  UserData userData{callback, false, {}, {}};
1278  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1279  void *userData) {
1280  UserData *calleeUserData = static_cast<UserData *>(userData);
1281  try {
1282  return (calleeUserData->callback)(op);
1283  } catch (py::error_already_set &e) {
1284  calleeUserData->gotException = true;
1285  calleeUserData->exceptionWhat = e.what();
1286  calleeUserData->exceptionType = e.type();
// NOTE(review): inner line 1287 is absent from this listing; presumably it
// returns a walk result from the catch block -- confirm against the full file.
1288  }
1289  };
1290  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
// The recorded exception type is not used in the visible rethrow; only the
// message text is forwarded.
1291  if (userData.gotException) {
1292  std::string message("Exception raised in callback: ");
1293  message.append(userData.exceptionWhat);
1294  throw std::runtime_error(message);
1295  }
1296 }
1297 
1298 py::object PyOperationBase::getAsm(bool binary,
1299  std::optional<int64_t> largeElementsLimit,
1300  bool enableDebugInfo, bool prettyDebugInfo,
1301  bool printGenericOpForm, bool useLocalScope,
1302  bool assumeVerified) {
1303  py::object fileObject;
1304  if (binary) {
1305  fileObject = py::module::import("io").attr("BytesIO")();
1306  } else {
1307  fileObject = py::module::import("io").attr("StringIO")();
1308  }
1309  print(/*largeElementsLimit=*/largeElementsLimit,
1310  /*enableDebugInfo=*/enableDebugInfo,
1311  /*prettyDebugInfo=*/prettyDebugInfo,
1312  /*printGenericOpForm=*/printGenericOpForm,
1313  /*useLocalScope=*/useLocalScope,
1314  /*assumeVerified=*/assumeVerified,
1315  /*fileObject=*/fileObject,
1316  /*binary=*/binary);
1317 
1318  return fileObject.attr("getvalue")();
1319 }
1320 
// NOTE(review): the signature line of this function is absent from this
// listing. Moves this operation after `other` (both must be valid) and adopts
// the other operation's parent keep-alive object.
1322  PyOperation &operation = getOperation();
1323  PyOperation &otherOp = other.getOperation();
1324  operation.checkValid();
1325  otherOp.checkValid();
1326  mlirOperationMoveAfter(operation, otherOp);
1327  operation.parentKeepAlive = otherOp.parentKeepAlive;
1328 }
1329 
// NOTE(review): the signature line of this function is absent from this
// listing. Moves this operation before `other` (both must be valid) and adopts
// the other operation's parent keep-alive object.
1331  PyOperation &operation = getOperation();
1332  PyOperation &otherOp = other.getOperation();
1333  operation.checkValid();
1334  otherOp.checkValid();
1335  mlirOperationMoveBefore(operation, otherOp);
1336  operation.parentKeepAlive = otherOp.parentKeepAlive;
1337 }
1338 
// NOTE(review): the signature line and inner line 1341 are absent from this
// listing; `errors` is used below without a visible declaration, so the
// missing line presumably sets up the error capture -- confirm.
// Returns true when verification succeeds and throws MLIRError otherwise.
1340  PyOperation &op = getOperation();
1342  if (!mlirOperationVerify(op.get()))
1343  throw MLIRError("Verification failed", errors.take());
1344  return true;
1345 }
1346 
1347 std::optional<PyOperationRef> PyOperation::getParentOperation() {
1348  checkValid();
1349  if (!isAttached())
1350  throw py::value_error("Detached operations have no parent");
1351  MlirOperation operation = mlirOperationGetParentOperation(get());
1352  if (mlirOperationIsNull(operation))
1353  return {};
1354  return PyOperation::forOperation(getContext(), operation);
1355 }
1356 
// NOTE(review): the signature line of this function is absent from this
// listing. Returns the block containing this operation as a PyBlock tied to
// the parent operation.
1358  checkValid();
// getParentOperation() throws for detached operations, so the code below only
// runs for attached ones.
1359  std::optional<PyOperationRef> parentOperation = getParentOperation();
1360  MlirBlock block = mlirOperationGetBlock(get());
1361  assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1362  assert(parentOperation && "Operation has no parent");
1363  return PyBlock{std::move(*parentOperation), block};
1364 }
1365 
// NOTE(review): the signature line of this function is absent from this
// listing. Wraps the underlying operation in a Python capsule after verifying
// validity; the returned py::object steals the capsule reference.
1367  checkValid();
1368  return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1369 }
1370 
1371 py::object PyOperation::createFromCapsule(py::object capsule) {
1372  MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1373  if (mlirOperationIsNull(rawOperation))
1374  throw py::error_already_set();
1375  MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1376  return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1377  .releaseObject();
1378 }
1379 
// NOTE(review): the first signature line of this function is absent from this
// listing. Inserts `op` at an insertion point selected by `maybeIp`: the
// Python value False disables insertion, None selects a default (see note
// below), and any other value is cast to a PyInsertionPoint.
1381  const py::object &maybeIp) {
1382  // InsertPoint active?
1383  if (!maybeIp.is(py::cast(false))) {
1384  PyInsertionPoint *ip;
1385  if (maybeIp.is_none()) {
// NOTE(review): inner line 1386 is absent from this listing, so `ip` has no
// visible assignment on this branch; presumably it is set to the default
// insertion point -- confirm against the full file.
1387  } else {
1388  ip = py::cast<PyInsertionPoint *>(maybeIp);
1389  }
// A null insertion point leaves the operation uninserted.
1390  if (ip)
1391  ip->insert(*op.get());
1392  }
1393 }
1394 
// Creates a new operation named `name` from the optional results, operands,
// attributes and successors, with `regions` freshly created regions, at
// `location`. The created operation is wrapped as a detached PyOperation,
// optionally inserted according to `maybeIp`, and returned as an op view.
// Throws py::value_error for a negative region count, None entries in the
// operand/result/successor lists or a null created operation, and
// py::cast_error for invalid attribute keys or values.
1395 py::object PyOperation::create(const std::string &name,
1396  std::optional<std::vector<PyType *>> results,
1397  std::optional<std::vector<PyValue *>> operands,
1398  std::optional<py::dict> attributes,
1399  std::optional<std::vector<PyBlock *>> successors,
1400  int regions, DefaultingPyLocation location,
1401  const py::object &maybeIp, bool inferType) {
1402  llvm::SmallVector<MlirValue, 4> mlirOperands;
1403  llvm::SmallVector<MlirType, 4> mlirResults;
1404  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
// NOTE(review): inner line 1405 is absent from this listing; `mlirAttributes`
// is used below without a visible declaration, so it is presumably declared on
// the missing line -- confirm against the full file.
1406 
1407  // General parameter validation.
1408  if (regions < 0)
1409  throw py::value_error("number of regions must be >= 0");
1410 
1411  // Unpack/validate operands.
1412  if (operands) {
1413  mlirOperands.reserve(operands->size());
1414  for (PyValue *operand : *operands) {
1415  if (!operand)
1416  throw py::value_error("operand value cannot be None");
1417  mlirOperands.push_back(operand->get());
1418  }
1419  }
1420 
1421  // Unpack/validate results.
1422  if (results) {
1423  mlirResults.reserve(results->size());
1424  for (PyType *result : *results) {
1425  // TODO: Verify result type originate from the same context.
1426  if (!result)
1427  throw py::value_error("result type cannot be None");
1428  mlirResults.push_back(*result);
1429  }
1430  }
1431  // Unpack/validate attributes.
1432  if (attributes) {
1433  mlirAttributes.reserve(attributes->size());
1434  for (auto &it : *attributes) {
1435  std::string key;
// Keys must convert to std::string; a failed conversion is rethrown with the
// operation name added to the message.
1436  try {
1437  key = it.first.cast<std::string>();
1438  } catch (py::cast_error &err) {
1439  std::string msg = "Invalid attribute key (not a string) when "
1440  "attempting to create the operation \"" +
1441  name + "\" (" + err.what() + ")";
1442  throw py::cast_error(msg);
1443  }
// Values must convert to PyAttribute; both failure modes are rethrown as
// py::cast_error with the key and operation name in the message.
1444  try {
1445  auto &attribute = it.second.cast<PyAttribute &>();
1446  // TODO: Verify attribute originates from the same context.
1447  mlirAttributes.emplace_back(std::move(key), attribute);
1448  } catch (py::reference_cast_error &) {
1449  // This exception seems thrown when the value is "None".
1450  std::string msg =
1451  "Found an invalid (`None`?) attribute value for the key \"" + key +
1452  "\" when attempting to create the operation \"" + name + "\"";
1453  throw py::cast_error(msg);
1454  } catch (py::cast_error &err) {
1455  std::string msg = "Invalid attribute value for the key \"" + key +
1456  "\" when attempting to create the operation \"" +
1457  name + "\" (" + err.what() + ")";
1458  throw py::cast_error(msg);
1459  }
1460  }
1461  }
1462  // Unpack/validate successors.
1463  if (successors) {
1464  mlirSuccessors.reserve(successors->size());
1465  for (auto *successor : *successors) {
1466  // TODO: Verify successor originate from the same context.
1467  if (!successor)
1468  throw py::value_error("successor block cannot be None");
1469  mlirSuccessors.push_back(successor->get());
1470  }
1471  }
1472 
1473  // Apply unpacked/validated to the operation state. Beyond this
1474  // point, exceptions cannot be thrown or else the state will leak.
1475  MlirOperationState state =
1476  mlirOperationStateGet(toMlirStringRef(name), location);
1477  if (!mlirOperands.empty())
1478  mlirOperationStateAddOperands(&state, mlirOperands.size(),
1479  mlirOperands.data());
1480  state.enableResultTypeInference = inferType;
1481  if (!mlirResults.empty())
1482  mlirOperationStateAddResults(&state, mlirResults.size(),
1483  mlirResults.data());
1484  if (!mlirAttributes.empty()) {
1485  // Note that the attribute names directly reference bytes in
1486  // mlirAttributes, so that vector must not be changed from here
1487  // on.
1488  llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1489  mlirNamedAttributes.reserve(mlirAttributes.size());
// NOTE(review): inner line 1492 (the start of the first argument to
// mlirNamedAttributeGet) is absent from this listing, so the call below is
// incomplete as shown.
1490  for (auto &it : mlirAttributes)
1491  mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1493  toMlirStringRef(it.first)),
1494  it.second));
1495  mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1496  mlirNamedAttributes.data());
1497  }
1498  if (!mlirSuccessors.empty())
1499  mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1500  mlirSuccessors.data());
1501  if (regions) {
// NOTE(review): inner line 1502 is absent from this listing; `mlirRegions` is
// presumably declared there -- confirm against the full file.
1503  mlirRegions.resize(regions);
1504  for (int i = 0; i < regions; ++i)
1505  mlirRegions[i] = mlirRegionCreate();
1506  mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1507  mlirRegions.data());
1508  }
1509 
1510  // Construct the operation.
1511  MlirOperation operation = mlirOperationCreate(&state);
1512  if (!operation.ptr)
1513  throw py::value_error("Operation creation failed");
1514  PyOperationRef created =
1515  PyOperation::createDetached(location->getContext(), operation);
1516  maybeInsertOperation(created, maybeIp);
1517 
1518  return created->createOpView();
1519 }
1520 
1521 py::object PyOperation::clone(const py::object &maybeIp) {
1522  MlirOperation clonedOperation = mlirOperationClone(operation);
1523  PyOperationRef cloned =
1524  PyOperation::createDetached(getContext(), clonedOperation);
1525  maybeInsertOperation(cloned, maybeIp);
1526 
1527  return cloned->createOpView();
1528 }
1529 
// NOTE(review): the signature line of this function is absent from this
// listing. Returns a Python op view for this operation: an instance of the
// class registered for the operation's name when PyGlobals knows one, and a
// generic PyOpView otherwise.
1531  checkValid();
1532  MlirIdentifier ident = mlirOperationGetName(get());
1533  MlirStringRef identStr = mlirIdentifierStr(ident);
// Look up a registered view class by the operation's name.
1534  auto operationCls = PyGlobals::get().lookupOperationClass(
1535  StringRef(identStr.data, identStr.length));
1536  if (operationCls)
1537  return PyOpView::constructDerived(*operationCls, *getRef().get());
1538  return py::cast(PyOpView(getRef().getObject()));
1539 }
1540 
// NOTE(review): the signature line of this function is absent from this
// listing. Destroys the underlying operation, removes it from the context's
// live-operation map if present, and clears the `valid` flag.
1542  checkValid();
1543  // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1544  // Python reference to a child operation is live. All children should also
1545  // have their `valid` bit set to false.
1546  auto &liveOperations = getContext()->liveOperations;
1547  if (liveOperations.count(operation.ptr))
1548  liveOperations.erase(operation.ptr);
1549  mlirOperationDestroy(operation);
1550  valid = false;
1551 }
1552 
1553 //------------------------------------------------------------------------------
1554 // PyOpView
1555 //------------------------------------------------------------------------------
1556 
// Unpacks `resultTypeList` into `resultTypes` for the operation `name`.
//
// When `resultSegmentSpecObj` is None every list element must be a non-None
// Type. Otherwise it is read as a list of ints, one per list element: 1 means
// a required single Type, 0 an optional single Type (None allowed), and -1 a
// sequence of Types (None is treated as an empty sequence). In the sized case
// the length of each segment is appended to `resultSegmentLengths`.
//
// Throws py::value_error when an element has the wrong kind, when the spec
// and list lengths differ, or when a spec entry is not 1, 0 or -1.
1557 static void populateResultTypes(StringRef name, py::list resultTypeList,
1558  const py::object &resultSegmentSpecObj,
1559  std::vector<int32_t> &resultSegmentLengths,
1560  std::vector<PyType *> &resultTypes) {
1561  resultTypes.reserve(resultTypeList.size());
1562  if (resultSegmentSpecObj.is_none()) {
1563  // Non-variadic result unpacking.
1564  for (const auto &it : llvm::enumerate(resultTypeList)) {
// The cast_error thrown for a None entry is caught by the handler below and
// rewrapped as a value_error, like a genuine cast failure.
1565  try {
1566  resultTypes.push_back(py::cast<PyType *>(it.value()));
1567  if (!resultTypes.back())
1568  throw py::cast_error();
1569  } catch (py::cast_error &err) {
1570  throw py::value_error((llvm::Twine("Result ") +
1571  llvm::Twine(it.index()) + " of operation \"" +
1572  name + "\" must be a Type (" + err.what() + ")")
1573  .str());
1574  }
1575  }
1576  } else {
1577  // Sized result unpacking.
1578  auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1579  if (resultSegmentSpec.size() != resultTypeList.size()) {
1580  throw py::value_error((llvm::Twine("Operation \"") + name +
1581  "\" requires " +
1582  llvm::Twine(resultSegmentSpec.size()) +
1583  " result segments but was provided " +
1584  llvm::Twine(resultTypeList.size()))
1585  .str());
1586  }
1587  resultSegmentLengths.reserve(resultTypeList.size());
1588  for (const auto &it :
1589  llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1590  int segmentSpec = std::get<1>(it.value());
1591  if (segmentSpec == 1 || segmentSpec == 0) {
1592  // Unpack unary element.
1593  try {
1594  auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1595  if (resultType) {
1596  resultTypes.push_back(resultType);
1597  resultSegmentLengths.push_back(1);
1598  } else if (segmentSpec == 0) {
1599  // Allowed to be optional.
1600  resultSegmentLengths.push_back(0);
1601  } else {
1602  throw py::cast_error("was None and result is not optional");
1603  }
1604  } catch (py::cast_error &err) {
1605  throw py::value_error((llvm::Twine("Result ") +
1606  llvm::Twine(it.index()) + " of operation \"" +
1607  name + "\" must be a Type (" + err.what() +
1608  ")")
1609  .str());
1610  }
1611  } else if (segmentSpec == -1) {
1612  // Unpack sequence by appending.
1613  try {
1614  if (std::get<0>(it.value()).is_none()) {
1615  // Treat it as an empty list.
1616  resultSegmentLengths.push_back(0);
1617  } else {
1618  // Unpack the list.
1619  auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1620  for (py::object segmentItem : segment) {
1621  resultTypes.push_back(py::cast<PyType *>(segmentItem));
1622  if (!resultTypes.back()) {
1623  throw py::cast_error("contained a None item");
1624  }
1625  }
1626  resultSegmentLengths.push_back(segment.size());
1627  }
1628  } catch (std::exception &err) {
1629  // NOTE: Sloppy to be using a catch-all here, but there are at least
1630  // three different unrelated exceptions that can be thrown in the
1631  // above "casts". Just keep the scope above small and catch them all.
1632  throw py::value_error((llvm::Twine("Result ") +
1633  llvm::Twine(it.index()) + " of operation \"" +
1634  name + "\" must be a Sequence of Types (" +
1635  err.what() + ")")
1636  .str());
1637  }
1638  } else {
1639  throw py::value_error("Unexpected segment spec");
1640  }
1641  }
1642  }
1643 }
1644 
// NOTE(review): the first signature line of this function is absent from this
// listing. The body builds an operation from class-level metadata on `cls`
// (OPERATION_NAME, _ODS_OPERAND_SEGMENTS, _ODS_RESULT_SEGMENTS, _ODS_REGIONS):
// it validates the region count, unpacks results and operands according to the
// segment specs, records segment sizes as attributes when any were computed,
// and delegates to PyOperation::create. Result type inference is requested
// when no `resultTypeList` was supplied.
1646  const py::object &cls, std::optional<py::list> resultTypeList,
1647  py::list operandList, std::optional<py::dict> attributes,
1648  std::optional<std::vector<PyBlock *>> successors,
1649  std::optional<int> regions, DefaultingPyLocation location,
1650  const py::object &maybeIp) {
1651  PyMlirContextRef context = location->getContext();
1652  // Class level operation construction metadata.
1653  std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1654  // Operand and result segment specs are either none, which does no
1655  // variadic unpacking, or a list of ints with segment sizes, where each
1656  // element is either a positive number (typically 1 for a scalar) or -1 to
1657  // indicate that it is derived from the length of the same-indexed operand
1658  // or result (implying that it is a list at that position).
// NOTE(review): the unpacking code below accepts only 1, 0 and -1 as spec
// values; any other value throws "Unexpected segment spec".
1659  py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1660  py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1661 
1662  std::vector<int32_t> operandSegmentLengths;
1663  std::vector<int32_t> resultSegmentLengths;
1664 
1665  // Validate/determine region count.
1666  auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1667  int opMinRegionCount = std::get<0>(opRegionSpec);
1668  bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
// An unspecified region count defaults to the minimum for the operation.
1669  if (!regions) {
1670  regions = opMinRegionCount;
1671  }
1672  if (*regions < opMinRegionCount) {
1673  throw py::value_error(
1674  (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1675  llvm::Twine(opMinRegionCount) +
1676  " regions but was built with regions=" + llvm::Twine(*regions))
1677  .str());
1678  }
// Without variadic regions the minimum is also the maximum.
1679  if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1680  throw py::value_error(
1681  (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1682  llvm::Twine(opMinRegionCount) +
1683  " regions but was built with regions=" + llvm::Twine(*regions))
1684  .str());
1685  }
1686 
1687  // Unpack results.
1688  std::vector<PyType *> resultTypes;
1689  if (resultTypeList.has_value()) {
1690  populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1691  resultSegmentLengths, resultTypes);
1692  }
1693 
1694  // Unpack operands.
1695  std::vector<PyValue *> operands;
// NOTE(review): this reserves the size of the still-empty `operands` vector,
// which has no effect; `operandList.size()` was presumably intended.
1696  operands.reserve(operands.size());
1697  if (operandSegmentSpecObj.is_none()) {
1698  // Non-sized operand unpacking.
1699  for (const auto &it : llvm::enumerate(operandList)) {
1700  try {
1701  operands.push_back(py::cast<PyValue *>(it.value()));
1702  if (!operands.back())
1703  throw py::cast_error();
1704  } catch (py::cast_error &err) {
1705  throw py::value_error((llvm::Twine("Operand ") +
1706  llvm::Twine(it.index()) + " of operation \"" +
1707  name + "\" must be a Value (" + err.what() + ")")
1708  .str());
1709  }
1710  }
1711  } else {
1712  // Sized operand unpacking.
1713  auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
// NOTE(review): the message fragment "operand segments but was provided "
// below has no leading space, so the count runs into the word (e.g.
// "requires 3operand segments"); the parallel result message in
// populateResultTypes starts with a space.
1714  if (operandSegmentSpec.size() != operandList.size()) {
1715  throw py::value_error((llvm::Twine("Operation \"") + name +
1716  "\" requires " +
1717  llvm::Twine(operandSegmentSpec.size()) +
1718  "operand segments but was provided " +
1719  llvm::Twine(operandList.size()))
1720  .str());
1721  }
1722  operandSegmentLengths.reserve(operandList.size());
1723  for (const auto &it :
1724  llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1725  int segmentSpec = std::get<1>(it.value());
1726  if (segmentSpec == 1 || segmentSpec == 0) {
1727  // Unpack unary element.
1728  try {
1729  auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1730  if (operandValue) {
1731  operands.push_back(operandValue);
1732  operandSegmentLengths.push_back(1);
1733  } else if (segmentSpec == 0) {
1734  // Allowed to be optional.
1735  operandSegmentLengths.push_back(0);
1736  } else {
1737  throw py::cast_error("was None and operand is not optional");
1738  }
1739  } catch (py::cast_error &err) {
1740  throw py::value_error((llvm::Twine("Operand ") +
1741  llvm::Twine(it.index()) + " of operation \"" +
1742  name + "\" must be a Value (" + err.what() +
1743  ")")
1744  .str());
1745  }
1746  } else if (segmentSpec == -1) {
1747  // Unpack sequence by appending.
1748  try {
1749  if (std::get<0>(it.value()).is_none()) {
1750  // Treat it as an empty list.
1751  operandSegmentLengths.push_back(0);
1752  } else {
1753  // Unpack the list.
1754  auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1755  for (py::object segmentItem : segment) {
1756  operands.push_back(py::cast<PyValue *>(segmentItem));
1757  if (!operands.back()) {
1758  throw py::cast_error("contained a None item");
1759  }
1760  }
1761  operandSegmentLengths.push_back(segment.size());
1762  }
1763  } catch (std::exception &err) {
1764  // NOTE: Sloppy to be using a catch-all here, but there are at least
1765  // three different unrelated exceptions that can be thrown in the
1766  // above "casts". Just keep the scope above small and catch them all.
1767  throw py::value_error((llvm::Twine("Operand ") +
1768  llvm::Twine(it.index()) + " of operation \"" +
1769  name + "\" must be a Sequence of Values (" +
1770  err.what() + ")")
1771  .str());
1772  }
1773  } else {
1774  throw py::value_error("Unexpected segment spec");
1775  }
1776  }
1777  }
1778 
1779  // Merge operand/result segment lengths into attributes if needed.
1780  if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1781  // Dup.
// The caller's dict is copied so the additions below do not modify it.
1782  if (attributes) {
1783  attributes = py::dict(*attributes);
1784  } else {
1785  attributes = py::dict();
1786  }
1787  if (attributes->contains("resultSegmentSizes") ||
1788  attributes->contains("operandSegmentSizes")) {
1789  throw py::value_error("Manually setting a 'resultSegmentSizes' or "
1790  "'operandSegmentSizes' attribute is unsupported. "
1791  "Use Operation.create for such low-level access.");
1792  }
1793 
1794  // Add resultSegmentSizes attribute.
1795  if (!resultSegmentLengths.empty()) {
1796  MlirAttribute segmentLengthAttr =
1797  mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1798  resultSegmentLengths.data());
1799  (*attributes)["resultSegmentSizes"] =
1800  PyAttribute(context, segmentLengthAttr);
1801  }
1802 
1803  // Add operandSegmentSizes attribute.
1804  if (!operandSegmentLengths.empty()) {
1805  MlirAttribute segmentLengthAttr =
1806  mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1807  operandSegmentLengths.data());
1808  (*attributes)["operandSegmentSizes"] =
1809  PyAttribute(context, segmentLengthAttr);
1810  }
1811  }
1812 
1813  // Delegate to create.
// The final argument enables result type inference exactly when the caller
// supplied no result type list.
1814  return PyOperation::create(name,
1815  /*results=*/std::move(resultTypes),
1816  /*operands=*/std::move(operands),
1817  /*attributes=*/std::move(attributes),
1818  /*successors=*/std::move(successors),
1819  /*regions=*/*regions, location, maybeIp,
1820  !resultTypeList);
1821 }
1822 
1823 pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
1824  const PyOperation &operation) {
1825  // TODO: pybind11 2.6 supports a more direct form.
1826  // Upgrade many years from now.
1827  // auto opViewType = py::type::of<PyOpView>();
1828  py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1829  py::object instance = cls.attr("__new__")(cls);
1830  opViewType.attr("__init__")(instance, operation);
1831  return instance;
1832 }
1833 
1834 PyOpView::PyOpView(const py::object &operationObject)
1835  // Casting through the PyOperationBase base-class and then back to the
1836  // Operation lets us accept any PyOperationBase subclass.
1837  : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1838  operationObject(operation.getRef().getObject()) {}
1839 
1840 //------------------------------------------------------------------------------
1841 // PyInsertionPoint.
1842 //------------------------------------------------------------------------------
1843 
1845 
// NOTE(review): the constructor signature this initializer list belongs to is
// not present in this listing; the visible lines use a parameter named
// `beforeOperationBase` - confirm against the full source.
// Stores a reference to the operation obtained from `beforeOperationBase` and
// takes the block from that referenced operation.
1847  : refOperation(beforeOperationBase.getOperation().getRef()),
1848  block((*refOperation)->getBlock()) {}
1849 
// NOTE(review): the signature line of this function is not present in this
// listing; the body uses a parameter named `operationBase`.
// Inserts the operation into `block`: before `refOperation` when one is set,
// otherwise at the end of the block. Throws py::value_error when the operation
// is already attached and py::index_error when appending to a block that
// already has a terminator.
1851  PyOperation &operation = operationBase.getOperation();
1852  if (operation.isAttached())
1853  throw py::value_error(
1854  "Attempt to insert operation that is already attached");
// The operation that owns the target block must still be valid.
1855  block.getParentOperation()->checkValid();
// A null `beforeOp` is what gets passed when inserting at the end.
1856  MlirOperation beforeOp = {nullptr};
1857  if (refOperation) {
1858  // Insert before operation.
1859  (*refOperation)->checkValid();
1860  beforeOp = (*refOperation)->get();
1861  } else {
1862  // Insert at end (before null) is only valid if the block does not
1863  // already end in a known terminator (violating this will cause assertion
1864  // failures later).
1865  if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1866  throw py::index_error("Cannot insert operation at the end of a block "
1867  "that already has a terminator. Did you mean to "
1868  "use 'InsertionPoint.at_block_terminator(block)' "
1869  "versus 'InsertionPoint(block)'?");
1870  }
1871  }
// Perform the insertion, then record on the wrapper that it is now attached.
1872  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1873  operation.setAttached();
1874 }
1875 
// NOTE(review): the signature line is not present in this listing, and the
// line that begins the declaration of `firstOpRef` (between the "Insert before
// first op" comment and its argument line) is absent as well.
// Returns an insertion point at the start of `block`: a plain end-of-block
// insertion point when the block has no operations, otherwise one positioned
// before the block's first operation.
1877  MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1878  if (mlirOperationIsNull(firstOp)) {
1879  // Just insert at end.
1880  return PyInsertionPoint(block);
1881  }
1882 
1883  // Insert before first op.
1885  block.getParentOperation()->getContext(), firstOp);
1886  return PyInsertionPoint{block, std::move(firstOpRef)};
1887 }
1888 
// NOTE(review): the signature line of this function is not present in this
// listing.
// Returns an insertion point positioned before the terminator of `block`.
// Throws py::value_error when the block has no terminator.
1890  MlirOperation terminator = mlirBlockGetTerminator(block.get());
1891  if (mlirOperationIsNull(terminator))
1892  throw py::value_error("Block has no terminator");
// Wrap the terminator using the context of the operation that owns the block.
1893  PyOperationRef terminatorOpRef = PyOperation::forOperation(
1894  block.getParentOperation()->getContext(), terminator);
1895  return PyInsertionPoint{block, std::move(terminatorOpRef)};
1896 }
1897 
1900 }
1901 
// Receives the exception type, value and traceback objects (presumably the
// arguments of a Python `__exit__` call - confirm against the binding code).
// NOTE(review): the statement forming the body is not present in this listing;
// none of the three parameters is used in the visible lines.
1902 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1903  const pybind11::object &excVal,
1904  const pybind11::object &excTb) {
1906 }
1907 
1908 //------------------------------------------------------------------------------
1909 // PyAttribute.
1910 //------------------------------------------------------------------------------
1911 
1912 bool PyAttribute::operator==(const PyAttribute &other) const {
1913  return mlirAttributeEqual(attr, other.attr);
1914 }
1915 
// NOTE(review): the signature line of this function is not present in this
// listing.
// Converts this object to a capsule with mlirPythonAttributeToCapsule and
// wraps the result in a py::object via reinterpret_steal.
1917  return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1918 }
1919 
// NOTE(review): the signature line and the argument line of the final
// `return PyAttribute(` are not present in this listing; the body reads a
// parameter named `capsule`.
// Extracts an MlirAttribute from the capsule and throws
// py::error_already_set when the extracted attribute is null.
1921  MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1922  if (mlirAttributeIsNull(rawAttr))
1923  throw py::error_already_set();
1924  return PyAttribute(
1926 }
1927 
1928 //------------------------------------------------------------------------------
1929 // PyNamedAttribute.
1930 //------------------------------------------------------------------------------
1931 
// Moves the by-value `ownedName` argument into a newly allocated std::string
// held by the `ownedName` member, then uses a string ref to that stored string
// together with `attr`.
// NOTE(review): the first two lines of the constructor body (the start of the
// statement that consumes the string ref and `attr`) are not present in this
// listing.
1932 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1933  : ownedName(new std::string(std::move(ownedName))) {
1936  toMlirStringRef(*this->ownedName)),
1937  attr);
1938 }
1939 
1940 //------------------------------------------------------------------------------
1941 // PyType.
1942 //------------------------------------------------------------------------------
1943 
1944 bool PyType::operator==(const PyType &other) const {
1945  return mlirTypeEqual(type, other.type);
1946 }
1947 
1948 py::object PyType::getCapsule() {
1949  return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1950 }
1951 
// Extracts an MlirType from `capsule` and throws py::error_already_set when
// the extracted type is null.
// NOTE(review): the line that begins the final return statement is not present
// in this listing; only its trailing `rawType);` argument line is visible.
1952 PyType PyType::createFromCapsule(py::object capsule) {
1953  MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1954  if (mlirTypeIsNull(rawType))
1955  throw py::error_already_set();
1957  rawType);
1958 }
1959 
1960 //------------------------------------------------------------------------------
1961 // PyTypeID.
1962 //------------------------------------------------------------------------------
1963 
1964 py::object PyTypeID::getCapsule() {
1965  return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
1966 }
1967 
// NOTE(review): the signature line of this function is not present in this
// listing; the body reads a parameter named `capsule`.
// Extracts an MlirTypeID from the capsule, throws py::error_already_set when
// it is null, and otherwise returns it wrapped in a PyTypeID.
1969  MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
1970  if (mlirTypeIDIsNull(mlirTypeID))
1971  throw py::error_already_set();
1972  return PyTypeID(mlirTypeID);
1973 }
1974 bool PyTypeID::operator==(const PyTypeID &other) const {
1975  return mlirTypeIDEqual(typeID, other.typeID);
1976 }
1977 
1978 //------------------------------------------------------------------------------
1979 // PyValue and subclasses.
1980 //------------------------------------------------------------------------------
1981 
1982 pybind11::object PyValue::getCapsule() {
1983  return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1984 }
1985 
// Looks up a caster for the type id of this value's type. When no caster is
// found, returns this value as a Python object; otherwise returns the result
// of calling the caster on that Python object.
// NOTE(review): the initializer of `valueCaster` is not present in this
// listing, so where the caster comes from cannot be seen here.
1986 pybind11::object PyValue::maybeDownCast() {
1987  MlirType type = mlirValueGetType(get());
1988  MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
1989  assert(!mlirTypeIDIsNull(mlirTypeID) &&
1990  "mlirTypeID was expected to be non-null.");
1991  std::optional<pybind11::function> valueCaster =
1993  // py::return_value_policy::move means use std::move to move the return value
1994  // contents into a new instance that will be owned by Python.
1995  py::object thisObj = py::cast(this, py::return_value_policy::move);
1996  if (!valueCaster)
1997  return thisObj;
1998  return valueCaster.value()(thisObj);
1999 }
2000 
// Extracts an MlirValue from `capsule`, determines the operation that owns it
// and returns a PyValue built from the owner reference and the value. Throws
// py::error_already_set when the value or its owner is null.
// NOTE(review): the statement under the block-argument branch and the
// initializer of `ownerRef` are not present in this listing.
2001 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
2002  MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2003  if (mlirValueIsNull(value))
2004  throw py::error_already_set();
// NOTE(review): `owner` is declared without an initializer; if neither branch
// below assigns it, the null check reads an indeterminate value - confirm
// against the full source.
2005  MlirOperation owner;
2006  if (mlirValueIsAOpResult(value))
2007  owner = mlirOpResultGetOwner(value);
2008  if (mlirValueIsABlockArgument(value))
2010  if (mlirOperationIsNull(owner))
2011  throw py::error_already_set();
2012  MlirContext ctx = mlirOperationGetContext(owner);
2013  PyOperationRef ownerRef =
2015  return PyValue(ownerRef, value);
2016 }
2017 
2018 //------------------------------------------------------------------------------
2019 // PySymbolTable.
2020 //------------------------------------------------------------------------------
2021 
// NOTE(review): the constructor signature line is not present in this
// listing; the visible lines use a parameter named `operation`.
// Keeps a reference to the given operation and creates a symbol table for it.
// Throws py::cast_error when mlirSymbolTableCreate yields a null table.
2023  : operation(operation.getOperation().getRef()) {
2024  symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2025  if (mlirSymbolTableIsNull(symbolTable)) {
2026  throw py::cast_error("Operation is not a Symbol Table.");
2027  }
2028 }
2029 
2030 py::object PySymbolTable::dunderGetItem(const std::string &name) {
2031  operation->checkValid();
2032  MlirOperation symbol = mlirSymbolTableLookup(
2033  symbolTable, mlirStringRefCreate(name.data(), name.length()));
2034  if (mlirOperationIsNull(symbol))
2035  throw py::key_error("Symbol '" + name + "' not in the symbol table.");
2036 
2037  return PyOperation::forOperation(operation->getContext(), symbol,
2038  operation.getObject())
2039  ->createOpView();
2040 }
2041 
// NOTE(review): the signature line of this function is not present in this
// listing; the body uses a parameter named `symbol`.
// Erases the symbol's operation through the symbol table after checking that
// both the table's operation and the symbol's operation are valid.
2043  operation->checkValid();
2044  symbol.getOperation().checkValid();
2045  mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2046  // The operation is also erased, so we must invalidate it. There may be Python
2047  // references to this operation so we don't want to delete it from the list of
2048  // live operations here.
2049  symbol.getOperation().valid = false;
2050 }
2051 
2052 void PySymbolTable::dunderDel(const std::string &name) {
2053  py::object operation = dunderGetItem(name);
2054  erase(py::cast<PyOperationBase &>(operation));
2055 }
2056 
// Inserts the operation of `symbol` into the symbol table and returns the
// attribute produced by mlirSymbolTableInsert. Throws py::value_error when the
// attribute looked up on the operation is null.
// NOTE(review): the argument line of the mlirOperationGetAttributeByName call
// is not present in this listing, so the attribute name being looked up is not
// visible here.
2057 MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2058  operation->checkValid();
2059  symbol.getOperation().checkValid();
2060  MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2062  if (mlirAttributeIsNull(symbolAttr))
2063  throw py::value_error("Expected operation to have a symbol name.");
2064  return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2065 }
2066 
// NOTE(review): the signature line and the declaration of `attrName` are not
// present in this listing; the body uses a parameter named `symbol`.
// Returns the attribute stored under `attrName` on the symbol's operation and
// throws py::value_error when that attribute is null.
2068  // Op must already be a symbol.
2069  PyOperation &operation = symbol.getOperation();
2070  operation.checkValid();
2072  MlirAttribute existingNameAttr =
2073  mlirOperationGetAttributeByName(operation.get(), attrName);
2074  if (mlirAttributeIsNull(existingNameAttr))
2075  throw py::value_error("Expected operation to have a symbol name.");
2076  return existingNameAttr;
2077 }
2078 
// NOTE(review): the first line of the signature and the declaration of
// `attrName` are not present in this listing; the body uses a parameter named
// `symbol` in addition to `name`.
// Replaces the attribute stored under `attrName` on the symbol's operation
// with a string attribute holding `name`. Throws py::value_error when the
// operation does not already carry that attribute.
2080  const std::string &name) {
2081  // Op must already be a symbol.
2082  PyOperation &operation = symbol.getOperation();
2083  operation.checkValid();
2085  MlirAttribute existingNameAttr =
2086  mlirOperationGetAttributeByName(operation.get(), attrName);
2087  if (mlirAttributeIsNull(existingNameAttr))
2088  throw py::value_error("Expected operation to have a symbol name.");
2089  MlirAttribute newNameAttr =
2090  mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2091  mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2092 }
2093 
// NOTE(review): the signature line and the declaration of `attrName` are not
// present in this listing; the body uses a parameter named `symbol`.
// Returns the attribute stored under `attrName` on the symbol's operation and
// throws py::value_error when that attribute is null.
2095  PyOperation &operation = symbol.getOperation();
2096  operation.checkValid();
2098  MlirAttribute existingVisAttr =
2099  mlirOperationGetAttributeByName(operation.get(), attrName);
2100  if (mlirAttributeIsNull(existingVisAttr))
2101  throw py::value_error("Expected operation to have a symbol visibility.");
2102  return existingVisAttr;
2103 }
2104 
// NOTE(review): the first line of the signature and the declaration of
// `attrName` are not present in this listing; the body uses a parameter named
// `symbol` in addition to `visibility`.
// Replaces the attribute stored under `attrName` on the symbol's operation
// with a string attribute holding `visibility`. Throws py::value_error when
// `visibility` is not one of the three accepted spellings or when the
// operation does not already carry that attribute.
2106  const std::string &visibility) {
2107  if (visibility != "public" && visibility != "private" &&
2108  visibility != "nested")
2109  throw py::value_error(
2110  "Expected visibility to be 'public', 'private' or 'nested'");
2111  PyOperation &operation = symbol.getOperation();
2112  operation.checkValid();
2114  MlirAttribute existingVisAttr =
2115  mlirOperationGetAttributeByName(operation.get(), attrName);
2116  if (mlirAttributeIsNull(existingVisAttr))
2117  throw py::value_error("Expected operation to have a symbol visibility.");
2118  MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2119  toMlirStringRef(visibility));
2120  mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2121 }
2122 
// Renames uses of `oldSymbol` to `newSymbol` starting from the operation of
// `from`, and throws py::value_error("Symbol rename failed") when the guarded
// condition holds.
// NOTE(review): the line that opens the `if (...)` condition (including the
// name of the function being called) is not present in this listing.
2123 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2124  const std::string &newSymbol,
2125  PyOperationBase &from) {
2126  PyOperation &fromOperation = from.getOperation();
2127  fromOperation.checkValid();
2129  toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2130  from.getOperation())))
2131 
2132  throw py::value_error("Symbol rename failed");
2133 }
2134 
// NOTE(review): the first line of the signature and the line that opens the
// walk call (naming the function that receives the lambda) are not present in
// this listing; the body uses a parameter named `from`.
// Invokes `callback(operation, isVisible)` for every operation handed to the
// lambda below. If a callback invocation raises, later invocations are skipped
// and, once the walk returns, a std::runtime_error carrying the original
// message is thrown.
2136  bool allSymUsesVisible,
2137  py::object callback) {
2138  PyOperation &fromOperation = from.getOperation();
2139  fromOperation.checkValid();
// State shared with the capture-less lambda through its `void *` argument.
2140  struct UserData {
2141  PyMlirContextRef context;
2142  py::object callback;
2143  bool gotException;
2144  std::string exceptionWhat;
2145  py::object exceptionType;
2146  };
2147  UserData userData{
2148  fromOperation.getContext(), std::move(callback), false, {}, {}};
2150  fromOperation.get(), allSymUsesVisible,
2151  [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2152  UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
// NOTE(review): the wrapper is created before the gotException check, so it
// is built even for invocations that return immediately.
2153  auto pyFoundOp =
2154  PyOperation::forOperation(calleeUserData->context, foundOp);
2155  if (calleeUserData->gotException)
2156  return;
2157  try {
2158  calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2159  } catch (py::error_already_set &e) {
// Record the failure instead of letting the exception leave the lambda.
2160  calleeUserData->gotException = true;
2161  calleeUserData->exceptionWhat = e.what();
2162  calleeUserData->exceptionType = e.type();
2163  }
2164  },
2165  static_cast<void *>(&userData));
// Only the recorded message is used below; `exceptionType` is stored but not
// read in the visible lines.
2166  if (userData.gotException) {
2167  std::string message("Exception raised in callback: ");
2168  message.append(userData.exceptionWhat);
2169  throw std::runtime_error(message);
2170  }
2171 }
2172 
2173 namespace {
2174 /// CRTP base class for Python MLIR values that subclass Value and should be
2175 /// castable from it. The value hierarchy is one level deep and is not supposed
2176 /// to accommodate other levels unless core MLIR changes.
2177 template <typename DerivedTy>
2178 class PyConcreteValue : public PyValue {
2179 public:
2180  // Derived classes must define statics for:
2181  // IsAFunctionTy isaFunction
2182  // const char *pyClassName
2183  // and redefine bindDerived.
2184  using ClassTy = py::class_<DerivedTy, PyValue>;
2185  using IsAFunctionTy = bool (*)(MlirValue);
2186 
2187  PyConcreteValue() = default;
// Wraps `value` together with a reference to its owning operation.
2188  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
2189  : PyValue(operationRef, value) {}
// Converting constructor from a generic PyValue; castFrom throws when `orig`
// does not satisfy DerivedTy::isaFunction.
2190  PyConcreteValue(PyValue &orig)
2191  : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
2192 
2193  /// Attempts to cast the original value to the derived type and throws on
2194  /// type mismatches.
2195  static MlirValue castFrom(PyValue &orig) {
2196  if (!DerivedTy::isaFunction(orig.get())) {
// The Python repr of the offending value is embedded in the error message.
2197  auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
2198  throw py::value_error((Twine("Cannot cast value to ") +
2199  DerivedTy::pyClassName + " (from " + origRepr +
2200  ")")
2201  .str());
2202  }
2203  return orig.get();
2204  }
2205 
2206  /// Binds the Python module objects to functions of this class.
2207  static void bind(py::module &m) {
2208  auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
2209  cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
2210  cls.def_static(
2211  "isinstance",
2212  [](PyValue &otherValue) -> bool {
2213  return DerivedTy::isaFunction(otherValue);
2214  },
2215  py::arg("other_value"));
// NOTE(review): the line that opens the next `cls.` call (and names the
// attribute bound to the lambda below) is not present in this listing.
2217  [](DerivedTy &self) { return self.maybeDownCast(); });
// Let the derived class register its own methods and properties.
2218  DerivedTy::bindDerived(cls);
2219  }
2220 
2221  /// Implemented by derived classes to add methods to the Python subclass.
2222  static void bindDerived(ClassTy &m) {}
2223 };
2224 
2225 /// Python wrapper for MlirBlockArgument.
2226 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2227 public:
2228  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2229  static constexpr const char *pyClassName = "BlockArgument";
2230  using PyConcreteValue::PyConcreteValue;
2231 
2232  static void bindDerived(ClassTy &c) {
2233  c.def_property_readonly("owner", [](PyBlockArgument &self) {
2234  return PyBlock(self.getParentOperation(),
2235  mlirBlockArgumentGetOwner(self.get()));
2236  });
2237  c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
2238  return mlirBlockArgumentGetArgNumber(self.get());
2239  });
2240  c.def(
2241  "set_type",
2242  [](PyBlockArgument &self, PyType type) {
2243  return mlirBlockArgumentSetType(self.get(), type);
2244  },
2245  py::arg("type"));
2246  }
2247 };
2248 
2249 /// Python wrapper for MlirOpResult.
2250 class PyOpResult : public PyConcreteValue<PyOpResult> {
2251 public:
2252  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
2253  static constexpr const char *pyClassName = "OpResult";
2254  using PyConcreteValue::PyConcreteValue;
2255 
2256  static void bindDerived(ClassTy &c) {
2257  c.def_property_readonly("owner", [](PyOpResult &self) {
2258  assert(
2259  mlirOperationEqual(self.getParentOperation()->get(),
2260  mlirOpResultGetOwner(self.get())) &&
2261  "expected the owner of the value in Python to match that in the IR");
2262  return self.getParentOperation().getObject();
2263  });
2264  c.def_property_readonly("result_number", [](PyOpResult &self) {
2265  return mlirOpResultGetResultNumber(self.get());
2266  });
2267  }
2268 };
2269 
2270 /// Returns the list of types of the values held by container.
2271 template <typename Container>
2272 static std::vector<MlirType> getValueTypes(Container &container,
2273  PyMlirContextRef &context) {
2274  std::vector<MlirType> result;
2275  result.reserve(container.size());
2276  for (int i = 0, e = container.size(); i < e; ++i) {
2277  result.push_back(mlirValueGetType(container.getElement(i).get()));
2278  }
2279  return result;
2280 }
2281 
2282 /// A list of block arguments. Internally, these are stored as consecutive
2283 /// elements, random access is cheap. The argument list is associated with the
2284 /// operation that contains the block (detached blocks are not allowed in
2285 /// Python bindings) and extends its lifetime.
2286 class PyBlockArgumentList
2287  : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2288 public:
2289  static constexpr const char *pyClassName = "BlockArgumentList";
// NOTE(review): one line following `pyClassName` is not present in this
// listing.
2291 
// A `length` of -1 selects all arguments of `block`.
2292  PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2293  intptr_t startIndex = 0, intptr_t length = -1,
2294  intptr_t step = 1)
2295  : Sliceable(startIndex,
2296  length == -1 ? mlirBlockGetNumArguments(block) : length,
2297  step),
2298  operation(std::move(operation)), block(block) {}
2299 
// Adds a read-only `types` property listing the type of each argument.
2300  static void bindDerived(ClassTy &c) {
2301  c.def_property_readonly("types", [](PyBlockArgumentList &self) {
2302  return getValueTypes(self, self.operation->getContext());
2303  });
2304  }
2305 
2306 private:
2307  /// Give the parent CRTP class access to hook implementations below.
2308  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2309 
2310  /// Returns the number of arguments in the list.
2311  intptr_t getRawNumElements() {
2312  operation->checkValid();
2313  return mlirBlockGetNumArguments(block);
2314  }
2315 
2316  /// Returns the `pos`-th element in the list.
2317  PyBlockArgument getRawElement(intptr_t pos) {
2318  MlirValue argument = mlirBlockGetArgument(block, pos);
2319  return PyBlockArgument(operation, argument);
2320  }
2321 
2322  /// Returns a sublist of this list.
2323  PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2324  intptr_t step) {
2325  return PyBlockArgumentList(operation, block, startIndex, length, step);
2326  }
2327 
// Reference to the operation passed at construction, and the block whose
// arguments are listed.
2328  PyOperationRef operation;
2329  MlirBlock block;
2330 };
2331 
2332 /// A list of operation operands. Internally, these are stored as consecutive
2333 /// elements, random access is cheap. The (returned) operand list is associated
2334 /// with the operation whose operands these are, and thus extends the lifetime
2335 /// of this operation.
2336 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2337 public:
2338  static constexpr const char *pyClassName = "OpOperandList";
2339  using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2340 
// A `length` of -1 selects all operands of `operation`.
2341  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2342  intptr_t length = -1, intptr_t step = 1)
2343  : Sliceable(startIndex,
2344  length == -1 ? mlirOperationGetNumOperands(operation->get())
2345  : length,
2346  step),
2347  operation(operation) {}
2348 
// Replaces the operand at `index` (after passing it through wrapIndex) with
// `value`.
2349  void dunderSetItem(intptr_t index, PyValue value) {
2350  index = wrapIndex(index);
2351  mlirOperationSetOperand(operation->get(), index, value.get());
2352  }
2353 
// Exposes dunderSetItem to Python as `__setitem__`.
2354  static void bindDerived(ClassTy &c) {
2355  c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2356  }
2357 
2358 private:
2359  /// Give the parent CRTP class access to hook implementations below.
2360  friend class Sliceable<PyOpOperandList, PyValue>;
2361 
// Number of operands of the operation; checks the operation is valid first.
2362  intptr_t getRawNumElements() {
2363  operation->checkValid();
2364  return mlirOperationGetNumOperands(operation->get());
2365  }
2366 
// Returns the operand at `pos`, wrapped together with the operation that
// owns the operand value.
// NOTE(review): the statement under the block-argument branch is not present
// in this listing. Also, `owner` has no initializer, so when the final `else`
// is reached with assertions compiled out it is read while indeterminate -
// confirm against the full source.
2367  PyValue getRawElement(intptr_t pos) {
2368  MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2369  MlirOperation owner;
2370  if (mlirValueIsAOpResult(operand))
2371  owner = mlirOpResultGetOwner(operand);
2372  else if (mlirValueIsABlockArgument(operand))
2374  else
2375  assert(false && "Value must be an block arg or op result.");
2376  PyOperationRef pyOwner =
2377  PyOperation::forOperation(operation->getContext(), owner);
2378  return PyValue(pyOwner, operand);
2379  }
2380 
// Returns a sublist over the same operation.
2381  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2382  return PyOpOperandList(operation, startIndex, length, step);
2383  }
2384 
2385  PyOperationRef operation;
2386 };
2387 
2388 /// A list of operation results. Internally, these are stored as consecutive
2389 /// elements, random access is cheap. The (returned) result list is associated
2390 /// with the operation whose results these are, and thus extends the lifetime of
2391 /// this operation.
2392 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2393 public:
2394  static constexpr const char *pyClassName = "OpResultList";
2395  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
2396 
2397  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2398  intptr_t length = -1, intptr_t step = 1)
2399  : Sliceable(startIndex,
2400  length == -1 ? mlirOperationGetNumResults(operation->get())
2401  : length,
2402  step),
2403  operation(std::move(operation)) {}
2404 
2405  static void bindDerived(ClassTy &c) {
2406  c.def_property_readonly("types", [](PyOpResultList &self) {
2407  return getValueTypes(self, self.operation->getContext());
2408  });
2409  c.def_property_readonly("owner", [](PyOpResultList &self) {
2410  return self.operation->createOpView();
2411  });
2412  }
2413 
2414 private:
2415  /// Give the parent CRTP class access to hook implementations below.
2416  friend class Sliceable<PyOpResultList, PyOpResult>;
2417 
2418  intptr_t getRawNumElements() {
2419  operation->checkValid();
2420  return mlirOperationGetNumResults(operation->get());
2421  }
2422 
2423  PyOpResult getRawElement(intptr_t index) {
2424  PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2425  return PyOpResult(value);
2426  }
2427 
2428  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2429  return PyOpResultList(operation, startIndex, length, step);
2430  }
2431 
2432  PyOperationRef operation;
2433 };
2434 
2435 /// A list of operation successors. Internally, these are stored as consecutive
2436 /// elements, random access is cheap. The (returned) successor list is
2437 /// associated with the operation whose successors these are, and thus extends
2438 /// the lifetime of this operation.
2439 class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2440 public:
2441  static constexpr const char *pyClassName = "OpSuccessors";
2442 
2443  PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2444  intptr_t length = -1, intptr_t step = 1)
2445  : Sliceable(startIndex,
2446  length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2447  : length,
2448  step),
2449  operation(operation) {}
2450 
2451  void dunderSetItem(intptr_t index, PyBlock block) {
2452  index = wrapIndex(index);
2453  mlirOperationSetSuccessor(operation->get(), index, block.get());
2454  }
2455 
2456  static void bindDerived(ClassTy &c) {
2457  c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2458  }
2459 
2460 private:
2461  /// Give the parent CRTP class access to hook implementations below.
2462  friend class Sliceable<PyOpSuccessors, PyBlock>;
2463 
2464  intptr_t getRawNumElements() {
2465  operation->checkValid();
2466  return mlirOperationGetNumSuccessors(operation->get());
2467  }
2468 
2469  PyBlock getRawElement(intptr_t pos) {
2470  MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2471  return PyBlock(operation, block);
2472  }
2473 
2474  PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2475  return PyOpSuccessors(operation, startIndex, length, step);
2476  }
2477 
2478  PyOperationRef operation;
2479 };
2480 
2481 /// A list of operation attributes. Can be indexed by name, producing
2482 /// attributes, or by index, producing named attributes.
2483 class PyOpAttributeMap {
2484 public:
2485  PyOpAttributeMap(PyOperationRef operation)
2486  : operation(std::move(operation)) {}
2487 
// Returns the attribute called `name`; throws py::key_error when the lookup
// yields a null attribute.
2488  MlirAttribute dunderGetItemNamed(const std::string &name) {
2489  MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2490  toMlirStringRef(name));
2491  if (mlirAttributeIsNull(attr)) {
2492  throw py::key_error("attempt to access a non-existent attribute");
2493  }
2494  return attr;
2495  }
2496 
// Returns the attribute at position `index` together with its name; throws
// py::index_error when `index` is negative or not below dunderLen().
2497  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2498  if (index < 0 || index >= dunderLen()) {
2499  throw py::index_error("attempt to access out of bounds attribute");
2500  }
2501  MlirNamedAttribute namedAttr =
2502  mlirOperationGetAttribute(operation->get(), index);
// The identifier's characters are copied into a std::string for the result.
2503  return PyNamedAttribute(
2504  namedAttr.attribute,
2505  std::string(mlirIdentifierStr(namedAttr.name).data,
2506  mlirIdentifierStr(namedAttr.name).length));
2507  }
2508 
// Sets the attribute called `name` on the operation to `attr`.
2509  void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2510  mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2511  attr);
2512  }
2513 
// Removes the attribute called `name`; throws py::key_error when the removal
// call reports that nothing was removed.
2514  void dunderDelItem(const std::string &name) {
2515  int removed = mlirOperationRemoveAttributeByName(operation->get(),
2516  toMlirStringRef(name));
2517  if (!removed)
2518  throw py::key_error("attempt to delete a non-existent attribute");
2519  }
2520 
// Number of attributes on the operation.
2521  intptr_t dunderLen() {
2522  return mlirOperationGetNumAttributes(operation->get());
2523  }
2524 
// NOTE(review): the line that begins the return expression of this method is
// not present in this listing; only its argument line is visible.
2525  bool dunderContains(const std::string &name) {
2527  operation->get(), toMlirStringRef(name)));
2528  }
2529 
// Registers the class as `OpAttributeMap`; `__getitem__` is bound twice, once
// for the by-name accessor and once for the by-index accessor.
2530  static void bind(py::module &m) {
2531  py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2532  .def("__contains__", &PyOpAttributeMap::dunderContains)
2533  .def("__len__", &PyOpAttributeMap::dunderLen)
2534  .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2535  .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2536  .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2537  .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2538  }
2539 
2540 private:
2541  PyOperationRef operation;
2542 };
2543 
2544 } // namespace
2545 
2546 //------------------------------------------------------------------------------
2547 // Populates the core exports of the 'ir' submodule.
2548 //------------------------------------------------------------------------------
2549 
2550 void mlir::python::populateIRCore(py::module &m) {
2551  //----------------------------------------------------------------------------
2552  // Enums.
2553  //----------------------------------------------------------------------------
2554  py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2555  .value("ERROR", MlirDiagnosticError)
2556  .value("WARNING", MlirDiagnosticWarning)
2557  .value("NOTE", MlirDiagnosticNote)
2558  .value("REMARK", MlirDiagnosticRemark);
2559 
2560  py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
2561  .value("PRE_ORDER", MlirWalkPreOrder)
2562  .value("POST_ORDER", MlirWalkPostOrder);
2563 
2564  py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
2565  .value("ADVANCE", MlirWalkResultAdvance)
2566  .value("INTERRUPT", MlirWalkResultInterrupt)
2567  .value("SKIP", MlirWalkResultSkip);
2568 
2569  //----------------------------------------------------------------------------
2570  // Mapping of Diagnostics.
2571  //----------------------------------------------------------------------------
2572  py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2573  .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2574  .def_property_readonly("location", &PyDiagnostic::getLocation)
2575  .def_property_readonly("message", &PyDiagnostic::getMessage)
2576  .def_property_readonly("notes", &PyDiagnostic::getNotes)
2577  .def("__str__", [](PyDiagnostic &self) -> py::str {
2578  if (!self.isValid())
2579  return "<Invalid Diagnostic>";
2580  return self.getMessage();
2581  });
2582 
2583  py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo",
2584  py::module_local())
2585  .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); }))
2586  .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity)
2587  .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location)
2588  .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message)
2589  .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes)
2590  .def("__str__",
2591  [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2592 
2593  py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2594  .def("detach", &PyDiagnosticHandler::detach)
2595  .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2596  .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2597  .def("__enter__", &PyDiagnosticHandler::contextEnter)
2598  .def("__exit__", &PyDiagnosticHandler::contextExit);
2599 
2600  //----------------------------------------------------------------------------
2601  // Mapping of MlirContext.
2602  // Note that this is exported as _BaseContext. The containing, Python level
2603  // __init__.py will subclass it with site-specific functionality and set a
2604  // "Context" attribute on this module.
2605  //----------------------------------------------------------------------------
2606  py::class_<PyMlirContext>(m, "_BaseContext", py::module_local())
2607  .def(py::init<>(&PyMlirContext::createNewContextForInit))
2608  .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2609  .def("_get_context_again",
2610  [](PyMlirContext &self) {
2611  PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2612  return ref.releaseObject();
2613  })
2614  .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2615  .def("_get_live_operation_objects",
2616  &PyMlirContext::getLiveOperationObjects)
2617  .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2618  .def("_clear_live_operations_inside",
2619  py::overload_cast<MlirOperation>(
2620  &PyMlirContext::clearOperationsInside))
2621  .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2622  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2623  &PyMlirContext::getCapsule)
2624  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2625  .def("__enter__", &PyMlirContext::contextEnter)
2626  .def("__exit__", &PyMlirContext::contextExit)
2627  .def_property_readonly_static(
2628  "current",
2629  [](py::object & /*class*/) {
2630  auto *context = PyThreadContextEntry::getDefaultContext();
2631  if (!context)
2632  return py::none().cast<py::object>();
2633  return py::cast(context);
2634  },
2635  "Gets the Context bound to the current thread or raises ValueError")
2636  .def_property_readonly(
2637  "dialects",
2638  [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2639  "Gets a container for accessing dialects by name")
2640  .def_property_readonly(
2641  "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2642  "Alias for 'dialect'")
2643  .def(
2644  "get_dialect_descriptor",
2645  [=](PyMlirContext &self, std::string &name) {
2646  MlirDialect dialect = mlirContextGetOrLoadDialect(
2647  self.get(), {name.data(), name.size()});
2648  if (mlirDialectIsNull(dialect)) {
2649  throw py::value_error(
2650  (Twine("Dialect '") + name + "' not found").str());
2651  }
2652  return PyDialectDescriptor(self.getRef(), dialect);
2653  },
2654  py::arg("dialect_name"),
2655  "Gets or loads a dialect by name, returning its descriptor object")
2656  .def_property(
2657  "allow_unregistered_dialects",
2658  [](PyMlirContext &self) -> bool {
2660  },
2661  [](PyMlirContext &self, bool value) {
2663  })
2664  .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2665  py::arg("callback"),
2666  "Attaches a diagnostic handler that will receive callbacks")
2667  .def(
2668  "enable_multithreading",
2669  [](PyMlirContext &self, bool enable) {
2670  mlirContextEnableMultithreading(self.get(), enable);
2671  },
2672  py::arg("enable"))
2673  .def(
2674  "is_registered_operation",
2675  [](PyMlirContext &self, std::string &name) {
2677  self.get(), MlirStringRef{name.data(), name.size()});
2678  },
2679  py::arg("operation_name"))
2680  .def(
2681  "append_dialect_registry",
2682  [](PyMlirContext &self, PyDialectRegistry &registry) {
2683  mlirContextAppendDialectRegistry(self.get(), registry);
2684  },
2685  py::arg("registry"))
2686  .def_property("emit_error_diagnostics", nullptr,
2687  &PyMlirContext::setEmitErrorDiagnostics,
2688  "Emit error diagnostics to diagnostic handlers. By default "
2689  "error diagnostics are captured and reported through "
2690  "MLIRError exceptions.")
2691  .def("load_all_available_dialects", [](PyMlirContext &self) {
2693  });
2694 
2695  //----------------------------------------------------------------------------
2696  // Mapping of PyDialectDescriptor
2697  //----------------------------------------------------------------------------
2698  py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2699  .def_property_readonly("namespace",
2700  [](PyDialectDescriptor &self) {
2701  MlirStringRef ns =
2702  mlirDialectGetNamespace(self.get());
2703  return py::str(ns.data, ns.length);
2704  })
2705  .def("__repr__", [](PyDialectDescriptor &self) {
2707  std::string repr("<DialectDescriptor ");
2708  repr.append(ns.data, ns.length);
2709  repr.append(">");
2710  return repr;
2711  });
2712 
2713  //----------------------------------------------------------------------------
2714  // Mapping of PyDialects
2715  //----------------------------------------------------------------------------
2716  py::class_<PyDialects>(m, "Dialects", py::module_local())
2717  .def("__getitem__",
2718  [=](PyDialects &self, std::string keyName) {
2719  MlirDialect dialect =
2720  self.getDialectForKey(keyName, /*attrError=*/false);
2721  py::object descriptor =
2722  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2723  return createCustomDialectWrapper(keyName, std::move(descriptor));
2724  })
2725  .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2726  MlirDialect dialect =
2727  self.getDialectForKey(attrName, /*attrError=*/true);
2728  py::object descriptor =
2729  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2730  return createCustomDialectWrapper(attrName, std::move(descriptor));
2731  });
2732 
2733  //----------------------------------------------------------------------------
2734  // Mapping of PyDialect
2735  //----------------------------------------------------------------------------
2736  py::class_<PyDialect>(m, "Dialect", py::module_local())
2737  .def(py::init<py::object>(), py::arg("descriptor"))
2738  .def_property_readonly(
2739  "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2740  .def("__repr__", [](py::object self) {
2741  auto clazz = self.attr("__class__");
2742  return py::str("<Dialect ") +
2743  self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2744  clazz.attr("__module__") + py::str(".") +
2745  clazz.attr("__name__") + py::str(")>");
2746  });
2747 
2748  //----------------------------------------------------------------------------
2749  // Mapping of PyDialectRegistry
2750  //----------------------------------------------------------------------------
2751  py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local())
2752  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2753  &PyDialectRegistry::getCapsule)
2754  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2755  .def(py::init<>());
2756 
2757  //----------------------------------------------------------------------------
2758  // Mapping of Location
2759  //----------------------------------------------------------------------------
2760  py::class_<PyLocation>(m, "Location", py::module_local())
2761  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2762  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2763  .def("__enter__", &PyLocation::contextEnter)
2764  .def("__exit__", &PyLocation::contextExit)
2765  .def("__eq__",
2766  [](PyLocation &self, PyLocation &other) -> bool {
2767  return mlirLocationEqual(self, other);
2768  })
2769  .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2770  .def_property_readonly_static(
2771  "current",
2772  [](py::object & /*class*/) {
2773  auto *loc = PyThreadContextEntry::getDefaultLocation();
2774  if (!loc)
2775  throw py::value_error("No current Location");
2776  return loc;
2777  },
2778  "Gets the Location bound to the current thread or raises ValueError")
2779  .def_static(
2780  "unknown",
2781  [](DefaultingPyMlirContext context) {
2782  return PyLocation(context->getRef(),
2783  mlirLocationUnknownGet(context->get()));
2784  },
2785  py::arg("context") = py::none(),
2786  "Gets a Location representing an unknown location")
2787  .def_static(
2788  "callsite",
2789  [](PyLocation callee, const std::vector<PyLocation> &frames,
2790  DefaultingPyMlirContext context) {
2791  if (frames.empty())
2792  throw py::value_error("No caller frames provided");
2793  MlirLocation caller = frames.back().get();
2794  for (const PyLocation &frame :
2795  llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2796  caller = mlirLocationCallSiteGet(frame.get(), caller);
2797  return PyLocation(context->getRef(),
2798  mlirLocationCallSiteGet(callee.get(), caller));
2799  },
2800  py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2802  .def_static(
2803  "file",
2804  [](std::string filename, int line, int col,
2805  DefaultingPyMlirContext context) {
2806  return PyLocation(
2807  context->getRef(),
2809  context->get(), toMlirStringRef(filename), line, col));
2810  },
2811  py::arg("filename"), py::arg("line"), py::arg("col"),
2812  py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2813  .def_static(
2814  "fused",
2815  [](const std::vector<PyLocation> &pyLocations,
2816  std::optional<PyAttribute> metadata,
2817  DefaultingPyMlirContext context) {
2819  locations.reserve(pyLocations.size());
2820  for (auto &pyLocation : pyLocations)
2821  locations.push_back(pyLocation.get());
2822  MlirLocation location = mlirLocationFusedGet(
2823  context->get(), locations.size(), locations.data(),
2824  metadata ? metadata->get() : MlirAttribute{0});
2825  return PyLocation(context->getRef(), location);
2826  },
2827  py::arg("locations"), py::arg("metadata") = py::none(),
2828  py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2829  .def_static(
2830  "name",
2831  [](std::string name, std::optional<PyLocation> childLoc,
2832  DefaultingPyMlirContext context) {
2833  return PyLocation(
2834  context->getRef(),
2836  context->get(), toMlirStringRef(name),
2837  childLoc ? childLoc->get()
2838  : mlirLocationUnknownGet(context->get())));
2839  },
2840  py::arg("name"), py::arg("childLoc") = py::none(),
2841  py::arg("context") = py::none(), kContextGetNameLocationDocString)
2842  .def_static(
2843  "from_attr",
2844  [](PyAttribute &attribute, DefaultingPyMlirContext context) {
2845  return PyLocation(context->getRef(),
2846  mlirLocationFromAttribute(attribute));
2847  },
2848  py::arg("attribute"), py::arg("context") = py::none(),
2849  "Gets a Location from a LocationAttr")
2850  .def_property_readonly(
2851  "context",
2852  [](PyLocation &self) { return self.getContext().getObject(); },
2853  "Context that owns the Location")
2854  .def_property_readonly(
2855  "attr",
2856  [](PyLocation &self) { return mlirLocationGetAttribute(self); },
2857  "Get the underlying LocationAttr")
2858  .def(
2859  "emit_error",
2860  [](PyLocation &self, std::string message) {
2861  mlirEmitError(self, message.c_str());
2862  },
2863  py::arg("message"), "Emits an error at this location")
2864  .def("__repr__", [](PyLocation &self) {
2865  PyPrintAccumulator printAccum;
2866  mlirLocationPrint(self, printAccum.getCallback(),
2867  printAccum.getUserData());
2868  return printAccum.join();
2869  });
2870 
2871  //----------------------------------------------------------------------------
2872  // Mapping of Module
2873  //----------------------------------------------------------------------------
2874  py::class_<PyModule>(m, "Module", py::module_local())
2875  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2876  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2877  .def_static(
2878  "parse",
2879  [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
2880  PyMlirContext::ErrorCapture errors(context->getRef());
2881  MlirModule module = mlirModuleCreateParse(
2882  context->get(), toMlirStringRef(moduleAsm));
2883  if (mlirModuleIsNull(module))
2884  throw MLIRError("Unable to parse module assembly", errors.take());
2885  return PyModule::forModule(module).releaseObject();
2886  },
2887  py::arg("asm"), py::arg("context") = py::none(),
2889  .def_static(
2890  "create",
2891  [](DefaultingPyLocation loc) {
2892  MlirModule module = mlirModuleCreateEmpty(loc);
2893  return PyModule::forModule(module).releaseObject();
2894  },
2895  py::arg("loc") = py::none(), "Creates an empty module")
2896  .def_property_readonly(
2897  "context",
2898  [](PyModule &self) { return self.getContext().getObject(); },
2899  "Context that created the Module")
2900  .def_property_readonly(
2901  "operation",
2902  [](PyModule &self) {
2903  return PyOperation::forOperation(self.getContext(),
2904  mlirModuleGetOperation(self.get()),
2905  self.getRef().releaseObject())
2906  .releaseObject();
2907  },
2908  "Accesses the module as an operation")
2909  .def_property_readonly(
2910  "body",
2911  [](PyModule &self) {
2912  PyOperationRef moduleOp = PyOperation::forOperation(
2913  self.getContext(), mlirModuleGetOperation(self.get()),
2914  self.getRef().releaseObject());
2915  PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2916  return returnBlock;
2917  },
2918  "Return the block for this module")
2919  .def(
2920  "dump",
2921  [](PyModule &self) {
2923  },
2925  .def(
2926  "__str__",
2927  [](py::object self) {
2928  // Defer to the operation's __str__.
2929  return self.attr("operation").attr("__str__")();
2930  },
2932 
2933  //----------------------------------------------------------------------------
2934  // Mapping of Operation.
2935  //----------------------------------------------------------------------------
2936  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2937  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2938  [](PyOperationBase &self) {
2939  return self.getOperation().getCapsule();
2940  })
2941  .def("__eq__",
2942  [](PyOperationBase &self, PyOperationBase &other) {
2943  return &self.getOperation() == &other.getOperation();
2944  })
2945  .def("__eq__",
2946  [](PyOperationBase &self, py::object other) { return false; })
2947  .def("__hash__",
2948  [](PyOperationBase &self) {
2949  return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2950  })
2951  .def_property_readonly("attributes",
2952  [](PyOperationBase &self) {
2953  return PyOpAttributeMap(
2954  self.getOperation().getRef());
2955  })
2956  .def_property_readonly(
2957  "context",
2958  [](PyOperationBase &self) {
2959  PyOperation &concreteOperation = self.getOperation();
2960  concreteOperation.checkValid();
2961  return concreteOperation.getContext().getObject();
2962  },
2963  "Context that owns the Operation")
2964  .def_property_readonly("name",
2965  [](PyOperationBase &self) {
2966  auto &concreteOperation = self.getOperation();
2967  concreteOperation.checkValid();
2968  MlirOperation operation =
2969  concreteOperation.get();
2971  mlirOperationGetName(operation));
2972  return py::str(name.data, name.length);
2973  })
2974  .def_property_readonly("operands",
2975  [](PyOperationBase &self) {
2976  return PyOpOperandList(
2977  self.getOperation().getRef());
2978  })
2979  .def_property_readonly("regions",
2980  [](PyOperationBase &self) {
2981  return PyRegionList(
2982  self.getOperation().getRef());
2983  })
2984  .def_property_readonly(
2985  "results",
2986  [](PyOperationBase &self) {
2987  return PyOpResultList(self.getOperation().getRef());
2988  },
2989  "Returns the list of Operation results.")
2990  .def_property_readonly(
2991  "result",
2992  [](PyOperationBase &self) {
2993  auto &operation = self.getOperation();
2994  auto numResults = mlirOperationGetNumResults(operation);
2995  if (numResults != 1) {
2996  auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2997  throw py::value_error(
2998  (Twine("Cannot call .result on operation ") +
2999  StringRef(name.data, name.length) + " which has " +
3000  Twine(numResults) +
3001  " results (it is only valid for operations with a "
3002  "single result)")
3003  .str());
3004  }
3005  return PyOpResult(operation.getRef(),
3006  mlirOperationGetResult(operation, 0))
3007  .maybeDownCast();
3008  },
3009  "Shortcut to get an op result if it has only one (throws an error "
3010  "otherwise).")
3011  .def_property_readonly(
3012  "location",
3013  [](PyOperationBase &self) {
3014  PyOperation &operation = self.getOperation();
3015  return PyLocation(operation.getContext(),
3016  mlirOperationGetLocation(operation.get()));
3017  },
3018  "Returns the source location the operation was defined or derived "
3019  "from.")
3020  .def_property_readonly("parent",
3021  [](PyOperationBase &self) -> py::object {
3022  auto parent =
3023  self.getOperation().getParentOperation();
3024  if (parent)
3025  return parent->getObject();
3026  return py::none();
3027  })
3028  .def(
3029  "__str__",
3030  [](PyOperationBase &self) {
3031  return self.getAsm(/*binary=*/false,
3032  /*largeElementsLimit=*/std::nullopt,
3033  /*enableDebugInfo=*/false,
3034  /*prettyDebugInfo=*/false,
3035  /*printGenericOpForm=*/false,
3036  /*useLocalScope=*/false,
3037  /*assumeVerified=*/false);
3038  },
3039  "Returns the assembly form of the operation.")
3040  .def("print",
3041  py::overload_cast<PyAsmState &, pybind11::object, bool>(
3043  py::arg("state"), py::arg("file") = py::none(),
3044  py::arg("binary") = false, kOperationPrintStateDocstring)
3045  .def("print",
3046  py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3047  bool, py::object, bool>(&PyOperationBase::print),
3048  // Careful: Lots of arguments must match up with print method.
3049  py::arg("large_elements_limit") = py::none(),
3050  py::arg("enable_debug_info") = false,
3051  py::arg("pretty_debug_info") = false,
3052  py::arg("print_generic_op_form") = false,
3053  py::arg("use_local_scope") = false,
3054  py::arg("assume_verified") = false, py::arg("file") = py::none(),
3055  py::arg("binary") = false, kOperationPrintDocstring)
3056  .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
3057  py::arg("desired_version") = py::none(),
3059  .def("get_asm", &PyOperationBase::getAsm,
3060  // Careful: Lots of arguments must match up with get_asm method.
3061  py::arg("binary") = false,
3062  py::arg("large_elements_limit") = py::none(),
3063  py::arg("enable_debug_info") = false,
3064  py::arg("pretty_debug_info") = false,
3065  py::arg("print_generic_op_form") = false,
3066  py::arg("use_local_scope") = false,
3067  py::arg("assume_verified") = false, kOperationGetAsmDocstring)
3068  .def("verify", &PyOperationBase::verify,
3069  "Verify the operation. Raises MLIRError if verification fails, and "
3070  "returns true otherwise.")
3071  .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
3072  "Puts self immediately after the other operation in its parent "
3073  "block.")
3074  .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
3075  "Puts self immediately before the other operation in its parent "
3076  "block.")
3077  .def(
3078  "clone",
3079  [](PyOperationBase &self, py::object ip) {
3080  return self.getOperation().clone(ip);
3081  },
3082  py::arg("ip") = py::none())
3083  .def(
3084  "detach_from_parent",
3085  [](PyOperationBase &self) {
3086  PyOperation &operation = self.getOperation();
3087  operation.checkValid();
3088  if (!operation.isAttached())
3089  throw py::value_error("Detached operation has no parent.");
3090 
3091  operation.detachFromParent();
3092  return operation.createOpView();
3093  },
3094  "Detaches the operation from its parent block.")
3095  .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3096  .def("walk", &PyOperationBase::walk, py::arg("callback"),
3097  py::arg("walk_order") = MlirWalkPostOrder);
3098 
3099  py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
3100  .def_static("create", &PyOperation::create, py::arg("name"),
3101  py::arg("results") = py::none(),
3102  py::arg("operands") = py::none(),
3103  py::arg("attributes") = py::none(),
3104  py::arg("successors") = py::none(), py::arg("regions") = 0,
3105  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3106  py::arg("infer_type") = false, kOperationCreateDocstring)
3107  .def_static(
3108  "parse",
3109  [](const std::string &sourceStr, const std::string &sourceName,
3110  DefaultingPyMlirContext context) {
3111  return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3112  ->createOpView();
3113  },
3114  py::arg("source"), py::kw_only(), py::arg("source_name") = "",
3115  py::arg("context") = py::none(),
3116  "Parses an operation. Supports both text assembly format and binary "
3117  "bytecode format.")
3118  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3119  &PyOperation::getCapsule)
3120  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3121  .def_property_readonly("operation", [](py::object self) { return self; })
3122  .def_property_readonly("opview", &PyOperation::createOpView)
3123  .def_property_readonly(
3124  "successors",
3125  [](PyOperationBase &self) {
3126  return PyOpSuccessors(self.getOperation().getRef());
3127  },
3128  "Returns the list of Operation successors.");
3129 
3130  auto opViewClass =
3131  py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
3132  .def(py::init<py::object>(), py::arg("operation"))
3133  .def_property_readonly("operation", &PyOpView::getOperationObject)
3134  .def_property_readonly("opview", [](py::object self) { return self; })
3135  .def(
3136  "__str__",
3137  [](PyOpView &self) { return py::str(self.getOperationObject()); })
3138  .def_property_readonly(
3139  "successors",
3140  [](PyOperationBase &self) {
3141  return PyOpSuccessors(self.getOperation().getRef());
3142  },
3143  "Returns the list of Operation successors.");
3144  opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
3145  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
3146  opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
3147  opViewClass.attr("build_generic") = classmethod(
3148  &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
3149  py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
3150  py::arg("successors") = py::none(), py::arg("regions") = py::none(),
3151  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3152  "Builds a specific, generated OpView based on class level attributes.");
3153  opViewClass.attr("parse") = classmethod(
3154  [](const py::object &cls, const std::string &sourceStr,
3155  const std::string &sourceName, DefaultingPyMlirContext context) {
3156  PyOperationRef parsed =
3157  PyOperation::parse(context->getRef(), sourceStr, sourceName);
3158 
3159  // Check if the expected operation was parsed, and cast to to the
3160  // appropriate `OpView` subclass if successful.
3161  // NOTE: This accesses attributes that have been automatically added to
3162  // `OpView` subclasses, and is not intended to be used on `OpView`
3163  // directly.
3164  std::string clsOpName =
3165  py::cast<std::string>(cls.attr("OPERATION_NAME"));
3166  MlirStringRef identifier =
3168  std::string_view parsedOpName(identifier.data, identifier.length);
3169  if (clsOpName != parsedOpName)
3170  throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3171  parsedOpName + "'");
3172  return PyOpView::constructDerived(cls, *parsed.get());
3173  },
3174  py::arg("cls"), py::arg("source"), py::kw_only(),
3175  py::arg("source_name") = "", py::arg("context") = py::none(),
3176  "Parses a specific, generated OpView based on class level attributes");
3177 
3178  //----------------------------------------------------------------------------
3179  // Mapping of PyRegion.
3180  //----------------------------------------------------------------------------
3181  py::class_<PyRegion>(m, "Region", py::module_local())
3182  .def_property_readonly(
3183  "blocks",
3184  [](PyRegion &self) {
3185  return PyBlockList(self.getParentOperation(), self.get());
3186  },
3187  "Returns a forward-optimized sequence of blocks.")
3188  .def_property_readonly(
3189  "owner",
3190  [](PyRegion &self) {
3191  return self.getParentOperation()->createOpView();
3192  },
3193  "Returns the operation owning this region.")
3194  .def(
3195  "__iter__",
3196  [](PyRegion &self) {
3197  self.checkValid();
3198  MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3199  return PyBlockIterator(self.getParentOperation(), firstBlock);
3200  },
3201  "Iterates over blocks in the region.")
3202  .def("__eq__",
3203  [](PyRegion &self, PyRegion &other) {
3204  return self.get().ptr == other.get().ptr;
3205  })
3206  .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
3207 
3208  //----------------------------------------------------------------------------
3209  // Mapping of PyBlock.
3210  //----------------------------------------------------------------------------
3211  py::class_<PyBlock>(m, "Block", py::module_local())
3212  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3213  .def_property_readonly(
3214  "owner",
3215  [](PyBlock &self) {
3216  return self.getParentOperation()->createOpView();
3217  },
3218  "Returns the owning operation of this block.")
3219  .def_property_readonly(
3220  "region",
3221  [](PyBlock &self) {
3222  MlirRegion region = mlirBlockGetParentRegion(self.get());
3223  return PyRegion(self.getParentOperation(), region);
3224  },
3225  "Returns the owning region of this block.")
3226  .def_property_readonly(
3227  "arguments",
3228  [](PyBlock &self) {
3229  return PyBlockArgumentList(self.getParentOperation(), self.get());
3230  },
3231  "Returns a list of block arguments.")
3232  .def_property_readonly(
3233  "operations",
3234  [](PyBlock &self) {
3235  return PyOperationList(self.getParentOperation(), self.get());
3236  },
3237  "Returns a forward-optimized sequence of operations.")
3238  .def_static(
3239  "create_at_start",
3240  [](PyRegion &parent, const py::list &pyArgTypes,
3241  const std::optional<py::sequence> &pyArgLocs) {
3242  parent.checkValid();
3243  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3244  mlirRegionInsertOwnedBlock(parent, 0, block);
3245  return PyBlock(parent.getParentOperation(), block);
3246  },
3247  py::arg("parent"), py::arg("arg_types") = py::list(),
3248  py::arg("arg_locs") = std::nullopt,
3249  "Creates and returns a new Block at the beginning of the given "
3250  "region (with given argument types and locations).")
3251  .def(
3252  "append_to",
3253  [](PyBlock &self, PyRegion &region) {
3254  MlirBlock b = self.get();
3256  mlirBlockDetach(b);
3257  mlirRegionAppendOwnedBlock(region.get(), b);
3258  },
3259  "Append this block to a region, transferring ownership if necessary")
3260  .def(
3261  "create_before",
3262  [](PyBlock &self, const py::args &pyArgTypes,
3263  const std::optional<py::sequence> &pyArgLocs) {
3264  self.checkValid();
3265  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3266  MlirRegion region = mlirBlockGetParentRegion(self.get());
3267  mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3268  return PyBlock(self.getParentOperation(), block);
3269  },
3270  py::arg("arg_locs") = std::nullopt,
3271  "Creates and returns a new Block before this block "
3272  "(with given argument types and locations).")
3273  .def(
3274  "create_after",
3275  [](PyBlock &self, const py::args &pyArgTypes,
3276  const std::optional<py::sequence> &pyArgLocs) {
3277  self.checkValid();
3278  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3279  MlirRegion region = mlirBlockGetParentRegion(self.get());
3280  mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3281  return PyBlock(self.getParentOperation(), block);
3282  },
3283  py::arg("arg_locs") = std::nullopt,
3284  "Creates and returns a new Block after this block "
3285  "(with given argument types and locations).")
3286  .def(
3287  "__iter__",
3288  [](PyBlock &self) {
3289  self.checkValid();
3290  MlirOperation firstOperation =
3292  return PyOperationIterator(self.getParentOperation(),
3293  firstOperation);
3294  },
3295  "Iterates over operations in the block.")
3296  .def("__eq__",
3297  [](PyBlock &self, PyBlock &other) {
3298  return self.get().ptr == other.get().ptr;
3299  })
3300  .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
3301  .def("__hash__",
3302  [](PyBlock &self) {
3303  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3304  })
3305  .def(
3306  "__str__",
3307  [](PyBlock &self) {
3308  self.checkValid();
3309  PyPrintAccumulator printAccum;
3310  mlirBlockPrint(self.get(), printAccum.getCallback(),
3311  printAccum.getUserData());
3312  return printAccum.join();
3313  },
3314  "Returns the assembly form of the block.")
3315  .def(
3316  "append",
3317  [](PyBlock &self, PyOperationBase &operation) {
3318  if (operation.getOperation().isAttached())
3319  operation.getOperation().detachFromParent();
3320 
3321  MlirOperation mlirOperation = operation.getOperation().get();
3322  mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3323  operation.getOperation().setAttached(
3324  self.getParentOperation().getObject());
3325  },
3326  py::arg("operation"),
3327  "Appends an operation to this block. If the operation is currently "
3328  "in another block, it will be moved.");
3329 
3330  //----------------------------------------------------------------------------
3331  // Mapping of PyInsertionPoint.
3332  //----------------------------------------------------------------------------
3333 
3334  py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
3335  .def(py::init<PyBlock &>(), py::arg("block"),
3336  "Inserts after the last operation but still inside the block.")
3337  .def("__enter__", &PyInsertionPoint::contextEnter)
3338  .def("__exit__", &PyInsertionPoint::contextExit)
3339  .def_property_readonly_static(
3340  "current",
3341  [](py::object & /*class*/) {
3342  auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3343  if (!ip)
3344  throw py::value_error("No current InsertionPoint");
3345  return ip;
3346  },
3347  "Gets the InsertionPoint bound to the current thread or raises "
3348  "ValueError if none has been set")
3349  .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
3350  "Inserts before a referenced operation.")
3351  .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3352  py::arg("block"), "Inserts at the beginning of the block.")
3353  .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3354  py::arg("block"), "Inserts before the block terminator.")
3355  .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
3356  "Inserts an operation.")
3357  .def_property_readonly(
3358  "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3359  "Returns the block that this InsertionPoint points to.")
3360  .def_property_readonly(
3361  "ref_operation",
3362  [](PyInsertionPoint &self) -> py::object {
3363  auto refOperation = self.getRefOperation();
3364  if (refOperation)
3365  return refOperation->getObject();
3366  return py::none();
3367  },
3368  "The reference operation before which new operations are "
3369  "inserted, or None if the insertion point is at the end of "
3370  "the block");
3371 
3372  //----------------------------------------------------------------------------
3373  // Mapping of PyAttribute.
3374  //----------------------------------------------------------------------------
3375  py::class_<PyAttribute>(m, "Attribute", py::module_local())
3376  // Delegate to the PyAttribute copy constructor, which will also lifetime
3377  // extend the backing context which owns the MlirAttribute.
3378  .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
3379  "Casts the passed attribute to the generic Attribute")
3380  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3381  &PyAttribute::getCapsule)
3382  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3383  .def_static(
3384  "parse",
3385  [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3386  PyMlirContext::ErrorCapture errors(context->getRef());
3387  MlirAttribute attr = mlirAttributeParseGet(
3388  context->get(), toMlirStringRef(attrSpec));
3389  if (mlirAttributeIsNull(attr))
3390  throw MLIRError("Unable to parse attribute", errors.take());
3391  return attr;
3392  },
3393  py::arg("asm"), py::arg("context") = py::none(),
3394  "Parses an attribute from an assembly form. Raises an MLIRError on "
3395  "failure.")
3396  .def_property_readonly(
3397  "context",
3398  [](PyAttribute &self) { return self.getContext().getObject(); },
3399  "Context that owns the Attribute")
3400  .def_property_readonly(
3401  "type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
3402  .def(
3403  "get_named",
3404  [](PyAttribute &self, std::string name) {
3405  return PyNamedAttribute(self, std::move(name));
3406  },
3407  py::keep_alive<0, 1>(), "Binds a name to the attribute")
3408  .def("__eq__",
3409  [](PyAttribute &self, PyAttribute &other) { return self == other; })
3410  .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
3411  .def("__hash__",
3412  [](PyAttribute &self) {
3413  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3414  })
3415  .def(
3416  "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3418  .def(
3419  "__str__",
3420  [](PyAttribute &self) {
3421  PyPrintAccumulator printAccum;
3422  mlirAttributePrint(self, printAccum.getCallback(),
3423  printAccum.getUserData());
3424  return printAccum.join();
3425  },
3426  "Returns the assembly form of the Attribute.")
3427  .def("__repr__",
3428  [](PyAttribute &self) {
3429  // Generally, assembly formats are not printed for __repr__ because
3430  // this can cause exceptionally long debug output and exceptions.
3431  // However, attribute values are generally considered useful and
3432  // are printed. This may need to be re-evaluated if debug dumps end
3433  // up being excessive.
3434  PyPrintAccumulator printAccum;
3435  printAccum.parts.append("Attribute(");
3436  mlirAttributePrint(self, printAccum.getCallback(),
3437  printAccum.getUserData());
3438  printAccum.parts.append(")");
3439  return printAccum.join();
3440  })
3441  .def_property_readonly(
3442  "typeid",
3443  [](PyAttribute &self) -> MlirTypeID {
3444  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3445  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3446  "mlirTypeID was expected to be non-null.");
3447  return mlirTypeID;
3448  })
3450  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3451  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3452  "mlirTypeID was expected to be non-null.");
3453  std::optional<pybind11::function> typeCaster =
3454  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3455  mlirAttributeGetDialect(self));
3456  if (!typeCaster)
3457  return py::cast(self);
3458  return typeCaster.value()(self);
3459  });
3460 
3461  //----------------------------------------------------------------------------
3462  // Mapping of PyNamedAttribute
3463  //----------------------------------------------------------------------------
3464  py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
3465  .def("__repr__",
3466  [](PyNamedAttribute &self) {
3467  PyPrintAccumulator printAccum;
3468  printAccum.parts.append("NamedAttribute(");
3469  printAccum.parts.append(
3470  py::str(mlirIdentifierStr(self.namedAttr.name).data,
3471  mlirIdentifierStr(self.namedAttr.name).length));
3472  printAccum.parts.append("=");
3473  mlirAttributePrint(self.namedAttr.attribute,
3474  printAccum.getCallback(),
3475  printAccum.getUserData());
3476  printAccum.parts.append(")");
3477  return printAccum.join();
3478  })
3479  .def_property_readonly(
3480  "name",
3481  [](PyNamedAttribute &self) {
3482  return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3483  mlirIdentifierStr(self.namedAttr.name).length);
3484  },
3485  "The name of the NamedAttribute binding")
3486  .def_property_readonly(
3487  "attr",
3488  [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3489  py::keep_alive<0, 1>(),
3490  "The underlying generic attribute of the NamedAttribute binding");
3491 
3492  //----------------------------------------------------------------------------
3493  // Mapping of PyType.
3494  //----------------------------------------------------------------------------
3495  py::class_<PyType>(m, "Type", py::module_local())
3496  // Delegate to the PyType copy constructor, which will also lifetime
3497  // extend the backing context which owns the MlirType.
3498  .def(py::init<PyType &>(), py::arg("cast_from_type"),
3499  "Casts the passed type to the generic Type")
3500  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3501  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3502  .def_static(
3503  "parse",
3504  [](std::string typeSpec, DefaultingPyMlirContext context) {
3505  PyMlirContext::ErrorCapture errors(context->getRef());
3506  MlirType type =
3507  mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3508  if (mlirTypeIsNull(type))
3509  throw MLIRError("Unable to parse type", errors.take());
3510  return type;
3511  },
3512  py::arg("asm"), py::arg("context") = py::none(),
3514  .def_property_readonly(
3515  "context", [](PyType &self) { return self.getContext().getObject(); },
3516  "Context that owns the Type")
3517  .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3518  .def("__eq__", [](PyType &self, py::object &other) { return false; })
3519  .def("__hash__",
3520  [](PyType &self) {
3521  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3522  })
3523  .def(
3524  "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3525  .def(
3526  "__str__",
3527  [](PyType &self) {
3528  PyPrintAccumulator printAccum;
3529  mlirTypePrint(self, printAccum.getCallback(),
3530  printAccum.getUserData());
3531  return printAccum.join();
3532  },
3533  "Returns the assembly form of the type.")
3534  .def("__repr__",
3535  [](PyType &self) {
3536  // Generally, assembly formats are not printed for __repr__ because
3537  // this can cause exceptionally long debug output and exceptions.
3538  // However, types are an exception as they typically have compact
3539  // assembly forms and printing them is useful.
3540  PyPrintAccumulator printAccum;
3541  printAccum.parts.append("Type(");
3542  mlirTypePrint(self, printAccum.getCallback(),
3543  printAccum.getUserData());
3544  printAccum.parts.append(")");
3545  return printAccum.join();
3546  })
3548  [](PyType &self) {
3549  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3550  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3551  "mlirTypeID was expected to be non-null.");
3552  std::optional<pybind11::function> typeCaster =
3553  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3554  mlirTypeGetDialect(self));
3555  if (!typeCaster)
3556  return py::cast(self);
3557  return typeCaster.value()(self);
3558  })
3559  .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
3560  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3561  if (!mlirTypeIDIsNull(mlirTypeID))
3562  return mlirTypeID;
3563  auto origRepr =
3564  pybind11::repr(pybind11::cast(self)).cast<std::string>();
3565  throw py::value_error(
3566  (origRepr + llvm::Twine(" has no typeid.")).str());
3567  });
3568 
3569  //----------------------------------------------------------------------------
3570  // Mapping of PyTypeID.
3571  //----------------------------------------------------------------------------
3572  py::class_<PyTypeID>(m, "TypeID", py::module_local())
3573  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3574  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3575  // Note, this tests whether the underlying TypeIDs are the same,
3576  // not whether the wrapper MlirTypeIDs are the same, nor whether
3577  // the Python objects are the same (i.e., PyTypeID is a value type).
3578  .def("__eq__",
3579  [](PyTypeID &self, PyTypeID &other) { return self == other; })
3580  .def("__eq__",
3581  [](PyTypeID &self, const py::object &other) { return false; })
3582  // Note, this gives the hash value of the underlying TypeID, not the
3583  // hash value of the Python object, nor the hash value of the
3584  // MlirTypeID wrapper.
3585  .def("__hash__", [](PyTypeID &self) {
3586  return static_cast<size_t>(mlirTypeIDHashValue(self));
3587  });
3588 
3589  //----------------------------------------------------------------------------
3590  // Mapping of Value.
3591  //----------------------------------------------------------------------------
3592  py::class_<PyValue>(m, "Value", py::module_local())
3593  .def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"))
3594  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3595  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3596  .def_property_readonly(
3597  "context",
3598  [](PyValue &self) { return self.getParentOperation()->getContext(); },
3599  "Context in which the value lives.")
3600  .def(
3601  "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3603  .def_property_readonly(
3604  "owner",
3605  [](PyValue &self) -> py::object {
3606  MlirValue v = self.get();
3607  if (mlirValueIsAOpResult(v)) {
3608  assert(
3609  mlirOperationEqual(self.getParentOperation()->get(),
3610  mlirOpResultGetOwner(self.get())) &&
3611  "expected the owner of the value in Python to match that in "
3612  "the IR");
3613  return self.getParentOperation().getObject();
3614  }
3615 
3616  if (mlirValueIsABlockArgument(v)) {
3617  MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3618  return py::cast(PyBlock(self.getParentOperation(), block));
3619  }
3620 
3621  assert(false && "Value must be a block argument or an op result");
3622  return py::none();
3623  })
3624  .def_property_readonly("uses",
3625  [](PyValue &self) {
3626  return PyOpOperandIterator(
3627  mlirValueGetFirstUse(self.get()));
3628  })
3629  .def("__eq__",
3630  [](PyValue &self, PyValue &other) {
3631  return self.get().ptr == other.get().ptr;
3632  })
3633  .def("__eq__", [](PyValue &self, py::object other) { return false; })
3634  .def("__hash__",
3635  [](PyValue &self) {
3636  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3637  })
3638  .def(
3639  "__str__",
3640  [](PyValue &self) {
3641  PyPrintAccumulator printAccum;
3642  printAccum.parts.append("Value(");
3643  mlirValuePrint(self.get(), printAccum.getCallback(),
3644  printAccum.getUserData());
3645  printAccum.parts.append(")");
3646  return printAccum.join();
3647  },
3649  .def(
3650  "get_name",
3651  [](PyValue &self, bool useLocalScope) {
3652  PyPrintAccumulator printAccum;
3653  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3654  if (useLocalScope)
3656  MlirAsmState valueState =
3657  mlirAsmStateCreateForValue(self.get(), flags);
3658  mlirValuePrintAsOperand(self.get(), valueState,
3659  printAccum.getCallback(),
3660  printAccum.getUserData());
3662  mlirAsmStateDestroy(valueState);
3663  return printAccum.join();
3664  },
3665  py::arg("use_local_scope") = false)
3666  .def(
3667  "get_name",
3668  [](PyValue &self, std::reference_wrapper<PyAsmState> state) {
3669  PyPrintAccumulator printAccum;
3670  MlirAsmState valueState = state.get().get();
3671  mlirValuePrintAsOperand(self.get(), valueState,
3672  printAccum.getCallback(),
3673  printAccum.getUserData());
3674  return printAccum.join();
3675  },
3676  py::arg("state"), kGetNameAsOperand)
3677  .def_property_readonly(
3678  "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
3679  .def(
3680  "set_type",
3681  [](PyValue &self, const PyType &type) {
3682  return mlirValueSetType(self.get(), type);
3683  },
3684  py::arg("type"))
3685  .def(
3686  "replace_all_uses_with",
3687  [](PyValue &self, PyValue &with) {
3688  mlirValueReplaceAllUsesOfWith(self.get(), with.get());
3689  },
3692  [](PyValue &self) { return self.maybeDownCast(); });
3693  PyBlockArgument::bind(m);
3694  PyOpResult::bind(m);
3695  PyOpOperand::bind(m);
3696 
3697  py::class_<PyAsmState>(m, "AsmState", py::module_local())
3698  .def(py::init<PyValue &, bool>(), py::arg("value"),
3699  py::arg("use_local_scope") = false)
3700  .def(py::init<PyOperationBase &, bool>(), py::arg("op"),
3701  py::arg("use_local_scope") = false);
3702 
3703  //----------------------------------------------------------------------------
3704  // Mapping of SymbolTable.
3705  //----------------------------------------------------------------------------
3706  py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3707  .def(py::init<PyOperationBase &>())
3708  .def("__getitem__", &PySymbolTable::dunderGetItem)
3709  .def("insert", &PySymbolTable::insert, py::arg("operation"))
3710  .def("erase", &PySymbolTable::erase, py::arg("operation"))
3711  .def("__delitem__", &PySymbolTable::dunderDel)
3712  .def("__contains__",
3713  [](PySymbolTable &table, const std::string &name) {
3715  table, mlirStringRefCreate(name.data(), name.length())));
3716  })
3717  // Static helpers.
3718  .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3719  py::arg("symbol"), py::arg("name"))
3720  .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3721  py::arg("symbol"))
3722  .def_static("get_visibility", &PySymbolTable::getVisibility,
3723  py::arg("symbol"))
3724  .def_static("set_visibility", &PySymbolTable::setVisibility,
3725  py::arg("symbol"), py::arg("visibility"))
3726  .def_static("replace_all_symbol_uses",
3727  &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3728  py::arg("new_symbol"), py::arg("from_op"))
3729  .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3730  py::arg("from_op"), py::arg("all_sym_uses_visible"),
3731  py::arg("callback"));
3732 
3733  // Container bindings.
3734  PyBlockArgumentList::bind(m);
3735  PyBlockIterator::bind(m);
3736  PyBlockList::bind(m);
3737  PyOperationIterator::bind(m);
3738  PyOperationList::bind(m);
3739  PyOpAttributeMap::bind(m);
3740  PyOpOperandIterator::bind(m);
3741  PyOpOperandList::bind(m);
3742  PyOpResultList::bind(m);
3743  PyOpSuccessors::bind(m);
3744  PyRegionIterator::bind(m);
3745  PyRegionList::bind(m);
3746 
3747  // Debug bindings.
3749 
3750  // Attribute builder getter.
3752 
3753  py::register_local_exception_translator([](std::exception_ptr p) {
3754  // We can't define exceptions with custom fields through pybind, so instead
3755  // the exception class is defined in python and imported here.
3756  try {
3757  if (p)
3758  std::rethrow_exception(p);
3759  } catch (const MLIRError &e) {
3760  py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
3761  .attr("MLIRError")(e.message, e.errorDiagnostics);
3762  PyErr_SetObject(PyExc_Exception, obj.ptr());
3763  }
3764  });
3765 }
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition: Debug.cpp:20
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Sets multiple current debug types, similarly to `-debug-only=type1,type2` in the command-line tools.
Definition: Debug.cpp:28
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition: Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition: Debug.cpp:16
static const char kOperationPrintStateDocstring[]
Definition: IRCore.cpp:113
static const char kValueReplaceAllUsesWithDocstring[]
Definition: IRCore.cpp:175
static py::object createCustomDialectWrapper(const std::string &dialectNamespace, py::object dialectDescriptor)
Definition: IRCore.cpp:192
static const char kContextGetNameLocationDocString[]
Definition: IRCore.cpp:57
static const char kGetNameAsOperand[]
Definition: IRCore.cpp:171
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:204
static const char kModuleParseDocstring[]
Definition: IRCore.cpp:60
static const char kOperationStrDunderDocstring[]
Definition: IRCore.cpp:145
static const char kOperationPrintDocstring[]
Definition: IRCore.cpp:87
static const char kContextGetFileLocationDocstring[]
Definition: IRCore.cpp:51
static const char kDumpDocstring[]
Definition: IRCore.cpp:153
static const char kAppendBlockDocstring[]
Definition: IRCore.cpp:156
static const char kContextGetFusedLocationDocstring[]
Definition: IRCore.cpp:54
static void maybeInsertOperation(PyOperationRef &op, const py::object &maybeIp)
Definition: IRCore.cpp:1380
static MlirBlock createBlock(const py::sequence &pyArgTypes, const std::optional< py::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition: IRCore.cpp:210
static const char kOperationPrintBytecodeDocstring[]
Definition: IRCore.cpp:135
static const char kOperationGetAsmDocstring[]
Definition: IRCore.cpp:122
static void populateResultTypes(StringRef name, py::list resultTypeList, const py::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition: IRCore.cpp:1557
static const char kOperationCreateDocstring[]
Definition: IRCore.cpp:68
static const char kContextParseTypeDocstring[]
Definition: IRCore.cpp:40
static const char kContextGetCallSiteLocationDocstring[]
Definition: IRCore.cpp:48
static const char kValueDunderStrDocstring[]
Definition: IRCore.cpp:163
py::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition: IRCore.cpp:186
static MLIRContext * getContext(OpFoldResult val)
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition: Interop.h:272
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition: Interop.h:117
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition: Interop.h:327
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition: Interop.h:317
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:96
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition: Interop.h:188
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition: Interop.h:179
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition: Interop.h:254
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition: Interop.h:109
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition: Interop.h:281
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts a MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition: Interop.h:223
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition: Interop.h:336
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition: Interop.h:234
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition: Interop.h:346
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition: Interop.h:244
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition: Interop.h:56
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition: Interop.h:355
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition: Interop.h:433
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition: Interop.h:197
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition: Interop.h:309
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition: Interop.h:263
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition: Interop.h:424
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition: Interop.h:215
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::context getDefaultContext()
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Accumulates into a Python file-like object, either writing text (default) or binary.
Definition: PybindUtils.h:125
MlirStringCallback getCallback()
Definition: PybindUtils.h:132
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
Definition: PybindUtils.h:209
Base class for all objects that directly or indirectly depend on an MlirContext.
Definition: IRModule.h:295
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:303
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:510
static PyLocation & resolve()
Definition: IRCore.cpp:1053
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:284
static PyMlirContext & resolve()
Definition: IRCore.cpp:777
ReferrentTy * get() const
Definition: PybindUtils.h:47
Wrapper around an MlirAsmState.
Definition: IRModule.h:776
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:993
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition: IRModule.h:995
static PyAttribute createFromCapsule(pybind11::object capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition: IRCore.cpp:1920
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition: IRCore.cpp:1916
bool operator==(const PyAttribute &other) const
Definition: IRCore.cpp:1912
Wrapper around an MlirBlock.
Definition: IRModule.h:811
MlirBlock get()
Definition: IRModule.h:818
PyOperationRef & getParentOperation()
Definition: IRModule.h:819
Represents a diagnostic handler attached to the context.
Definition: IRModule.h:391
void detach()
Detaches the handler. Does nothing if not attached.
Definition: IRCore.cpp:939
PyDiagnosticHandler(MlirContext context, pybind11::object callback)
Definition: IRCore.cpp:933
Python class mirroring the C MlirDiagnostic struct.
Definition: IRModule.h:341
pybind11::str getMessage()
Definition: IRCore.cpp:969
PyLocation getLocation()
Definition: IRCore.cpp:962
DiagnosticInfo getInfo()
Definition: IRCore.cpp:990
PyDiagnostic(MlirDiagnostic diagnostic)
Definition: IRModule.h:343
MlirDiagnosticSeverity getSeverity()
Definition: IRCore.cpp:957
pybind11::tuple getNotes()
Definition: IRCore.cpp:977
Wrapper around an MlirDialect.
Definition: IRModule.h:446
Wrapper around an MlirDialectRegistry.
Definition: IRModule.h:483
static PyDialectRegistry createFromCapsule(pybind11::object capsule)
Definition: IRCore.cpp:1019
pybind11::object getCapsule()
Definition: IRCore.cpp:1014
User-level dialect object.
Definition: IRModule.h:470
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition: IRModule.h:459
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition: IRCore.cpp:1001
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition: Globals.h:34
std::optional< pybind11::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition: IRModule.cpp:172
std::optional< pybind11::function > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:145
An insertion point maintains a pointer to a Block and a reference operation.
Definition: IRModule.h:835
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition: IRCore.cpp:1889
PyInsertionPoint(PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition: IRCore.cpp:1844
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition: IRCore.cpp:1876
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition: IRCore.cpp:1850
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:1902
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:1898
Wrapper around an MlirLocation.
Definition: IRModule.h:310
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition: IRModule.h:312
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition: IRCore.cpp:1031
static PyLocation createFromCapsule(pybind11::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition: IRCore.cpp:1035
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:1047
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:1043
MlirLocation get() const
Definition: IRModule.h:316
pybind11::object attachDiagnosticHandler(pybind11::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition: IRCore.cpp:712
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:184
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition: IRModule.h:188
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition: IRCore.cpp:614
void clearOperationsInside(PyOperationBase &op)
Clears all operations nested inside the given op using clearOperation(MlirOperation).
Definition: IRCore.cpp:675
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition: IRCore.cpp:648
static PyMlirContext * createNewContextForInit()
For the case of a python init (py::init) method, pybind11 is quite strict about needing to return a p...
Definition: IRCore.cpp:621
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition: IRCore.cpp:700
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:702
size_t clearLiveOperations()
Clears the live operations map, returning the number of entries which were invalidated.
Definition: IRCore.cpp:659
std::vector< PyOperation * > getLiveOperationObjects()
Get a list of Python objects which are still in the live context map.
Definition: IRCore.cpp:652
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:706
void clearOperation(MlirOperation op)
Removes an operation from the live operations map and sets it invalid.
Definition: IRCore.cpp:667
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition: IRCore.cpp:626
size_t getLiveOperationCount()
Gets the count of live operations associated with this context.
Definition: IRCore.cpp:650
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition: IRCore.cpp:610
MlirModule get()
Gets the backing MlirModule.
Definition: IRModule.h:533
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition: IRCore.cpp:1080
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition: IRCore.cpp:1113
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition: IRCore.cpp:1106
PyModule(PyModule &)=delete
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition: IRModule.h:1017
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttribute that retains an owned name.
Definition: IRCore.cpp:1932
MlirNamedAttribute namedAttr
Definition: IRModule.h:1026
pybind11::object getObject()
Definition: IRModule.h:87
pybind11::object releaseObject()
Releases the object held by this instance, returning it.
Definition: IRModule.h:75
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition: IRModule.h:725
PyOpView(const pybind11::object &operationObject)
Definition: IRCore.cpp:1834
static pybind11::object buildGeneric(const pybind11::object &cls, std::optional< pybind11::list > resultTypeList, pybind11::list operandList, std::optional< pybind11::dict > attributes, std::optional< std::vector< PyBlock * >> successors, std::optional< int > regions, DefaultingPyLocation location, const pybind11::object &maybeIp)
Definition: IRCore.cpp:1645
static pybind11::object constructDerived(const pybind11::object &cls, const PyOperation &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition: IRCore.cpp:1823
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition: IRModule.h:563
void walk(std::function< MlirWalkResult(MlirOperation)> callback, MlirWalkOrder walkOrder)
Definition: IRCore.cpp:1266
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
void writeBytecode(const pybind11::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition: IRCore.cpp:1245
pybind11::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified)
Definition: IRCore.cpp:1298
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition: IRCore.cpp:1321
void print(std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, py::object fileObject, bool binary)
Implements the bound 'print' method and helps with others.
Definition: IRCore.cpp:1205
void moveBefore(PyOperationBase &other)
Definition: IRCore.cpp:1330
bool verify()
Verify the operation.
Definition: IRCore.cpp:1339
pybind11::object clone(const pybind11::object &ip)
Clones this operation.
Definition: IRCore.cpp:1521
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition: IRModule.h:631
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition: IRCore.cpp:1541
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition: IRCore.cpp:1366
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition: IRModule.h:609
PyOperationRef getRef()
Definition: IRModule.h:644
MlirOperation get() const
Definition: IRModule.h:639
void setAttached(const pybind11::object &parent=pybind11::object())
Definition: IRModule.h:650
static pybind11::object create(const std::string &name, std::optional< std::vector< PyType * >> results, std::optional< std::vector< PyValue * >> operands, std::optional< pybind11::dict > attributes, std::optional< std::vector< PyBlock * >> successors, int regions, DefaultingPyLocation location, const pybind11::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition: IRCore.cpp:1395
pybind11::object createOpView()
Creates an OpView suitable for this operation.
Definition: IRCore.cpp:1530
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive=pybind11::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition: IRCore.cpp:1157
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition: IRCore.cpp:1347
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition: IRCore.cpp:1357
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive=pybind11::object())
Creates a detached operation.
Definition: IRCore.cpp:1173
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition: IRCore.cpp:1187
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition: IRCore.cpp:1371
void checkValid() const
Definition: IRCore.cpp:1199
Wrapper around an MlirRegion.
Definition: IRModule.h:757
PyOperationRef & getParentOperation()
Definition: IRModule.h:766
MlirRegion get()
Definition: IRModule.h:765
Bindings for MLIR symbol tables.
Definition: IRModule.h:1223
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition: IRCore.cpp:2052
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition: IRCore.cpp:2123
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition: IRCore.cpp:2105
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition: IRCore.cpp:2079
MlirAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition: IRCore.cpp:2057
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition: IRCore.cpp:2042
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition: IRCore.cpp:2022
static MlirAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition: IRCore.cpp:2067
pybind11::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition: IRCore.cpp:2030
static MlirAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition: IRCore.cpp:2094
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, pybind11::object callback)
Walks all symbol tables under and including 'from'.
Definition: IRCore.cpp:2135
Tracks an entry in the thread context stack.
Definition: IRModule.h:106
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition: IRCore.cpp:797
static void popLocation(PyLocation &location)
Definition: IRCore.cpp:909
static pybind11::object pushContext(PyMlirContext &context)
Definition: IRCore.cpp:859
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition: IRCore.cpp:854
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition: IRCore.cpp:889
static void popContext(PyMlirContext &context)
Definition: IRCore.cpp:867
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and returns nullptr if not defined.
Definition: IRCore.cpp:849
static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint)
Definition: IRCore.cpp:878
static pybind11::object pushLocation(PyLocation &location)
Definition: IRCore.cpp:900
PyMlirContext * getContext()
Definition: IRCore.cpp:826
static PyMlirContext * getDefaultContext()
Gets the top of stack context and returns nullptr if not defined.
Definition: IRCore.cpp:844
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition: IRCore.cpp:792
PyInsertionPoint * getInsertionPoint()
Definition: IRCore.cpp:832
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition: IRModule.h:895
bool operator==(const PyTypeID &other) const
Definition: IRCore.cpp:1974
static PyTypeID createFromCapsule(pybind11::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition: IRCore.cpp:1968
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition: IRCore.cpp:1964
PyTypeID(MlirTypeID typeID)
Definition: IRModule.h:897
Wrapper around the generic MlirType.
Definition: IRModule.h:871
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition: IRCore.cpp:1948
static PyType createFromCapsule(pybind11::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition: IRCore.cpp:1952
PyType(PyMlirContextRef contextRef, MlirType type)
Definition: IRModule.h:873
bool operator==(const PyType &other) const
Definition: IRCore.cpp:1944
Wrapper around the generic MlirValue.
Definition: IRModule.h:1124
static PyValue createFromCapsule(pybind11::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition: IRCore.cpp:2001
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition: IRModule.h:1130
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition: IRCore.cpp:1982
pybind11::object maybeDownCast()
Definition: IRCore.cpp:1986
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
Definition: Diagnostics.cpp:44
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
Definition: Diagnostics.cpp:28
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
Definition: Diagnostics.cpp:18
MlirDiagnosticSeverity
Severity of a diagnostic.
Definition: Diagnostics.h:32
@ MlirDiagnosticNote
Definition: Diagnostics.h:35
@ MlirDiagnosticRemark
Definition: Diagnostics.h:36
@ MlirDiagnosticWarning
Definition: Diagnostics.h:34
@ MlirDiagnosticError
Definition: Diagnostics.h:33
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
Definition: Diagnostics.cpp:50
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
Definition: Diagnostics.cpp:78
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
Definition: Diagnostics.cpp:56
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
Definition: Diagnostics.cpp:72
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition: Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
Definition: Diagnostics.cpp:24
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition: IR.cpp:243
MLIR_CAPI_EXPORTED intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Returns the position of the value in the argument list of its block.
Definition: IR.cpp:944
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1019
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition: IR.h:723
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition: IR.cpp:692
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
Definition: IR.cpp:520
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates a File/Line/Column location owned by the given context.
Definition: IR.cpp:251
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition: IR.cpp:1173
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition: IR.cpp:127
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op)
Returns the number of results of the operation.
Definition: IR.cpp:579
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition: IR.cpp:1152
MlirWalkOrder
Traversal order for operation walk.
Definition: IR.h:716
@ MlirWalkPreOrder
Definition: IR.h:717
@ MlirWalkPostOrder
Definition: IR.h:718
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Return pos-th attribute of the operation.
Definition: IR.cpp:653
MLIR_CAPI_EXPORTED void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition: IR.cpp:368
MLIR_CAPI_EXPORTED void mlirModuleDestroy(MlirModule module)
Takes a module owned by the caller and deletes it.
Definition: IR.cpp:321
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition: IR.cpp:1100
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition: IR.cpp:1157
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Use local scope when printing the operation.
Definition: IR.cpp:214
MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value)
Returns 1 if the value is a block argument, 0 otherwise.
Definition: IR.cpp:932
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition: IR.cpp:82
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:1121
static bool mlirModuleIsNull(MlirModule module)
Checks whether a module is null.
Definition: IR.h:314
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition: IR.cpp:1034
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition: IR.cpp:1017
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute)
Gets the type of this attribute.
Definition: IR.cpp:1073
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition: IR.cpp:71
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition: IR.cpp:792
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition: IR.cpp:995
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Returns pos-th successor of the operation.
Definition: IR.cpp:591
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition: IR.cpp:978
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition: IR.cpp:279
MLIR_CAPI_EXPORTED void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Prints a type by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1054
MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Sets an attribute by name, replacing the existing if it exists or adding a new one otherwise.
Definition: IR.cpp:663
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition: IR.cpp:1005
MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute)
Gets the dialect of the attribute.
Definition: IR.cpp:1084
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1092
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition: IR.cpp:831
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition: IR.cpp:716
static bool mlirValueIsNull(MlirValue value)
Returns whether the value is null.
Definition: IR.h:883
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition: IR.cpp:684
MlirWalkResult
Operation walk result.
Definition: IR.h:709
@ MlirWalkResultInterrupt
Definition: IR.h:711
@ MlirWalkResultSkip
Definition: IR.h:712
@ MlirWalkResultAdvance
Definition: IR.h:710
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition: IR.cpp:772
MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Returns an attribute attached to the operation given its name.
Definition: IR.cpp:658
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition: IR.h:984
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e. 'dialect.operation') is registered with the context.
Definition: IR.cpp:98
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Returns the number of successor blocks of the operation.
Definition: IR.cpp:587
MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op)
Creates a deep copy of an operation.
Definition: IR.cpp:494
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition: IR.cpp:900
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Always print operations in the generic form.
Definition: IR.cpp:210
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition: IR.cpp:262
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition: IR.cpp:881
MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state)
Destroys an AsmState created with one of the mlirAsmStateCreateFor* functions (e.g. mlirAsmStateCreateForValue).
Definition: IR.cpp:186
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition: IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition: IR.cpp:93
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningl...
Definition: IR.cpp:200
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition: IR.cpp:778
MLIR_CAPI_EXPORTED void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Sets the type of the block argument to the given type.
Definition: IR.cpp:949
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op)
Gets the context this operation is associated with.
Definition: IR.cpp:506
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition: IR.cpp:815
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition: IR.h:805
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition: IR.cpp:856
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition: IR.cpp:914
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition: IR.cpp:1147
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition: IR.cpp:1038
MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value)
Prints the value to the standard error stream.
Definition: IR.cpp:970
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location)
Creates a new, empty module and transfers ownership to the caller.
Definition: IR.cpp:301
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition: IR.cpp:1003
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition: IR.cpp:1137
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition: IR.cpp:283
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetBlock(MlirOperation op)
Gets the block that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:524
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition: IR.h:282
MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Checks whether two operation handles point to the same operation.
Definition: IR.cpp:502
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition: IR.cpp:678
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition: IR.cpp:985
MLIR_CAPI_EXPORTED void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:291
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition: IR.cpp:708
MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module)
Views the module as a generic operation.
Definition: IR.cpp:327
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition: IR.cpp:1050
MLIR_CAPI_EXPORTED MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Constructs an operation state from a name and a location.
Definition: IR.cpp:339
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition: IR.cpp:1013
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition: IR.cpp:846
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Returns an operation immediately following the given operation in its enclosing block.
Definition: IR.cpp:556
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Gets the operation that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:528
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module)
Gets the context that a module was created with.
Definition: IR.cpp:313
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition: IR.cpp:247
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Do not verify the operation when using custom operation printers.
Definition: IR.cpp:218
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition: IR.cpp:1042
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition: IR.cpp:1133
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition: IR.h:173
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Destroys a bytecode writer config created with mlirBytecodeWriterConfigCreate.
Definition: IR.cpp:230
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Returns pos-th operand of the operation.
Definition: IR.cpp:564
MLIR_CAPI_EXPORTED void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition: IR.cpp:380
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition: IR.cpp:835
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition: IR.cpp:258
MLIR_CAPI_EXPORTED MlirOperation mlirOpResultGetOwner(MlirValue value)
Returns an operation that produced this value as its result.
Definition: IR.cpp:953
MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value)
Returns 1 if the value is an operation result, 0 otherwise.
Definition: IR.cpp:936
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op)
Returns the number of operands of the operation.
Definition: IR.cpp:560
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition: IR.h:235
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition: IR.cpp:733
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
Definition: IR.cpp:53
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition: IR.cpp:827
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumRegions(MlirOperation op)
Returns the number of regions attached to the given operation.
Definition: IR.cpp:532
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition: IR.cpp:287
MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Removes an attribute by name.
Definition: IR.cpp:668
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition: IR.cpp:1098
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition: IR.cpp:1162
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Parses an attribute. The attribute is owned by the context.
Definition: IR.cpp:1065
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Parses a module from the string and transfers ownership to the caller.
Definition: IR.cpp:305
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition: IR.cpp:768
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition: IR.cpp:839
MLIR_CAPI_EXPORTED void mlirTypeDump(MlirType type)
Prints the type to the standard error stream.
Definition: IR.cpp:1059
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Returns pos-th result of the operation.
Definition: IR.cpp:583
MLIR_CAPI_EXPORTED MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate(void)
Creates a new bytecode writer config with defaults, intended for customization.
Definition: IR.cpp:226
MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Gets the context that an attribute was created with.
Definition: IR.cpp:1069
MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Returns the block in which this value is defined as an argument.
Definition: IR.cpp:940
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition: IR.h:744
MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op)
Takes an operation owned by the caller and destroys it.
Definition: IR.cpp:498
MLIR_CAPI_EXPORTED MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Returns pos-th region attached to the operation.
Definition: IR.cpp:536
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition: IR.cpp:1046
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition: IR.cpp:1109
MLIR_CAPI_EXPORTED void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Set pos-th successor of the operation.
Definition: IR.cpp:644
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition: IR.cpp:106
MLIR_CAPI_EXPORTED void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition: IR.cpp:372
MLIR_CAPI_EXPORTED void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition: IR.cpp:376
MLIR_CAPI_EXPORTED MlirBlock mlirModuleGetBody(MlirModule module)
Gets the body of the module, i.e. the only block it contains.
Definition: IR.cpp:317
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Destroys printing flags created with mlirOpPrintingFlagsCreate.
Definition: IR.cpp:196
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition: IR.cpp:270
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Sets the threading mode (must be set to false to use mlir-print-ir-after-all).
Definition: IR.cpp:102
MLIR_CAPI_EXPORTED void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Prints a block by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:918
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Sets the version to emit in the writer config.
Definition: IR.cpp:234
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition: IR.cpp:1129
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition: IR.cpp:755
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition: IR.cpp:895
MLIR_CAPI_EXPORTED void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Adds a list of components to the operation state.
Definition: IR.cpp:363
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Enable or disable printing of debug information (based on enable).
Definition: IR.cpp:205
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op)
Gets the location of the operation.
Definition: IR.cpp:510
MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute)
Gets the type id of the attribute.
Definition: IR.cpp:1080
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Sets the pos-th operand of the operation.
Definition: IR.cpp:568
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition: IR.cpp:706
MLIR_CAPI_EXPORTED intptr_t mlirOpResultGetResultNumber(MlirValue value)
Returns the position of the value in the list of results of the operation that produced it.
Definition: IR.cpp:957
MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void)
Creates new printing flags with defaults, intended for customization.
Definition: IR.cpp:192
MLIR_CAPI_EXPORTED MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Creates new AsmState from value.
Definition: IR.cpp:168
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state)
Creates an operation and transfers ownership to the caller.
Definition: IR.cpp:448
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition: IR.h:1074
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition: IR.cpp:75
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition: IR.cpp:712
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Returns the number of attributes attached to the operation.
Definition: IR.cpp:649
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:972
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition: IR.cpp:699
MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type)
Set the type of the value.
Definition: IR.cpp:966
MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value)
Returns the type of the value.
Definition: IR.cpp:962
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition: IR.cpp:69
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Parses an operation, giving ownership to the caller.
Definition: IR.cpp:485
MLIR_CAPI_EXPORTED bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Checks if two attributes are equal.
Definition: IR.cpp:1088
static bool mlirOperationIsNull(MlirOperation op)
Checks whether the underlying operation is null.
Definition: IR.h:507
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition: IR.cpp:761
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition: Support.h:138
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition: Support.cpp:51
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition: Support.h:132
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition: Support.h:163
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition: Support.cpp:47
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
llvm::hash_code hash_value(const MPInt &x)
Redeclarations of friend declaration above to make it discoverable by lookups.
Definition: MPInt.cpp:17
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition: IRModule.h:161
PyObjectRef< PyModule > PyModuleRef
Definition: IRModule.h:522
void populateIRCore(pybind11::module &m)
PyObjectRef< PyOperation > PyOperationRef
Definition: IRModule.h:605
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition: Diagnostics.h:26
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
Named MLIR attribute.
Definition: IR.h:76
MlirAttribute attribute
Definition: IR.h:78
MlirIdentifier name
Definition: IR.h:77
An auxiliary class for constructing operations.
Definition: IR.h:340
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
static bool dunderContains(const std::string &attributeKind)
Definition: IRCore.cpp:261
static void dundeSetItemNamed(const std::string &attributeKind, py::function func, bool replace)
Definition: IRCore.cpp:270
static py::function dundeGetItemNamed(const std::string &attributeKind)
Definition: IRCore.cpp:264
static void bind(py::module &m)
Definition: IRCore.cpp:276
Wrapper for the global LLVM debugging flag.
Definition: IRCore.cpp:234
static void bind(py::module &m)
Definition: IRCore.cpp:239
static bool get(const py::object &)
Definition: IRCore.cpp:237
static void set(py::object &o, bool enable)
Definition: IRCore.cpp:235
Accumulates into a python string from a method that accepts an MlirStringCallback.
Definition: PybindUtils.h:102
pybind11::list parts
Definition: PybindUtils.h:103
pybind11::str join()
Definition: PybindUtils.h:117
MlirStringCallback getCallback()
Definition: PybindUtils.h:107
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1275
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Definition: IRModule.h:1280
Materialized diagnostic information.
Definition: IRModule.h:353
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:419
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition: IRModule.h:429
ErrorCapture(PyMlirContextRef ctx)
Definition: IRModule.h:420