MLIR  20.0.0git
IRCore.cpp
Go to the documentation of this file.
1 //===- IRCore.cpp - IR Submodules of pybind module ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
16 #include "mlir-c/Debug.h"
17 #include "mlir-c/Diagnostics.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Support.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/SmallVector.h"
23 
24 #include <optional>
25 #include <utility>
26 
27 namespace py = pybind11;
28 using namespace py::literals;
29 using namespace mlir;
30 using namespace mlir::python;
31 
32 using llvm::SmallVector;
33 using llvm::StringRef;
34 using llvm::Twine;
35 
36 //------------------------------------------------------------------------------
37 // Docstrings (trivial, non-duplicated docstrings are included inline).
38 //------------------------------------------------------------------------------
39 
// Docstrings for the Context-level helpers: type parsing and the Location
// factories (call site, file, fused, name).

static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";
59 
// Docstrings for module parsing and for Operation creation/printing. The
// keyword names listed under "Args:" are user-facing and are expected to match
// the corresponding Python keyword arguments.

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
  infer_type: Whether to infer result types.
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

// NOTE: the keyword is spelled `use_local_scope` (all lowercase, like every
// other keyword argument documented here).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
  skip_regions: Whether to skip printing regions. Defaults to False.
)";
113 
// Docstrings for the remaining Operation, Block and Value bindings.

static const char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  state: AsmState capturing the operation numbering and flags.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
  file: The file like object to write to.
  desired_version: The version of bytecode to emit.
Returns:
  The bytecode writer status.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

static const char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

static const char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";
180 
181 //------------------------------------------------------------------------------
182 // Utilities.
183 //------------------------------------------------------------------------------
184 
185 /// Helper for creating an @classmethod.
186 template <class Func, typename... Args>
187 py::object classmethod(Func f, Args... args) {
188  py::object cf = py::cpp_function(f, args...);
189  return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
190 }
191 
192 static py::object
193 createCustomDialectWrapper(const std::string &dialectNamespace,
194  py::object dialectDescriptor) {
195  auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
196  if (!dialectClass) {
197  // Use the base class.
198  return py::cast(PyDialect(std::move(dialectDescriptor)));
199  }
200 
201  // Create the custom implementation.
202  return (*dialectClass)(std::move(dialectDescriptor));
203 }
204 
205 static MlirStringRef toMlirStringRef(const std::string &s) {
206  return mlirStringRefCreate(s.data(), s.size());
207 }
208 
209 /// Create a block, using the current location context if no locations are
210 /// specified.
211 static MlirBlock createBlock(const py::sequence &pyArgTypes,
212  const std::optional<py::sequence> &pyArgLocs) {
213  SmallVector<MlirType> argTypes;
214  argTypes.reserve(pyArgTypes.size());
215  for (const auto &pyType : pyArgTypes)
216  argTypes.push_back(pyType.cast<PyType &>());
217 
219  if (pyArgLocs) {
220  argLocs.reserve(pyArgLocs->size());
221  for (const auto &pyLoc : *pyArgLocs)
222  argLocs.push_back(pyLoc.cast<PyLocation &>());
223  } else if (!argTypes.empty()) {
224  argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
225  }
226 
227  if (argTypes.size() != argLocs.size())
228  throw py::value_error(("Expected " + Twine(argTypes.size()) +
229  " locations, got: " + Twine(argLocs.size()))
230  .str());
231  return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
232 }
233 
234 /// Wrapper for the global LLVM debugging flag.
236  static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
237 
238  static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
239 
240  static void bind(py::module &m) {
241  // Debug flags.
242  py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
243  .def_property_static("flag", &PyGlobalDebugFlag::get,
244  &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
245  .def_static(
246  "set_types",
247  [](const std::string &type) {
248  mlirSetGlobalDebugType(type.c_str());
249  },
250  "types"_a, "Sets specific debug types to be produced by LLVM")
251  .def_static("set_types", [](const std::vector<std::string> &types) {
252  std::vector<const char *> pointers;
253  pointers.reserve(types.size());
254  for (const std::string &str : types)
255  pointers.push_back(str.c_str());
256  mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
257  });
258  }
259 };
260 
262  static bool dunderContains(const std::string &attributeKind) {
263  return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
264  }
265  static py::function dundeGetItemNamed(const std::string &attributeKind) {
266  auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
267  if (!builder)
268  throw py::key_error(attributeKind);
269  return *builder;
270  }
271  static void dundeSetItemNamed(const std::string &attributeKind,
272  py::function func, bool replace) {
273  PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
274  replace);
275  }
276 
277  static void bind(py::module &m) {
278  py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
279  .def_static("contains", &PyAttrBuilderMap::dunderContains)
280  .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
281  .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed,
282  "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
283  "Register an attribute builder for building MLIR "
284  "attributes from python values.");
285  }
286 };
287 
288 //------------------------------------------------------------------------------
289 // PyBlock
290 //------------------------------------------------------------------------------
291 
292 py::object PyBlock::getCapsule() {
293  return py::reinterpret_steal<py::object>(mlirPythonBlockToCapsule(get()));
294 }
295 
296 //------------------------------------------------------------------------------
297 // Collections.
298 //------------------------------------------------------------------------------
299 
300 namespace {
301 
302 class PyRegionIterator {
303 public:
304  PyRegionIterator(PyOperationRef operation)
305  : operation(std::move(operation)) {}
306 
307  PyRegionIterator &dunderIter() { return *this; }
308 
309  PyRegion dunderNext() {
310  operation->checkValid();
311  if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
312  throw py::stop_iteration();
313  }
314  MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
315  return PyRegion(operation, region);
316  }
317 
318  static void bind(py::module &m) {
319  py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
320  .def("__iter__", &PyRegionIterator::dunderIter)
321  .def("__next__", &PyRegionIterator::dunderNext);
322  }
323 
324 private:
325  PyOperationRef operation;
326  int nextIndex = 0;
327 };
328 
329 /// Regions of an op are fixed length and indexed numerically so are represented
330 /// with a sequence-like container.
331 class PyRegionList {
332 public:
333  PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
334 
335  PyRegionIterator dunderIter() {
336  operation->checkValid();
337  return PyRegionIterator(operation);
338  }
339 
340  intptr_t dunderLen() {
341  operation->checkValid();
342  return mlirOperationGetNumRegions(operation->get());
343  }
344 
345  PyRegion dunderGetItem(intptr_t index) {
346  // dunderLen checks validity.
347  if (index < 0 || index >= dunderLen()) {
348  throw py::index_error("attempt to access out of bounds region");
349  }
350  MlirRegion region = mlirOperationGetRegion(operation->get(), index);
351  return PyRegion(operation, region);
352  }
353 
354  static void bind(py::module &m) {
355  py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
356  .def("__len__", &PyRegionList::dunderLen)
357  .def("__iter__", &PyRegionList::dunderIter)
358  .def("__getitem__", &PyRegionList::dunderGetItem);
359  }
360 
361 private:
362  PyOperationRef operation;
363 };
364 
365 class PyBlockIterator {
366 public:
367  PyBlockIterator(PyOperationRef operation, MlirBlock next)
368  : operation(std::move(operation)), next(next) {}
369 
370  PyBlockIterator &dunderIter() { return *this; }
371 
372  PyBlock dunderNext() {
373  operation->checkValid();
374  if (mlirBlockIsNull(next)) {
375  throw py::stop_iteration();
376  }
377 
378  PyBlock returnBlock(operation, next);
379  next = mlirBlockGetNextInRegion(next);
380  return returnBlock;
381  }
382 
383  static void bind(py::module &m) {
384  py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
385  .def("__iter__", &PyBlockIterator::dunderIter)
386  .def("__next__", &PyBlockIterator::dunderNext);
387  }
388 
389 private:
390  PyOperationRef operation;
391  MlirBlock next;
392 };
393 
394 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
395 /// we present them as a more full-featured list-like container but optimize
396 /// it for forward iteration. Blocks are always owned by a region.
397 class PyBlockList {
398 public:
399  PyBlockList(PyOperationRef operation, MlirRegion region)
400  : operation(std::move(operation)), region(region) {}
401 
402  PyBlockIterator dunderIter() {
403  operation->checkValid();
404  return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
405  }
406 
407  intptr_t dunderLen() {
408  operation->checkValid();
409  intptr_t count = 0;
410  MlirBlock block = mlirRegionGetFirstBlock(region);
411  while (!mlirBlockIsNull(block)) {
412  count += 1;
413  block = mlirBlockGetNextInRegion(block);
414  }
415  return count;
416  }
417 
418  PyBlock dunderGetItem(intptr_t index) {
419  operation->checkValid();
420  if (index < 0) {
421  throw py::index_error("attempt to access out of bounds block");
422  }
423  MlirBlock block = mlirRegionGetFirstBlock(region);
424  while (!mlirBlockIsNull(block)) {
425  if (index == 0) {
426  return PyBlock(operation, block);
427  }
428  block = mlirBlockGetNextInRegion(block);
429  index -= 1;
430  }
431  throw py::index_error("attempt to access out of bounds block");
432  }
433 
434  PyBlock appendBlock(const py::args &pyArgTypes,
435  const std::optional<py::sequence> &pyArgLocs) {
436  operation->checkValid();
437  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
438  mlirRegionAppendOwnedBlock(region, block);
439  return PyBlock(operation, block);
440  }
441 
442  static void bind(py::module &m) {
443  py::class_<PyBlockList>(m, "BlockList", py::module_local())
444  .def("__getitem__", &PyBlockList::dunderGetItem)
445  .def("__iter__", &PyBlockList::dunderIter)
446  .def("__len__", &PyBlockList::dunderLen)
447  .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
448  py::arg("arg_locs") = std::nullopt);
449  }
450 
451 private:
452  PyOperationRef operation;
453  MlirRegion region;
454 };
455 
456 class PyOperationIterator {
457 public:
458  PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
459  : parentOperation(std::move(parentOperation)), next(next) {}
460 
461  PyOperationIterator &dunderIter() { return *this; }
462 
463  py::object dunderNext() {
464  parentOperation->checkValid();
465  if (mlirOperationIsNull(next)) {
466  throw py::stop_iteration();
467  }
468 
469  PyOperationRef returnOperation =
470  PyOperation::forOperation(parentOperation->getContext(), next);
471  next = mlirOperationGetNextInBlock(next);
472  return returnOperation->createOpView();
473  }
474 
475  static void bind(py::module &m) {
476  py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
477  .def("__iter__", &PyOperationIterator::dunderIter)
478  .def("__next__", &PyOperationIterator::dunderNext);
479  }
480 
481 private:
482  PyOperationRef parentOperation;
483  MlirOperation next;
484 };
485 
486 /// Operations are exposed by the C-API as a forward-only linked list. In
487 /// Python, we present them as a more full-featured list-like container but
488 /// optimize it for forward iteration. Iterable operations are always owned
489 /// by a block.
490 class PyOperationList {
491 public:
492  PyOperationList(PyOperationRef parentOperation, MlirBlock block)
493  : parentOperation(std::move(parentOperation)), block(block) {}
494 
495  PyOperationIterator dunderIter() {
496  parentOperation->checkValid();
497  return PyOperationIterator(parentOperation,
499  }
500 
501  intptr_t dunderLen() {
502  parentOperation->checkValid();
503  intptr_t count = 0;
504  MlirOperation childOp = mlirBlockGetFirstOperation(block);
505  while (!mlirOperationIsNull(childOp)) {
506  count += 1;
507  childOp = mlirOperationGetNextInBlock(childOp);
508  }
509  return count;
510  }
511 
512  py::object dunderGetItem(intptr_t index) {
513  parentOperation->checkValid();
514  if (index < 0) {
515  throw py::index_error("attempt to access out of bounds operation");
516  }
517  MlirOperation childOp = mlirBlockGetFirstOperation(block);
518  while (!mlirOperationIsNull(childOp)) {
519  if (index == 0) {
520  return PyOperation::forOperation(parentOperation->getContext(), childOp)
521  ->createOpView();
522  }
523  childOp = mlirOperationGetNextInBlock(childOp);
524  index -= 1;
525  }
526  throw py::index_error("attempt to access out of bounds operation");
527  }
528 
529  static void bind(py::module &m) {
530  py::class_<PyOperationList>(m, "OperationList", py::module_local())
531  .def("__getitem__", &PyOperationList::dunderGetItem)
532  .def("__iter__", &PyOperationList::dunderIter)
533  .def("__len__", &PyOperationList::dunderLen);
534  }
535 
536 private:
537  PyOperationRef parentOperation;
538  MlirBlock block;
539 };
540 
541 class PyOpOperand {
542 public:
543  PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
544 
545  py::object getOwner() {
546  MlirOperation owner = mlirOpOperandGetOwner(opOperand);
547  PyMlirContextRef context =
548  PyMlirContext::forContext(mlirOperationGetContext(owner));
549  return PyOperation::forOperation(context, owner)->createOpView();
550  }
551 
552  size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
553 
554  static void bind(py::module &m) {
555  py::class_<PyOpOperand>(m, "OpOperand", py::module_local())
556  .def_property_readonly("owner", &PyOpOperand::getOwner)
557  .def_property_readonly("operand_number",
558  &PyOpOperand::getOperandNumber);
559  }
560 
561 private:
562  MlirOpOperand opOperand;
563 };
564 
565 class PyOpOperandIterator {
566 public:
567  PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
568 
569  PyOpOperandIterator &dunderIter() { return *this; }
570 
571  PyOpOperand dunderNext() {
572  if (mlirOpOperandIsNull(opOperand))
573  throw py::stop_iteration();
574 
575  PyOpOperand returnOpOperand(opOperand);
576  opOperand = mlirOpOperandGetNextUse(opOperand);
577  return returnOpOperand;
578  }
579 
580  static void bind(py::module &m) {
581  py::class_<PyOpOperandIterator>(m, "OpOperandIterator", py::module_local())
582  .def("__iter__", &PyOpOperandIterator::dunderIter)
583  .def("__next__", &PyOpOperandIterator::dunderNext);
584  }
585 
586 private:
587  MlirOpOperand opOperand;
588 };
589 
590 } // namespace
591 
592 //------------------------------------------------------------------------------
593 // PyMlirContext
594 //------------------------------------------------------------------------------
595 
596 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
597  py::gil_scoped_acquire acquire;
598  auto &liveContexts = getLiveContexts();
599  liveContexts[context.ptr] = this;
600 }
601 
603  // Note that the only public way to construct an instance is via the
604  // forContext method, which always puts the associated handle into
605  // liveContexts.
606  py::gil_scoped_acquire acquire;
607  getLiveContexts().erase(context.ptr);
608  mlirContextDestroy(context);
609 }
610 
612  return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
613 }
614 
615 py::object PyMlirContext::createFromCapsule(py::object capsule) {
616  MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
617  if (mlirContextIsNull(rawContext))
618  throw py::error_already_set();
619  return forContext(rawContext).releaseObject();
620 }
621 
623  MlirContext context = mlirContextCreateWithThreading(false);
624  return new PyMlirContext(context);
625 }
626 
628  py::gil_scoped_acquire acquire;
629  auto &liveContexts = getLiveContexts();
630  auto it = liveContexts.find(context.ptr);
631  if (it == liveContexts.end()) {
632  // Create.
633  PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
634  py::object pyRef = py::cast(unownedContextWrapper);
635  assert(pyRef && "cast to py::object failed");
636  liveContexts[context.ptr] = unownedContextWrapper;
637  return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
638  }
639  // Use existing.
640  py::object pyRef = py::cast(it->second);
641  return PyMlirContextRef(it->second, std::move(pyRef));
642 }
643 
644 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
645  static LiveContextMap liveContexts;
646  return liveContexts;
647 }
648 
649 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
650 
651 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
652 
653 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
654  std::vector<PyOperation *> liveObjects;
655  for (auto &entry : liveOperations)
656  liveObjects.push_back(entry.second.second);
657  return liveObjects;
658 }
659 
661  for (auto &op : liveOperations)
662  op.second.second->setInvalid();
663  size_t numInvalidated = liveOperations.size();
664  liveOperations.clear();
665  return numInvalidated;
666 }
667 
668 void PyMlirContext::clearOperation(MlirOperation op) {
669  auto it = liveOperations.find(op.ptr);
670  if (it != liveOperations.end()) {
671  it->second.second->setInvalid();
672  liveOperations.erase(it);
673  }
674 }
675 
677  typedef struct {
678  PyOperation &rootOp;
679  bool rootSeen;
680  } callBackData;
681  callBackData data{op.getOperation(), false};
682  // Mark all ops below the op that the passmanager will be rooted
683  // at (but not op itself - note the preorder) as invalid.
684  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
685  void *userData) {
686  callBackData *data = static_cast<callBackData *>(userData);
687  if (LLVM_LIKELY(data->rootSeen))
688  data->rootOp.getOperation().getContext()->clearOperation(op);
689  else
690  data->rootSeen = true;
692  };
693  mlirOperationWalk(op.getOperation(), invalidatingCallback,
694  static_cast<void *>(&data), MlirWalkPreOrder);
695 }
696 void PyMlirContext::clearOperationsInside(MlirOperation op) {
699 }
700 
702  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
703  void *userData) {
704  PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
705  contextRef->clearOperation(op);
707  };
708  mlirOperationWalk(op.getOperation(), invalidatingCallback,
709  &op.getOperation().getContext(), MlirWalkPreOrder);
710 }
711 
712 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
713 
714 pybind11::object PyMlirContext::contextEnter() {
715  return PyThreadContextEntry::pushContext(*this);
716 }
717 
718 void PyMlirContext::contextExit(const pybind11::object &excType,
719  const pybind11::object &excVal,
720  const pybind11::object &excTb) {
722 }
723 
724 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
725  // Note that ownership is transferred to the delete callback below by way of
726  // an explicit inc_ref (borrow).
727  PyDiagnosticHandler *pyHandler =
728  new PyDiagnosticHandler(get(), std::move(callback));
729  py::object pyHandlerObject =
730  py::cast(pyHandler, py::return_value_policy::take_ownership);
731  pyHandlerObject.inc_ref();
732 
733  // In these C callbacks, the userData is a PyDiagnosticHandler* that is
734  // guaranteed to be known to pybind.
735  auto handlerCallback =
736  +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
737  PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
738  py::object pyDiagnosticObject =
739  py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
740 
741  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
742  bool result = false;
743  {
744  // Since this can be called from arbitrary C++ contexts, always get the
745  // gil.
746  py::gil_scoped_acquire gil;
747  try {
748  result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
749  } catch (std::exception &e) {
750  fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
751  e.what());
752  pyHandler->hadError = true;
753  }
754  }
755 
756  pyDiagnostic->invalidate();
758  };
759  auto deleteCallback = +[](void *userData) {
760  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
761  assert(pyHandler->registeredID && "handler is not registered");
762  pyHandler->registeredID.reset();
763 
764  // Decrement reference, balancing the inc_ref() above.
765  py::object pyHandlerObject =
766  py::cast(pyHandler, py::return_value_policy::reference);
767  pyHandlerObject.dec_ref();
768  };
769 
770  pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
771  get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
772  return pyHandlerObject;
773 }
774 
775 MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
776  void *userData) {
777  auto *self = static_cast<ErrorCapture *>(userData);
778  // Check if the context requested we emit errors instead of capturing them.
779  if (self->ctx->emitErrorDiagnostics)
780  return mlirLogicalResultFailure();
781 
783  return mlirLogicalResultFailure();
784 
785  self->errors.emplace_back(PyDiagnostic(diag).getInfo());
786  return mlirLogicalResultSuccess();
787 }
788 
791  if (!context) {
792  throw std::runtime_error(
793  "An MLIR function requires a Context but none was provided in the call "
794  "or from the surrounding environment. Either pass to the function with "
795  "a 'context=' argument or establish a default using 'with Context():'");
796  }
797  return *context;
798 }
799 
800 //------------------------------------------------------------------------------
801 // PyThreadContextEntry management
802 //------------------------------------------------------------------------------
803 
804 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
805  static thread_local std::vector<PyThreadContextEntry> stack;
806  return stack;
807 }
808 
810  auto &stack = getStack();
811  if (stack.empty())
812  return nullptr;
813  return &stack.back();
814 }
815 
816 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
817  py::object insertionPoint,
818  py::object location) {
819  auto &stack = getStack();
820  stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
821  std::move(location));
822  // If the new stack has more than one entry and the context of the new top
823  // entry matches the previous, copy the insertionPoint and location from the
824  // previous entry if missing from the new top entry.
825  if (stack.size() > 1) {
826  auto &prev = *(stack.rbegin() + 1);
827  auto &current = stack.back();
828  if (current.context.is(prev.context)) {
829  // Default non-context objects from the previous entry.
830  if (!current.insertionPoint)
831  current.insertionPoint = prev.insertionPoint;
832  if (!current.location)
833  current.location = prev.location;
834  }
835  }
836 }
837 
839  if (!context)
840  return nullptr;
841  return py::cast<PyMlirContext *>(context);
842 }
843 
845  if (!insertionPoint)
846  return nullptr;
847  return py::cast<PyInsertionPoint *>(insertionPoint);
848 }
849 
851  if (!location)
852  return nullptr;
853  return py::cast<PyLocation *>(location);
854 }
855 
857  auto *tos = getTopOfStack();
858  return tos ? tos->getContext() : nullptr;
859 }
860 
862  auto *tos = getTopOfStack();
863  return tos ? tos->getInsertionPoint() : nullptr;
864 }
865 
867  auto *tos = getTopOfStack();
868  return tos ? tos->getLocation() : nullptr;
869 }
870 
872  py::object contextObj = py::cast(context);
873  push(FrameKind::Context, /*context=*/contextObj,
874  /*insertionPoint=*/py::object(),
875  /*location=*/py::object());
876  return contextObj;
877 }
878 
880  auto &stack = getStack();
881  if (stack.empty())
882  throw std::runtime_error("Unbalanced Context enter/exit");
883  auto &tos = stack.back();
884  if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
885  throw std::runtime_error("Unbalanced Context enter/exit");
886  stack.pop_back();
887 }
888 
889 py::object
891  py::object contextObj =
892  insertionPoint.getBlock().getParentOperation()->getContext().getObject();
893  py::object insertionPointObj = py::cast(insertionPoint);
894  push(FrameKind::InsertionPoint,
895  /*context=*/contextObj,
896  /*insertionPoint=*/insertionPointObj,
897  /*location=*/py::object());
898  return insertionPointObj;
899 }
900 
902  auto &stack = getStack();
903  if (stack.empty())
904  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
905  auto &tos = stack.back();
906  if (tos.frameKind != FrameKind::InsertionPoint &&
907  tos.getInsertionPoint() != &insertionPoint)
908  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
909  stack.pop_back();
910 }
911 
913  py::object contextObj = location.getContext().getObject();
914  py::object locationObj = py::cast(location);
915  push(FrameKind::Location, /*context=*/contextObj,
916  /*insertionPoint=*/py::object(),
917  /*location=*/locationObj);
918  return locationObj;
919 }
920 
922  auto &stack = getStack();
923  if (stack.empty())
924  throw std::runtime_error("Unbalanced Location enter/exit");
925  auto &tos = stack.back();
926  if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
927  throw std::runtime_error("Unbalanced Location enter/exit");
928  stack.pop_back();
929 }
930 
931 //------------------------------------------------------------------------------
932 // PyDiagnostic*
933 //------------------------------------------------------------------------------
934 
936  valid = false;
937  if (materializedNotes) {
938  for (auto &noteObject : *materializedNotes) {
939  PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
940  note->invalidate();
941  }
942  }
943 }
944 
946  py::object callback)
947  : context(context), callback(std::move(callback)) {}
948 
950 
952  if (!registeredID)
953  return;
954  MlirDiagnosticHandlerID localID = *registeredID;
955  mlirContextDetachDiagnosticHandler(context, localID);
956  assert(!registeredID && "should have unregistered");
957  // Not strictly necessary but keeps stale pointers from being around to cause
958  // issues.
959  context = {nullptr};
960 }
961 
962 void PyDiagnostic::checkValid() {
963  if (!valid) {
964  throw std::invalid_argument(
965  "Diagnostic is invalid (used outside of callback)");
966  }
967 }
968 
970  checkValid();
971  return mlirDiagnosticGetSeverity(diagnostic);
972 }
973 
975  checkValid();
976  MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
977  MlirContext context = mlirLocationGetContext(loc);
978  return PyLocation(PyMlirContext::forContext(context), loc);
979 }
980 
982  checkValid();
983  py::object fileObject = py::module::import("io").attr("StringIO")();
984  PyFileAccumulator accum(fileObject, /*binary=*/false);
985  mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
986  return fileObject.attr("getvalue")();
987 }
988 
990  checkValid();
991  if (materializedNotes)
992  return *materializedNotes;
993  intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
994  materializedNotes = py::tuple(numNotes);
995  for (intptr_t i = 0; i < numNotes; ++i) {
996  MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
997  (*materializedNotes)[i] = PyDiagnostic(noteDiag);
998  }
999  return *materializedNotes;
1000 }
1001 
1003  std::vector<DiagnosticInfo> notes;
1004  for (py::handle n : getNotes())
1005  notes.emplace_back(n.cast<PyDiagnostic>().getInfo());
1006  return {getSeverity(), getLocation(), getMessage(), std::move(notes)};
1007 }
1008 
1009 //------------------------------------------------------------------------------
1010 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
1011 //------------------------------------------------------------------------------
1012 
1013 MlirDialect PyDialects::getDialectForKey(const std::string &key,
1014  bool attrError) {
1015  MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1016  {key.data(), key.size()});
1017  if (mlirDialectIsNull(dialect)) {
1018  std::string msg = (Twine("Dialect '") + key + "' not found").str();
1019  if (attrError)
1020  throw py::attribute_error(msg);
1021  throw py::index_error(msg);
1022  }
1023  return dialect;
1024 }
1025 
1027  return py::reinterpret_steal<py::object>(
1029 }
1030 
1032  MlirDialectRegistry rawRegistry =
1033  mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1034  if (mlirDialectRegistryIsNull(rawRegistry))
1035  throw py::error_already_set();
1036  return PyDialectRegistry(rawRegistry);
1037 }
1038 
1039 //------------------------------------------------------------------------------
1040 // PyLocation
1041 //------------------------------------------------------------------------------
1042 
1044  return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
1045 }
1046 
1048  MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1049  if (mlirLocationIsNull(rawLoc))
1050  throw py::error_already_set();
1052  rawLoc);
1053 }
1054 
1056  return PyThreadContextEntry::pushLocation(*this);
1057 }
1058 
1059 void PyLocation::contextExit(const pybind11::object &excType,
1060  const pybind11::object &excVal,
1061  const pybind11::object &excTb) {
1063 }
1064 
1066  auto *location = PyThreadContextEntry::getDefaultLocation();
1067  if (!location) {
1068  throw std::runtime_error(
1069  "An MLIR function requires a Location but none was provided in the "
1070  "call or from the surrounding environment. Either pass to the function "
1071  "with a 'loc=' argument or establish a default using 'with loc:'");
1072  }
1073  return *location;
1074 }
1075 
1076 //------------------------------------------------------------------------------
1077 // PyModule
1078 //------------------------------------------------------------------------------
1079 
1080 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1081  : BaseContextObject(std::move(contextRef)), module(module) {}
1082 
1084  py::gil_scoped_acquire acquire;
1085  auto &liveModules = getContext()->liveModules;
1086  assert(liveModules.count(module.ptr) == 1 &&
1087  "destroying module not in live map");
1088  liveModules.erase(module.ptr);
1089  mlirModuleDestroy(module);
1090 }
1091 
1092 PyModuleRef PyModule::forModule(MlirModule module) {
1093  MlirContext context = mlirModuleGetContext(module);
1094  PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1095 
1096  py::gil_scoped_acquire acquire;
1097  auto &liveModules = contextRef->liveModules;
1098  auto it = liveModules.find(module.ptr);
1099  if (it == liveModules.end()) {
1100  // Create.
1101  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1102  // Note that the default return value policy on cast is automatic_reference,
1103  // which does not take ownership (delete will not be called).
1104  // Just be explicit.
1105  py::object pyRef =
1106  py::cast(unownedModule, py::return_value_policy::take_ownership);
1107  unownedModule->handle = pyRef;
1108  liveModules[module.ptr] =
1109  std::make_pair(unownedModule->handle, unownedModule);
1110  return PyModuleRef(unownedModule, std::move(pyRef));
1111  }
1112  // Use existing.
1113  PyModule *existing = it->second.second;
1114  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1115  return PyModuleRef(existing, std::move(pyRef));
1116 }
1117 
1118 py::object PyModule::createFromCapsule(py::object capsule) {
1119  MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1120  if (mlirModuleIsNull(rawModule))
1121  throw py::error_already_set();
1122  return forModule(rawModule).releaseObject();
1123 }
1124 
1125 py::object PyModule::getCapsule() {
1126  return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
1127 }
1128 
1129 //------------------------------------------------------------------------------
1130 // PyOperation
1131 //------------------------------------------------------------------------------
1132 
1133 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1134  : BaseContextObject(std::move(contextRef)), operation(operation) {}
1135 
1137  // If the operation has already been invalidated there is nothing to do.
1138  if (!valid)
1139  return;
1140 
1141  // Otherwise, invalidate the operation and remove it from live map when it is
1142  // attached.
1143  if (isAttached()) {
1144  getContext()->clearOperation(*this);
1145  } else {
1146  // And destroy it when it is detached, i.e. owned by Python, in which case
1147  // all nested operations must be invalidated at removed from the live map as
1148  // well.
1149  erase();
1150  }
1151 }
1152 
1153 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1154  MlirOperation operation,
1155  py::object parentKeepAlive) {
1156  auto &liveOperations = contextRef->liveOperations;
1157  // Create.
1158  PyOperation *unownedOperation =
1159  new PyOperation(std::move(contextRef), operation);
1160  // Note that the default return value policy on cast is automatic_reference,
1161  // which does not take ownership (delete will not be called).
1162  // Just be explicit.
1163  py::object pyRef =
1164  py::cast(unownedOperation, py::return_value_policy::take_ownership);
1165  unownedOperation->handle = pyRef;
1166  if (parentKeepAlive) {
1167  unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1168  }
1169  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
1170  return PyOperationRef(unownedOperation, std::move(pyRef));
1171 }
1172 
1174  MlirOperation operation,
1175  py::object parentKeepAlive) {
1176  auto &liveOperations = contextRef->liveOperations;
1177  auto it = liveOperations.find(operation.ptr);
1178  if (it == liveOperations.end()) {
1179  // Create.
1180  return createInstance(std::move(contextRef), operation,
1181  std::move(parentKeepAlive));
1182  }
1183  // Use existing.
1184  PyOperation *existing = it->second.second;
1185  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1186  return PyOperationRef(existing, std::move(pyRef));
1187 }
1188 
1190  MlirOperation operation,
1191  py::object parentKeepAlive) {
1192  auto &liveOperations = contextRef->liveOperations;
1193  assert(liveOperations.count(operation.ptr) == 0 &&
1194  "cannot create detached operation that already exists");
1195  (void)liveOperations;
1196 
1197  PyOperationRef created = createInstance(std::move(contextRef), operation,
1198  std::move(parentKeepAlive));
1199  created->attached = false;
1200  return created;
1201 }
1202 
1204  const std::string &sourceStr,
1205  const std::string &sourceName) {
1206  PyMlirContext::ErrorCapture errors(contextRef);
1207  MlirOperation op =
1208  mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1209  toMlirStringRef(sourceName));
1210  if (mlirOperationIsNull(op))
1211  throw MLIRError("Unable to parse operation assembly", errors.take());
1212  return PyOperation::createDetached(std::move(contextRef), op);
1213 }
1214 
1216  if (!valid) {
1217  throw std::runtime_error("the operation has been invalidated");
1218  }
1219 }
1220 
1221 void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1222  bool enableDebugInfo, bool prettyDebugInfo,
1223  bool printGenericOpForm, bool useLocalScope,
1224  bool assumeVerified, py::object fileObject,
1225  bool binary, bool skipRegions) {
1226  PyOperation &operation = getOperation();
1227  operation.checkValid();
1228  if (fileObject.is_none())
1229  fileObject = py::module::import("sys").attr("stdout");
1230 
1231  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1232  if (largeElementsLimit)
1233  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1234  if (enableDebugInfo)
1235  mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1236  /*prettyForm=*/prettyDebugInfo);
1237  if (printGenericOpForm)
1239  if (useLocalScope)
1241  if (assumeVerified)
1243  if (skipRegions)
1245 
1246  PyFileAccumulator accum(fileObject, binary);
1247  mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1248  accum.getUserData());
1250 }
1251 
1252 void PyOperationBase::print(PyAsmState &state, py::object fileObject,
1253  bool binary) {
1254  PyOperation &operation = getOperation();
1255  operation.checkValid();
1256  if (fileObject.is_none())
1257  fileObject = py::module::import("sys").attr("stdout");
1258  PyFileAccumulator accum(fileObject, binary);
1259  mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1260  accum.getUserData());
1261 }
1262 
1263 void PyOperationBase::writeBytecode(const py::object &fileObject,
1264  std::optional<int64_t> bytecodeVersion) {
1265  PyOperation &operation = getOperation();
1266  operation.checkValid();
1267  PyFileAccumulator accum(fileObject, /*binary=*/true);
1268 
1269  if (!bytecodeVersion.has_value())
1270  return mlirOperationWriteBytecode(operation, accum.getCallback(),
1271  accum.getUserData());
1272 
1273  MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1274  mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1276  operation, config, accum.getCallback(), accum.getUserData());
1278  if (mlirLogicalResultIsFailure(res))
1279  throw py::value_error((Twine("Unable to honor desired bytecode version ") +
1280  Twine(*bytecodeVersion))
1281  .str());
1282 }
1283 
1285  std::function<MlirWalkResult(MlirOperation)> callback,
1286  MlirWalkOrder walkOrder) {
1287  PyOperation &operation = getOperation();
1288  operation.checkValid();
1289  struct UserData {
1290  std::function<MlirWalkResult(MlirOperation)> callback;
1291  bool gotException;
1292  std::string exceptionWhat;
1293  py::object exceptionType;
1294  };
1295  UserData userData{callback, false, {}, {}};
1296  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1297  void *userData) {
1298  UserData *calleeUserData = static_cast<UserData *>(userData);
1299  try {
1300  return (calleeUserData->callback)(op);
1301  } catch (py::error_already_set &e) {
1302  calleeUserData->gotException = true;
1303  calleeUserData->exceptionWhat = e.what();
1304  calleeUserData->exceptionType = e.type();
1306  }
1307  };
1308  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1309  if (userData.gotException) {
1310  std::string message("Exception raised in callback: ");
1311  message.append(userData.exceptionWhat);
1312  throw std::runtime_error(message);
1313  }
1314 }
1315 
1316 py::object PyOperationBase::getAsm(bool binary,
1317  std::optional<int64_t> largeElementsLimit,
1318  bool enableDebugInfo, bool prettyDebugInfo,
1319  bool printGenericOpForm, bool useLocalScope,
1320  bool assumeVerified, bool skipRegions) {
1321  py::object fileObject;
1322  if (binary) {
1323  fileObject = py::module::import("io").attr("BytesIO")();
1324  } else {
1325  fileObject = py::module::import("io").attr("StringIO")();
1326  }
1327  print(/*largeElementsLimit=*/largeElementsLimit,
1328  /*enableDebugInfo=*/enableDebugInfo,
1329  /*prettyDebugInfo=*/prettyDebugInfo,
1330  /*printGenericOpForm=*/printGenericOpForm,
1331  /*useLocalScope=*/useLocalScope,
1332  /*assumeVerified=*/assumeVerified,
1333  /*fileObject=*/fileObject,
1334  /*binary=*/binary,
1335  /*skipRegions=*/skipRegions);
1336 
1337  return fileObject.attr("getvalue")();
1338 }
1339 
1341  PyOperation &operation = getOperation();
1342  PyOperation &otherOp = other.getOperation();
1343  operation.checkValid();
1344  otherOp.checkValid();
1345  mlirOperationMoveAfter(operation, otherOp);
1346  operation.parentKeepAlive = otherOp.parentKeepAlive;
1347 }
1348 
1350  PyOperation &operation = getOperation();
1351  PyOperation &otherOp = other.getOperation();
1352  operation.checkValid();
1353  otherOp.checkValid();
1354  mlirOperationMoveBefore(operation, otherOp);
1355  operation.parentKeepAlive = otherOp.parentKeepAlive;
1356 }
1357 
1359  PyOperation &op = getOperation();
1361  if (!mlirOperationVerify(op.get()))
1362  throw MLIRError("Verification failed", errors.take());
1363  return true;
1364 }
1365 
1366 std::optional<PyOperationRef> PyOperation::getParentOperation() {
1367  checkValid();
1368  if (!isAttached())
1369  throw py::value_error("Detached operations have no parent");
1370  MlirOperation operation = mlirOperationGetParentOperation(get());
1371  if (mlirOperationIsNull(operation))
1372  return {};
1373  return PyOperation::forOperation(getContext(), operation);
1374 }
1375 
1377  checkValid();
1378  std::optional<PyOperationRef> parentOperation = getParentOperation();
1379  MlirBlock block = mlirOperationGetBlock(get());
1380  assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1381  assert(parentOperation && "Operation has no parent");
1382  return PyBlock{std::move(*parentOperation), block};
1383 }
1384 
1386  checkValid();
1387  return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1388 }
1389 
1390 py::object PyOperation::createFromCapsule(py::object capsule) {
1391  MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1392  if (mlirOperationIsNull(rawOperation))
1393  throw py::error_already_set();
1394  MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1395  return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1396  .releaseObject();
1397 }
1398 
1400  const py::object &maybeIp) {
1401  // InsertPoint active?
1402  if (!maybeIp.is(py::cast(false))) {
1403  PyInsertionPoint *ip;
1404  if (maybeIp.is_none()) {
1406  } else {
1407  ip = py::cast<PyInsertionPoint *>(maybeIp);
1408  }
1409  if (ip)
1410  ip->insert(*op.get());
1411  }
1412 }
1413 
1414 py::object PyOperation::create(const std::string &name,
1415  std::optional<std::vector<PyType *>> results,
1416  std::optional<std::vector<PyValue *>> operands,
1417  std::optional<py::dict> attributes,
1418  std::optional<std::vector<PyBlock *>> successors,
1419  int regions, DefaultingPyLocation location,
1420  const py::object &maybeIp, bool inferType) {
1421  llvm::SmallVector<MlirValue, 4> mlirOperands;
1422  llvm::SmallVector<MlirType, 4> mlirResults;
1423  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
1425 
1426  // General parameter validation.
1427  if (regions < 0)
1428  throw py::value_error("number of regions must be >= 0");
1429 
1430  // Unpack/validate operands.
1431  if (operands) {
1432  mlirOperands.reserve(operands->size());
1433  for (PyValue *operand : *operands) {
1434  if (!operand)
1435  throw py::value_error("operand value cannot be None");
1436  mlirOperands.push_back(operand->get());
1437  }
1438  }
1439 
1440  // Unpack/validate results.
1441  if (results) {
1442  mlirResults.reserve(results->size());
1443  for (PyType *result : *results) {
1444  // TODO: Verify result type originate from the same context.
1445  if (!result)
1446  throw py::value_error("result type cannot be None");
1447  mlirResults.push_back(*result);
1448  }
1449  }
1450  // Unpack/validate attributes.
1451  if (attributes) {
1452  mlirAttributes.reserve(attributes->size());
1453  for (auto &it : *attributes) {
1454  std::string key;
1455  try {
1456  key = it.first.cast<std::string>();
1457  } catch (py::cast_error &err) {
1458  std::string msg = "Invalid attribute key (not a string) when "
1459  "attempting to create the operation \"" +
1460  name + "\" (" + err.what() + ")";
1461  throw py::cast_error(msg);
1462  }
1463  try {
1464  auto &attribute = it.second.cast<PyAttribute &>();
1465  // TODO: Verify attribute originates from the same context.
1466  mlirAttributes.emplace_back(std::move(key), attribute);
1467  } catch (py::reference_cast_error &) {
1468  // This exception seems thrown when the value is "None".
1469  std::string msg =
1470  "Found an invalid (`None`?) attribute value for the key \"" + key +
1471  "\" when attempting to create the operation \"" + name + "\"";
1472  throw py::cast_error(msg);
1473  } catch (py::cast_error &err) {
1474  std::string msg = "Invalid attribute value for the key \"" + key +
1475  "\" when attempting to create the operation \"" +
1476  name + "\" (" + err.what() + ")";
1477  throw py::cast_error(msg);
1478  }
1479  }
1480  }
1481  // Unpack/validate successors.
1482  if (successors) {
1483  mlirSuccessors.reserve(successors->size());
1484  for (auto *successor : *successors) {
1485  // TODO: Verify successor originate from the same context.
1486  if (!successor)
1487  throw py::value_error("successor block cannot be None");
1488  mlirSuccessors.push_back(successor->get());
1489  }
1490  }
1491 
1492  // Apply unpacked/validated to the operation state. Beyond this
1493  // point, exceptions cannot be thrown or else the state will leak.
1494  MlirOperationState state =
1495  mlirOperationStateGet(toMlirStringRef(name), location);
1496  if (!mlirOperands.empty())
1497  mlirOperationStateAddOperands(&state, mlirOperands.size(),
1498  mlirOperands.data());
1499  state.enableResultTypeInference = inferType;
1500  if (!mlirResults.empty())
1501  mlirOperationStateAddResults(&state, mlirResults.size(),
1502  mlirResults.data());
1503  if (!mlirAttributes.empty()) {
1504  // Note that the attribute names directly reference bytes in
1505  // mlirAttributes, so that vector must not be changed from here
1506  // on.
1507  llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1508  mlirNamedAttributes.reserve(mlirAttributes.size());
1509  for (auto &it : mlirAttributes)
1510  mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1512  toMlirStringRef(it.first)),
1513  it.second));
1514  mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1515  mlirNamedAttributes.data());
1516  }
1517  if (!mlirSuccessors.empty())
1518  mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1519  mlirSuccessors.data());
1520  if (regions) {
1522  mlirRegions.resize(regions);
1523  for (int i = 0; i < regions; ++i)
1524  mlirRegions[i] = mlirRegionCreate();
1525  mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1526  mlirRegions.data());
1527  }
1528 
1529  // Construct the operation.
1530  MlirOperation operation = mlirOperationCreate(&state);
1531  if (!operation.ptr)
1532  throw py::value_error("Operation creation failed");
1533  PyOperationRef created =
1534  PyOperation::createDetached(location->getContext(), operation);
1535  maybeInsertOperation(created, maybeIp);
1536 
1537  return created->createOpView();
1538 }
1539 
1540 py::object PyOperation::clone(const py::object &maybeIp) {
1541  MlirOperation clonedOperation = mlirOperationClone(operation);
1542  PyOperationRef cloned =
1543  PyOperation::createDetached(getContext(), clonedOperation);
1544  maybeInsertOperation(cloned, maybeIp);
1545 
1546  return cloned->createOpView();
1547 }
1548 
1550  checkValid();
1551  MlirIdentifier ident = mlirOperationGetName(get());
1552  MlirStringRef identStr = mlirIdentifierStr(ident);
1553  auto operationCls = PyGlobals::get().lookupOperationClass(
1554  StringRef(identStr.data, identStr.length));
1555  if (operationCls)
1556  return PyOpView::constructDerived(*operationCls, *getRef().get());
1557  return py::cast(PyOpView(getRef().getObject()));
1558 }
1559 
1561  checkValid();
1563  mlirOperationDestroy(operation);
1564 }
1565 
1566 //------------------------------------------------------------------------------
1567 // PyOpView
1568 //------------------------------------------------------------------------------
1569 
1570 static void populateResultTypes(StringRef name, py::list resultTypeList,
1571  const py::object &resultSegmentSpecObj,
1572  std::vector<int32_t> &resultSegmentLengths,
1573  std::vector<PyType *> &resultTypes) {
1574  resultTypes.reserve(resultTypeList.size());
1575  if (resultSegmentSpecObj.is_none()) {
1576  // Non-variadic result unpacking.
1577  for (const auto &it : llvm::enumerate(resultTypeList)) {
1578  try {
1579  resultTypes.push_back(py::cast<PyType *>(it.value()));
1580  if (!resultTypes.back())
1581  throw py::cast_error();
1582  } catch (py::cast_error &err) {
1583  throw py::value_error((llvm::Twine("Result ") +
1584  llvm::Twine(it.index()) + " of operation \"" +
1585  name + "\" must be a Type (" + err.what() + ")")
1586  .str());
1587  }
1588  }
1589  } else {
1590  // Sized result unpacking.
1591  auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1592  if (resultSegmentSpec.size() != resultTypeList.size()) {
1593  throw py::value_error((llvm::Twine("Operation \"") + name +
1594  "\" requires " +
1595  llvm::Twine(resultSegmentSpec.size()) +
1596  " result segments but was provided " +
1597  llvm::Twine(resultTypeList.size()))
1598  .str());
1599  }
1600  resultSegmentLengths.reserve(resultTypeList.size());
1601  for (const auto &it :
1602  llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1603  int segmentSpec = std::get<1>(it.value());
1604  if (segmentSpec == 1 || segmentSpec == 0) {
1605  // Unpack unary element.
1606  try {
1607  auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1608  if (resultType) {
1609  resultTypes.push_back(resultType);
1610  resultSegmentLengths.push_back(1);
1611  } else if (segmentSpec == 0) {
1612  // Allowed to be optional.
1613  resultSegmentLengths.push_back(0);
1614  } else {
1615  throw py::cast_error("was None and result is not optional");
1616  }
1617  } catch (py::cast_error &err) {
1618  throw py::value_error((llvm::Twine("Result ") +
1619  llvm::Twine(it.index()) + " of operation \"" +
1620  name + "\" must be a Type (" + err.what() +
1621  ")")
1622  .str());
1623  }
1624  } else if (segmentSpec == -1) {
1625  // Unpack sequence by appending.
1626  try {
1627  if (std::get<0>(it.value()).is_none()) {
1628  // Treat it as an empty list.
1629  resultSegmentLengths.push_back(0);
1630  } else {
1631  // Unpack the list.
1632  auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1633  for (py::object segmentItem : segment) {
1634  resultTypes.push_back(py::cast<PyType *>(segmentItem));
1635  if (!resultTypes.back()) {
1636  throw py::cast_error("contained a None item");
1637  }
1638  }
1639  resultSegmentLengths.push_back(segment.size());
1640  }
1641  } catch (std::exception &err) {
1642  // NOTE: Sloppy to be using a catch-all here, but there are at least
1643  // three different unrelated exceptions that can be thrown in the
1644  // above "casts". Just keep the scope above small and catch them all.
1645  throw py::value_error((llvm::Twine("Result ") +
1646  llvm::Twine(it.index()) + " of operation \"" +
1647  name + "\" must be a Sequence of Types (" +
1648  err.what() + ")")
1649  .str());
1650  }
1651  } else {
1652  throw py::value_error("Unexpected segment spec");
1653  }
1654  }
1655  }
1656 }
1657 
1659  const py::object &cls, std::optional<py::list> resultTypeList,
1660  py::list operandList, std::optional<py::dict> attributes,
1661  std::optional<std::vector<PyBlock *>> successors,
1662  std::optional<int> regions, DefaultingPyLocation location,
1663  const py::object &maybeIp) {
1664  PyMlirContextRef context = location->getContext();
1665  // Class level operation construction metadata.
1666  std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1667  // Operand and result segment specs are either none, which does no
1668  // variadic unpacking, or a list of ints with segment sizes, where each
1669  // element is either a positive number (typically 1 for a scalar) or -1 to
1670  // indicate that it is derived from the length of the same-indexed operand
1671  // or result (implying that it is a list at that position).
1672  py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1673  py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1674 
1675  std::vector<int32_t> operandSegmentLengths;
1676  std::vector<int32_t> resultSegmentLengths;
1677 
1678  // Validate/determine region count.
1679  auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1680  int opMinRegionCount = std::get<0>(opRegionSpec);
1681  bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1682  if (!regions) {
1683  regions = opMinRegionCount;
1684  }
1685  if (*regions < opMinRegionCount) {
1686  throw py::value_error(
1687  (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1688  llvm::Twine(opMinRegionCount) +
1689  " regions but was built with regions=" + llvm::Twine(*regions))
1690  .str());
1691  }
1692  if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1693  throw py::value_error(
1694  (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1695  llvm::Twine(opMinRegionCount) +
1696  " regions but was built with regions=" + llvm::Twine(*regions))
1697  .str());
1698  }
1699 
1700  // Unpack results.
1701  std::vector<PyType *> resultTypes;
1702  if (resultTypeList.has_value()) {
1703  populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1704  resultSegmentLengths, resultTypes);
1705  }
1706 
1707  // Unpack operands.
1708  std::vector<PyValue *> operands;
1709  operands.reserve(operands.size());
1710  if (operandSegmentSpecObj.is_none()) {
1711  // Non-sized operand unpacking.
1712  for (const auto &it : llvm::enumerate(operandList)) {
1713  try {
1714  operands.push_back(py::cast<PyValue *>(it.value()));
1715  if (!operands.back())
1716  throw py::cast_error();
1717  } catch (py::cast_error &err) {
1718  throw py::value_error((llvm::Twine("Operand ") +
1719  llvm::Twine(it.index()) + " of operation \"" +
1720  name + "\" must be a Value (" + err.what() + ")")
1721  .str());
1722  }
1723  }
1724  } else {
1725  // Sized operand unpacking.
1726  auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1727  if (operandSegmentSpec.size() != operandList.size()) {
1728  throw py::value_error((llvm::Twine("Operation \"") + name +
1729  "\" requires " +
1730  llvm::Twine(operandSegmentSpec.size()) +
1731  "operand segments but was provided " +
1732  llvm::Twine(operandList.size()))
1733  .str());
1734  }
1735  operandSegmentLengths.reserve(operandList.size());
1736  for (const auto &it :
1737  llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1738  int segmentSpec = std::get<1>(it.value());
1739  if (segmentSpec == 1 || segmentSpec == 0) {
1740  // Unpack unary element.
1741  try {
1742  auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1743  if (operandValue) {
1744  operands.push_back(operandValue);
1745  operandSegmentLengths.push_back(1);
1746  } else if (segmentSpec == 0) {
1747  // Allowed to be optional.
1748  operandSegmentLengths.push_back(0);
1749  } else {
1750  throw py::cast_error("was None and operand is not optional");
1751  }
1752  } catch (py::cast_error &err) {
1753  throw py::value_error((llvm::Twine("Operand ") +
1754  llvm::Twine(it.index()) + " of operation \"" +
1755  name + "\" must be a Value (" + err.what() +
1756  ")")
1757  .str());
1758  }
1759  } else if (segmentSpec == -1) {
1760  // Unpack sequence by appending.
1761  try {
1762  if (std::get<0>(it.value()).is_none()) {
1763  // Treat it as an empty list.
1764  operandSegmentLengths.push_back(0);
1765  } else {
1766  // Unpack the list.
1767  auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1768  for (py::object segmentItem : segment) {
1769  operands.push_back(py::cast<PyValue *>(segmentItem));
1770  if (!operands.back()) {
1771  throw py::cast_error("contained a None item");
1772  }
1773  }
1774  operandSegmentLengths.push_back(segment.size());
1775  }
1776  } catch (std::exception &err) {
1777  // NOTE: Sloppy to be using a catch-all here, but there are at least
1778  // three different unrelated exceptions that can be thrown in the
1779  // above "casts". Just keep the scope above small and catch them all.
1780  throw py::value_error((llvm::Twine("Operand ") +
1781  llvm::Twine(it.index()) + " of operation \"" +
1782  name + "\" must be a Sequence of Values (" +
1783  err.what() + ")")
1784  .str());
1785  }
1786  } else {
1787  throw py::value_error("Unexpected segment spec");
1788  }
1789  }
1790  }
1791 
1792  // Merge operand/result segment lengths into attributes if needed.
1793  if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1794  // Dup.
1795  if (attributes) {
1796  attributes = py::dict(*attributes);
1797  } else {
1798  attributes = py::dict();
1799  }
1800  if (attributes->contains("resultSegmentSizes") ||
1801  attributes->contains("operandSegmentSizes")) {
1802  throw py::value_error("Manually setting a 'resultSegmentSizes' or "
1803  "'operandSegmentSizes' attribute is unsupported. "
1804  "Use Operation.create for such low-level access.");
1805  }
1806 
1807  // Add resultSegmentSizes attribute.
1808  if (!resultSegmentLengths.empty()) {
1809  MlirAttribute segmentLengthAttr =
1810  mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1811  resultSegmentLengths.data());
1812  (*attributes)["resultSegmentSizes"] =
1813  PyAttribute(context, segmentLengthAttr);
1814  }
1815 
1816  // Add operandSegmentSizes attribute.
1817  if (!operandSegmentLengths.empty()) {
1818  MlirAttribute segmentLengthAttr =
1819  mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1820  operandSegmentLengths.data());
1821  (*attributes)["operandSegmentSizes"] =
1822  PyAttribute(context, segmentLengthAttr);
1823  }
1824  }
1825 
1826  // Delegate to create.
1827  return PyOperation::create(name,
1828  /*results=*/std::move(resultTypes),
1829  /*operands=*/std::move(operands),
1830  /*attributes=*/std::move(attributes),
1831  /*successors=*/std::move(successors),
1832  /*regions=*/*regions, location, maybeIp,
1833  !resultTypeList);
1834 }
1835 
1836 pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
1837  const PyOperation &operation) {
1838  // TODO: pybind11 2.6 supports a more direct form.
1839  // Upgrade many years from now.
1840  // auto opViewType = py::type::of<PyOpView>();
1841  py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1842  py::object instance = cls.attr("__new__")(cls);
1843  opViewType.attr("__init__")(instance, operation);
1844  return instance;
1845 }
1846 
1847 PyOpView::PyOpView(const py::object &operationObject)
1848  // Casting through the PyOperationBase base-class and then back to the
1849  // Operation lets us accept any PyOperationBase subclass.
1850  : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1851  operationObject(operation.getRef().getObject()) {}
1852 
1853 //------------------------------------------------------------------------------
1854 // PyInsertionPoint.
1855 //------------------------------------------------------------------------------
1856 
1858 
1860  : refOperation(beforeOperationBase.getOperation().getRef()),
1861  block((*refOperation)->getBlock()) {}
1862 
1864  PyOperation &operation = operationBase.getOperation();
1865  if (operation.isAttached())
1866  throw py::value_error(
1867  "Attempt to insert operation that is already attached");
1868  block.getParentOperation()->checkValid();
1869  MlirOperation beforeOp = {nullptr};
1870  if (refOperation) {
1871  // Insert before operation.
1872  (*refOperation)->checkValid();
1873  beforeOp = (*refOperation)->get();
1874  } else {
1875  // Insert at end (before null) is only valid if the block does not
1876  // already end in a known terminator (violating this will cause assertion
1877  // failures later).
1878  if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1879  throw py::index_error("Cannot insert operation at the end of a block "
1880  "that already has a terminator. Did you mean to "
1881  "use 'InsertionPoint.at_block_terminator(block)' "
1882  "versus 'InsertionPoint(block)'?");
1883  }
1884  }
1885  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1886  operation.setAttached();
1887 }
1888 
1890  MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1891  if (mlirOperationIsNull(firstOp)) {
1892  // Just insert at end.
1893  return PyInsertionPoint(block);
1894  }
1895 
1896  // Insert before first op.
1898  block.getParentOperation()->getContext(), firstOp);
1899  return PyInsertionPoint{block, std::move(firstOpRef)};
1900 }
1901 
1903  MlirOperation terminator = mlirBlockGetTerminator(block.get());
1904  if (mlirOperationIsNull(terminator))
1905  throw py::value_error("Block has no terminator");
1906  PyOperationRef terminatorOpRef = PyOperation::forOperation(
1907  block.getParentOperation()->getContext(), terminator);
1908  return PyInsertionPoint{block, std::move(terminatorOpRef)};
1909 }
1910 
1913 }
1914 
1915 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1916  const pybind11::object &excVal,
1917  const pybind11::object &excTb) {
1919 }
1920 
1921 //------------------------------------------------------------------------------
1922 // PyAttribute.
1923 //------------------------------------------------------------------------------
1924 
1925 bool PyAttribute::operator==(const PyAttribute &other) const {
1926  return mlirAttributeEqual(attr, other.attr);
1927 }
1928 
1930  return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1931 }
1932 
1934  MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1935  if (mlirAttributeIsNull(rawAttr))
1936  throw py::error_already_set();
1937  return PyAttribute(
1939 }
1940 
1941 //------------------------------------------------------------------------------
1942 // PyNamedAttribute.
1943 //------------------------------------------------------------------------------
1944 
1945 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1946  : ownedName(new std::string(std::move(ownedName))) {
1949  toMlirStringRef(*this->ownedName)),
1950  attr);
1951 }
1952 
1953 //------------------------------------------------------------------------------
1954 // PyType.
1955 //------------------------------------------------------------------------------
1956 
1957 bool PyType::operator==(const PyType &other) const {
1958  return mlirTypeEqual(type, other.type);
1959 }
1960 
1961 py::object PyType::getCapsule() {
1962  return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1963 }
1964 
1965 PyType PyType::createFromCapsule(py::object capsule) {
1966  MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1967  if (mlirTypeIsNull(rawType))
1968  throw py::error_already_set();
1970  rawType);
1971 }
1972 
1973 //------------------------------------------------------------------------------
1974 // PyTypeID.
1975 //------------------------------------------------------------------------------
1976 
1977 py::object PyTypeID::getCapsule() {
1978  return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
1979 }
1980 
1982  MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
1983  if (mlirTypeIDIsNull(mlirTypeID))
1984  throw py::error_already_set();
1985  return PyTypeID(mlirTypeID);
1986 }
1987 bool PyTypeID::operator==(const PyTypeID &other) const {
1988  return mlirTypeIDEqual(typeID, other.typeID);
1989 }
1990 
1991 //------------------------------------------------------------------------------
1992 // PyValue and subclasses.
1993 //------------------------------------------------------------------------------
1994 
1995 pybind11::object PyValue::getCapsule() {
1996  return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1997 }
1998 
1999 pybind11::object PyValue::maybeDownCast() {
2000  MlirType type = mlirValueGetType(get());
2001  MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2002  assert(!mlirTypeIDIsNull(mlirTypeID) &&
2003  "mlirTypeID was expected to be non-null.");
2004  std::optional<pybind11::function> valueCaster =
2006  // py::return_value_policy::move means use std::move to move the return value
2007  // contents into a new instance that will be owned by Python.
2008  py::object thisObj = py::cast(this, py::return_value_policy::move);
2009  if (!valueCaster)
2010  return thisObj;
2011  return valueCaster.value()(thisObj);
2012 }
2013 
2014 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
2015  MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2016  if (mlirValueIsNull(value))
2017  throw py::error_already_set();
2018  MlirOperation owner;
2019  if (mlirValueIsAOpResult(value))
2020  owner = mlirOpResultGetOwner(value);
2021  if (mlirValueIsABlockArgument(value))
2023  if (mlirOperationIsNull(owner))
2024  throw py::error_already_set();
2025  MlirContext ctx = mlirOperationGetContext(owner);
2026  PyOperationRef ownerRef =
2028  return PyValue(ownerRef, value);
2029 }
2030 
2031 //------------------------------------------------------------------------------
2032 // PySymbolTable.
2033 //------------------------------------------------------------------------------
2034 
2036  : operation(operation.getOperation().getRef()) {
2037  symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2038  if (mlirSymbolTableIsNull(symbolTable)) {
2039  throw py::cast_error("Operation is not a Symbol Table.");
2040  }
2041 }
2042 
2043 py::object PySymbolTable::dunderGetItem(const std::string &name) {
2044  operation->checkValid();
2045  MlirOperation symbol = mlirSymbolTableLookup(
2046  symbolTable, mlirStringRefCreate(name.data(), name.length()));
2047  if (mlirOperationIsNull(symbol))
2048  throw py::key_error("Symbol '" + name + "' not in the symbol table.");
2049 
2050  return PyOperation::forOperation(operation->getContext(), symbol,
2051  operation.getObject())
2052  ->createOpView();
2053 }
2054 
2056  operation->checkValid();
2057  symbol.getOperation().checkValid();
2058  mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2059  // The operation is also erased, so we must invalidate it. There may be Python
2060  // references to this operation so we don't want to delete it from the list of
2061  // live operations here.
2062  symbol.getOperation().valid = false;
2063 }
2064 
2065 void PySymbolTable::dunderDel(const std::string &name) {
2066  py::object operation = dunderGetItem(name);
2067  erase(py::cast<PyOperationBase &>(operation));
2068 }
2069 
2070 MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2071  operation->checkValid();
2072  symbol.getOperation().checkValid();
2073  MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2075  if (mlirAttributeIsNull(symbolAttr))
2076  throw py::value_error("Expected operation to have a symbol name.");
2077  return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2078 }
2079 
2081  // Op must already be a symbol.
2082  PyOperation &operation = symbol.getOperation();
2083  operation.checkValid();
2085  MlirAttribute existingNameAttr =
2086  mlirOperationGetAttributeByName(operation.get(), attrName);
2087  if (mlirAttributeIsNull(existingNameAttr))
2088  throw py::value_error("Expected operation to have a symbol name.");
2089  return existingNameAttr;
2090 }
2091 
2093  const std::string &name) {
2094  // Op must already be a symbol.
2095  PyOperation &operation = symbol.getOperation();
2096  operation.checkValid();
2098  MlirAttribute existingNameAttr =
2099  mlirOperationGetAttributeByName(operation.get(), attrName);
2100  if (mlirAttributeIsNull(existingNameAttr))
2101  throw py::value_error("Expected operation to have a symbol name.");
2102  MlirAttribute newNameAttr =
2103  mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2104  mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2105 }
2106 
2108  PyOperation &operation = symbol.getOperation();
2109  operation.checkValid();
2111  MlirAttribute existingVisAttr =
2112  mlirOperationGetAttributeByName(operation.get(), attrName);
2113  if (mlirAttributeIsNull(existingVisAttr))
2114  throw py::value_error("Expected operation to have a symbol visibility.");
2115  return existingVisAttr;
2116 }
2117 
2119  const std::string &visibility) {
2120  if (visibility != "public" && visibility != "private" &&
2121  visibility != "nested")
2122  throw py::value_error(
2123  "Expected visibility to be 'public', 'private' or 'nested'");
2124  PyOperation &operation = symbol.getOperation();
2125  operation.checkValid();
2127  MlirAttribute existingVisAttr =
2128  mlirOperationGetAttributeByName(operation.get(), attrName);
2129  if (mlirAttributeIsNull(existingVisAttr))
2130  throw py::value_error("Expected operation to have a symbol visibility.");
2131  MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2132  toMlirStringRef(visibility));
2133  mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2134 }
2135 
2136 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2137  const std::string &newSymbol,
2138  PyOperationBase &from) {
2139  PyOperation &fromOperation = from.getOperation();
2140  fromOperation.checkValid();
2142  toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2143  from.getOperation())))
2144 
2145  throw py::value_error("Symbol rename failed");
2146 }
2147 
2149  bool allSymUsesVisible,
2150  py::object callback) {
2151  PyOperation &fromOperation = from.getOperation();
2152  fromOperation.checkValid();
2153  struct UserData {
2154  PyMlirContextRef context;
2155  py::object callback;
2156  bool gotException;
2157  std::string exceptionWhat;
2158  py::object exceptionType;
2159  };
2160  UserData userData{
2161  fromOperation.getContext(), std::move(callback), false, {}, {}};
2163  fromOperation.get(), allSymUsesVisible,
2164  [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2165  UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2166  auto pyFoundOp =
2167  PyOperation::forOperation(calleeUserData->context, foundOp);
2168  if (calleeUserData->gotException)
2169  return;
2170  try {
2171  calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2172  } catch (py::error_already_set &e) {
2173  calleeUserData->gotException = true;
2174  calleeUserData->exceptionWhat = e.what();
2175  calleeUserData->exceptionType = e.type();
2176  }
2177  },
2178  static_cast<void *>(&userData));
2179  if (userData.gotException) {
2180  std::string message("Exception raised in callback: ");
2181  message.append(userData.exceptionWhat);
2182  throw std::runtime_error(message);
2183  }
2184 }
2185 
2186 namespace {
2187 /// CRTP base class for Python MLIR values that subclass Value and should be
2188 /// castable from it. The value hierarchy is one level deep and is not supposed
2189 /// to accommodate other levels unless core MLIR changes.
2190 template <typename DerivedTy>
2191 class PyConcreteValue : public PyValue {
2192 public:
2193  // Derived classes must define statics for:
2194  // IsAFunctionTy isaFunction
2195  // const char *pyClassName
2196  // and redefine bindDerived.
2197  using ClassTy = py::class_<DerivedTy, PyValue>;
2198  using IsAFunctionTy = bool (*)(MlirValue);
2199 
2200  PyConcreteValue() = default;
2201  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
2202  : PyValue(operationRef, value) {}
2203  PyConcreteValue(PyValue &orig)
2204  : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
2205 
2206  /// Attempts to cast the original value to the derived type and throws on
2207  /// type mismatches.
2208  static MlirValue castFrom(PyValue &orig) {
2209  if (!DerivedTy::isaFunction(orig.get())) {
2210  auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
2211  throw py::value_error((Twine("Cannot cast value to ") +
2212  DerivedTy::pyClassName + " (from " + origRepr +
2213  ")")
2214  .str());
2215  }
2216  return orig.get();
2217  }
2218 
2219  /// Binds the Python module objects to functions of this class.
2220  static void bind(py::module &m) {
2221  auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
2222  cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
2223  cls.def_static(
2224  "isinstance",
2225  [](PyValue &otherValue) -> bool {
2226  return DerivedTy::isaFunction(otherValue);
2227  },
2228  py::arg("other_value"));
2230  [](DerivedTy &self) { return self.maybeDownCast(); });
2231  DerivedTy::bindDerived(cls);
2232  }
2233 
2234  /// Implemented by derived classes to add methods to the Python subclass.
2235  static void bindDerived(ClassTy &m) {}
2236 };
2237 
2238 /// Python wrapper for MlirBlockArgument.
2239 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2240 public:
2241  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2242  static constexpr const char *pyClassName = "BlockArgument";
2243  using PyConcreteValue::PyConcreteValue;
2244 
2245  static void bindDerived(ClassTy &c) {
2246  c.def_property_readonly("owner", [](PyBlockArgument &self) {
2247  return PyBlock(self.getParentOperation(),
2248  mlirBlockArgumentGetOwner(self.get()));
2249  });
2250  c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
2251  return mlirBlockArgumentGetArgNumber(self.get());
2252  });
2253  c.def(
2254  "set_type",
2255  [](PyBlockArgument &self, PyType type) {
2256  return mlirBlockArgumentSetType(self.get(), type);
2257  },
2258  py::arg("type"));
2259  }
2260 };
2261 
2262 /// Python wrapper for MlirOpResult.
2263 class PyOpResult : public PyConcreteValue<PyOpResult> {
2264 public:
2265  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
2266  static constexpr const char *pyClassName = "OpResult";
2267  using PyConcreteValue::PyConcreteValue;
2268 
2269  static void bindDerived(ClassTy &c) {
2270  c.def_property_readonly("owner", [](PyOpResult &self) {
2271  assert(
2272  mlirOperationEqual(self.getParentOperation()->get(),
2273  mlirOpResultGetOwner(self.get())) &&
2274  "expected the owner of the value in Python to match that in the IR");
2275  return self.getParentOperation().getObject();
2276  });
2277  c.def_property_readonly("result_number", [](PyOpResult &self) {
2278  return mlirOpResultGetResultNumber(self.get());
2279  });
2280  }
2281 };
2282 
2283 /// Returns the list of types of the values held by container.
2284 template <typename Container>
2285 static std::vector<MlirType> getValueTypes(Container &container,
2286  PyMlirContextRef &context) {
2287  std::vector<MlirType> result;
2288  result.reserve(container.size());
2289  for (int i = 0, e = container.size(); i < e; ++i) {
2290  result.push_back(mlirValueGetType(container.getElement(i).get()));
2291  }
2292  return result;
2293 }
2294 
2295 /// A list of block arguments. Internally, these are stored as consecutive
2296 /// elements, random access is cheap. The argument list is associated with the
2297 /// operation that contains the block (detached blocks are not allowed in
2298 /// Python bindings) and extends its lifetime.
2299 class PyBlockArgumentList
2300  : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2301 public:
2302  static constexpr const char *pyClassName = "BlockArgumentList";
2304 
2305  PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2306  intptr_t startIndex = 0, intptr_t length = -1,
2307  intptr_t step = 1)
2308  : Sliceable(startIndex,
2309  length == -1 ? mlirBlockGetNumArguments(block) : length,
2310  step),
2311  operation(std::move(operation)), block(block) {}
2312 
2313  static void bindDerived(ClassTy &c) {
2314  c.def_property_readonly("types", [](PyBlockArgumentList &self) {
2315  return getValueTypes(self, self.operation->getContext());
2316  });
2317  }
2318 
2319 private:
2320  /// Give the parent CRTP class access to hook implementations below.
2321  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2322 
2323  /// Returns the number of arguments in the list.
2324  intptr_t getRawNumElements() {
2325  operation->checkValid();
2326  return mlirBlockGetNumArguments(block);
2327  }
2328 
2329  /// Returns `pos`-the element in the list.
2330  PyBlockArgument getRawElement(intptr_t pos) {
2331  MlirValue argument = mlirBlockGetArgument(block, pos);
2332  return PyBlockArgument(operation, argument);
2333  }
2334 
2335  /// Returns a sublist of this list.
2336  PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2337  intptr_t step) {
2338  return PyBlockArgumentList(operation, block, startIndex, length, step);
2339  }
2340 
2341  PyOperationRef operation;
2342  MlirBlock block;
2343 };
2344 
2345 /// A list of operation operands. Internally, these are stored as consecutive
2346 /// elements, random access is cheap. The (returned) operand list is associated
2347 /// with the operation whose operands these are, and thus extends the lifetime
2348 /// of this operation.
2349 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2350 public:
2351  static constexpr const char *pyClassName = "OpOperandList";
2352  using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2353 
2354  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2355  intptr_t length = -1, intptr_t step = 1)
2356  : Sliceable(startIndex,
2357  length == -1 ? mlirOperationGetNumOperands(operation->get())
2358  : length,
2359  step),
2360  operation(operation) {}
2361 
2362  void dunderSetItem(intptr_t index, PyValue value) {
2363  index = wrapIndex(index);
2364  mlirOperationSetOperand(operation->get(), index, value.get());
2365  }
2366 
2367  static void bindDerived(ClassTy &c) {
2368  c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2369  }
2370 
2371 private:
2372  /// Give the parent CRTP class access to hook implementations below.
2373  friend class Sliceable<PyOpOperandList, PyValue>;
2374 
2375  intptr_t getRawNumElements() {
2376  operation->checkValid();
2377  return mlirOperationGetNumOperands(operation->get());
2378  }
2379 
2380  PyValue getRawElement(intptr_t pos) {
2381  MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2382  MlirOperation owner;
2383  if (mlirValueIsAOpResult(operand))
2384  owner = mlirOpResultGetOwner(operand);
2385  else if (mlirValueIsABlockArgument(operand))
2387  else
2388  assert(false && "Value must be an block arg or op result.");
2389  PyOperationRef pyOwner =
2390  PyOperation::forOperation(operation->getContext(), owner);
2391  return PyValue(pyOwner, operand);
2392  }
2393 
2394  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2395  return PyOpOperandList(operation, startIndex, length, step);
2396  }
2397 
2398  PyOperationRef operation;
2399 };
2400 
2401 /// A list of operation results. Internally, these are stored as consecutive
2402 /// elements, random access is cheap. The (returned) result list is associated
2403 /// with the operation whose results these are, and thus extends the lifetime of
2404 /// this operation.
2405 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2406 public:
2407  static constexpr const char *pyClassName = "OpResultList";
2408  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
2409 
2410  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2411  intptr_t length = -1, intptr_t step = 1)
2412  : Sliceable(startIndex,
2413  length == -1 ? mlirOperationGetNumResults(operation->get())
2414  : length,
2415  step),
2416  operation(std::move(operation)) {}
2417 
2418  static void bindDerived(ClassTy &c) {
2419  c.def_property_readonly("types", [](PyOpResultList &self) {
2420  return getValueTypes(self, self.operation->getContext());
2421  });
2422  c.def_property_readonly("owner", [](PyOpResultList &self) {
2423  return self.operation->createOpView();
2424  });
2425  }
2426 
2427 private:
2428  /// Give the parent CRTP class access to hook implementations below.
2429  friend class Sliceable<PyOpResultList, PyOpResult>;
2430 
2431  intptr_t getRawNumElements() {
2432  operation->checkValid();
2433  return mlirOperationGetNumResults(operation->get());
2434  }
2435 
2436  PyOpResult getRawElement(intptr_t index) {
2437  PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2438  return PyOpResult(value);
2439  }
2440 
2441  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2442  return PyOpResultList(operation, startIndex, length, step);
2443  }
2444 
2445  PyOperationRef operation;
2446 };
2447 
2448 /// A list of operation successors. Internally, these are stored as consecutive
2449 /// elements, random access is cheap. The (returned) successor list is
2450 /// associated with the operation whose successors these are, and thus extends
2451 /// the lifetime of this operation.
2452 class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2453 public:
2454  static constexpr const char *pyClassName = "OpSuccessors";
2455 
2456  PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2457  intptr_t length = -1, intptr_t step = 1)
2458  : Sliceable(startIndex,
2459  length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2460  : length,
2461  step),
2462  operation(operation) {}
2463 
2464  void dunderSetItem(intptr_t index, PyBlock block) {
2465  index = wrapIndex(index);
2466  mlirOperationSetSuccessor(operation->get(), index, block.get());
2467  }
2468 
2469  static void bindDerived(ClassTy &c) {
2470  c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2471  }
2472 
2473 private:
2474  /// Give the parent CRTP class access to hook implementations below.
2475  friend class Sliceable<PyOpSuccessors, PyBlock>;
2476 
2477  intptr_t getRawNumElements() {
2478  operation->checkValid();
2479  return mlirOperationGetNumSuccessors(operation->get());
2480  }
2481 
2482  PyBlock getRawElement(intptr_t pos) {
2483  MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2484  return PyBlock(operation, block);
2485  }
2486 
2487  PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2488  return PyOpSuccessors(operation, startIndex, length, step);
2489  }
2490 
2491  PyOperationRef operation;
2492 };
2493 
2494 /// A list of operation attributes. Can be indexed by name, producing
2495 /// attributes, or by index, producing named attributes.
2496 class PyOpAttributeMap {
2497 public:
2498  PyOpAttributeMap(PyOperationRef operation)
2499  : operation(std::move(operation)) {}
2500 
2501  MlirAttribute dunderGetItemNamed(const std::string &name) {
2502  MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2503  toMlirStringRef(name));
2504  if (mlirAttributeIsNull(attr)) {
2505  throw py::key_error("attempt to access a non-existent attribute");
2506  }
2507  return attr;
2508  }
2509 
2510  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2511  if (index < 0 || index >= dunderLen()) {
2512  throw py::index_error("attempt to access out of bounds attribute");
2513  }
2514  MlirNamedAttribute namedAttr =
2515  mlirOperationGetAttribute(operation->get(), index);
2516  return PyNamedAttribute(
2517  namedAttr.attribute,
2518  std::string(mlirIdentifierStr(namedAttr.name).data,
2519  mlirIdentifierStr(namedAttr.name).length));
2520  }
2521 
2522  void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2523  mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2524  attr);
2525  }
2526 
2527  void dunderDelItem(const std::string &name) {
2528  int removed = mlirOperationRemoveAttributeByName(operation->get(),
2529  toMlirStringRef(name));
2530  if (!removed)
2531  throw py::key_error("attempt to delete a non-existent attribute");
2532  }
2533 
2534  intptr_t dunderLen() {
2535  return mlirOperationGetNumAttributes(operation->get());
2536  }
2537 
2538  bool dunderContains(const std::string &name) {
2540  operation->get(), toMlirStringRef(name)));
2541  }
2542 
2543  static void bind(py::module &m) {
2544  py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2545  .def("__contains__", &PyOpAttributeMap::dunderContains)
2546  .def("__len__", &PyOpAttributeMap::dunderLen)
2547  .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2548  .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2549  .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2550  .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2551  }
2552 
2553 private:
2554  PyOperationRef operation;
2555 };
2556 
2557 } // namespace
2558 
2559 //------------------------------------------------------------------------------
2560 // Populates the core exports of the 'ir' submodule.
2561 //------------------------------------------------------------------------------
2562 
2563 void mlir::python::populateIRCore(py::module &m) {
2564  //----------------------------------------------------------------------------
2565  // Enums.
2566  //----------------------------------------------------------------------------
2567  py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2568  .value("ERROR", MlirDiagnosticError)
2569  .value("WARNING", MlirDiagnosticWarning)
2570  .value("NOTE", MlirDiagnosticNote)
2571  .value("REMARK", MlirDiagnosticRemark);
2572 
2573  py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
2574  .value("PRE_ORDER", MlirWalkPreOrder)
2575  .value("POST_ORDER", MlirWalkPostOrder);
2576 
2577  py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
2578  .value("ADVANCE", MlirWalkResultAdvance)
2579  .value("INTERRUPT", MlirWalkResultInterrupt)
2580  .value("SKIP", MlirWalkResultSkip);
2581 
2582  //----------------------------------------------------------------------------
2583  // Mapping of Diagnostics.
2584  //----------------------------------------------------------------------------
2585  py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2586  .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2587  .def_property_readonly("location", &PyDiagnostic::getLocation)
2588  .def_property_readonly("message", &PyDiagnostic::getMessage)
2589  .def_property_readonly("notes", &PyDiagnostic::getNotes)
2590  .def("__str__", [](PyDiagnostic &self) -> py::str {
2591  if (!self.isValid())
2592  return "<Invalid Diagnostic>";
2593  return self.getMessage();
2594  });
2595 
2596  py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo",
2597  py::module_local())
2598  .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); }))
2599  .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity)
2600  .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location)
2601  .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message)
2602  .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes)
2603  .def("__str__",
2604  [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2605 
2606  py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2607  .def("detach", &PyDiagnosticHandler::detach)
2608  .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2609  .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2610  .def("__enter__", &PyDiagnosticHandler::contextEnter)
2611  .def("__exit__", &PyDiagnosticHandler::contextExit);
2612 
2613  //----------------------------------------------------------------------------
2614  // Mapping of MlirContext.
2615  // Note that this is exported as _BaseContext. The containing, Python level
2616  // __init__.py will subclass it with site-specific functionality and set a
2617  // "Context" attribute on this module.
2618  //----------------------------------------------------------------------------
2619  py::class_<PyMlirContext>(m, "_BaseContext", py::module_local())
2620  .def(py::init<>(&PyMlirContext::createNewContextForInit))
2621  .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2622  .def("_get_context_again",
2623  [](PyMlirContext &self) {
2624  PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2625  return ref.releaseObject();
2626  })
2627  .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2628  .def("_get_live_operation_objects",
2629  &PyMlirContext::getLiveOperationObjects)
2630  .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2631  .def("_clear_live_operations_inside",
2632  py::overload_cast<MlirOperation>(
2633  &PyMlirContext::clearOperationsInside))
2634  .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2635  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2636  &PyMlirContext::getCapsule)
2637  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2638  .def("__enter__", &PyMlirContext::contextEnter)
2639  .def("__exit__", &PyMlirContext::contextExit)
2640  .def_property_readonly_static(
2641  "current",
2642  [](py::object & /*class*/) {
2643  auto *context = PyThreadContextEntry::getDefaultContext();
2644  if (!context)
2645  return py::none().cast<py::object>();
2646  return py::cast(context);
2647  },
2648  "Gets the Context bound to the current thread or raises ValueError")
2649  .def_property_readonly(
2650  "dialects",
2651  [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2652  "Gets a container for accessing dialects by name")
2653  .def_property_readonly(
2654  "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2655  "Alias for 'dialect'")
2656  .def(
2657  "get_dialect_descriptor",
2658  [=](PyMlirContext &self, std::string &name) {
2659  MlirDialect dialect = mlirContextGetOrLoadDialect(
2660  self.get(), {name.data(), name.size()});
2661  if (mlirDialectIsNull(dialect)) {
2662  throw py::value_error(
2663  (Twine("Dialect '") + name + "' not found").str());
2664  }
2665  return PyDialectDescriptor(self.getRef(), dialect);
2666  },
2667  py::arg("dialect_name"),
2668  "Gets or loads a dialect by name, returning its descriptor object")
2669  .def_property(
2670  "allow_unregistered_dialects",
2671  [](PyMlirContext &self) -> bool {
2673  },
2674  [](PyMlirContext &self, bool value) {
2676  })
2677  .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2678  py::arg("callback"),
2679  "Attaches a diagnostic handler that will receive callbacks")
2680  .def(
2681  "enable_multithreading",
2682  [](PyMlirContext &self, bool enable) {
2683  mlirContextEnableMultithreading(self.get(), enable);
2684  },
2685  py::arg("enable"))
2686  .def(
2687  "is_registered_operation",
2688  [](PyMlirContext &self, std::string &name) {
2690  self.get(), MlirStringRef{name.data(), name.size()});
2691  },
2692  py::arg("operation_name"))
2693  .def(
2694  "append_dialect_registry",
2695  [](PyMlirContext &self, PyDialectRegistry &registry) {
2696  mlirContextAppendDialectRegistry(self.get(), registry);
2697  },
2698  py::arg("registry"))
2699  .def_property("emit_error_diagnostics", nullptr,
2700  &PyMlirContext::setEmitErrorDiagnostics,
2701  "Emit error diagnostics to diagnostic handlers. By default "
2702  "error diagnostics are captured and reported through "
2703  "MLIRError exceptions.")
2704  .def("load_all_available_dialects", [](PyMlirContext &self) {
2706  });
2707 
2708  //----------------------------------------------------------------------------
2709  // Mapping of PyDialectDescriptor
2710  //----------------------------------------------------------------------------
2711  py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2712  .def_property_readonly("namespace",
2713  [](PyDialectDescriptor &self) {
2714  MlirStringRef ns =
2715  mlirDialectGetNamespace(self.get());
2716  return py::str(ns.data, ns.length);
2717  })
2718  .def("__repr__", [](PyDialectDescriptor &self) {
2720  std::string repr("<DialectDescriptor ");
2721  repr.append(ns.data, ns.length);
2722  repr.append(">");
2723  return repr;
2724  });
2725 
2726  //----------------------------------------------------------------------------
2727  // Mapping of PyDialects
2728  //----------------------------------------------------------------------------
2729  py::class_<PyDialects>(m, "Dialects", py::module_local())
2730  .def("__getitem__",
2731  [=](PyDialects &self, std::string keyName) {
2732  MlirDialect dialect =
2733  self.getDialectForKey(keyName, /*attrError=*/false);
2734  py::object descriptor =
2735  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2736  return createCustomDialectWrapper(keyName, std::move(descriptor));
2737  })
2738  .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2739  MlirDialect dialect =
2740  self.getDialectForKey(attrName, /*attrError=*/true);
2741  py::object descriptor =
2742  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2743  return createCustomDialectWrapper(attrName, std::move(descriptor));
2744  });
2745 
2746  //----------------------------------------------------------------------------
2747  // Mapping of PyDialect
2748  //----------------------------------------------------------------------------
2749  py::class_<PyDialect>(m, "Dialect", py::module_local())
2750  .def(py::init<py::object>(), py::arg("descriptor"))
2751  .def_property_readonly(
2752  "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2753  .def("__repr__", [](py::object self) {
2754  auto clazz = self.attr("__class__");
2755  return py::str("<Dialect ") +
2756  self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2757  clazz.attr("__module__") + py::str(".") +
2758  clazz.attr("__name__") + py::str(")>");
2759  });
2760 
2761  //----------------------------------------------------------------------------
2762  // Mapping of PyDialectRegistry
2763  //----------------------------------------------------------------------------
2764  py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local())
2765  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2766  &PyDialectRegistry::getCapsule)
2767  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2768  .def(py::init<>());
2769 
2770  //----------------------------------------------------------------------------
2771  // Mapping of Location
2772  //----------------------------------------------------------------------------
2773  py::class_<PyLocation>(m, "Location", py::module_local())
2774  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2775  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2776  .def("__enter__", &PyLocation::contextEnter)
2777  .def("__exit__", &PyLocation::contextExit)
2778  .def("__eq__",
2779  [](PyLocation &self, PyLocation &other) -> bool {
2780  return mlirLocationEqual(self, other);
2781  })
2782  .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2783  .def_property_readonly_static(
2784  "current",
2785  [](py::object & /*class*/) {
2786  auto *loc = PyThreadContextEntry::getDefaultLocation();
2787  if (!loc)
2788  throw py::value_error("No current Location");
2789  return loc;
2790  },
2791  "Gets the Location bound to the current thread or raises ValueError")
2792  .def_static(
2793  "unknown",
2794  [](DefaultingPyMlirContext context) {
2795  return PyLocation(context->getRef(),
2796  mlirLocationUnknownGet(context->get()));
2797  },
2798  py::arg("context") = py::none(),
2799  "Gets a Location representing an unknown location")
2800  .def_static(
2801  "callsite",
2802  [](PyLocation callee, const std::vector<PyLocation> &frames,
2803  DefaultingPyMlirContext context) {
2804  if (frames.empty())
2805  throw py::value_error("No caller frames provided");
2806  MlirLocation caller = frames.back().get();
2807  for (const PyLocation &frame :
2808  llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2809  caller = mlirLocationCallSiteGet(frame.get(), caller);
2810  return PyLocation(context->getRef(),
2811  mlirLocationCallSiteGet(callee.get(), caller));
2812  },
2813  py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2815  .def_static(
2816  "file",
2817  [](std::string filename, int line, int col,
2818  DefaultingPyMlirContext context) {
2819  return PyLocation(
2820  context->getRef(),
2822  context->get(), toMlirStringRef(filename), line, col));
2823  },
2824  py::arg("filename"), py::arg("line"), py::arg("col"),
2825  py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2826  .def_static(
2827  "fused",
2828  [](const std::vector<PyLocation> &pyLocations,
2829  std::optional<PyAttribute> metadata,
2830  DefaultingPyMlirContext context) {
2832  locations.reserve(pyLocations.size());
2833  for (auto &pyLocation : pyLocations)
2834  locations.push_back(pyLocation.get());
2835  MlirLocation location = mlirLocationFusedGet(
2836  context->get(), locations.size(), locations.data(),
2837  metadata ? metadata->get() : MlirAttribute{0});
2838  return PyLocation(context->getRef(), location);
2839  },
2840  py::arg("locations"), py::arg("metadata") = py::none(),
2841  py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2842  .def_static(
2843  "name",
2844  [](std::string name, std::optional<PyLocation> childLoc,
2845  DefaultingPyMlirContext context) {
2846  return PyLocation(
2847  context->getRef(),
2849  context->get(), toMlirStringRef(name),
2850  childLoc ? childLoc->get()
2851  : mlirLocationUnknownGet(context->get())));
2852  },
2853  py::arg("name"), py::arg("childLoc") = py::none(),
2854  py::arg("context") = py::none(), kContextGetNameLocationDocString)
2855  .def_static(
2856  "from_attr",
2857  [](PyAttribute &attribute, DefaultingPyMlirContext context) {
2858  return PyLocation(context->getRef(),
2859  mlirLocationFromAttribute(attribute));
2860  },
2861  py::arg("attribute"), py::arg("context") = py::none(),
2862  "Gets a Location from a LocationAttr")
2863  .def_property_readonly(
2864  "context",
2865  [](PyLocation &self) { return self.getContext().getObject(); },
2866  "Context that owns the Location")
2867  .def_property_readonly(
2868  "attr",
2869  [](PyLocation &self) { return mlirLocationGetAttribute(self); },
2870  "Get the underlying LocationAttr")
2871  .def(
2872  "emit_error",
2873  [](PyLocation &self, std::string message) {
2874  mlirEmitError(self, message.c_str());
2875  },
2876  py::arg("message"), "Emits an error at this location")
2877  .def("__repr__", [](PyLocation &self) {
2878  PyPrintAccumulator printAccum;
2879  mlirLocationPrint(self, printAccum.getCallback(),
2880  printAccum.getUserData());
2881  return printAccum.join();
2882  });
2883 
2884  //----------------------------------------------------------------------------
2885  // Mapping of Module
2886  //----------------------------------------------------------------------------
2887  py::class_<PyModule>(m, "Module", py::module_local())
2888  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2889  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2890  .def_static(
2891  "parse",
2892  [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
2893  PyMlirContext::ErrorCapture errors(context->getRef());
2894  MlirModule module = mlirModuleCreateParse(
2895  context->get(), toMlirStringRef(moduleAsm));
2896  if (mlirModuleIsNull(module))
2897  throw MLIRError("Unable to parse module assembly", errors.take());
2898  return PyModule::forModule(module).releaseObject();
2899  },
2900  py::arg("asm"), py::arg("context") = py::none(),
2902  .def_static(
2903  "create",
2904  [](DefaultingPyLocation loc) {
2905  MlirModule module = mlirModuleCreateEmpty(loc);
2906  return PyModule::forModule(module).releaseObject();
2907  },
2908  py::arg("loc") = py::none(), "Creates an empty module")
2909  .def_property_readonly(
2910  "context",
2911  [](PyModule &self) { return self.getContext().getObject(); },
2912  "Context that created the Module")
2913  .def_property_readonly(
2914  "operation",
2915  [](PyModule &self) {
2916  return PyOperation::forOperation(self.getContext(),
2917  mlirModuleGetOperation(self.get()),
2918  self.getRef().releaseObject())
2919  .releaseObject();
2920  },
2921  "Accesses the module as an operation")
2922  .def_property_readonly(
2923  "body",
2924  [](PyModule &self) {
2925  PyOperationRef moduleOp = PyOperation::forOperation(
2926  self.getContext(), mlirModuleGetOperation(self.get()),
2927  self.getRef().releaseObject());
2928  PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2929  return returnBlock;
2930  },
2931  "Return the block for this module")
2932  .def(
2933  "dump",
2934  [](PyModule &self) {
2936  },
2938  .def(
2939  "__str__",
2940  [](py::object self) {
2941  // Defer to the operation's __str__.
2942  return self.attr("operation").attr("__str__")();
2943  },
2945 
2946  //----------------------------------------------------------------------------
2947  // Mapping of Operation.
2948  //----------------------------------------------------------------------------
2949  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2950  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2951  [](PyOperationBase &self) {
2952  return self.getOperation().getCapsule();
2953  })
2954  .def("__eq__",
2955  [](PyOperationBase &self, PyOperationBase &other) {
2956  return &self.getOperation() == &other.getOperation();
2957  })
2958  .def("__eq__",
2959  [](PyOperationBase &self, py::object other) { return false; })
2960  .def("__hash__",
2961  [](PyOperationBase &self) {
2962  return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2963  })
2964  .def_property_readonly("attributes",
2965  [](PyOperationBase &self) {
2966  return PyOpAttributeMap(
2967  self.getOperation().getRef());
2968  })
2969  .def_property_readonly(
2970  "context",
2971  [](PyOperationBase &self) {
2972  PyOperation &concreteOperation = self.getOperation();
2973  concreteOperation.checkValid();
2974  return concreteOperation.getContext().getObject();
2975  },
2976  "Context that owns the Operation")
2977  .def_property_readonly("name",
2978  [](PyOperationBase &self) {
2979  auto &concreteOperation = self.getOperation();
2980  concreteOperation.checkValid();
2981  MlirOperation operation =
2982  concreteOperation.get();
2984  mlirOperationGetName(operation));
2985  return py::str(name.data, name.length);
2986  })
2987  .def_property_readonly("operands",
2988  [](PyOperationBase &self) {
2989  return PyOpOperandList(
2990  self.getOperation().getRef());
2991  })
2992  .def_property_readonly("regions",
2993  [](PyOperationBase &self) {
2994  return PyRegionList(
2995  self.getOperation().getRef());
2996  })
2997  .def_property_readonly(
2998  "results",
2999  [](PyOperationBase &self) {
3000  return PyOpResultList(self.getOperation().getRef());
3001  },
3002  "Returns the list of Operation results.")
3003  .def_property_readonly(
3004  "result",
3005  [](PyOperationBase &self) {
3006  auto &operation = self.getOperation();
3007  auto numResults = mlirOperationGetNumResults(operation);
3008  if (numResults != 1) {
3009  auto name = mlirIdentifierStr(mlirOperationGetName(operation));
3010  throw py::value_error(
3011  (Twine("Cannot call .result on operation ") +
3012  StringRef(name.data, name.length) + " which has " +
3013  Twine(numResults) +
3014  " results (it is only valid for operations with a "
3015  "single result)")
3016  .str());
3017  }
3018  return PyOpResult(operation.getRef(),
3019  mlirOperationGetResult(operation, 0))
3020  .maybeDownCast();
3021  },
3022  "Shortcut to get an op result if it has only one (throws an error "
3023  "otherwise).")
3024  .def_property_readonly(
3025  "location",
3026  [](PyOperationBase &self) {
3027  PyOperation &operation = self.getOperation();
3028  return PyLocation(operation.getContext(),
3029  mlirOperationGetLocation(operation.get()));
3030  },
3031  "Returns the source location the operation was defined or derived "
3032  "from.")
3033  .def_property_readonly("parent",
3034  [](PyOperationBase &self) -> py::object {
3035  auto parent =
3036  self.getOperation().getParentOperation();
3037  if (parent)
3038  return parent->getObject();
3039  return py::none();
3040  })
3041  .def(
3042  "__str__",
3043  [](PyOperationBase &self) {
3044  return self.getAsm(/*binary=*/false,
3045  /*largeElementsLimit=*/std::nullopt,
3046  /*enableDebugInfo=*/false,
3047  /*prettyDebugInfo=*/false,
3048  /*printGenericOpForm=*/false,
3049  /*useLocalScope=*/false,
3050  /*assumeVerified=*/false,
3051  /*skipRegions=*/false);
3052  },
3053  "Returns the assembly form of the operation.")
3054  .def("print",
3055  py::overload_cast<PyAsmState &, pybind11::object, bool>(
3057  py::arg("state"), py::arg("file") = py::none(),
3058  py::arg("binary") = false, kOperationPrintStateDocstring)
3059  .def("print",
3060  py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3061  bool, py::object, bool, bool>(
3063  // Careful: Lots of arguments must match up with print method.
3064  py::arg("large_elements_limit") = py::none(),
3065  py::arg("enable_debug_info") = false,
3066  py::arg("pretty_debug_info") = false,
3067  py::arg("print_generic_op_form") = false,
3068  py::arg("use_local_scope") = false,
3069  py::arg("assume_verified") = false, py::arg("file") = py::none(),
3070  py::arg("binary") = false, py::arg("skip_regions") = false,
3072  .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
3073  py::arg("desired_version") = py::none(),
3075  .def("get_asm", &PyOperationBase::getAsm,
3076  // Careful: Lots of arguments must match up with get_asm method.
3077  py::arg("binary") = false,
3078  py::arg("large_elements_limit") = py::none(),
3079  py::arg("enable_debug_info") = false,
3080  py::arg("pretty_debug_info") = false,
3081  py::arg("print_generic_op_form") = false,
3082  py::arg("use_local_scope") = false,
3083  py::arg("assume_verified") = false, py::arg("skip_regions") = false,
3085  .def("verify", &PyOperationBase::verify,
3086  "Verify the operation. Raises MLIRError if verification fails, and "
3087  "returns true otherwise.")
3088  .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
3089  "Puts self immediately after the other operation in its parent "
3090  "block.")
3091  .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
3092  "Puts self immediately before the other operation in its parent "
3093  "block.")
3094  .def(
3095  "clone",
3096  [](PyOperationBase &self, py::object ip) {
3097  return self.getOperation().clone(ip);
3098  },
3099  py::arg("ip") = py::none())
3100  .def(
3101  "detach_from_parent",
3102  [](PyOperationBase &self) {
3103  PyOperation &operation = self.getOperation();
3104  operation.checkValid();
3105  if (!operation.isAttached())
3106  throw py::value_error("Detached operation has no parent.");
3107 
3108  operation.detachFromParent();
3109  return operation.createOpView();
3110  },
3111  "Detaches the operation from its parent block.")
3112  .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3113  .def("walk", &PyOperationBase::walk, py::arg("callback"),
3114  py::arg("walk_order") = MlirWalkPostOrder);
3115 
3116  py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
3117  .def_static("create", &PyOperation::create, py::arg("name"),
3118  py::arg("results") = py::none(),
3119  py::arg("operands") = py::none(),
3120  py::arg("attributes") = py::none(),
3121  py::arg("successors") = py::none(), py::arg("regions") = 0,
3122  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3123  py::arg("infer_type") = false, kOperationCreateDocstring)
3124  .def_static(
3125  "parse",
3126  [](const std::string &sourceStr, const std::string &sourceName,
3127  DefaultingPyMlirContext context) {
3128  return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3129  ->createOpView();
3130  },
3131  py::arg("source"), py::kw_only(), py::arg("source_name") = "",
3132  py::arg("context") = py::none(),
3133  "Parses an operation. Supports both text assembly format and binary "
3134  "bytecode format.")
3135  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3136  &PyOperation::getCapsule)
3137  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3138  .def_property_readonly("operation", [](py::object self) { return self; })
3139  .def_property_readonly("opview", &PyOperation::createOpView)
3140  .def_property_readonly(
3141  "successors",
3142  [](PyOperationBase &self) {
3143  return PyOpSuccessors(self.getOperation().getRef());
3144  },
3145  "Returns the list of Operation successors.");
3146 
3147  auto opViewClass =
3148  py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
3149  .def(py::init<py::object>(), py::arg("operation"))
3150  .def_property_readonly("operation", &PyOpView::getOperationObject)
3151  .def_property_readonly("opview", [](py::object self) { return self; })
3152  .def(
3153  "__str__",
3154  [](PyOpView &self) { return py::str(self.getOperationObject()); })
3155  .def_property_readonly(
3156  "successors",
3157  [](PyOperationBase &self) {
3158  return PyOpSuccessors(self.getOperation().getRef());
3159  },
3160  "Returns the list of Operation successors.");
3161  opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
3162  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
3163  opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
3164  opViewClass.attr("build_generic") = classmethod(
3165  &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
3166  py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
3167  py::arg("successors") = py::none(), py::arg("regions") = py::none(),
3168  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3169  "Builds a specific, generated OpView based on class level attributes.");
3170  opViewClass.attr("parse") = classmethod(
3171  [](const py::object &cls, const std::string &sourceStr,
3172  const std::string &sourceName, DefaultingPyMlirContext context) {
3173  PyOperationRef parsed =
3174  PyOperation::parse(context->getRef(), sourceStr, sourceName);
3175 
3176  // Check if the expected operation was parsed, and cast to to the
3177  // appropriate `OpView` subclass if successful.
3178  // NOTE: This accesses attributes that have been automatically added to
3179  // `OpView` subclasses, and is not intended to be used on `OpView`
3180  // directly.
3181  std::string clsOpName =
3182  py::cast<std::string>(cls.attr("OPERATION_NAME"));
3183  MlirStringRef identifier =
3185  std::string_view parsedOpName(identifier.data, identifier.length);
3186  if (clsOpName != parsedOpName)
3187  throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3188  parsedOpName + "'");
3189  return PyOpView::constructDerived(cls, *parsed.get());
3190  },
3191  py::arg("cls"), py::arg("source"), py::kw_only(),
3192  py::arg("source_name") = "", py::arg("context") = py::none(),
3193  "Parses a specific, generated OpView based on class level attributes");
3194 
3195  //----------------------------------------------------------------------------
3196  // Mapping of PyRegion.
3197  //----------------------------------------------------------------------------
3198  py::class_<PyRegion>(m, "Region", py::module_local())
3199  .def_property_readonly(
3200  "blocks",
3201  [](PyRegion &self) {
3202  return PyBlockList(self.getParentOperation(), self.get());
3203  },
3204  "Returns a forward-optimized sequence of blocks.")
3205  .def_property_readonly(
3206  "owner",
3207  [](PyRegion &self) {
3208  return self.getParentOperation()->createOpView();
3209  },
3210  "Returns the operation owning this region.")
3211  .def(
3212  "__iter__",
3213  [](PyRegion &self) {
3214  self.checkValid();
3215  MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3216  return PyBlockIterator(self.getParentOperation(), firstBlock);
3217  },
3218  "Iterates over blocks in the region.")
3219  .def("__eq__",
3220  [](PyRegion &self, PyRegion &other) {
3221  return self.get().ptr == other.get().ptr;
3222  })
3223  .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
3224 
3225  //----------------------------------------------------------------------------
3226  // Mapping of PyBlock.
3227  //----------------------------------------------------------------------------
3228  py::class_<PyBlock>(m, "Block", py::module_local())
3229  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3230  .def_property_readonly(
3231  "owner",
3232  [](PyBlock &self) {
3233  return self.getParentOperation()->createOpView();
3234  },
3235  "Returns the owning operation of this block.")
3236  .def_property_readonly(
3237  "region",
3238  [](PyBlock &self) {
3239  MlirRegion region = mlirBlockGetParentRegion(self.get());
3240  return PyRegion(self.getParentOperation(), region);
3241  },
3242  "Returns the owning region of this block.")
3243  .def_property_readonly(
3244  "arguments",
3245  [](PyBlock &self) {
3246  return PyBlockArgumentList(self.getParentOperation(), self.get());
3247  },
3248  "Returns a list of block arguments.")
3249  .def(
3250  "add_argument",
3251  [](PyBlock &self, const PyType &type, const PyLocation &loc) {
3252  return mlirBlockAddArgument(self.get(), type, loc);
3253  },
3254  "Append an argument of the specified type to the block and returns "
3255  "the newly added argument.")
3256  .def(
3257  "erase_argument",
3258  [](PyBlock &self, unsigned index) {
3259  return mlirBlockEraseArgument(self.get(), index);
3260  },
3261  "Erase the argument at 'index' and remove it from the argument list.")
3262  .def_property_readonly(
3263  "operations",
3264  [](PyBlock &self) {
3265  return PyOperationList(self.getParentOperation(), self.get());
3266  },
3267  "Returns a forward-optimized sequence of operations.")
3268  .def_static(
3269  "create_at_start",
3270  [](PyRegion &parent, const py::list &pyArgTypes,
3271  const std::optional<py::sequence> &pyArgLocs) {
3272  parent.checkValid();
3273  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3274  mlirRegionInsertOwnedBlock(parent, 0, block);
3275  return PyBlock(parent.getParentOperation(), block);
3276  },
3277  py::arg("parent"), py::arg("arg_types") = py::list(),
3278  py::arg("arg_locs") = std::nullopt,
3279  "Creates and returns a new Block at the beginning of the given "
3280  "region (with given argument types and locations).")
3281  .def(
3282  "append_to",
3283  [](PyBlock &self, PyRegion &region) {
3284  MlirBlock b = self.get();
3286  mlirBlockDetach(b);
3287  mlirRegionAppendOwnedBlock(region.get(), b);
3288  },
3289  "Append this block to a region, transferring ownership if necessary")
3290  .def(
3291  "create_before",
3292  [](PyBlock &self, const py::args &pyArgTypes,
3293  const std::optional<py::sequence> &pyArgLocs) {
3294  self.checkValid();
3295  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3296  MlirRegion region = mlirBlockGetParentRegion(self.get());
3297  mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3298  return PyBlock(self.getParentOperation(), block);
3299  },
3300  py::arg("arg_locs") = std::nullopt,
3301  "Creates and returns a new Block before this block "
3302  "(with given argument types and locations).")
3303  .def(
3304  "create_after",
3305  [](PyBlock &self, const py::args &pyArgTypes,
3306  const std::optional<py::sequence> &pyArgLocs) {
3307  self.checkValid();
3308  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3309  MlirRegion region = mlirBlockGetParentRegion(self.get());
3310  mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3311  return PyBlock(self.getParentOperation(), block);
3312  },
3313  py::arg("arg_locs") = std::nullopt,
3314  "Creates and returns a new Block after this block "
3315  "(with given argument types and locations).")
3316  .def(
3317  "__iter__",
3318  [](PyBlock &self) {
3319  self.checkValid();
3320  MlirOperation firstOperation =
3322  return PyOperationIterator(self.getParentOperation(),
3323  firstOperation);
3324  },
3325  "Iterates over operations in the block.")
3326  .def("__eq__",
3327  [](PyBlock &self, PyBlock &other) {
3328  return self.get().ptr == other.get().ptr;
3329  })
3330  .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
3331  .def("__hash__",
3332  [](PyBlock &self) {
3333  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3334  })
3335  .def(
3336  "__str__",
3337  [](PyBlock &self) {
3338  self.checkValid();
3339  PyPrintAccumulator printAccum;
3340  mlirBlockPrint(self.get(), printAccum.getCallback(),
3341  printAccum.getUserData());
3342  return printAccum.join();
3343  },
3344  "Returns the assembly form of the block.")
3345  .def(
3346  "append",
3347  [](PyBlock &self, PyOperationBase &operation) {
3348  if (operation.getOperation().isAttached())
3349  operation.getOperation().detachFromParent();
3350 
3351  MlirOperation mlirOperation = operation.getOperation().get();
3352  mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3353  operation.getOperation().setAttached(
3354  self.getParentOperation().getObject());
3355  },
3356  py::arg("operation"),
3357  "Appends an operation to this block. If the operation is currently "
3358  "in another block, it will be moved.");
3359 
3360  //----------------------------------------------------------------------------
3361  // Mapping of PyInsertionPoint.
3362  //----------------------------------------------------------------------------
3363 
3364  py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
3365  .def(py::init<PyBlock &>(), py::arg("block"),
3366  "Inserts after the last operation but still inside the block.")
3367  .def("__enter__", &PyInsertionPoint::contextEnter)
3368  .def("__exit__", &PyInsertionPoint::contextExit)
3369  .def_property_readonly_static(
3370  "current",
3371  [](py::object & /*class*/) {
3372  auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3373  if (!ip)
3374  throw py::value_error("No current InsertionPoint");
3375  return ip;
3376  },
3377  "Gets the InsertionPoint bound to the current thread or raises "
3378  "ValueError if none has been set")
3379  .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
3380  "Inserts before a referenced operation.")
3381  .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3382  py::arg("block"), "Inserts at the beginning of the block.")
3383  .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3384  py::arg("block"), "Inserts before the block terminator.")
3385  .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
3386  "Inserts an operation.")
3387  .def_property_readonly(
3388  "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3389  "Returns the block that this InsertionPoint points to.")
3390  .def_property_readonly(
3391  "ref_operation",
3392  [](PyInsertionPoint &self) -> py::object {
3393  auto refOperation = self.getRefOperation();
3394  if (refOperation)
3395  return refOperation->getObject();
3396  return py::none();
3397  },
3398  "The reference operation before which new operations are "
3399  "inserted, or None if the insertion point is at the end of "
3400  "the block");
3401 
3402  //----------------------------------------------------------------------------
3403  // Mapping of PyAttribute.
3404  //----------------------------------------------------------------------------
3405  py::class_<PyAttribute>(m, "Attribute", py::module_local())
3406  // Delegate to the PyAttribute copy constructor, which will also lifetime
3407  // extend the backing context which owns the MlirAttribute.
3408  .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
3409  "Casts the passed attribute to the generic Attribute")
3410  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3411  &PyAttribute::getCapsule)
3412  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3413  .def_static(
3414  "parse",
3415  [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3416  PyMlirContext::ErrorCapture errors(context->getRef());
3417  MlirAttribute attr = mlirAttributeParseGet(
3418  context->get(), toMlirStringRef(attrSpec));
3419  if (mlirAttributeIsNull(attr))
3420  throw MLIRError("Unable to parse attribute", errors.take());
3421  return attr;
3422  },
3423  py::arg("asm"), py::arg("context") = py::none(),
3424  "Parses an attribute from an assembly form. Raises an MLIRError on "
3425  "failure.")
3426  .def_property_readonly(
3427  "context",
3428  [](PyAttribute &self) { return self.getContext().getObject(); },
3429  "Context that owns the Attribute")
3430  .def_property_readonly(
3431  "type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
3432  .def(
3433  "get_named",
3434  [](PyAttribute &self, std::string name) {
3435  return PyNamedAttribute(self, std::move(name));
3436  },
3437  py::keep_alive<0, 1>(), "Binds a name to the attribute")
3438  .def("__eq__",
3439  [](PyAttribute &self, PyAttribute &other) { return self == other; })
3440  .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
3441  .def("__hash__",
3442  [](PyAttribute &self) {
3443  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3444  })
3445  .def(
3446  "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3448  .def(
3449  "__str__",
3450  [](PyAttribute &self) {
3451  PyPrintAccumulator printAccum;
3452  mlirAttributePrint(self, printAccum.getCallback(),
3453  printAccum.getUserData());
3454  return printAccum.join();
3455  },
3456  "Returns the assembly form of the Attribute.")
3457  .def("__repr__",
3458  [](PyAttribute &self) {
3459  // Generally, assembly formats are not printed for __repr__ because
3460  // this can cause exceptionally long debug output and exceptions.
3461  // However, attribute values are generally considered useful and
3462  // are printed. This may need to be re-evaluated if debug dumps end
3463  // up being excessive.
3464  PyPrintAccumulator printAccum;
3465  printAccum.parts.append("Attribute(");
3466  mlirAttributePrint(self, printAccum.getCallback(),
3467  printAccum.getUserData());
3468  printAccum.parts.append(")");
3469  return printAccum.join();
3470  })
3471  .def_property_readonly(
3472  "typeid",
3473  [](PyAttribute &self) -> MlirTypeID {
3474  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3475  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3476  "mlirTypeID was expected to be non-null.");
3477  return mlirTypeID;
3478  })
3480  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3481  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3482  "mlirTypeID was expected to be non-null.");
3483  std::optional<pybind11::function> typeCaster =
3484  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3485  mlirAttributeGetDialect(self));
3486  if (!typeCaster)
3487  return py::cast(self);
3488  return typeCaster.value()(self);
3489  });
3490 
3491  //----------------------------------------------------------------------------
3492  // Mapping of PyNamedAttribute
3493  //----------------------------------------------------------------------------
3494  py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
3495  .def("__repr__",
3496  [](PyNamedAttribute &self) {
3497  PyPrintAccumulator printAccum;
3498  printAccum.parts.append("NamedAttribute(");
3499  printAccum.parts.append(
3500  py::str(mlirIdentifierStr(self.namedAttr.name).data,
3501  mlirIdentifierStr(self.namedAttr.name).length));
3502  printAccum.parts.append("=");
3503  mlirAttributePrint(self.namedAttr.attribute,
3504  printAccum.getCallback(),
3505  printAccum.getUserData());
3506  printAccum.parts.append(")");
3507  return printAccum.join();
3508  })
3509  .def_property_readonly(
3510  "name",
3511  [](PyNamedAttribute &self) {
3512  return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3513  mlirIdentifierStr(self.namedAttr.name).length);
3514  },
3515  "The name of the NamedAttribute binding")
3516  .def_property_readonly(
3517  "attr",
3518  [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3519  py::keep_alive<0, 1>(),
3520  "The underlying generic attribute of the NamedAttribute binding");
3521 
3522  //----------------------------------------------------------------------------
3523  // Mapping of PyType.
3524  //----------------------------------------------------------------------------
3525  py::class_<PyType>(m, "Type", py::module_local())
3526  // Delegate to the PyType copy constructor, which will also lifetime
3527  // extend the backing context which owns the MlirType.
3528  .def(py::init<PyType &>(), py::arg("cast_from_type"),
3529  "Casts the passed type to the generic Type")
3530  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3531  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3532  .def_static(
3533  "parse",
3534  [](std::string typeSpec, DefaultingPyMlirContext context) {
3535  PyMlirContext::ErrorCapture errors(context->getRef());
3536  MlirType type =
3537  mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3538  if (mlirTypeIsNull(type))
3539  throw MLIRError("Unable to parse type", errors.take());
3540  return type;
3541  },
3542  py::arg("asm"), py::arg("context") = py::none(),
3544  .def_property_readonly(
3545  "context", [](PyType &self) { return self.getContext().getObject(); },
3546  "Context that owns the Type")
3547  .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3548  .def("__eq__", [](PyType &self, py::object &other) { return false; })
3549  .def("__hash__",
3550  [](PyType &self) {
3551  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3552  })
3553  .def(
3554  "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3555  .def(
3556  "__str__",
3557  [](PyType &self) {
3558  PyPrintAccumulator printAccum;
3559  mlirTypePrint(self, printAccum.getCallback(),
3560  printAccum.getUserData());
3561  return printAccum.join();
3562  },
3563  "Returns the assembly form of the type.")
3564  .def("__repr__",
3565  [](PyType &self) {
3566  // Generally, assembly formats are not printed for __repr__ because
3567  // this can cause exceptionally long debug output and exceptions.
3568  // However, types are an exception as they typically have compact
3569  // assembly forms and printing them is useful.
3570  PyPrintAccumulator printAccum;
3571  printAccum.parts.append("Type(");
3572  mlirTypePrint(self, printAccum.getCallback(),
3573  printAccum.getUserData());
3574  printAccum.parts.append(")");
3575  return printAccum.join();
3576  })
3578  [](PyType &self) {
3579  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3580  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3581  "mlirTypeID was expected to be non-null.");
3582  std::optional<pybind11::function> typeCaster =
3583  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3584  mlirTypeGetDialect(self));
3585  if (!typeCaster)
3586  return py::cast(self);
3587  return typeCaster.value()(self);
3588  })
3589  .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
3590  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3591  if (!mlirTypeIDIsNull(mlirTypeID))
3592  return mlirTypeID;
3593  auto origRepr =
3594  pybind11::repr(pybind11::cast(self)).cast<std::string>();
3595  throw py::value_error(
3596  (origRepr + llvm::Twine(" has no typeid.")).str());
3597  });
3598 
3599  //----------------------------------------------------------------------------
3600  // Mapping of PyTypeID.
3601  //----------------------------------------------------------------------------
3602  py::class_<PyTypeID>(m, "TypeID", py::module_local())
3603  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3604  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3605  // Note, this tests whether the underlying TypeIDs are the same,
3606  // not whether the wrapper MlirTypeIDs are the same, nor whether
3607  // the Python objects are the same (i.e., PyTypeID is a value type).
3608  .def("__eq__",
3609  [](PyTypeID &self, PyTypeID &other) { return self == other; })
3610  .def("__eq__",
3611  [](PyTypeID &self, const py::object &other) { return false; })
3612  // Note, this gives the hash value of the underlying TypeID, not the
3613  // hash value of the Python object, nor the hash value of the
3614  // MlirTypeID wrapper.
3615  .def("__hash__", [](PyTypeID &self) {
3616  return static_cast<size_t>(mlirTypeIDHashValue(self));
3617  });
3618 
3619  //----------------------------------------------------------------------------
3620  // Mapping of Value.
3621  //----------------------------------------------------------------------------
3622  py::class_<PyValue>(m, "Value", py::module_local())
3623  .def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"))
3624  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3625  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3626  .def_property_readonly(
3627  "context",
3628  [](PyValue &self) { return self.getParentOperation()->getContext(); },
3629  "Context in which the value lives.")
3630  .def(
3631  "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3633  .def_property_readonly(
3634  "owner",
3635  [](PyValue &self) -> py::object {
3636  MlirValue v = self.get();
3637  if (mlirValueIsAOpResult(v)) {
3638  assert(
3639  mlirOperationEqual(self.getParentOperation()->get(),
3640  mlirOpResultGetOwner(self.get())) &&
3641  "expected the owner of the value in Python to match that in "
3642  "the IR");
3643  return self.getParentOperation().getObject();
3644  }
3645 
3646  if (mlirValueIsABlockArgument(v)) {
3647  MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3648  return py::cast(PyBlock(self.getParentOperation(), block));
3649  }
3650 
3651  assert(false && "Value must be a block argument or an op result");
3652  return py::none();
3653  })
3654  .def_property_readonly("uses",
3655  [](PyValue &self) {
3656  return PyOpOperandIterator(
3657  mlirValueGetFirstUse(self.get()));
3658  })
3659  .def("__eq__",
3660  [](PyValue &self, PyValue &other) {
3661  return self.get().ptr == other.get().ptr;
3662  })
3663  .def("__eq__", [](PyValue &self, py::object other) { return false; })
3664  .def("__hash__",
3665  [](PyValue &self) {
3666  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3667  })
3668  .def(
3669  "__str__",
3670  [](PyValue &self) {
3671  PyPrintAccumulator printAccum;
3672  printAccum.parts.append("Value(");
3673  mlirValuePrint(self.get(), printAccum.getCallback(),
3674  printAccum.getUserData());
3675  printAccum.parts.append(")");
3676  return printAccum.join();
3677  },
3679  .def(
3680  "get_name",
3681  [](PyValue &self, bool useLocalScope) {
3682  PyPrintAccumulator printAccum;
3683  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3684  if (useLocalScope)
3686  MlirAsmState valueState =
3687  mlirAsmStateCreateForValue(self.get(), flags);
3688  mlirValuePrintAsOperand(self.get(), valueState,
3689  printAccum.getCallback(),
3690  printAccum.getUserData());
3692  mlirAsmStateDestroy(valueState);
3693  return printAccum.join();
3694  },
3695  py::arg("use_local_scope") = false)
3696  .def(
3697  "get_name",
3698  [](PyValue &self, std::reference_wrapper<PyAsmState> state) {
3699  PyPrintAccumulator printAccum;
3700  MlirAsmState valueState = state.get().get();
3701  mlirValuePrintAsOperand(self.get(), valueState,
3702  printAccum.getCallback(),
3703  printAccum.getUserData());
3704  return printAccum.join();
3705  },
3706  py::arg("state"), kGetNameAsOperand)
3707  .def_property_readonly(
3708  "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
3709  .def(
3710  "set_type",
3711  [](PyValue &self, const PyType &type) {
3712  return mlirValueSetType(self.get(), type);
3713  },
3714  py::arg("type"))
3715  .def(
3716  "replace_all_uses_with",
3717  [](PyValue &self, PyValue &with) {
3718  mlirValueReplaceAllUsesOfWith(self.get(), with.get());
3719  },
3722  [](PyValue &self) { return self.maybeDownCast(); });
3723  PyBlockArgument::bind(m);
3724  PyOpResult::bind(m);
3725  PyOpOperand::bind(m);
3726 
3727  py::class_<PyAsmState>(m, "AsmState", py::module_local())
3728  .def(py::init<PyValue &, bool>(), py::arg("value"),
3729  py::arg("use_local_scope") = false)
3730  .def(py::init<PyOperationBase &, bool>(), py::arg("op"),
3731  py::arg("use_local_scope") = false);
3732 
3733  //----------------------------------------------------------------------------
3734  // Mapping of SymbolTable.
3735  //----------------------------------------------------------------------------
3736  py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3737  .def(py::init<PyOperationBase &>())
3738  .def("__getitem__", &PySymbolTable::dunderGetItem)
3739  .def("insert", &PySymbolTable::insert, py::arg("operation"))
3740  .def("erase", &PySymbolTable::erase, py::arg("operation"))
3741  .def("__delitem__", &PySymbolTable::dunderDel)
3742  .def("__contains__",
3743  [](PySymbolTable &table, const std::string &name) {
3745  table, mlirStringRefCreate(name.data(), name.length())));
3746  })
3747  // Static helpers.
3748  .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3749  py::arg("symbol"), py::arg("name"))
3750  .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3751  py::arg("symbol"))
3752  .def_static("get_visibility", &PySymbolTable::getVisibility,
3753  py::arg("symbol"))
3754  .def_static("set_visibility", &PySymbolTable::setVisibility,
3755  py::arg("symbol"), py::arg("visibility"))
3756  .def_static("replace_all_symbol_uses",
3757  &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3758  py::arg("new_symbol"), py::arg("from_op"))
3759  .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3760  py::arg("from_op"), py::arg("all_sym_uses_visible"),
3761  py::arg("callback"));
3762 
3763  // Container bindings.
3764  PyBlockArgumentList::bind(m);
3765  PyBlockIterator::bind(m);
3766  PyBlockList::bind(m);
3767  PyOperationIterator::bind(m);
3768  PyOperationList::bind(m);
3769  PyOpAttributeMap::bind(m);
3770  PyOpOperandIterator::bind(m);
3771  PyOpOperandList::bind(m);
3772  PyOpResultList::bind(m);
3773  PyOpSuccessors::bind(m);
3774  PyRegionIterator::bind(m);
3775  PyRegionList::bind(m);
3776 
3777  // Debug bindings.
3779 
3780  // Attribute builder getter.
3782 
3783  py::register_local_exception_translator([](std::exception_ptr p) {
3784  // We can't define exceptions with custom fields through pybind, so instead
3785  // the exception class is defined in python and imported here.
3786  try {
3787  if (p)
3788  std::rethrow_exception(p);
3789  } catch (const MLIRError &e) {
3790  py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
3791  .attr("MLIRError")(e.message, e.errorDiagnostics);
3792  PyErr_SetObject(PyExc_Exception, obj.ptr());
3793  }
3794  });
3795 }
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition: Debug.cpp:20
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Sets multiple current debug types, similarly to `-debug-only=type1,type2` in the command-line tools.
Definition: Debug.cpp:28
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition: Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition: Debug.cpp:16
static const char kOperationPrintStateDocstring[]
Definition: IRCore.cpp:114
static const char kValueReplaceAllUsesWithDocstring[]
Definition: IRCore.cpp:176
static py::object createCustomDialectWrapper(const std::string &dialectNamespace, py::object dialectDescriptor)
Definition: IRCore.cpp:193
static const char kContextGetNameLocationDocString[]
Definition: IRCore.cpp:57
static const char kGetNameAsOperand[]
Definition: IRCore.cpp:172
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:205
static const char kModuleParseDocstring[]
Definition: IRCore.cpp:60
static const char kOperationStrDunderDocstring[]
Definition: IRCore.cpp:146
static const char kOperationPrintDocstring[]
Definition: IRCore.cpp:87
static const char kContextGetFileLocationDocstring[]
Definition: IRCore.cpp:51
static const char kDumpDocstring[]
Definition: IRCore.cpp:154
static const char kAppendBlockDocstring[]
Definition: IRCore.cpp:157
static const char kContextGetFusedLocationDocstring[]
Definition: IRCore.cpp:54
static void maybeInsertOperation(PyOperationRef &op, const py::object &maybeIp)
Definition: IRCore.cpp:1399
static MlirBlock createBlock(const py::sequence &pyArgTypes, const std::optional< py::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition: IRCore.cpp:211
static const char kOperationPrintBytecodeDocstring[]
Definition: IRCore.cpp:136
static const char kOperationGetAsmDocstring[]
Definition: IRCore.cpp:123
static void populateResultTypes(StringRef name, py::list resultTypeList, const py::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition: IRCore.cpp:1570
static const char kOperationCreateDocstring[]
Definition: IRCore.cpp:68
static const char kContextParseTypeDocstring[]
Definition: IRCore.cpp:40
static const char kContextGetCallSiteLocationDocstring[]
Definition: IRCore.cpp:48
static const char kValueDunderStrDocstring[]
Definition: IRCore.cpp:164
py::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition: IRCore.cpp:187
static MLIRContext * getContext(OpFoldResult val)
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition: Interop.h:273
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition: Interop.h:118
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition: Interop.h:348
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition: Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition: Interop.h:189
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition: Interop.h:180
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition: Interop.h:255
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition: Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition: Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts a MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition: Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition: Interop.h:357
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition: Interop.h:235
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition: Interop.h:367
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition: Interop.h:245
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition: Interop.h:57
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition: Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition: Interop.h:454
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition: Interop.h:198
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition: Interop.h:330
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition: Interop.h:264
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition: Interop.h:445
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition: Interop.h:216
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::context getDefaultContext()
const float * table
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Accumulates into a Python file-like object, either writing text (default) or binary.
Definition: PybindUtils.h:125
MlirStringCallback getCallback()
Definition: PybindUtils.h:132
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
Definition: PybindUtils.h:209
Base class for all objects that directly or indirectly depend on an MlirContext.
Definition: IRModule.h:303
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:311
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:518
static PyLocation & resolve()
Definition: IRCore.cpp:1065
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:292
static PyMlirContext & resolve()
Definition: IRCore.cpp:789
ReferrentTy * get() const
Definition: PybindUtils.h:47
Wrapper around an MlirAsmState.
Definition: IRModule.h:785
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:1002
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition: IRModule.h:1004
static PyAttribute createFromCapsule(pybind11::object capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition: IRCore.cpp:1933
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition: IRCore.cpp:1929
bool operator==(const PyAttribute &other) const
Definition: IRCore.cpp:1925
Wrapper around an MlirBlock.
Definition: IRModule.h:820
MlirBlock get()
Definition: IRModule.h:827
PyOperationRef & getParentOperation()
Definition: IRModule.h:828
Represents a diagnostic handler attached to the context.
Definition: IRModule.h:399
void detach()
Detaches the handler. Does nothing if not attached.
Definition: IRCore.cpp:951
PyDiagnosticHandler(MlirContext context, pybind11::object callback)
Definition: IRCore.cpp:945
Python class mirroring the C MlirDiagnostic struct.
Definition: IRModule.h:349
pybind11::str getMessage()
Definition: IRCore.cpp:981
PyLocation getLocation()
Definition: IRCore.cpp:974
DiagnosticInfo getInfo()
Definition: IRCore.cpp:1002
PyDiagnostic(MlirDiagnostic diagnostic)
Definition: IRModule.h:351
MlirDiagnosticSeverity getSeverity()
Definition: IRCore.cpp:969
pybind11::tuple getNotes()
Definition: IRCore.cpp:989
Wrapper around an MlirDialect.
Definition: IRModule.h:454
Wrapper around an MlirDialectRegistry.
Definition: IRModule.h:491
static PyDialectRegistry createFromCapsule(pybind11::object capsule)
Definition: IRCore.cpp:1031
pybind11::object getCapsule()
Definition: IRCore.cpp:1026
User-level dialect object.
Definition: IRModule.h:478
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition: IRModule.h:467
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition: IRCore.cpp:1013
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition: Globals.h:34
std::optional< pybind11::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition: IRModule.cpp:172
std::optional< pybind11::function > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:145
An insertion point maintains a pointer to a Block and a reference operation.
Definition: IRModule.h:844
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition: IRCore.cpp:1902
PyInsertionPoint(PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition: IRCore.cpp:1857
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition: IRCore.cpp:1889
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition: IRCore.cpp:1863
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:1915
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:1911
Wrapper around an MlirLocation.
Definition: IRModule.h:318
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition: IRModule.h:320
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition: IRCore.cpp:1043
static PyLocation createFromCapsule(pybind11::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition: IRCore.cpp:1047
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:1059
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:1055
MlirLocation get() const
Definition: IRModule.h:324
pybind11::object attachDiagnosticHandler(pybind11::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition: IRCore.cpp:724
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:185
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition: IRModule.h:189
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition: IRCore.cpp:615
void clearOperationsInside(PyOperationBase &op)
Clears all operations nested inside the given op using clearOperation(MlirOperation).
Definition: IRCore.cpp:676
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition: IRCore.cpp:649
void clearOperationAndInside(PyOperationBase &op)
Clears the operation and all operations inside using clearOperation(MlirOperation).
Definition: IRCore.cpp:701
static PyMlirContext * createNewContextForInit()
For the case of a python init (py::init) method, pybind11 is quite strict about needing to return a p...
Definition: IRCore.cpp:622
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition: IRCore.cpp:712
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:714
size_t clearLiveOperations()
Clears the live operations map, returning the number of entries which were invalidated.
Definition: IRCore.cpp:660
std::vector< PyOperation * > getLiveOperationObjects()
Get a list of Python objects which are still in the live context map.
Definition: IRCore.cpp:653
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:718
void clearOperation(MlirOperation op)
Removes an operation from the live operations map and sets it invalid.
Definition: IRCore.cpp:668
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition: IRCore.cpp:627
size_t getLiveOperationCount()
Gets the count of live operations associated with this context.
Definition: IRCore.cpp:651
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition: IRCore.cpp:611
MlirModule get()
Gets the backing MlirModule.
Definition: IRModule.h:541
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition: IRCore.cpp:1092
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition: IRCore.cpp:1125
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition: IRCore.cpp:1118
PyModule(PyModule &)=delete
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition: IRModule.h:1026
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition: IRCore.cpp:1945
MlirNamedAttribute namedAttr
Definition: IRModule.h:1035
pybind11::object getObject()
Definition: IRModule.h:88
pybind11::object releaseObject()
Releases the object held by this instance, returning it.
Definition: IRModule.h:76
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition: IRModule.h:734
PyOpView(const pybind11::object &operationObject)
Definition: IRCore.cpp:1847
static pybind11::object buildGeneric(const pybind11::object &cls, std::optional< pybind11::list > resultTypeList, pybind11::list operandList, std::optional< pybind11::dict > attributes, std::optional< std::vector< PyBlock * >> successors, std::optional< int > regions, DefaultingPyLocation location, const pybind11::object &maybeIp)
Definition: IRCore.cpp:1658
static pybind11::object constructDerived(const pybind11::object &cls, const PyOperation &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition: IRCore.cpp:1836
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition: IRModule.h:571
void print(std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, py::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
Definition: IRCore.cpp:1221
void walk(std::function< MlirWalkResult(MlirOperation)> callback, MlirWalkOrder walkOrder)
Definition: IRCore.cpp:1284
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
void writeBytecode(const pybind11::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition: IRCore.cpp:1263
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition: IRCore.cpp:1340
void moveBefore(PyOperationBase &other)
Definition: IRCore.cpp:1349
bool verify()
Verify the operation.
Definition: IRCore.cpp:1358
pybind11::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, bool skipRegions)
Definition: IRCore.cpp:1316
pybind11::object clone(const pybind11::object &ip)
Clones this operation.
Definition: IRCore.cpp:1540
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition: IRModule.h:640
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition: IRCore.cpp:1560
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition: IRCore.cpp:1385
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition: IRModule.h:618
PyOperationRef getRef()
Definition: IRModule.h:653
MlirOperation get() const
Definition: IRModule.h:648
void setAttached(const pybind11::object &parent=pybind11::object())
Definition: IRModule.h:659
static pybind11::object create(const std::string &name, std::optional< std::vector< PyType * >> results, std::optional< std::vector< PyValue * >> operands, std::optional< pybind11::dict > attributes, std::optional< std::vector< PyBlock * >> successors, int regions, DefaultingPyLocation location, const pybind11::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition: IRCore.cpp:1414
pybind11::object createOpView()
Creates an OpView suitable for this operation.
Definition: IRCore.cpp:1549
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive=pybind11::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition: IRCore.cpp:1173
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition: IRCore.cpp:1366
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition: IRCore.cpp:1376
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive=pybind11::object())
Creates a detached operation.
Definition: IRCore.cpp:1189
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition: IRCore.cpp:1203
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition: IRCore.cpp:1390
void checkValid() const
Definition: IRCore.cpp:1215
Wrapper around an MlirRegion.
Definition: IRModule.h:766
PyOperationRef & getParentOperation()
Definition: IRModule.h:775
MlirRegion get()
Definition: IRModule.h:774
Bindings for MLIR symbol tables.
Definition: IRModule.h:1232
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition: IRCore.cpp:2065
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition: IRCore.cpp:2136
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition: IRCore.cpp:2118
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition: IRCore.cpp:2092
MlirAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition: IRCore.cpp:2070
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition: IRCore.cpp:2055
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition: IRCore.cpp:2035
static MlirAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition: IRCore.cpp:2080
pybind11::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition: IRCore.cpp:2043
static MlirAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition: IRCore.cpp:2107
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, pybind11::object callback)
Walks all symbol tables under and including 'from'.
Definition: IRCore.cpp:2148
Tracks an entry in the thread context stack.
Definition: IRModule.h:107
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition: IRCore.cpp:809
static void popLocation(PyLocation &location)
Definition: IRCore.cpp:921
static pybind11::object pushContext(PyMlirContext &context)
Definition: IRCore.cpp:871
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition: IRCore.cpp:866
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition: IRCore.cpp:901
static void popContext(PyMlirContext &context)
Definition: IRCore.cpp:879
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and returns nullptr if not defined.
Definition: IRCore.cpp:861
static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint)
Definition: IRCore.cpp:890
static pybind11::object pushLocation(PyLocation &location)
Definition: IRCore.cpp:912
PyMlirContext * getContext()
Definition: IRCore.cpp:838
static PyMlirContext * getDefaultContext()
Gets the top of stack context and returns nullptr if not defined.
Definition: IRCore.cpp:856
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition: IRCore.cpp:804
PyInsertionPoint * getInsertionPoint()
Definition: IRCore.cpp:844
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition: IRModule.h:904
bool operator==(const PyTypeID &other) const
Definition: IRCore.cpp:1987
static PyTypeID createFromCapsule(pybind11::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition: IRCore.cpp:1981
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition: IRCore.cpp:1977
PyTypeID(MlirTypeID typeID)
Definition: IRModule.h:906
Wrapper around the generic MlirType.
Definition: IRModule.h:880
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition: IRCore.cpp:1961
static PyType createFromCapsule(pybind11::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition: IRCore.cpp:1965
PyType(PyMlirContextRef contextRef, MlirType type)
Definition: IRModule.h:882
bool operator==(const PyType &other) const
Definition: IRCore.cpp:1957
Wrapper around the generic MlirValue.
Definition: IRModule.h:1133
static PyValue createFromCapsule(pybind11::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition: IRCore.cpp:2014
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition: IRModule.h:1139
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition: IRCore.cpp:1995
pybind11::object maybeDownCast()
Definition: IRCore.cpp:1999
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
Definition: Diagnostics.cpp:44
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
Definition: Diagnostics.cpp:28
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
Definition: Diagnostics.cpp:18
MlirDiagnosticSeverity
Severity of a diagnostic.
Definition: Diagnostics.h:32
@ MlirDiagnosticNote
Definition: Diagnostics.h:35
@ MlirDiagnosticRemark
Definition: Diagnostics.h:36
@ MlirDiagnosticWarning
Definition: Diagnostics.h:34
@ MlirDiagnosticError
Definition: Diagnostics.h:33
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
Definition: Diagnostics.cpp:50
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
Definition: Diagnostics.cpp:78
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
Definition: Diagnostics.cpp:56
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
Definition: Diagnostics.cpp:72
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition: Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
Definition: Diagnostics.cpp:24
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition: IR.cpp:251
MLIR_CAPI_EXPORTED intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Returns the position of the value in the argument list of its block.
Definition: IR.cpp:956
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1034
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition: IR.h:735
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition: IR.cpp:700
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
Definition: IR.cpp:528
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates a File/Line/Column location owned by the given context.
Definition: IR.cpp:259
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition: IR.cpp:1185
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition: IR.cpp:127
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op)
Returns the number of results of the operation.
Definition: IR.cpp:587
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition: IR.cpp:1164
MlirWalkOrder
Traversal order for operation walk.
Definition: IR.h:728
@ MlirWalkPreOrder
Definition: IR.h:729
@ MlirWalkPostOrder
Definition: IR.h:730
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Return pos-th attribute of the operation.
Definition: IR.cpp:661
MLIR_CAPI_EXPORTED void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition: IR.cpp:376
MLIR_CAPI_EXPORTED void mlirModuleDestroy(MlirModule module)
Takes a module owned by the caller and deletes it.
Definition: IR.cpp:329
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition: IR.cpp:1112
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition: IR.cpp:1169
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Use local scope when printing the operation.
Definition: IR.cpp:219
MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value)
Returns 1 if the value is a block argument, 0 otherwise.
Definition: IR.cpp:944
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition: IR.cpp:82
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:1133
static bool mlirModuleIsNull(MlirModule module)
Checks whether a module is null.
Definition: IR.h:314
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition: IR.cpp:1046
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition: IR.cpp:1029
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute)
Gets the type of this attribute.
Definition: IR.cpp:1085
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition: IR.cpp:71
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition: IR.cpp:800
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition: IR.cpp:1007
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Returns pos-th successor of the operation.
Definition: IR.cpp:599
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition: IR.cpp:990
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition: IR.cpp:287
MLIR_CAPI_EXPORTED void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Prints a type by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1066
MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Sets an attribute by name, replacing the existing if it exists or adding a new one otherwise.
Definition: IR.cpp:671
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition: IR.cpp:1017
MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute)
Gets the dialect of the attribute.
Definition: IR.cpp:1096
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1104
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition: IR.cpp:839
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition: IR.cpp:724
static bool mlirValueIsNull(MlirValue value)
Returns whether the value is null.
Definition: IR.h:898
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition: IR.cpp:692
MlirWalkResult
Operation walk result.
Definition: IR.h:721
@ MlirWalkResultInterrupt
Definition: IR.h:723
@ MlirWalkResultSkip
Definition: IR.h:724
@ MlirWalkResultAdvance
Definition: IR.h:722
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition: IR.cpp:780
MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Returns an attribute attached to the operation given its name.
Definition: IR.cpp:666
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition: IR.h:999
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e. 'dialect.operation') is registered with the context.
Definition: IR.cpp:98
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Returns the number of successor blocks of the operation.
Definition: IR.cpp:595
MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op)
Creates a deep copy of an operation.
Definition: IR.cpp:502
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition: IR.cpp:908
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Always print operations in the generic form.
Definition: IR.cpp:215
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition: IR.cpp:270
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition: IR.cpp:889
MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state)
Destroys an AsmState created with one of the mlirAsmStateCreate* functions.
Definition: IR.cpp:186
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition: IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition: IR.cpp:93
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningl...
Definition: IR.cpp:200
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition: IR.cpp:786
MLIR_CAPI_EXPORTED void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Sets the type of the block argument to the given type.
Definition: IR.cpp:961
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op)
Gets the context this operation is associated with.
Definition: IR.cpp:514
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition: IR.cpp:823
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition: IR.h:817
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition: IR.cpp:864
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition: IR.cpp:926
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition: IR.cpp:1159
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition: IR.cpp:1050
MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value)
Prints the value to the standard error stream.
Definition: IR.cpp:982
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location)
Creates a new, empty module and transfers ownership to the caller.
Definition: IR.cpp:309
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition: IR.cpp:1015
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition: IR.cpp:1149
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition: IR.cpp:291
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetBlock(MlirOperation op)
Gets the block that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:532
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition: IR.h:282
MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Checks whether two operation handles point to the same operation.
Definition: IR.cpp:510
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition: IR.cpp:912
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition: IR.cpp:686
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition: IR.cpp:997
MLIR_CAPI_EXPORTED void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:299
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition: IR.cpp:716
MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module)
Views the module as a generic operation.
Definition: IR.cpp:335
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition: IR.cpp:1062
MLIR_CAPI_EXPORTED MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Constructs an operation state from a name and a location.
Definition: IR.cpp:347
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition: IR.cpp:1025
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition: IR.cpp:854
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Skip printing regions.
Definition: IR.cpp:227
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Returns the operation immediately following the given operation in its enclosing block.
Definition: IR.cpp:564
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Gets the operation that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:536
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module)
Gets the context that a module was created with.
Definition: IR.cpp:321
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition: IR.cpp:255
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Do not verify the operation when using custom operation printers.
Definition: IR.cpp:223
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition: IR.cpp:1054
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition: IR.cpp:1145
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition: IR.h:173
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Destroys a bytecode writer config created with mlirBytecodeWriterConfigCreate.
Definition: IR.cpp:238
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Returns pos-th operand of the operation.
Definition: IR.cpp:572
MLIR_CAPI_EXPORTED void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition: IR.cpp:388
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition: IR.cpp:843
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition: IR.cpp:266
MLIR_CAPI_EXPORTED MlirOperation mlirOpResultGetOwner(MlirValue value)
Returns an operation that produced this value as its result.
Definition: IR.cpp:965
MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value)
Returns 1 if the value is an operation result, 0 otherwise.
Definition: IR.cpp:948
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op)
Returns the number of operands of the operation.
Definition: IR.cpp:568
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition: IR.h:235
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition: IR.cpp:741
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
Definition: IR.cpp:53
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition: IR.cpp:835
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumRegions(MlirOperation op)
Returns the number of regions attached to the given operation.
Definition: IR.cpp:540
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition: IR.cpp:295
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition: IR.cpp:917
MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Removes an attribute by name.
Definition: IR.cpp:676
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition: IR.cpp:1110
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition: IR.cpp:1174
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Parses an attribute. The attribute is owned by the context.
Definition: IR.cpp:1077
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Parses a module from the string and transfers ownership to the caller.
Definition: IR.cpp:313
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition: IR.cpp:776
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition: IR.cpp:847
MLIR_CAPI_EXPORTED void mlirTypeDump(MlirType type)
Prints the type to the standard error stream.
Definition: IR.cpp:1071
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Returns pos-th result of the operation.
Definition: IR.cpp:591
MLIR_CAPI_EXPORTED MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate(void)
Creates a new bytecode writer config with defaults, intended for customization.
Definition: IR.cpp:234
MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Gets the context that an attribute was created with.
Definition: IR.cpp:1081
MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Returns the block in which this value is defined as an argument.
Definition: IR.cpp:952
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition: IR.h:756
MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op)
Takes an operation owned by the caller and destroys it.
Definition: IR.cpp:506
MLIR_CAPI_EXPORTED MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Returns pos-th region attached to the operation.
Definition: IR.cpp:544
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition: IR.cpp:1058
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition: IR.cpp:1121
MLIR_CAPI_EXPORTED void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Set pos-th successor of the operation.
Definition: IR.cpp:652
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition: IR.cpp:106
MLIR_CAPI_EXPORTED void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition: IR.cpp:380
MLIR_CAPI_EXPORTED void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition: IR.cpp:384
MLIR_CAPI_EXPORTED MlirBlock mlirModuleGetBody(MlirModule module)
Gets the body of the module, i.e. the only block it contains.
Definition: IR.cpp:325
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Destroys printing flags created with mlirOpPrintingFlagsCreate.
Definition: IR.cpp:196
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition: IR.cpp:278
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Sets the threading mode (must be set to false to use mlir-print-ir-after-all).
Definition: IR.cpp:102
MLIR_CAPI_EXPORTED void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Prints a block by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:930
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Sets the version to emit in the writer config.
Definition: IR.cpp:242
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition: IR.cpp:1141
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition: IR.cpp:763
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition: IR.cpp:903
MLIR_CAPI_EXPORTED void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Adds a list of components to the operation state.
Definition: IR.cpp:371
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Enable or disable printing of debug information (based on enable).
Definition: IR.cpp:210
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op)
Gets the location of the operation.
Definition: IR.cpp:518
MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute)
Gets the type id of the attribute.
Definition: IR.cpp:1092
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Sets the pos-th operand of the operation.
Definition: IR.cpp:576
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition: IR.cpp:714
MLIR_CAPI_EXPORTED intptr_t mlirOpResultGetResultNumber(MlirValue value)
Returns the position of the value in the list of results of the operation that produced it.
Definition: IR.cpp:969
MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void)
Creates new printing flags with defaults, intended for customization.
Definition: IR.cpp:192
MLIR_CAPI_EXPORTED MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Creates new AsmState from value.
Definition: IR.cpp:168
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state)
Creates an operation and transfers ownership to the caller.
Definition: IR.cpp:456
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition: IR.h:1089
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition: IR.cpp:75
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition: IR.cpp:720
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Returns the number of attributes attached to the operation.
Definition: IR.cpp:657
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:984
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition: IR.cpp:707
MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type)
Set the type of the value.
Definition: IR.cpp:978
MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value)
Returns the type of the value.
Definition: IR.cpp:974
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition: IR.cpp:69
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Parses an operation, giving ownership to the caller.
Definition: IR.cpp:493
MLIR_CAPI_EXPORTED bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Checks if two attributes are equal.
Definition: IR.cpp:1100
static bool mlirOperationIsNull(MlirOperation op)
Checks whether the underlying operation is null.
Definition: IR.h:519
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition: IR.cpp:769
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition: Support.h:138
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition: Support.cpp:51
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition: Support.h:132
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition: Support.h:163
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition: Support.cpp:47
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
inline ::llvm::hash_code hash_value(const PolynomialBase< D, T > &arg)
Definition: Polynomial.h:262
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition: IRModule.h:162
PyObjectRef< PyModule > PyModuleRef
Definition: IRModule.h:530
void populateIRCore(pybind11::module &m)
PyObjectRef< PyOperation > PyOperationRef
Definition: IRModule.h:614
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition: Diagnostics.h:26
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
Named MLIR attribute.
Definition: IR.h:76
MlirAttribute attribute
Definition: IR.h:78
MlirIdentifier name
Definition: IR.h:77
An auxiliary class for constructing operations.
Definition: IR.h:340
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
static bool dunderContains(const std::string &attributeKind)
Definition: IRCore.cpp:262
static void dundeSetItemNamed(const std::string &attributeKind, py::function func, bool replace)
Definition: IRCore.cpp:271
static py::function dundeGetItemNamed(const std::string &attributeKind)
Definition: IRCore.cpp:265
static void bind(py::module &m)
Definition: IRCore.cpp:277
Wrapper for the global LLVM debugging flag.
Definition: IRCore.cpp:235
static void bind(py::module &m)
Definition: IRCore.cpp:240
static bool get(const py::object &)
Definition: IRCore.cpp:238
static void set(py::object &o, bool enable)
Definition: IRCore.cpp:236
Accumulates into a python string from a method that accepts an MlirStringCallback.
Definition: PybindUtils.h:102
pybind11::list parts
Definition: PybindUtils.h:103
pybind11::str join()
Definition: PybindUtils.h:117
MlirStringCallback getCallback()
Definition: PybindUtils.h:107
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1284
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Definition: IRModule.h:1289
Materialized diagnostic information.
Definition: IRModule.h:361
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:427
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition: IRModule.h:437
ErrorCapture(PyMlirContextRef ctx)
Definition: IRModule.h:428