MLIR  20.0.0git
IRCore.cpp
Go to the documentation of this file.
4 //===- IRCore.cpp - IR Submodules of pybind module ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
16 #include "mlir-c/Debug.h"
17 #include "mlir-c/Diagnostics.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Support.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/SmallVector.h"
23 
24 #include <optional>
25 #include <utility>
26 
27 namespace py = pybind11;
28 using namespace py::literals;
29 using namespace mlir;
30 using namespace mlir::python;
31 
32 using llvm::SmallVector;
33 using llvm::StringRef;
34 using llvm::Twine;
35 
36 //------------------------------------------------------------------------------
37 // Docstrings (trivial, non-duplicated docstrings are included inline).
38 //------------------------------------------------------------------------------
39 
// Docstring for parsing a type from its textual assembly form; the text tells
// Python users that an MLIRError is raised on failure.
40 static const char kContextParseTypeDocstring[] =
41  R"(Parses the assembly form of a type.
42 
43 Returns a Type object or raises an MLIRError if the type cannot be parsed.
44 
45 See also: https://mlir.llvm.org/docs/LangRef/#type-system
46 )";
47 
// NOTE(review): the declaration line for the following literal (line 48) is
// missing from this listing; judging by its text it is presumably the
// call-site Location docstring — confirm against the original source.
49  R"(Gets a Location representing a caller and callsite)";
50 
// Docstrings for the Location factory helpers (file/line/col, fused, named).
51 static const char kContextGetFileLocationDocstring[] =
52  R"(Gets a Location representing a file, line and column)";
53 
54 static const char kContextGetFusedLocationDocstring[] =
55  R"(Gets a Location representing a fused location with optional metadata)";
56 
// NOTE(review): spelled `...DocString` (capital S) unlike its siblings;
// renaming it would require updating its uses elsewhere in the file.
57 static const char kContextGetNameLocationDocString[] =
58  R"(Gets a Location representing a named location with optional child location)";
59 
// Docstring for parsing a module from a string; failure is reported to
// Python users as an MLIRError.
60 static const char kModuleParseDocstring[] =
61  R"(Parses a module's assembly format from a string.
62 
63 Returns a new MlirModule or raises an MLIRError if the parsing fails.
64 
65 See also: https://mlir.llvm.org/docs/LangRef/
66 )";
67 
// Docstring for operation creation; lists the keyword arguments described to
// Python users and states that the result is a detached Operation.
// NOTE(review): no `operands` entry is documented here — confirm whether the
// bound function accepts one and, if so, document it.
68 static const char kOperationCreateDocstring[] =
69  R"(Creates a new operation.
70 
71 Args:
72  name: Operation name (e.g. "dialect.operation").
73  results: Sequence of Type representing op result types.
74  attributes: Dict of str:Attribute.
75  successors: List of Block for the operation's successors.
76  regions: Number of regions to create.
77  location: A Location object (defaults to resolve from context manager).
78  ip: An InsertionPoint (defaults to resolve from context manager or set to
79  False to disable insertion, even with an insertion point set in the
80  context manager).
81  infer_type: Whether to infer result types.
82 Returns:
83  A new "detached" Operation object. Detached operations can be added
84  to blocks, which causes them to become "attached."
85 )";
86 
/// Docstring for the operation print binding. Documents every keyword argument
/// that controls how the operation's assembly form is written.
///
/// The `use_local_scope` entry was previously misspelled `use_local_Scope`. The
/// misspelling does not match the snake_case keyword names used throughout,
/// so users copying it from help() would pass an unknown keyword.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
  skip_regions: Whether to skip printing regions. Defaults to False.
)";
113 
// Docstring for printing an operation with an explicit AsmState.
114 static const char kOperationPrintStateDocstring[] =
115  R"(Prints the assembly form of the operation to a file like object.
116 
117 Args:
118  file: The file like object to write to. Defaults to sys.stdout.
119  binary: Whether to write bytes (True) or str (False). Defaults to False.
120  state: AsmState capturing the operation numbering and flags.
121 )";
122 
// Docstring for get_asm; it defers to print() for the shared keyword arguments.
123 static const char kOperationGetAsmDocstring[] =
124  R"(Gets the assembly form of the operation with all options available.
125 
126 Args:
127  binary: Whether to return a bytes (True) or str (False) object. Defaults to
128  False.
129  ... others ...: See the print() method for common keyword arguments for
130  configuring the printout.
131 Returns:
132  Either a bytes or str object, depending on the setting of the 'binary'
133  argument.
134 )";
135 
// Docstring for writing an operation's bytecode form.
136 static const char kOperationPrintBytecodeDocstring[] =
137  R"(Write the bytecode form of the operation to a file like object.
138 
139 Args:
140  file: The file like object to write to.
141  desired_version: The version of bytecode to emit.
142 Returns:
143  The bytecode writer status.
144 )";
145 
// Docstring presumably attached to the operation's __str__ binding (by name).
146 static const char kOperationStrDunderDocstring[] =
147  R"(Gets the assembly form of the operation with default options.
148 
149 If more advanced control over the assembly formatting or I/O options is needed,
150 use the dedicated print or get_asm method, which supports keyword arguments to
151 customize behavior.
152 )";
153 
// Shared docstring for dump() helpers.
154 static const char kDumpDocstring[] =
155  R"(Dumps a debug representation of the object to stderr.)";
156 
// Docstring for BlockList.append (used by PyBlockList::bind in this file).
157 static const char kAppendBlockDocstring[] =
158  R"(Appends a new block, with argument types as positional args.
159 
160 Returns:
161  The created block.
162 )";
163 
// Docstring for the string form of a Value.
164 static const char kValueDunderStrDocstring[] =
165  R"(Returns the string form of the value.
166 
167 If the value is a block argument, this is the assembly form of its type and the
168 position in the argument list. If the value is an operation result, this is
169 equivalent to printing the operation that produced it.
170 )";
171 
// Docstring for rendering a Value as an operand name.
172 static const char kGetNameAsOperand[] =
173  R"(Returns the string form of value as an operand (i.e., the ValueID).
174 )";
175 
// NOTE(review): the declaration line for this literal (line 176) is missing
// from this listing; by its text it is presumably the replace-all-uses
// docstring — confirm against the original source.
177  R"(Replace all uses of value with the new value, updating anything in
178 the IR that uses 'self' to use the other value instead.
179 )";
180 
// NOTE(review): the declaration line for this literal (line 181) is missing
// from this listing. The literal also starts with a stray `"` right after the
// raw-string opener, which ends up in the user-visible docstring — likely
// unintended.
182  R"("Replace all uses of this value with the 'with' value, except for those
183 in 'exceptions'. 'exceptions' can be either a single operation or a list of
184 operations.
185 )";
186 
187 //------------------------------------------------------------------------------
188 // Utilities.
189 //------------------------------------------------------------------------------
190 
191 /// Helper for creating an @classmethod.
192 template <class Func, typename... Args>
193 py::object classmethod(Func f, Args... args) {
194  py::object cf = py::cpp_function(f, args...);
195  return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
196 }
197 
198 static py::object
199 createCustomDialectWrapper(const std::string &dialectNamespace,
200  py::object dialectDescriptor) {
201  auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
202  if (!dialectClass) {
203  // Use the base class.
204  return py::cast(PyDialect(std::move(dialectDescriptor)));
205  }
206 
207  // Create the custom implementation.
208  return (*dialectClass)(std::move(dialectDescriptor));
209 }
210 
211 static MlirStringRef toMlirStringRef(const std::string &s) {
212  return mlirStringRefCreate(s.data(), s.size());
213 }
214 
215 /// Create a block, using the current location context if no locations are
216 /// specified.
/// Each element of `pyArgTypes` is cast to PyType. When `pyArgLocs` is given,
/// each of its elements is cast to PyLocation. When it is absent and there is
/// at least one argument type, every argument gets
/// DefaultingPyLocation::resolve().
/// Raises ValueError (py::value_error) when the number of locations differs
/// from the number of argument types. The block comes from mlirBlockCreate and
/// is not inserted into any region here.
217 static MlirBlock createBlock(const py::sequence &pyArgTypes,
218  const std::optional<py::sequence> &pyArgLocs) {
219  SmallVector<MlirType> argTypes;
220  argTypes.reserve(pyArgTypes.size());
221  for (const auto &pyType : pyArgTypes)
222  argTypes.push_back(pyType.cast<PyType &>());
223 
// NOTE(review): line 224, which must declare `argLocs` (presumably a
// SmallVector<MlirLocation>), is missing from this listing.
225  if (pyArgLocs) {
226  argLocs.reserve(pyArgLocs->size());
227  for (const auto &pyLoc : *pyArgLocs)
228  argLocs.push_back(pyLoc.cast<PyLocation &>());
229  } else if (!argTypes.empty()) {
// No explicit locations: reuse the location resolved from the environment
// for every argument.
230  argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
231  }
232 
233  if (argTypes.size() != argLocs.size())
234  throw py::value_error(("Expected " + Twine(argTypes.size()) +
235  " locations, got: " + Twine(argLocs.size()))
236  .str());
237  return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
238 }
239 
240 /// Wrapper for the global LLVM debugging flag.
242  static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
243 
244  static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
245 
246  static void bind(py::module &m) {
247  // Debug flags.
248  py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
249  .def_property_static("flag", &PyGlobalDebugFlag::get,
250  &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
251  .def_static(
252  "set_types",
253  [](const std::string &type) {
254  mlirSetGlobalDebugType(type.c_str());
255  },
256  "types"_a, "Sets specific debug types to be produced by LLVM")
257  .def_static("set_types", [](const std::vector<std::string> &types) {
258  std::vector<const char *> pointers;
259  pointers.reserve(types.size());
260  for (const std::string &str : types)
261  pointers.push_back(str.c_str());
262  mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
263  });
264  }
265 };
266 
268  static bool dunderContains(const std::string &attributeKind) {
269  return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
270  }
271  static py::function dundeGetItemNamed(const std::string &attributeKind) {
272  auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
273  if (!builder)
274  throw py::key_error(attributeKind);
275  return *builder;
276  }
277  static void dundeSetItemNamed(const std::string &attributeKind,
278  py::function func, bool replace) {
279  PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
280  replace);
281  }
282 
283  static void bind(py::module &m) {
284  py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
285  .def_static("contains", &PyAttrBuilderMap::dunderContains)
286  .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
287  .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed,
288  "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
289  "Register an attribute builder for building MLIR "
290  "attributes from python values.");
291  }
292 };
293 
294 //------------------------------------------------------------------------------
295 // PyBlock
296 //------------------------------------------------------------------------------
297 
298 py::object PyBlock::getCapsule() {
299  return py::reinterpret_steal<py::object>(mlirPythonBlockToCapsule(get()));
300 }
301 
302 //------------------------------------------------------------------------------
303 // Collections.
304 //------------------------------------------------------------------------------
305 
306 namespace {
307 
308 class PyRegionIterator {
309 public:
310  PyRegionIterator(PyOperationRef operation)
311  : operation(std::move(operation)) {}
312 
313  PyRegionIterator &dunderIter() { return *this; }
314 
315  PyRegion dunderNext() {
316  operation->checkValid();
317  if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
318  throw py::stop_iteration();
319  }
320  MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
321  return PyRegion(operation, region);
322  }
323 
324  static void bind(py::module &m) {
325  py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
326  .def("__iter__", &PyRegionIterator::dunderIter)
327  .def("__next__", &PyRegionIterator::dunderNext);
328  }
329 
330 private:
331  PyOperationRef operation;
332  int nextIndex = 0;
333 };
334 
335 /// Regions of an op are fixed length and indexed numerically so are represented
336 /// with a sequence-like container.
337 class PyRegionList {
338 public:
339  PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
340 
341  PyRegionIterator dunderIter() {
342  operation->checkValid();
343  return PyRegionIterator(operation);
344  }
345 
346  intptr_t dunderLen() {
347  operation->checkValid();
348  return mlirOperationGetNumRegions(operation->get());
349  }
350 
351  PyRegion dunderGetItem(intptr_t index) {
352  // dunderLen checks validity.
353  if (index < 0 || index >= dunderLen()) {
354  throw py::index_error("attempt to access out of bounds region");
355  }
356  MlirRegion region = mlirOperationGetRegion(operation->get(), index);
357  return PyRegion(operation, region);
358  }
359 
360  static void bind(py::module &m) {
361  py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
362  .def("__len__", &PyRegionList::dunderLen)
363  .def("__iter__", &PyRegionList::dunderIter)
364  .def("__getitem__", &PyRegionList::dunderGetItem);
365  }
366 
367 private:
368  PyOperationRef operation;
369 };
370 
371 class PyBlockIterator {
372 public:
373  PyBlockIterator(PyOperationRef operation, MlirBlock next)
374  : operation(std::move(operation)), next(next) {}
375 
376  PyBlockIterator &dunderIter() { return *this; }
377 
378  PyBlock dunderNext() {
379  operation->checkValid();
380  if (mlirBlockIsNull(next)) {
381  throw py::stop_iteration();
382  }
383 
384  PyBlock returnBlock(operation, next);
385  next = mlirBlockGetNextInRegion(next);
386  return returnBlock;
387  }
388 
389  static void bind(py::module &m) {
390  py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
391  .def("__iter__", &PyBlockIterator::dunderIter)
392  .def("__next__", &PyBlockIterator::dunderNext);
393  }
394 
395 private:
396  PyOperationRef operation;
397  MlirBlock next;
398 };
399 
400 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
401 /// we present them as a more full-featured list-like container but optimize
402 /// it for forward iteration. Blocks are always owned by a region.
403 class PyBlockList {
404 public:
405  PyBlockList(PyOperationRef operation, MlirRegion region)
406  : operation(std::move(operation)), region(region) {}
407 
408  PyBlockIterator dunderIter() {
409  operation->checkValid();
410  return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
411  }
412 
413  intptr_t dunderLen() {
414  operation->checkValid();
415  intptr_t count = 0;
416  MlirBlock block = mlirRegionGetFirstBlock(region);
417  while (!mlirBlockIsNull(block)) {
418  count += 1;
419  block = mlirBlockGetNextInRegion(block);
420  }
421  return count;
422  }
423 
424  PyBlock dunderGetItem(intptr_t index) {
425  operation->checkValid();
426  if (index < 0) {
427  throw py::index_error("attempt to access out of bounds block");
428  }
429  MlirBlock block = mlirRegionGetFirstBlock(region);
430  while (!mlirBlockIsNull(block)) {
431  if (index == 0) {
432  return PyBlock(operation, block);
433  }
434  block = mlirBlockGetNextInRegion(block);
435  index -= 1;
436  }
437  throw py::index_error("attempt to access out of bounds block");
438  }
439 
440  PyBlock appendBlock(const py::args &pyArgTypes,
441  const std::optional<py::sequence> &pyArgLocs) {
442  operation->checkValid();
443  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
444  mlirRegionAppendOwnedBlock(region, block);
445  return PyBlock(operation, block);
446  }
447 
448  static void bind(py::module &m) {
449  py::class_<PyBlockList>(m, "BlockList", py::module_local())
450  .def("__getitem__", &PyBlockList::dunderGetItem)
451  .def("__iter__", &PyBlockList::dunderIter)
452  .def("__len__", &PyBlockList::dunderLen)
453  .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
454  py::arg("arg_locs") = std::nullopt);
455  }
456 
457 private:
458  PyOperationRef operation;
459  MlirRegion region;
460 };
461 
462 class PyOperationIterator {
463 public:
464  PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
465  : parentOperation(std::move(parentOperation)), next(next) {}
466 
467  PyOperationIterator &dunderIter() { return *this; }
468 
469  py::object dunderNext() {
470  parentOperation->checkValid();
471  if (mlirOperationIsNull(next)) {
472  throw py::stop_iteration();
473  }
474 
475  PyOperationRef returnOperation =
476  PyOperation::forOperation(parentOperation->getContext(), next);
477  next = mlirOperationGetNextInBlock(next);
478  return returnOperation->createOpView();
479  }
480 
481  static void bind(py::module &m) {
482  py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
483  .def("__iter__", &PyOperationIterator::dunderIter)
484  .def("__next__", &PyOperationIterator::dunderNext);
485  }
486 
487 private:
488  PyOperationRef parentOperation;
489  MlirOperation next;
490 };
491 
492 /// Operations are exposed by the C-API as a forward-only linked list. In
493 /// Python, we present them as a more full-featured list-like container but
494 /// optimize it for forward iteration. Iterable operations are always owned
495 /// by a block.
/// Every accessor validity-checks `parentOperation` first. Length and
/// indexing walk the block's operation list, so both take linear time.
496 class PyOperationList {
497 public:
498  PyOperationList(PyOperationRef parentOperation, MlirBlock block)
499  : parentOperation(std::move(parentOperation)), block(block) {}
500 
// Returns an iterator over the block's operations.
// NOTE(review): line 504 (the second argument to PyOperationIterator,
// presumably the block's first operation) is missing from this listing.
501  PyOperationIterator dunderIter() {
502  parentOperation->checkValid();
503  return PyOperationIterator(parentOperation,
505  }
506 
// Counts the operations by walking the block's operation list.
507  intptr_t dunderLen() {
508  parentOperation->checkValid();
509  intptr_t count = 0;
510  MlirOperation childOp = mlirBlockGetFirstOperation(block);
511  while (!mlirOperationIsNull(childOp)) {
512  count += 1;
513  childOp = mlirOperationGetNextInBlock(childOp);
514  }
515  return count;
516  }
517 
// Returns the OpView of the operation at `index`. Negative and too-large
// indices raise IndexError (py::index_error).
518  py::object dunderGetItem(intptr_t index) {
519  parentOperation->checkValid();
520  if (index < 0) {
521  throw py::index_error("attempt to access out of bounds operation");
522  }
523  MlirOperation childOp = mlirBlockGetFirstOperation(block);
524  while (!mlirOperationIsNull(childOp)) {
525  if (index == 0) {
526  return PyOperation::forOperation(parentOperation->getContext(), childOp)
527  ->createOpView();
528  }
529  childOp = mlirOperationGetNextInBlock(childOp);
530  index -= 1;
531  }
532  throw py::index_error("attempt to access out of bounds operation");
533  }
534 
// Registers the Python `OperationList` class.
535  static void bind(py::module &m) {
536  py::class_<PyOperationList>(m, "OperationList", py::module_local())
537  .def("__getitem__", &PyOperationList::dunderGetItem)
538  .def("__iter__", &PyOperationList::dunderIter)
539  .def("__len__", &PyOperationList::dunderLen);
540  }
541 
542 private:
543  PyOperationRef parentOperation;
544  MlirBlock block;
545 };
546 
547 class PyOpOperand {
548 public:
549  PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
550 
551  py::object getOwner() {
552  MlirOperation owner = mlirOpOperandGetOwner(opOperand);
553  PyMlirContextRef context =
554  PyMlirContext::forContext(mlirOperationGetContext(owner));
555  return PyOperation::forOperation(context, owner)->createOpView();
556  }
557 
558  size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
559 
560  static void bind(py::module &m) {
561  py::class_<PyOpOperand>(m, "OpOperand", py::module_local())
562  .def_property_readonly("owner", &PyOpOperand::getOwner)
563  .def_property_readonly("operand_number",
564  &PyOpOperand::getOperandNumber);
565  }
566 
567 private:
568  MlirOpOperand opOperand;
569 };
570 
571 class PyOpOperandIterator {
572 public:
573  PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
574 
575  PyOpOperandIterator &dunderIter() { return *this; }
576 
577  PyOpOperand dunderNext() {
578  if (mlirOpOperandIsNull(opOperand))
579  throw py::stop_iteration();
580 
581  PyOpOperand returnOpOperand(opOperand);
582  opOperand = mlirOpOperandGetNextUse(opOperand);
583  return returnOpOperand;
584  }
585 
586  static void bind(py::module &m) {
587  py::class_<PyOpOperandIterator>(m, "OpOperandIterator", py::module_local())
588  .def("__iter__", &PyOpOperandIterator::dunderIter)
589  .def("__next__", &PyOpOperandIterator::dunderNext);
590  }
591 
592 private:
593  MlirOpOperand opOperand;
594 };
595 
596 } // namespace
597 
598 //------------------------------------------------------------------------------
599 // PyMlirContext
600 //------------------------------------------------------------------------------
601 
602 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
603  py::gil_scoped_acquire acquire;
604  auto &liveContexts = getLiveContexts();
605  liveContexts[context.ptr] = this;
606 }
607 
// NOTE(review): the signature line (608) is missing from this listing; from
// its body this is presumably the PyMlirContext destructor.
609  // Note that the only public way to construct an instance is via the
610  // forContext method, which always puts the associated handle into
611  // liveContexts.
612  py::gil_scoped_acquire acquire;
// Unregister this wrapper, then destroy the wrapped MlirContext.
613  getLiveContexts().erase(context.ptr);
614  mlirContextDestroy(context);
615 }
616 
// NOTE(review): the signature line (617) is missing; presumably
// PyMlirContext::getCapsule(). reinterpret_steal adopts the reference returned
// by mlirPythonContextToCapsule.
618  return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
619 }
620 
621 py::object PyMlirContext::createFromCapsule(py::object capsule) {
622  MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
623  if (mlirContextIsNull(rawContext))
624  throw py::error_already_set();
625  return forContext(rawContext).releaseObject();
626 }
627 
// NOTE(review): the signature line (628) is missing; presumably
// createNewContextForInit(). It creates an MlirContext with threading disabled
// (the `false` argument) and wraps it in a new PyMlirContext, whose
// constructor registers it in the live-context map.
629  MlirContext context = mlirContextCreateWithThreading(false);
630  return new PyMlirContext(context);
631 }
632 
// NOTE(review): the signature line (633) is missing; presumably
// PyMlirContext::forContext(MlirContext). Under the GIL it returns a ref to the
// live wrapper for `context`, creating and registering one if none exists.
634  py::gil_scoped_acquire acquire;
635  auto &liveContexts = getLiveContexts();
636  auto it = liveContexts.find(context.ptr);
637  if (it == liveContexts.end()) {
638  // Create.
639  PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
640  py::object pyRef = py::cast(unownedContextWrapper);
641  assert(pyRef && "cast to py::object failed");
// NOTE(review): the PyMlirContext constructor already stores the new wrapper
// in the live-context map, so this assignment appears redundant.
642  liveContexts[context.ptr] = unownedContextWrapper;
643  return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
644  }
645  // Use existing.
646  py::object pyRef = py::cast(it->second);
647  return PyMlirContextRef(it->second, std::move(pyRef));
648 }
649 
650 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
651  static LiveContextMap liveContexts;
652  return liveContexts;
653 }
654 
655 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
656 
657 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
658 
659 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
660  std::vector<PyOperation *> liveObjects;
661  for (auto &entry : liveOperations)
662  liveObjects.push_back(entry.second.second);
663  return liveObjects;
664 }
665 
// NOTE(review): the signature line (666) is missing; presumably
// PyMlirContext::clearLiveOperations(). It marks every tracked PyOperation
// invalid, empties the live-operation map, and returns how many entries it
// invalidated.
667  for (auto &op : liveOperations)
668  op.second.second->setInvalid();
669  size_t numInvalidated = liveOperations.size();
670  liveOperations.clear();
671  return numInvalidated;
672 }
673 
674 void PyMlirContext::clearOperation(MlirOperation op) {
675  auto it = liveOperations.find(op.ptr);
676  if (it != liveOperations.end()) {
677  it->second.second->setInvalid();
678  liveOperations.erase(it);
679  }
680 }
681 
// NOTE(review): the signature line (682) is missing; presumably
// clearOperationsInside(PyOperationBase &op). It walks `op` in pre-order and
// clears every nested operation from the owning context's live map. `op`
// itself is skipped because it is the first operation the walk visits.
683  typedef struct {
684  PyOperation &rootOp;
685  bool rootSeen;
686  } callBackData;
687  callBackData data{op.getOperation(), false};
688  // Mark all ops below the op that the passmanager will be rooted
689  // at (but not op itself - note the preorder) as invalid.
690  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
691  void *userData) {
692  callBackData *data = static_cast<callBackData *>(userData);
693  if (LLVM_LIKELY(data->rootSeen))
694  data->rootOp.getOperation().getContext()->clearOperation(op);
695  else
696  data->rootSeen = true;
// NOTE(review): line 697 (the callback's return value, presumably a walk
// result that continues the walk) is missing from this listing.
698  };
699  mlirOperationWalk(op.getOperation(), invalidatingCallback,
700  static_cast<void *>(&data), MlirWalkPreOrder);
701 }
// Overload taking a raw MlirOperation.
// NOTE(review): its body (lines 703-704) is missing from this listing.
702 void PyMlirContext::clearOperationsInside(MlirOperation op) {
705 }
706 
// NOTE(review): the signature line (707) is missing, as are lines 712 and 715.
// The visible callback reads userData as a PyMlirContextRef and clears each
// visited operation from that context.
708  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
709  void *userData) {
710  PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
711  contextRef->clearOperation(op);
713  };
714  mlirOperationWalk(op.getOperation(), invalidatingCallback,
716 }
717 
718 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
719 
720 pybind11::object PyMlirContext::contextEnter() {
721  return PyThreadContextEntry::pushContext(*this);
722 }
723 
// Exit hook, presumably paired with contextEnter().
// NOTE(review): the body (line 727) is missing from this listing; presumably
// it pops this context via PyThreadContextEntry::popContext — confirm.
724 void PyMlirContext::contextExit(const pybind11::object &excType,
725  const pybind11::object &excVal,
726  const pybind11::object &excTb) {
728 }
729 
/// Registers `callback` as a diagnostic handler on this context. Returns the
/// Python object for the created PyDiagnosticHandler.
/// An extra reference (inc_ref) keeps that object alive. The delete callback
/// releases it (dec_ref) when MLIR invokes the callback for this handler.
730 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
731  // Note that ownership is transferred to the delete callback below by way of
732  // an explicit inc_ref (borrow).
733  PyDiagnosticHandler *pyHandler =
734  new PyDiagnosticHandler(get(), std::move(callback));
735  py::object pyHandlerObject =
736  py::cast(pyHandler, py::return_value_policy::take_ownership);
737  pyHandlerObject.inc_ref();
738 
739  // In these C callbacks, the userData is a PyDiagnosticHandler* that is
740  // guaranteed to be known to pybind.
741  auto handlerCallback =
742  +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
// NOTE(review): the PyDiagnostic allocation and py::cast below run outside the
// gil_scoped_acquire scope opened further down. So does the destruction of
// pyDiagnosticObject when this lambda returns. The comment below says this
// callback can run from arbitrary C++ contexts, so consider taking the GIL
// before touching any Python object.
743  PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
744  py::object pyDiagnosticObject =
745  py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
746 
747  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
748  bool result = false;
749  {
750  // Since this can be called from arbitrary C++ contexts, always get the
751  // gil.
752  py::gil_scoped_acquire gil;
753  try {
754  result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
755  } catch (std::exception &e) {
756  fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
757  e.what());
758  pyHandler->hadError = true;
759  }
760  }
761 
// The diagnostic must not be used after the callback; invalidating it makes
// later Python accesses fail in PyDiagnostic::checkValid().
762  pyDiagnostic->invalidate();
// NOTE(review): line 763 (the callback's return, presumably derived from
// `result`) is missing from this listing.
764  };
765  auto deleteCallback = +[](void *userData) {
766  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
767  assert(pyHandler->registeredID && "handler is not registered");
768  pyHandler->registeredID.reset();
769 
770  // Decrement reference, balancing the inc_ref() above.
771  py::object pyHandlerObject =
772  py::cast(pyHandler, py::return_value_policy::reference);
773  pyHandlerObject.dec_ref();
774  };
775 
776  pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
777  get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
778  return pyHandlerObject;
779 }
780 
/// Diagnostic handler that captures diagnostics into `self->errors` as
/// DiagnosticInfo records and reports them as handled. It returns failure
/// (not handled) when the context asks for error diagnostics to be emitted
/// instead.
781 MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
782  void *userData) {
783  auto *self = static_cast<ErrorCapture *>(userData);
784  // Check if the context requested we emit errors instead of capturing them.
785  if (self->ctx->emitErrorDiagnostics)
786  return mlirLogicalResultFailure();
787 
// NOTE(review): the condition guarding this early return (line 788) is
// missing from this listing; as shown, the capture code below is unreachable.
789  return mlirLogicalResultFailure();
790 
791  self->errors.emplace_back(PyDiagnostic(diag).getInfo());
792  return mlirLogicalResultSuccess();
793 }
794 
// NOTE(review): the signature lines (795-796) are missing; presumably
// DefaultingPyMlirContext::resolve(), where `context` is the context found in
// the surrounding environment. It returns that context, or throws
// std::runtime_error with guidance when none is available.
797  if (!context) {
798  throw std::runtime_error(
799  "An MLIR function requires a Context but none was provided in the call "
800  "or from the surrounding environment. Either pass to the function with "
801  "a 'context=' argument or establish a default using 'with Context():'");
802  }
803  return *context;
804 }
805 
806 //------------------------------------------------------------------------------
807 // PyThreadContextEntry management
808 //------------------------------------------------------------------------------
809 
810 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
811  static thread_local std::vector<PyThreadContextEntry> stack;
812  return stack;
813 }
814 
// NOTE(review): the signature line (815) is missing; presumably
// PyThreadContextEntry::getTopOfStack(). It returns the innermost frame of
// this thread's stack, or nullptr when the stack is empty.
816  auto &stack = getStack();
817  if (stack.empty())
818  return nullptr;
819  return &stack.back();
820 }
821 
822 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
823  py::object insertionPoint,
824  py::object location) {
825  auto &stack = getStack();
826  stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
827  std::move(location));
828  // If the new stack has more than one entry and the context of the new top
829  // entry matches the previous, copy the insertionPoint and location from the
830  // previous entry if missing from the new top entry.
831  if (stack.size() > 1) {
832  auto &prev = *(stack.rbegin() + 1);
833  auto &current = stack.back();
834  if (current.context.is(prev.context)) {
835  // Default non-context objects from the previous entry.
836  if (!current.insertionPoint)
837  current.insertionPoint = prev.insertionPoint;
838  if (!current.location)
839  current.location = prev.location;
840  }
841  }
842 }
843 
// NOTE(review): the signature lines for the six functions below (844, 850,
// 856, 862, 867, 872) are missing from this listing. From their bodies:
// - The first three convert this frame's stored context, insertion point, or
//   location object to a C++ pointer, returning nullptr when it is unset.
// - The last three return the corresponding pointer from the top-of-stack
//   frame, or nullptr when the stack is empty.
845  if (!context)
846  return nullptr;
847  return py::cast<PyMlirContext *>(context);
848 }
849 
851  if (!insertionPoint)
852  return nullptr;
853  return py::cast<PyInsertionPoint *>(insertionPoint);
854 }
855 
857  if (!location)
858  return nullptr;
859  return py::cast<PyLocation *>(location);
860 }
861 
863  auto *tos = getTopOfStack();
864  return tos ? tos->getContext() : nullptr;
865 }
866 
868  auto *tos = getTopOfStack();
869  return tos ? tos->getInsertionPoint() : nullptr;
870 }
871 
873  auto *tos = getTopOfStack();
874  return tos ? tos->getLocation() : nullptr;
875 }
876 
// NOTE(review): the signature lines for the functions below (877, 885, 896,
// 907, 918, 927) are missing from this listing. From their bodies:
// - The push* functions push a frame of the matching FrameKind and return
//   the pushed Python object.
// - The pop* functions throw std::runtime_error on an empty stack or a
//   mismatched top frame, and otherwise pop the top frame.
878  py::object contextObj = py::cast(context);
879  push(FrameKind::Context, /*context=*/contextObj,
880  /*insertionPoint=*/py::object(),
881  /*location=*/py::object());
882  return contextObj;
883 }
884 
886  auto &stack = getStack();
887  if (stack.empty())
888  throw std::runtime_error("Unbalanced Context enter/exit");
889  auto &tos = stack.back();
// NOTE(review): with `&&`, a mismatch is reported only when BOTH the frame
// kind and the stored object differ; `||` was probably intended. The same
// pattern appears in the InsertionPoint and Location pops below.
890  if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
891  throw std::runtime_error("Unbalanced Context enter/exit");
892  stack.pop_back();
893 }
894 
895 py::object
// The context is derived from the insertion block's parent operation.
897  py::object contextObj =
898  insertionPoint.getBlock().getParentOperation()->getContext().getObject();
899  py::object insertionPointObj = py::cast(insertionPoint);
900  push(FrameKind::InsertionPoint,
901  /*context=*/contextObj,
902  /*insertionPoint=*/insertionPointObj,
903  /*location=*/py::object());
904  return insertionPointObj;
905 }
906 
908  auto &stack = getStack();
909  if (stack.empty())
910  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
911  auto &tos = stack.back();
912  if (tos.frameKind != FrameKind::InsertionPoint &&
913  tos.getInsertionPoint() != &insertionPoint)
914  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
915  stack.pop_back();
916 }
917 
919  py::object contextObj = location.getContext().getObject();
920  py::object locationObj = py::cast(location);
921  push(FrameKind::Location, /*context=*/contextObj,
922  /*insertionPoint=*/py::object(),
923  /*location=*/locationObj);
924  return locationObj;
925 }
926 
928  auto &stack = getStack();
929  if (stack.empty())
930  throw std::runtime_error("Unbalanced Location enter/exit");
931  auto &tos = stack.back();
932  if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
933  throw std::runtime_error("Unbalanced Location enter/exit");
934  stack.pop_back();
935 }
936 
937 //------------------------------------------------------------------------------
938 // PyDiagnostic*
939 //------------------------------------------------------------------------------
940 
// Marks this diagnostic invalid and recursively invalidates any note
// diagnostics that were already materialized by getNotes().
942  valid = false;
943  if (materializedNotes) {
944  for (auto &noteObject : *materializedNotes) {
945  PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
946  note->invalidate();
947  }
948  }
949 }
950 
952  py::object callback)
// Stores the raw context and takes ownership of the Python callback by move.
953  : context(context), callback(std::move(callback)) {}
954 
956 
// Detaches this handler from its context if it is still registered; a no-op
// when `registeredID` is empty.
958  if (!registeredID)
959  return;
// Copy the ID first: detaching is expected to clear `registeredID` (see the
// assert below), presumably via a callback fired by the C API — confirm.
960  MlirDiagnosticHandlerID localID = *registeredID;
961  mlirContextDetachDiagnosticHandler(context, localID);
962  assert(!registeredID && "should have unregistered");
963  // Not strictly necessary but keeps stale pointers from being around to cause
964  // issues.
965  context = {nullptr};
966 }
967 
968 void PyDiagnostic::checkValid() {
969  if (!valid) {
970  throw std::invalid_argument(
971  "Diagnostic is invalid (used outside of callback)");
972  }
973 }
974 
// Returns the diagnostic's severity; throws if the diagnostic is invalid.
976  checkValid();
977  return mlirDiagnosticGetSeverity(diagnostic);
978 }
979 
// Returns the diagnostic's location wrapped as a PyLocation bound to the
// Python context object for the location's MlirContext.
981  checkValid();
982  MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
983  MlirContext context = mlirLocationGetContext(loc);
984  return PyLocation(PyMlirContext::forContext(context), loc);
985 }
986 
// Renders the diagnostic's message into an io.StringIO buffer and returns the
// resulting Python str.
988  checkValid();
989  py::object fileObject = py::module::import("io").attr("StringIO")();
990  PyFileAccumulator accum(fileObject, /*binary=*/false);
991  mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
992  return fileObject.attr("getvalue")();
993 }
994 
// Returns a tuple of PyDiagnostic notes. The tuple is built once and cached
// in `materializedNotes` so invalidate() can later reach every note.
996  checkValid();
997  if (materializedNotes)
998  return *materializedNotes;
999  intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
1000  materializedNotes = py::tuple(numNotes);
1001  for (intptr_t i = 0; i < numNotes; ++i) {
1002  MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
1003  (*materializedNotes)[i] = PyDiagnostic(noteDiag);
1004  }
1005  return *materializedNotes;
1006 }
1007 
// Snapshots this diagnostic (severity, location, message and, recursively,
// its notes) into a DiagnosticInfo value that outlives the callback.
1009  std::vector<DiagnosticInfo> notes;
1010  for (py::handle n : getNotes())
1011  notes.emplace_back(n.cast<PyDiagnostic>().getInfo());
1012  return {getSeverity(), getLocation(), getMessage(), std::move(notes)};
1013 }
1014 
1015 //------------------------------------------------------------------------------
1016 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
1017 //------------------------------------------------------------------------------
1018 
1019 MlirDialect PyDialects::getDialectForKey(const std::string &key,
1020  bool attrError) {
1021  MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1022  {key.data(), key.size()});
1023  if (mlirDialectIsNull(dialect)) {
1024  std::string msg = (Twine("Dialect '") + key + "' not found").str();
1025  if (attrError)
1026  throw py::attribute_error(msg);
1027  throw py::index_error(msg);
1028  }
1029  return dialect;
1030 }
1031 
// Returns a py::object that steals the new reference produced by the call on
// listing line 1034, which is missing from this listing (presumably the
// registry-to-capsule conversion — confirm against the real source).
1033  return py::reinterpret_steal<py::object>(
1035 }
1036 
// Unwraps an MlirDialectRegistry from a Python capsule. A null result raises
// py::error_already_set, surfacing whatever Python error is currently set.
1038  MlirDialectRegistry rawRegistry =
1039  mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1040  if (mlirDialectRegistryIsNull(rawRegistry))
1041  throw py::error_already_set();
1042  return PyDialectRegistry(rawRegistry);
1043 }
1043 }
1044 
1045 //------------------------------------------------------------------------------
1046 // PyLocation
1047 //------------------------------------------------------------------------------
1048 
// Returns a new Python capsule wrapping this location (ownership of the new
// reference is taken via reinterpret_steal).
1050  return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
1051 }
1052 
// Unwraps an MlirLocation from a Python capsule; a null result raises
// py::error_already_set. NOTE(review): listing line 1057 (the start of the
// returned expression that consumes `rawLoc`) is missing from this listing.
1054  MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1055  if (mlirLocationIsNull(rawLoc))
1056  throw py::error_already_set();
1058  rawLoc);
1059 }
1060 
// `with loc:` entry: pushes this location as a Location frame.
1062  return PyThreadContextEntry::pushLocation(*this);
1063 }
1064 
// `with loc:` exit. The exception arguments follow Python's __exit__
// protocol. NOTE(review): the body statement (listing line 1068) is missing
// from this listing; presumably it pops the frame pushed by contextEnter —
// confirm.
1065 void PyLocation::contextExit(const pybind11::object &excType,
1066  const pybind11::object &excVal,
1067  const pybind11::object &excTb) {
1069 }
1070 
// Resolves the default location from the context stack, throwing a
// descriptive std::runtime_error when no location is in effect.
1072  auto *location = PyThreadContextEntry::getDefaultLocation();
1073  if (!location) {
1074  throw std::runtime_error(
1075  "An MLIR function requires a Location but none was provided in the "
1076  "call or from the surrounding environment. Either pass to the function "
1077  "with a 'loc=' argument or establish a default using 'with loc:'");
1078  }
1079  return *location;
1080 }
1081 
1082 //------------------------------------------------------------------------------
1083 // PyModule
1084 //------------------------------------------------------------------------------
1085 
1086 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1087  : BaseContextObject(std::move(contextRef)), module(module) {}
1088 
// Destructor: under the GIL, removes this module from the context's live
// module map (asserting it was registered) and destroys the MlirModule.
1090  py::gil_scoped_acquire acquire;
1091  auto &liveModules = getContext()->liveModules;
1092  assert(liveModules.count(module.ptr) == 1 &&
1093  "destroying module not in live map");
1094  liveModules.erase(module.ptr);
1095  mlirModuleDestroy(module);
1096 }
1097 
1098 PyModuleRef PyModule::forModule(MlirModule module) {
1099  MlirContext context = mlirModuleGetContext(module);
1100  PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1101 
1102  py::gil_scoped_acquire acquire;
1103  auto &liveModules = contextRef->liveModules;
1104  auto it = liveModules.find(module.ptr);
1105  if (it == liveModules.end()) {
1106  // Create.
1107  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1108  // Note that the default return value policy on cast is automatic_reference,
1109  // which does not take ownership (delete will not be called).
1110  // Just be explicit.
1111  py::object pyRef =
1112  py::cast(unownedModule, py::return_value_policy::take_ownership);
1113  unownedModule->handle = pyRef;
1114  liveModules[module.ptr] =
1115  std::make_pair(unownedModule->handle, unownedModule);
1116  return PyModuleRef(unownedModule, std::move(pyRef));
1117  }
1118  // Use existing.
1119  PyModule *existing = it->second.second;
1120  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1121  return PyModuleRef(existing, std::move(pyRef));
1122 }
1123 
1124 py::object PyModule::createFromCapsule(py::object capsule) {
1125  MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1126  if (mlirModuleIsNull(rawModule))
1127  throw py::error_already_set();
1128  return forModule(rawModule).releaseObject();
1129 }
1130 
1131 py::object PyModule::getCapsule() {
1132  return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
1133 }
1134 
1135 //------------------------------------------------------------------------------
1136 // PyOperation
1137 //------------------------------------------------------------------------------
1138 
1139 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1140  : BaseContextObject(std::move(contextRef)), operation(operation) {}
1141 
// Destructor: attached operations are only dropped from the context's live
// map; detached operations (owned by Python) are erased, which destroys them.
1143  // If the operation has already been invalidated there is nothing to do.
1144  if (!valid)
1145  return;
1146 
1147  // Otherwise, invalidate the operation and remove it from live map when it is
1148  // attached.
1149  if (isAttached()) {
1150  getContext()->clearOperation(*this);
1151  } else {
1152  // And destroy it when it is detached, i.e. owned by Python, in which case
1153  // all nested operations must be invalidated and removed from the live map
1154  // as well.
1155  erase();
1156  }
1157 }
1158 
1159 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1160  MlirOperation operation,
1161  py::object parentKeepAlive) {
1162  auto &liveOperations = contextRef->liveOperations;
1163  // Create.
1164  PyOperation *unownedOperation =
1165  new PyOperation(std::move(contextRef), operation);
1166  // Note that the default return value policy on cast is automatic_reference,
1167  // which does not take ownership (delete will not be called).
1168  // Just be explicit.
1169  py::object pyRef =
1170  py::cast(unownedOperation, py::return_value_policy::take_ownership);
1171  unownedOperation->handle = pyRef;
1172  if (parentKeepAlive) {
1173  unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1174  }
1175  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
1176  return PyOperationRef(unownedOperation, std::move(pyRef));
1177 }
1178 
1180  MlirOperation operation,
1181  py::object parentKeepAlive) {
// Returns the live wrapper for `operation`, creating one if none exists.
// Note: when a wrapper already exists, `parentKeepAlive` is ignored.
1182  auto &liveOperations = contextRef->liveOperations;
1183  auto it = liveOperations.find(operation.ptr);
1184  if (it == liveOperations.end()) {
1185  // Create.
1186  return createInstance(std::move(contextRef), operation,
1187  std::move(parentKeepAlive));
1188  }
1189  // Use existing.
1190  PyOperation *existing = it->second.second;
1191  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1192  return PyOperationRef(existing, std::move(pyRef));
1193 }
1194 
1196  MlirOperation operation,
1197  py::object parentKeepAlive) {
// Creates a wrapper for a freshly built operation that is not yet attached
// to any block, so the wrapper is marked detached (Python-owned).
1198  auto &liveOperations = contextRef->liveOperations;
1199  assert(liveOperations.count(operation.ptr) == 0 &&
1200  "cannot create detached operation that already exists");
1201  (void)liveOperations;
1202 
1203  PyOperationRef created = createInstance(std::move(contextRef), operation,
1204  std::move(parentKeepAlive));
1205  created->attached = false;
1206  return created;
1207 }
1208 
1210  const std::string &sourceStr,
1211  const std::string &sourceName) {
// Parses `sourceStr` (reported as `sourceName`) into a detached operation.
// Diagnostics emitted while parsing are captured and attached to the
// MLIRError that is thrown on failure.
1212  PyMlirContext::ErrorCapture errors(contextRef);
1213  MlirOperation op =
1214  mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1215  toMlirStringRef(sourceName));
1216  if (mlirOperationIsNull(op))
1217  throw MLIRError("Unable to parse operation assembly", errors.take());
1218  return PyOperation::createDetached(std::move(contextRef), op);
1219 }
1220 
// Throws std::runtime_error when this wrapper has been invalidated.
1222  if (!valid) {
1223  throw std::runtime_error("the operation has been invalidated");
1224  }
1225 }
1226 
// Prints the operation to `fileObject` (sys.stdout when None) with the
// requested printing flags. NOTE(review): listing lines 1244, 1246, 1248 and
// 1250 (the flag-setting statements for the generic-form, local-scope,
// assume-verified and skip-regions options) and 1255 (presumably releasing
// `flags`) are missing from this listing — confirm against the real source.
1227 void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1228  bool enableDebugInfo, bool prettyDebugInfo,
1229  bool printGenericOpForm, bool useLocalScope,
1230  bool assumeVerified, py::object fileObject,
1231  bool binary, bool skipRegions) {
1232  PyOperation &operation = getOperation();
1233  operation.checkValid();
1234  if (fileObject.is_none())
1235  fileObject = py::module::import("sys").attr("stdout");
1236 
1237  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1238  if (largeElementsLimit)
1239  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1240  if (enableDebugInfo)
1241  mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1242  /*prettyForm=*/prettyDebugInfo);
1243  if (printGenericOpForm)
1245  if (useLocalScope)
1247  if (assumeVerified)
1249  if (skipRegions)
1251 
1252  PyFileAccumulator accum(fileObject, binary);
1253  mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1254  accum.getUserData());
1256 }
1257 
1258 void PyOperationBase::print(PyAsmState &state, py::object fileObject,
1259  bool binary) {
1260  PyOperation &operation = getOperation();
1261  operation.checkValid();
1262  if (fileObject.is_none())
1263  fileObject = py::module::import("sys").attr("stdout");
1264  PyFileAccumulator accum(fileObject, binary);
1265  mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1266  accum.getUserData());
1267 }
1268 
// Writes the operation as MLIR bytecode into the binary file object. With no
// `bytecodeVersion` the default writer is used; otherwise a writer config
// targeting that version is used and failure raises ValueError.
// NOTE(review): listing lines 1281 (the start of the call producing `res`)
// and 1283 (presumably releasing `config`) are missing from this listing.
1269 void PyOperationBase::writeBytecode(const py::object &fileObject,
1270  std::optional<int64_t> bytecodeVersion) {
1271  PyOperation &operation = getOperation();
1272  operation.checkValid();
1273  PyFileAccumulator accum(fileObject, /*binary=*/true);
1274 
1275  if (!bytecodeVersion.has_value())
1276  return mlirOperationWriteBytecode(operation, accum.getCallback(),
1277  accum.getUserData());
1278 
1279  MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1280  mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1282  operation, config, accum.getCallback(), accum.getUserData());
1284  if (mlirLogicalResultIsFailure(res))
1285  throw py::value_error((Twine("Unable to honor desired bytecode version ") +
1286  Twine(*bytecodeVersion))
1287  .str());
1288 }
1289 
1291  std::function<MlirWalkResult(MlirOperation)> callback,
1292  MlirWalkOrder walkOrder) {
// Walks the operation in `walkOrder`, invoking `callback` on each op. A
// Python exception raised by the callback is recorded in UserData (it cannot
// propagate through the C API) and rethrown afterwards as a
// std::runtime_error carrying the original message.
// NOTE(review): listing line 1311 (the catch block's return value, presumably
// an interrupt walk result) is missing from this listing; `exceptionType` is
// recorded but never used, so the original Python exception type is lost.
1293  PyOperation &operation = getOperation();
1294  operation.checkValid();
1295  struct UserData {
1296  std::function<MlirWalkResult(MlirOperation)> callback;
1297  bool gotException;
1298  std::string exceptionWhat;
1299  py::object exceptionType;
1300  };
1301  UserData userData{callback, false, {}, {}};
1302  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1303  void *userData) {
1304  UserData *calleeUserData = static_cast<UserData *>(userData);
1305  try {
1306  return (calleeUserData->callback)(op);
1307  } catch (py::error_already_set &e) {
1308  calleeUserData->gotException = true;
1309  calleeUserData->exceptionWhat = e.what();
1310  calleeUserData->exceptionType = e.type();
1312  }
1313  };
1314  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1315  if (userData.gotException) {
1316  std::string message("Exception raised in callback: ");
1317  message.append(userData.exceptionWhat);
1318  throw std::runtime_error(message);
1319  }
1320 }
1321 
1322 py::object PyOperationBase::getAsm(bool binary,
1323  std::optional<int64_t> largeElementsLimit,
1324  bool enableDebugInfo, bool prettyDebugInfo,
1325  bool printGenericOpForm, bool useLocalScope,
1326  bool assumeVerified, bool skipRegions) {
1327  py::object fileObject;
1328  if (binary) {
1329  fileObject = py::module::import("io").attr("BytesIO")();
1330  } else {
1331  fileObject = py::module::import("io").attr("StringIO")();
1332  }
1333  print(/*largeElementsLimit=*/largeElementsLimit,
1334  /*enableDebugInfo=*/enableDebugInfo,
1335  /*prettyDebugInfo=*/prettyDebugInfo,
1336  /*printGenericOpForm=*/printGenericOpForm,
1337  /*useLocalScope=*/useLocalScope,
1338  /*assumeVerified=*/assumeVerified,
1339  /*fileObject=*/fileObject,
1340  /*binary=*/binary,
1341  /*skipRegions=*/skipRegions);
1342 
1343  return fileObject.attr("getvalue")();
1344 }
1345 
// Moves this operation to directly after `other` and adopts `other`'s
// parentKeepAlive so the new parent stays alive.
// NOTE(review): `attached` is not updated here; if this operation was
// detached (see createDetached), confirm whether it should become attached.
1347  PyOperation &operation = getOperation();
1348  PyOperation &otherOp = other.getOperation();
1349  operation.checkValid();
1350  otherOp.checkValid();
1351  mlirOperationMoveAfter(operation, otherOp);
1352  operation.parentKeepAlive = otherOp.parentKeepAlive;
1353 }
1354 
// Moves this operation to directly before `other` and adopts `other`'s
// parentKeepAlive so the new parent stays alive.
// NOTE(review): `attached` is not updated here; if this operation was
// detached (see createDetached), confirm whether it should become attached.
1356  PyOperation &operation = getOperation();
1357  PyOperation &otherOp = other.getOperation();
1358  operation.checkValid();
1359  otherOp.checkValid();
1360  mlirOperationMoveBefore(operation, otherOp);
1361  operation.parentKeepAlive = otherOp.parentKeepAlive;
1362 }
1363 
// Verifies the operation, returning true on success and raising MLIRError
// with the captured diagnostics on failure. NOTE(review): listing line 1366
// (presumably the declaration of the `errors` capture) is missing from this
// listing, and no op.checkValid() call is visible before the IR is accessed.
1365  PyOperation &op = getOperation();
1367  if (!mlirOperationVerify(op.get()))
1368  throw MLIRError("Verification failed", errors.take());
1369  return true;
1370 }
1371 
1372 std::optional<PyOperationRef> PyOperation::getParentOperation() {
1373  checkValid();
1374  if (!isAttached())
1375  throw py::value_error("Detached operations have no parent");
1376  MlirOperation operation = mlirOperationGetParentOperation(get());
1377  if (mlirOperationIsNull(operation))
1378  return {};
1379  return PyOperation::forOperation(getContext(), operation);
1380 }
1381 
// Returns the block containing this operation, paired with the wrapper of
// that block's parent operation. getParentOperation() raises for detached
// ops; the asserts document that an attached op has both a block and parent.
1383  checkValid();
1384  std::optional<PyOperationRef> parentOperation = getParentOperation();
1385  MlirBlock block = mlirOperationGetBlock(get());
1386  assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1387  assert(parentOperation && "Operation has no parent");
1388  return PyBlock{std::move(*parentOperation), block};
1389 }
1390 
// Returns a new Python capsule wrapping this (valid) operation.
1392  checkValid();
1393  return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1394 }
1395 
1396 py::object PyOperation::createFromCapsule(py::object capsule) {
1397  MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1398  if (mlirOperationIsNull(rawOperation))
1399  throw py::error_already_set();
1400  MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1401  return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1402  .releaseObject();
1403 }
1404 
1406  const py::object &maybeIp) {
// Inserts `op` according to `maybeIp`: the literal `False` disables
// insertion, None selects a default insertion point (the assignment on
// listing line 1411 is missing from this listing), and anything else must
// cast to a PyInsertionPoint. Nothing is inserted if the result is null.
1407  // InsertPoint active?
1408  if (!maybeIp.is(py::cast(false))) {
1409  PyInsertionPoint *ip;
1410  if (maybeIp.is_none()) {
1412  } else {
1413  ip = py::cast<PyInsertionPoint *>(maybeIp);
1414  }
1415  if (ip)
1416  ip->insert(*op.get());
1417  }
1418 }
1419 
// Generic operation constructor: validates and unpacks the Python-level
// operands, result types, attributes and successors into C-API vectors,
// fills an MlirOperationState (with `regions` fresh empty regions and result
// type inference when `inferType` is set), creates the operation, wraps it
// as detached and optionally inserts it per `maybeIp`.
// NOTE(review): listing lines 1430 (declaration of `mlirAttributes`,
// presumably a vector of (std::string, MlirAttribute) pairs judging by its
// use), 1517 (part of the identifier construction for each attribute name)
// and 1527 (declaration of `mlirRegions`) are missing from this listing.
1420 py::object PyOperation::create(const std::string &name,
1421  std::optional<std::vector<PyType *>> results,
1422  std::optional<std::vector<PyValue *>> operands,
1423  std::optional<py::dict> attributes,
1424  std::optional<std::vector<PyBlock *>> successors,
1425  int regions, DefaultingPyLocation location,
1426  const py::object &maybeIp, bool inferType) {
1427  llvm::SmallVector<MlirValue, 4> mlirOperands;
1428  llvm::SmallVector<MlirType, 4> mlirResults;
1429  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
1431 
1432  // General parameter validation.
1433  if (regions < 0)
1434  throw py::value_error("number of regions must be >= 0");
1435 
1436  // Unpack/validate operands.
1437  if (operands) {
1438  mlirOperands.reserve(operands->size());
1439  for (PyValue *operand : *operands) {
1440  if (!operand)
1441  throw py::value_error("operand value cannot be None");
1442  mlirOperands.push_back(operand->get());
1443  }
1444  }
1445 
1446  // Unpack/validate results.
1447  if (results) {
1448  mlirResults.reserve(results->size());
1449  for (PyType *result : *results) {
1450  // TODO: Verify result type originate from the same context.
1451  if (!result)
1452  throw py::value_error("result type cannot be None");
1453  mlirResults.push_back(*result);
1454  }
1455  }
1456  // Unpack/validate attributes.
1457  if (attributes) {
1458  mlirAttributes.reserve(attributes->size());
1459  for (auto &it : *attributes) {
1460  std::string key;
1461  try {
1462  key = it.first.cast<std::string>();
1463  } catch (py::cast_error &err) {
1464  std::string msg = "Invalid attribute key (not a string) when "
1465  "attempting to create the operation \"" +
1466  name + "\" (" + err.what() + ")";
1467  throw py::cast_error(msg);
1468  }
1469  try {
1470  auto &attribute = it.second.cast<PyAttribute &>();
1471  // TODO: Verify attribute originates from the same context.
1472  mlirAttributes.emplace_back(std::move(key), attribute);
1473  } catch (py::reference_cast_error &) {
1474  // This exception seems thrown when the value is "None".
1475  std::string msg =
1476  "Found an invalid (`None`?) attribute value for the key \"" + key +
1477  "\" when attempting to create the operation \"" + name + "\"";
1478  throw py::cast_error(msg);
1479  } catch (py::cast_error &err) {
1480  std::string msg = "Invalid attribute value for the key \"" + key +
1481  "\" when attempting to create the operation \"" +
1482  name + "\" (" + err.what() + ")";
1483  throw py::cast_error(msg);
1484  }
1485  }
1486  }
1487  // Unpack/validate successors.
1488  if (successors) {
1489  mlirSuccessors.reserve(successors->size());
1490  for (auto *successor : *successors) {
1491  // TODO: Verify successor originate from the same context.
1492  if (!successor)
1493  throw py::value_error("successor block cannot be None");
1494  mlirSuccessors.push_back(successor->get());
1495  }
1496  }
1497 
1498  // Apply unpacked/validated to the operation state. Beyond this
1499  // point, exceptions cannot be thrown or else the state will leak.
1500  MlirOperationState state =
1501  mlirOperationStateGet(toMlirStringRef(name), location);
1502  if (!mlirOperands.empty())
1503  mlirOperationStateAddOperands(&state, mlirOperands.size(),
1504  mlirOperands.data());
1505  state.enableResultTypeInference = inferType;
1506  if (!mlirResults.empty())
1507  mlirOperationStateAddResults(&state, mlirResults.size(),
1508  mlirResults.data());
1509  if (!mlirAttributes.empty()) {
1510  // Note that the attribute names directly reference bytes in
1511  // mlirAttributes, so that vector must not be changed from here
1512  // on.
1513  llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1514  mlirNamedAttributes.reserve(mlirAttributes.size());
1515  for (auto &it : mlirAttributes)
1516  mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1518  toMlirStringRef(it.first)),
1519  it.second));
1520  mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1521  mlirNamedAttributes.data());
1522  }
1523  if (!mlirSuccessors.empty())
1524  mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1525  mlirSuccessors.data());
1526  if (regions) {
1528  mlirRegions.resize(regions);
1529  for (int i = 0; i < regions; ++i)
1530  mlirRegions[i] = mlirRegionCreate();
1531  mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1532  mlirRegions.data());
1533  }
1534 
1535  // Construct the operation.
1536  MlirOperation operation = mlirOperationCreate(&state);
1537  if (!operation.ptr)
1538  throw py::value_error("Operation creation failed");
1539  PyOperationRef created =
1540  PyOperation::createDetached(location->getContext(), operation);
1541  maybeInsertOperation(created, maybeIp);
1542 
1543  return created.getObject();
1544 }
1545 
1546 py::object PyOperation::clone(const py::object &maybeIp) {
1547  MlirOperation clonedOperation = mlirOperationClone(operation);
1548  PyOperationRef cloned =
1549  PyOperation::createDetached(getContext(), clonedOperation);
1550  maybeInsertOperation(cloned, maybeIp);
1551 
1552  return cloned->createOpView();
1553 }
1554 
// Returns the most specific OpView for this operation: an instance of the
// Python class registered for its name when one exists, else a plain
// PyOpView.
1556  checkValid();
1557  MlirIdentifier ident = mlirOperationGetName(get());
1558  MlirStringRef identStr = mlirIdentifierStr(ident);
1559  auto operationCls = PyGlobals::get().lookupOperationClass(
1560  StringRef(identStr.data, identStr.length));
1561  if (operationCls)
1562  return PyOpView::constructDerived(*operationCls, *getRef().get());
1563  return py::cast(PyOpView(getRef().getObject()));
1564 }
1565 
// Destroys the underlying operation. NOTE(review): listing line 1568 (work
// done between validation and destruction, presumably clearing this op and
// its nested ops from the live map) is missing from this listing — confirm.
1567  checkValid();
1569  mlirOperationDestroy(operation);
1570 }
1571 
1572 //------------------------------------------------------------------------------
1573 // PyOpView
1574 //------------------------------------------------------------------------------
1575 
// Unpacks `resultTypeList` into `resultTypes` for operation `name`.
// - If `resultSegmentSpecObj` is None, every entry must be a non-None Type.
// - Otherwise it is a list of per-entry specs:
//   1 = exactly one Type;
//   0 = an optional single Type (None allowed, giving segment length 0);
//   -1 = a sequence of Types (None is treated as empty).
// In the second case one segment length per entry is appended to
// `resultSegmentLengths`. Malformed input raises ValueError.
1576 static void populateResultTypes(StringRef name, py::list resultTypeList,
1577  const py::object &resultSegmentSpecObj,
1578  std::vector<int32_t> &resultSegmentLengths,
1579  std::vector<PyType *> &resultTypes) {
1580  resultTypes.reserve(resultTypeList.size());
1581  if (resultSegmentSpecObj.is_none()) {
1582  // Non-variadic result unpacking.
1583  for (const auto &it : llvm::enumerate(resultTypeList)) {
1584  try {
1585  resultTypes.push_back(py::cast<PyType *>(it.value()));
1586  if (!resultTypes.back())
1587  throw py::cast_error();
1588  } catch (py::cast_error &err) {
1589  throw py::value_error((llvm::Twine("Result ") +
1590  llvm::Twine(it.index()) + " of operation \"" +
1591  name + "\" must be a Type (" + err.what() + ")")
1592  .str());
1593  }
1594  }
1595  } else {
1596  // Sized result unpacking.
1597  auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1598  if (resultSegmentSpec.size() != resultTypeList.size()) {
1599  throw py::value_error((llvm::Twine("Operation \"") + name +
1600  "\" requires " +
1601  llvm::Twine(resultSegmentSpec.size()) +
1602  " result segments but was provided " +
1603  llvm::Twine(resultTypeList.size()))
1604  .str());
1605  }
1606  resultSegmentLengths.reserve(resultTypeList.size());
1607  for (const auto &it :
1608  llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1609  int segmentSpec = std::get<1>(it.value());
1610  if (segmentSpec == 1 || segmentSpec == 0) {
1611  // Unpack unary element.
1612  try {
1613  auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1614  if (resultType) {
1615  resultTypes.push_back(resultType);
1616  resultSegmentLengths.push_back(1);
1617  } else if (segmentSpec == 0) {
1618  // Allowed to be optional.
1619  resultSegmentLengths.push_back(0);
1620  } else {
1621  throw py::cast_error("was None and result is not optional");
1622  }
1623  } catch (py::cast_error &err) {
1624  throw py::value_error((llvm::Twine("Result ") +
1625  llvm::Twine(it.index()) + " of operation \"" +
1626  name + "\" must be a Type (" + err.what() +
1627  ")")
1628  .str());
1629  }
1630  } else if (segmentSpec == -1) {
1631  // Unpack sequence by appending.
1632  try {
1633  if (std::get<0>(it.value()).is_none()) {
1634  // Treat it as an empty list.
1635  resultSegmentLengths.push_back(0);
1636  } else {
1637  // Unpack the list.
1638  auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1639  for (py::object segmentItem : segment) {
1640  resultTypes.push_back(py::cast<PyType *>(segmentItem));
1641  if (!resultTypes.back()) {
1642  throw py::cast_error("contained a None item");
1643  }
1644  }
1645  resultSegmentLengths.push_back(segment.size());
1646  }
1647  } catch (std::exception &err) {
1648  // NOTE: Sloppy to be using a catch-all here, but there are at least
1649  // three different unrelated exceptions that can be thrown in the
1650  // above "casts". Just keep the scope above small and catch them all.
1651  throw py::value_error((llvm::Twine("Result ") +
1652  llvm::Twine(it.index()) + " of operation \"" +
1653  name + "\" must be a Sequence of Types (" +
1654  err.what() + ")")
1655  .str());
1656  }
1657  } else {
1658  throw py::value_error("Unexpected segment spec");
1659  }
1660  }
1661  }
1662 }
1663 
1665  const py::object &cls, std::optional<py::list> resultTypeList,
1666  py::list operandList, std::optional<py::dict> attributes,
1667  std::optional<std::vector<PyBlock *>> successors,
1668  std::optional<int> regions, DefaultingPyLocation location,
1669  const py::object &maybeIp) {
1670  PyMlirContextRef context = location->getContext();
1671  // Class level operation construction metadata.
1672  std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1673  // Operand and result segment specs are either none, which does no
1674  // variadic unpacking, or a list of ints with segment sizes, where each
1675  // element is either a positive number (typically 1 for a scalar) or -1 to
1676  // indicate that it is derived from the length of the same-indexed operand
1677  // or result (implying that it is a list at that position).
1678  py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1679  py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1680 
1681  std::vector<int32_t> operandSegmentLengths;
1682  std::vector<int32_t> resultSegmentLengths;
1683 
1684  // Validate/determine region count.
1685  auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1686  int opMinRegionCount = std::get<0>(opRegionSpec);
1687  bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1688  if (!regions) {
1689  regions = opMinRegionCount;
1690  }
1691  if (*regions < opMinRegionCount) {
1692  throw py::value_error(
1693  (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1694  llvm::Twine(opMinRegionCount) +
1695  " regions but was built with regions=" + llvm::Twine(*regions))
1696  .str());
1697  }
1698  if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1699  throw py::value_error(
1700  (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1701  llvm::Twine(opMinRegionCount) +
1702  " regions but was built with regions=" + llvm::Twine(*regions))
1703  .str());
1704  }
1705 
1706  // Unpack results.
1707  std::vector<PyType *> resultTypes;
1708  if (resultTypeList.has_value()) {
1709  populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1710  resultSegmentLengths, resultTypes);
1711  }
1712 
1713  // Unpack operands.
1714  std::vector<PyValue *> operands;
1715  operands.reserve(operands.size());
1716  if (operandSegmentSpecObj.is_none()) {
1717  // Non-sized operand unpacking.
1718  for (const auto &it : llvm::enumerate(operandList)) {
1719  try {
1720  operands.push_back(py::cast<PyValue *>(it.value()));
1721  if (!operands.back())
1722  throw py::cast_error();
1723  } catch (py::cast_error &err) {
1724  throw py::value_error((llvm::Twine("Operand ") +
1725  llvm::Twine(it.index()) + " of operation \"" +
1726  name + "\" must be a Value (" + err.what() + ")")
1727  .str());
1728  }
1729  }
1730  } else {
1731  // Sized operand unpacking.
1732  auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1733  if (operandSegmentSpec.size() != operandList.size()) {
1734  throw py::value_error((llvm::Twine("Operation \"") + name +
1735  "\" requires " +
1736  llvm::Twine(operandSegmentSpec.size()) +
1737  "operand segments but was provided " +
1738  llvm::Twine(operandList.size()))
1739  .str());
1740  }
1741  operandSegmentLengths.reserve(operandList.size());
1742  for (const auto &it :
1743  llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1744  int segmentSpec = std::get<1>(it.value());
1745  if (segmentSpec == 1 || segmentSpec == 0) {
1746  // Unpack unary element.
1747  try {
1748  auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1749  if (operandValue) {
1750  operands.push_back(operandValue);
1751  operandSegmentLengths.push_back(1);
1752  } else if (segmentSpec == 0) {
1753  // Allowed to be optional.
1754  operandSegmentLengths.push_back(0);
1755  } else {
1756  throw py::cast_error("was None and operand is not optional");
1757  }
1758  } catch (py::cast_error &err) {
1759  throw py::value_error((llvm::Twine("Operand ") +
1760  llvm::Twine(it.index()) + " of operation \"" +
1761  name + "\" must be a Value (" + err.what() +
1762  ")")
1763  .str());
1764  }
1765  } else if (segmentSpec == -1) {
1766  // Unpack sequence by appending.
1767  try {
1768  if (std::get<0>(it.value()).is_none()) {
1769  // Treat it as an empty list.
1770  operandSegmentLengths.push_back(0);
1771  } else {
1772  // Unpack the list.
1773  auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1774  for (py::object segmentItem : segment) {
1775  operands.push_back(py::cast<PyValue *>(segmentItem));
1776  if (!operands.back()) {
1777  throw py::cast_error("contained a None item");
1778  }
1779  }
1780  operandSegmentLengths.push_back(segment.size());
1781  }
1782  } catch (std::exception &err) {
1783  // NOTE: Sloppy to be using a catch-all here, but there are at least
1784  // three different unrelated exceptions that can be thrown in the
1785  // above "casts". Just keep the scope above small and catch them all.
1786  throw py::value_error((llvm::Twine("Operand ") +
1787  llvm::Twine(it.index()) + " of operation \"" +
1788  name + "\" must be a Sequence of Values (" +
1789  err.what() + ")")
1790  .str());
1791  }
1792  } else {
1793  throw py::value_error("Unexpected segment spec");
1794  }
1795  }
1796  }
1797 
1798  // Merge operand/result segment lengths into attributes if needed.
1799  if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1800  // Dup.
1801  if (attributes) {
1802  attributes = py::dict(*attributes);
1803  } else {
1804  attributes = py::dict();
1805  }
1806  if (attributes->contains("resultSegmentSizes") ||
1807  attributes->contains("operandSegmentSizes")) {
1808  throw py::value_error("Manually setting a 'resultSegmentSizes' or "
1809  "'operandSegmentSizes' attribute is unsupported. "
1810  "Use Operation.create for such low-level access.");
1811  }
1812 
1813  // Add resultSegmentSizes attribute.
1814  if (!resultSegmentLengths.empty()) {
1815  MlirAttribute segmentLengthAttr =
1816  mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1817  resultSegmentLengths.data());
1818  (*attributes)["resultSegmentSizes"] =
1819  PyAttribute(context, segmentLengthAttr);
1820  }
1821 
1822  // Add operandSegmentSizes attribute.
1823  if (!operandSegmentLengths.empty()) {
1824  MlirAttribute segmentLengthAttr =
1825  mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1826  operandSegmentLengths.data());
1827  (*attributes)["operandSegmentSizes"] =
1828  PyAttribute(context, segmentLengthAttr);
1829  }
1830  }
1831 
1832  // Delegate to create.
1833  return PyOperation::create(name,
1834  /*results=*/std::move(resultTypes),
1835  /*operands=*/std::move(operands),
1836  /*attributes=*/std::move(attributes),
1837  /*successors=*/std::move(successors),
1838  /*regions=*/*regions, location, maybeIp,
1839  !resultTypeList);
1840 }
1841 
1842 pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
1843  const PyOperation &operation) {
1844  // TODO: pybind11 2.6 supports a more direct form.
1845  // Upgrade many years from now.
1846  // auto opViewType = py::type::of<PyOpView>();
1847  py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1848  py::object instance = cls.attr("__new__")(cls);
1849  opViewType.attr("__init__")(instance, operation);
1850  return instance;
1851 }
1852 
/// Builds an OpView from any Python object castable to PyOperationBase.
/// The stored `operationObject` is the Python object of the underlying
/// PyOperation (via getRef().getObject()), not necessarily the argument.
PyOpView::PyOpView(const py::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      // NOTE(review): this reads the `operation` member initialized above,
      // which is only correct if `operation` is declared before
      // `operationObject` in the class definition (header not visible here)
      // -- confirm.
      operationObject(operation.getRef().getObject()) {}
1858 
1859 //------------------------------------------------------------------------------
1860 // PyInsertionPoint.
1861 //------------------------------------------------------------------------------
1862 
1864 
1866  : refOperation(beforeOperationBase.getOperation().getRef()),
1867  block((*refOperation)->getBlock()) {}
1868 
1870  PyOperation &operation = operationBase.getOperation();
1871  if (operation.isAttached())
1872  throw py::value_error(
1873  "Attempt to insert operation that is already attached");
1874  block.getParentOperation()->checkValid();
1875  MlirOperation beforeOp = {nullptr};
1876  if (refOperation) {
1877  // Insert before operation.
1878  (*refOperation)->checkValid();
1879  beforeOp = (*refOperation)->get();
1880  } else {
1881  // Insert at end (before null) is only valid if the block does not
1882  // already end in a known terminator (violating this will cause assertion
1883  // failures later).
1884  if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1885  throw py::index_error("Cannot insert operation at the end of a block "
1886  "that already has a terminator. Did you mean to "
1887  "use 'InsertionPoint.at_block_terminator(block)' "
1888  "versus 'InsertionPoint(block)'?");
1889  }
1890  }
1891  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1892  operation.setAttached();
1893 }
1894 
1896  MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1897  if (mlirOperationIsNull(firstOp)) {
1898  // Just insert at end.
1899  return PyInsertionPoint(block);
1900  }
1901 
1902  // Insert before first op.
1904  block.getParentOperation()->getContext(), firstOp);
1905  return PyInsertionPoint{block, std::move(firstOpRef)};
1906 }
1907 
1909  MlirOperation terminator = mlirBlockGetTerminator(block.get());
1910  if (mlirOperationIsNull(terminator))
1911  throw py::value_error("Block has no terminator");
1912  PyOperationRef terminatorOpRef = PyOperation::forOperation(
1913  block.getParentOperation()->getContext(), terminator);
1914  return PyInsertionPoint{block, std::move(terminatorOpRef)};
1915 }
1916 
1919 }
1920 
1921 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1922  const pybind11::object &excVal,
1923  const pybind11::object &excTb) {
1925 }
1926 
1927 //------------------------------------------------------------------------------
1928 // PyAttribute.
1929 //------------------------------------------------------------------------------
1930 
1931 bool PyAttribute::operator==(const PyAttribute &other) const {
1932  return mlirAttributeEqual(attr, other.attr);
1933 }
1934 
1936  return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1937 }
1938 
1940  MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1941  if (mlirAttributeIsNull(rawAttr))
1942  throw py::error_already_set();
1943  return PyAttribute(
1945 }
1946 
1947 //------------------------------------------------------------------------------
1948 // PyNamedAttribute.
1949 //------------------------------------------------------------------------------
1950 
1951 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1952  : ownedName(new std::string(std::move(ownedName))) {
1955  toMlirStringRef(*this->ownedName)),
1956  attr);
1957 }
1958 
1959 //------------------------------------------------------------------------------
1960 // PyType.
1961 //------------------------------------------------------------------------------
1962 
1963 bool PyType::operator==(const PyType &other) const {
1964  return mlirTypeEqual(type, other.type);
1965 }
1966 
1967 py::object PyType::getCapsule() {
1968  return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1969 }
1970 
1971 PyType PyType::createFromCapsule(py::object capsule) {
1972  MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1973  if (mlirTypeIsNull(rawType))
1974  throw py::error_already_set();
1976  rawType);
1977 }
1978 
1979 //------------------------------------------------------------------------------
1980 // PyTypeID.
1981 //------------------------------------------------------------------------------
1982 
1983 py::object PyTypeID::getCapsule() {
1984  return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
1985 }
1986 
1988  MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
1989  if (mlirTypeIDIsNull(mlirTypeID))
1990  throw py::error_already_set();
1991  return PyTypeID(mlirTypeID);
1992 }
1993 bool PyTypeID::operator==(const PyTypeID &other) const {
1994  return mlirTypeIDEqual(typeID, other.typeID);
1995 }
1996 
1997 //------------------------------------------------------------------------------
1998 // PyValue and subclasses.
1999 //------------------------------------------------------------------------------
2000 
2001 pybind11::object PyValue::getCapsule() {
2002  return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
2003 }
2004 
2005 pybind11::object PyValue::maybeDownCast() {
2006  MlirType type = mlirValueGetType(get());
2007  MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2008  assert(!mlirTypeIDIsNull(mlirTypeID) &&
2009  "mlirTypeID was expected to be non-null.");
2010  std::optional<pybind11::function> valueCaster =
2012  // py::return_value_policy::move means use std::move to move the return value
2013  // contents into a new instance that will be owned by Python.
2014  py::object thisObj = py::cast(this, py::return_value_policy::move);
2015  if (!valueCaster)
2016  return thisObj;
2017  return valueCaster.value()(thisObj);
2018 }
2019 
2020 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
2021  MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2022  if (mlirValueIsNull(value))
2023  throw py::error_already_set();
2024  MlirOperation owner;
2025  if (mlirValueIsAOpResult(value))
2026  owner = mlirOpResultGetOwner(value);
2027  if (mlirValueIsABlockArgument(value))
2029  if (mlirOperationIsNull(owner))
2030  throw py::error_already_set();
2031  MlirContext ctx = mlirOperationGetContext(owner);
2032  PyOperationRef ownerRef =
2034  return PyValue(ownerRef, value);
2035 }
2036 
2037 //------------------------------------------------------------------------------
2038 // PySymbolTable.
2039 //------------------------------------------------------------------------------
2040 
2042  : operation(operation.getOperation().getRef()) {
2043  symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2044  if (mlirSymbolTableIsNull(symbolTable)) {
2045  throw py::cast_error("Operation is not a Symbol Table.");
2046  }
2047 }
2048 
2049 py::object PySymbolTable::dunderGetItem(const std::string &name) {
2050  operation->checkValid();
2051  MlirOperation symbol = mlirSymbolTableLookup(
2052  symbolTable, mlirStringRefCreate(name.data(), name.length()));
2053  if (mlirOperationIsNull(symbol))
2054  throw py::key_error("Symbol '" + name + "' not in the symbol table.");
2055 
2056  return PyOperation::forOperation(operation->getContext(), symbol,
2057  operation.getObject())
2058  ->createOpView();
2059 }
2060 
2062  operation->checkValid();
2063  symbol.getOperation().checkValid();
2064  mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2065  // The operation is also erased, so we must invalidate it. There may be Python
2066  // references to this operation so we don't want to delete it from the list of
2067  // live operations here.
2068  symbol.getOperation().valid = false;
2069 }
2070 
2071 void PySymbolTable::dunderDel(const std::string &name) {
2072  py::object operation = dunderGetItem(name);
2073  erase(py::cast<PyOperationBase &>(operation));
2074 }
2075 
2076 MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2077  operation->checkValid();
2078  symbol.getOperation().checkValid();
2079  MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2081  if (mlirAttributeIsNull(symbolAttr))
2082  throw py::value_error("Expected operation to have a symbol name.");
2083  return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2084 }
2085 
2087  // Op must already be a symbol.
2088  PyOperation &operation = symbol.getOperation();
2089  operation.checkValid();
2091  MlirAttribute existingNameAttr =
2092  mlirOperationGetAttributeByName(operation.get(), attrName);
2093  if (mlirAttributeIsNull(existingNameAttr))
2094  throw py::value_error("Expected operation to have a symbol name.");
2095  return existingNameAttr;
2096 }
2097 
2099  const std::string &name) {
2100  // Op must already be a symbol.
2101  PyOperation &operation = symbol.getOperation();
2102  operation.checkValid();
2104  MlirAttribute existingNameAttr =
2105  mlirOperationGetAttributeByName(operation.get(), attrName);
2106  if (mlirAttributeIsNull(existingNameAttr))
2107  throw py::value_error("Expected operation to have a symbol name.");
2108  MlirAttribute newNameAttr =
2109  mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2110  mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2111 }
2112 
2114  PyOperation &operation = symbol.getOperation();
2115  operation.checkValid();
2117  MlirAttribute existingVisAttr =
2118  mlirOperationGetAttributeByName(operation.get(), attrName);
2119  if (mlirAttributeIsNull(existingVisAttr))
2120  throw py::value_error("Expected operation to have a symbol visibility.");
2121  return existingVisAttr;
2122 }
2123 
2125  const std::string &visibility) {
2126  if (visibility != "public" && visibility != "private" &&
2127  visibility != "nested")
2128  throw py::value_error(
2129  "Expected visibility to be 'public', 'private' or 'nested'");
2130  PyOperation &operation = symbol.getOperation();
2131  operation.checkValid();
2133  MlirAttribute existingVisAttr =
2134  mlirOperationGetAttributeByName(operation.get(), attrName);
2135  if (mlirAttributeIsNull(existingVisAttr))
2136  throw py::value_error("Expected operation to have a symbol visibility.");
2137  MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2138  toMlirStringRef(visibility));
2139  mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2140 }
2141 
2142 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2143  const std::string &newSymbol,
2144  PyOperationBase &from) {
2145  PyOperation &fromOperation = from.getOperation();
2146  fromOperation.checkValid();
2148  toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2149  from.getOperation())))
2150 
2151  throw py::value_error("Symbol rename failed");
2152 }
2153 
2155  bool allSymUsesVisible,
2156  py::object callback) {
2157  PyOperation &fromOperation = from.getOperation();
2158  fromOperation.checkValid();
2159  struct UserData {
2160  PyMlirContextRef context;
2161  py::object callback;
2162  bool gotException;
2163  std::string exceptionWhat;
2164  py::object exceptionType;
2165  };
2166  UserData userData{
2167  fromOperation.getContext(), std::move(callback), false, {}, {}};
2169  fromOperation.get(), allSymUsesVisible,
2170  [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2171  UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2172  auto pyFoundOp =
2173  PyOperation::forOperation(calleeUserData->context, foundOp);
2174  if (calleeUserData->gotException)
2175  return;
2176  try {
2177  calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2178  } catch (py::error_already_set &e) {
2179  calleeUserData->gotException = true;
2180  calleeUserData->exceptionWhat = e.what();
2181  calleeUserData->exceptionType = e.type();
2182  }
2183  },
2184  static_cast<void *>(&userData));
2185  if (userData.gotException) {
2186  std::string message("Exception raised in callback: ");
2187  message.append(userData.exceptionWhat);
2188  throw std::runtime_error(message);
2189  }
2190 }
2191 
2192 namespace {
2193 /// CRTP base class for Python MLIR values that subclass Value and should be
2194 /// castable from it. The value hierarchy is one level deep and is not supposed
2195 /// to accommodate other levels unless core MLIR changes.
2196 template <typename DerivedTy>
2197 class PyConcreteValue : public PyValue {
2198 public:
  // Derived classes must define statics for:
  //   IsAFunctionTy isaFunction
  //   const char *pyClassName
  // and redefine bindDerived.
  /// pybind11 class type of the derived value, registered with PyValue as its
  /// Python base class.
  using ClassTy = py::class_<DerivedTy, PyValue>;
  /// Signature of the predicate castFrom uses to check that a generic
  /// MlirValue is of the derived kind.
  using IsAFunctionTy = bool (*)(MlirValue);

  PyConcreteValue() = default;
  /// Wraps `value` together with `operationRef` as its parent operation.
  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
      : PyValue(operationRef, value) {}
  /// Checked conversion from a generic value: keeps the same parent operation
  /// and goes through castFrom, which raises ValueError on a kind mismatch.
  PyConcreteValue(PyValue &orig)
      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
2211 
2212  /// Attempts to cast the original value to the derived type and throws on
2213  /// type mismatches.
2214  static MlirValue castFrom(PyValue &orig) {
2215  if (!DerivedTy::isaFunction(orig.get())) {
2216  auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
2217  throw py::value_error((Twine("Cannot cast value to ") +
2218  DerivedTy::pyClassName + " (from " + origRepr +
2219  ")")
2220  .str());
2221  }
2222  return orig.get();
2223  }
2224 
2225  /// Binds the Python module objects to functions of this class.
2226  static void bind(py::module &m) {
2227  auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
2228  cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
2229  cls.def_static(
2230  "isinstance",
2231  [](PyValue &otherValue) -> bool {
2232  return DerivedTy::isaFunction(otherValue);
2233  },
2234  py::arg("other_value"));
2236  [](DerivedTy &self) { return self.maybeDownCast(); });
2237  DerivedTy::bindDerived(cls);
2238  }
2239 
2240  /// Implemented by derived classes to add methods to the Python subclass.
2241  static void bindDerived(ClassTy &m) {}
2242 };
2243 
2244 /// Python wrapper for MlirBlockArgument.
2245 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2246 public:
2247  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2248  static constexpr const char *pyClassName = "BlockArgument";
2249  using PyConcreteValue::PyConcreteValue;
2250 
2251  static void bindDerived(ClassTy &c) {
2252  c.def_property_readonly("owner", [](PyBlockArgument &self) {
2253  return PyBlock(self.getParentOperation(),
2254  mlirBlockArgumentGetOwner(self.get()));
2255  });
2256  c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
2257  return mlirBlockArgumentGetArgNumber(self.get());
2258  });
2259  c.def(
2260  "set_type",
2261  [](PyBlockArgument &self, PyType type) {
2262  return mlirBlockArgumentSetType(self.get(), type);
2263  },
2264  py::arg("type"));
2265  }
2266 };
2267 
2268 /// Python wrapper for MlirOpResult.
2269 class PyOpResult : public PyConcreteValue<PyOpResult> {
2270 public:
2271  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
2272  static constexpr const char *pyClassName = "OpResult";
2273  using PyConcreteValue::PyConcreteValue;
2274 
2275  static void bindDerived(ClassTy &c) {
2276  c.def_property_readonly("owner", [](PyOpResult &self) {
2277  assert(
2278  mlirOperationEqual(self.getParentOperation()->get(),
2279  mlirOpResultGetOwner(self.get())) &&
2280  "expected the owner of the value in Python to match that in the IR");
2281  return self.getParentOperation().getObject();
2282  });
2283  c.def_property_readonly("result_number", [](PyOpResult &self) {
2284  return mlirOpResultGetResultNumber(self.get());
2285  });
2286  }
2287 };
2288 
2289 /// Returns the list of types of the values held by container.
2290 template <typename Container>
2291 static std::vector<MlirType> getValueTypes(Container &container,
2292  PyMlirContextRef &context) {
2293  std::vector<MlirType> result;
2294  result.reserve(container.size());
2295  for (int i = 0, e = container.size(); i < e; ++i) {
2296  result.push_back(mlirValueGetType(container.getElement(i).get()));
2297  }
2298  return result;
2299 }
2300 
2301 /// A list of block arguments. Internally, these are stored as consecutive
2302 /// elements, random access is cheap. The argument list is associated with the
2303 /// operation that contains the block (detached blocks are not allowed in
2304 /// Python bindings) and extends its lifetime.
2305 class PyBlockArgumentList
2306  : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2307 public:
2308  static constexpr const char *pyClassName = "BlockArgumentList";
2310 
2311  PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2312  intptr_t startIndex = 0, intptr_t length = -1,
2313  intptr_t step = 1)
2314  : Sliceable(startIndex,
2315  length == -1 ? mlirBlockGetNumArguments(block) : length,
2316  step),
2317  operation(std::move(operation)), block(block) {}
2318 
2319  static void bindDerived(ClassTy &c) {
2320  c.def_property_readonly("types", [](PyBlockArgumentList &self) {
2321  return getValueTypes(self, self.operation->getContext());
2322  });
2323  }
2324 
2325 private:
2326  /// Give the parent CRTP class access to hook implementations below.
2327  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2328 
2329  /// Returns the number of arguments in the list.
2330  intptr_t getRawNumElements() {
2331  operation->checkValid();
2332  return mlirBlockGetNumArguments(block);
2333  }
2334 
2335  /// Returns `pos`-the element in the list.
2336  PyBlockArgument getRawElement(intptr_t pos) {
2337  MlirValue argument = mlirBlockGetArgument(block, pos);
2338  return PyBlockArgument(operation, argument);
2339  }
2340 
2341  /// Returns a sublist of this list.
2342  PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2343  intptr_t step) {
2344  return PyBlockArgumentList(operation, block, startIndex, length, step);
2345  }
2346 
2347  PyOperationRef operation;
2348  MlirBlock block;
2349 };
2350 
2351 /// A list of operation operands. Internally, these are stored as consecutive
2352 /// elements, random access is cheap. The (returned) operand list is associated
2353 /// with the operation whose operands these are, and thus extends the lifetime
2354 /// of this operation.
2355 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2356 public:
2357  static constexpr const char *pyClassName = "OpOperandList";
2358  using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2359 
2360  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2361  intptr_t length = -1, intptr_t step = 1)
2362  : Sliceable(startIndex,
2363  length == -1 ? mlirOperationGetNumOperands(operation->get())
2364  : length,
2365  step),
2366  operation(operation) {}
2367 
2368  void dunderSetItem(intptr_t index, PyValue value) {
2369  index = wrapIndex(index);
2370  mlirOperationSetOperand(operation->get(), index, value.get());
2371  }
2372 
2373  static void bindDerived(ClassTy &c) {
2374  c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2375  }
2376 
2377 private:
2378  /// Give the parent CRTP class access to hook implementations below.
2379  friend class Sliceable<PyOpOperandList, PyValue>;
2380 
2381  intptr_t getRawNumElements() {
2382  operation->checkValid();
2383  return mlirOperationGetNumOperands(operation->get());
2384  }
2385 
2386  PyValue getRawElement(intptr_t pos) {
2387  MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2388  MlirOperation owner;
2389  if (mlirValueIsAOpResult(operand))
2390  owner = mlirOpResultGetOwner(operand);
2391  else if (mlirValueIsABlockArgument(operand))
2393  else
2394  assert(false && "Value must be an block arg or op result.");
2395  PyOperationRef pyOwner =
2396  PyOperation::forOperation(operation->getContext(), owner);
2397  return PyValue(pyOwner, operand);
2398  }
2399 
2400  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2401  return PyOpOperandList(operation, startIndex, length, step);
2402  }
2403 
2404  PyOperationRef operation;
2405 };
2406 
2407 /// A list of operation results. Internally, these are stored as consecutive
2408 /// elements, random access is cheap. The (returned) result list is associated
2409 /// with the operation whose results these are, and thus extends the lifetime of
2410 /// this operation.
2411 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2412 public:
2413  static constexpr const char *pyClassName = "OpResultList";
2414  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
2415 
2416  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2417  intptr_t length = -1, intptr_t step = 1)
2418  : Sliceable(startIndex,
2419  length == -1 ? mlirOperationGetNumResults(operation->get())
2420  : length,
2421  step),
2422  operation(std::move(operation)) {}
2423 
2424  static void bindDerived(ClassTy &c) {
2425  c.def_property_readonly("types", [](PyOpResultList &self) {
2426  return getValueTypes(self, self.operation->getContext());
2427  });
2428  c.def_property_readonly("owner", [](PyOpResultList &self) {
2429  return self.operation->createOpView();
2430  });
2431  }
2432 
2433 private:
2434  /// Give the parent CRTP class access to hook implementations below.
2435  friend class Sliceable<PyOpResultList, PyOpResult>;
2436 
2437  intptr_t getRawNumElements() {
2438  operation->checkValid();
2439  return mlirOperationGetNumResults(operation->get());
2440  }
2441 
2442  PyOpResult getRawElement(intptr_t index) {
2443  PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2444  return PyOpResult(value);
2445  }
2446 
2447  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2448  return PyOpResultList(operation, startIndex, length, step);
2449  }
2450 
2451  PyOperationRef operation;
2452 };
2453 
2454 /// A list of operation successors. Internally, these are stored as consecutive
2455 /// elements, random access is cheap. The (returned) successor list is
2456 /// associated with the operation whose successors these are, and thus extends
2457 /// the lifetime of this operation.
2458 class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2459 public:
2460  static constexpr const char *pyClassName = "OpSuccessors";
2461 
2462  PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2463  intptr_t length = -1, intptr_t step = 1)
2464  : Sliceable(startIndex,
2465  length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2466  : length,
2467  step),
2468  operation(operation) {}
2469 
2470  void dunderSetItem(intptr_t index, PyBlock block) {
2471  index = wrapIndex(index);
2472  mlirOperationSetSuccessor(operation->get(), index, block.get());
2473  }
2474 
2475  static void bindDerived(ClassTy &c) {
2476  c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2477  }
2478 
2479 private:
2480  /// Give the parent CRTP class access to hook implementations below.
2481  friend class Sliceable<PyOpSuccessors, PyBlock>;
2482 
2483  intptr_t getRawNumElements() {
2484  operation->checkValid();
2485  return mlirOperationGetNumSuccessors(operation->get());
2486  }
2487 
2488  PyBlock getRawElement(intptr_t pos) {
2489  MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2490  return PyBlock(operation, block);
2491  }
2492 
2493  PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2494  return PyOpSuccessors(operation, startIndex, length, step);
2495  }
2496 
2497  PyOperationRef operation;
2498 };
2499 
2500 /// A list of operation attributes. Can be indexed by name, producing
2501 /// attributes, or by index, producing named attributes.
2502 class PyOpAttributeMap {
2503 public:
2504  PyOpAttributeMap(PyOperationRef operation)
2505  : operation(std::move(operation)) {}
2506 
2507  MlirAttribute dunderGetItemNamed(const std::string &name) {
2508  MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2509  toMlirStringRef(name));
2510  if (mlirAttributeIsNull(attr)) {
2511  throw py::key_error("attempt to access a non-existent attribute");
2512  }
2513  return attr;
2514  }
2515 
2516  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2517  if (index < 0 || index >= dunderLen()) {
2518  throw py::index_error("attempt to access out of bounds attribute");
2519  }
2520  MlirNamedAttribute namedAttr =
2521  mlirOperationGetAttribute(operation->get(), index);
2522  return PyNamedAttribute(
2523  namedAttr.attribute,
2524  std::string(mlirIdentifierStr(namedAttr.name).data,
2525  mlirIdentifierStr(namedAttr.name).length));
2526  }
2527 
2528  void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2529  mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2530  attr);
2531  }
2532 
2533  void dunderDelItem(const std::string &name) {
2534  int removed = mlirOperationRemoveAttributeByName(operation->get(),
2535  toMlirStringRef(name));
2536  if (!removed)
2537  throw py::key_error("attempt to delete a non-existent attribute");
2538  }
2539 
2540  intptr_t dunderLen() {
2541  return mlirOperationGetNumAttributes(operation->get());
2542  }
2543 
2544  bool dunderContains(const std::string &name) {
2546  operation->get(), toMlirStringRef(name)));
2547  }
2548 
2549  static void bind(py::module &m) {
2550  py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2551  .def("__contains__", &PyOpAttributeMap::dunderContains)
2552  .def("__len__", &PyOpAttributeMap::dunderLen)
2553  .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2554  .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2555  .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2556  .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2557  }
2558 
2559 private:
2560  PyOperationRef operation;
2561 };
2562 
2563 } // namespace
2564 
2565 //------------------------------------------------------------------------------
2566 // Populates the core exports of the 'ir' submodule.
2567 //------------------------------------------------------------------------------
2568 
2569 void mlir::python::populateIRCore(py::module &m) {
2570  //----------------------------------------------------------------------------
2571  // Enums.
2572  //----------------------------------------------------------------------------
2573  py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2574  .value("ERROR", MlirDiagnosticError)
2575  .value("WARNING", MlirDiagnosticWarning)
2576  .value("NOTE", MlirDiagnosticNote)
2577  .value("REMARK", MlirDiagnosticRemark);
2578 
2579  py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
2580  .value("PRE_ORDER", MlirWalkPreOrder)
2581  .value("POST_ORDER", MlirWalkPostOrder);
2582 
2583  py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
2584  .value("ADVANCE", MlirWalkResultAdvance)
2585  .value("INTERRUPT", MlirWalkResultInterrupt)
2586  .value("SKIP", MlirWalkResultSkip);
2587 
2588  //----------------------------------------------------------------------------
2589  // Mapping of Diagnostics.
2590  //----------------------------------------------------------------------------
2591  py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2592  .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2593  .def_property_readonly("location", &PyDiagnostic::getLocation)
2594  .def_property_readonly("message", &PyDiagnostic::getMessage)
2595  .def_property_readonly("notes", &PyDiagnostic::getNotes)
2596  .def("__str__", [](PyDiagnostic &self) -> py::str {
2597  if (!self.isValid())
2598  return "<Invalid Diagnostic>";
2599  return self.getMessage();
2600  });
2601 
2602  py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo",
2603  py::module_local())
2604  .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); }))
2605  .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity)
2606  .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location)
2607  .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message)
2608  .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes)
2609  .def("__str__",
2610  [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2611 
2612  py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2613  .def("detach", &PyDiagnosticHandler::detach)
2614  .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2615  .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2616  .def("__enter__", &PyDiagnosticHandler::contextEnter)
2617  .def("__exit__", &PyDiagnosticHandler::contextExit);
2618 
2619  //----------------------------------------------------------------------------
2620  // Mapping of MlirContext.
2621  // Note that this is exported as _BaseContext. The containing, Python level
2622  // __init__.py will subclass it with site-specific functionality and set a
2623  // "Context" attribute on this module.
2624  //----------------------------------------------------------------------------
2625  py::class_<PyMlirContext>(m, "_BaseContext", py::module_local())
2626  .def(py::init<>(&PyMlirContext::createNewContextForInit))
2627  .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2628  .def("_get_context_again",
2629  [](PyMlirContext &self) {
2630  PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2631  return ref.releaseObject();
2632  })
2633  .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2634  .def("_get_live_operation_objects",
2635  &PyMlirContext::getLiveOperationObjects)
2636  .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2637  .def("_clear_live_operations_inside",
2638  py::overload_cast<MlirOperation>(
2639  &PyMlirContext::clearOperationsInside))
2640  .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2641  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2642  &PyMlirContext::getCapsule)
2643  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2644  .def("__enter__", &PyMlirContext::contextEnter)
2645  .def("__exit__", &PyMlirContext::contextExit)
2646  .def_property_readonly_static(
2647  "current",
2648  [](py::object & /*class*/) {
2649  auto *context = PyThreadContextEntry::getDefaultContext();
2650  if (!context)
2651  return py::none().cast<py::object>();
2652  return py::cast(context);
2653  },
2654  "Gets the Context bound to the current thread or raises ValueError")
2655  .def_property_readonly(
2656  "dialects",
2657  [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2658  "Gets a container for accessing dialects by name")
2659  .def_property_readonly(
2660  "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2661  "Alias for 'dialect'")
2662  .def(
2663  "get_dialect_descriptor",
2664  [=](PyMlirContext &self, std::string &name) {
2665  MlirDialect dialect = mlirContextGetOrLoadDialect(
2666  self.get(), {name.data(), name.size()});
2667  if (mlirDialectIsNull(dialect)) {
2668  throw py::value_error(
2669  (Twine("Dialect '") + name + "' not found").str());
2670  }
2671  return PyDialectDescriptor(self.getRef(), dialect);
2672  },
2673  py::arg("dialect_name"),
2674  "Gets or loads a dialect by name, returning its descriptor object")
2675  .def_property(
2676  "allow_unregistered_dialects",
2677  [](PyMlirContext &self) -> bool {
2679  },
2680  [](PyMlirContext &self, bool value) {
2682  })
2683  .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2684  py::arg("callback"),
2685  "Attaches a diagnostic handler that will receive callbacks")
2686  .def(
2687  "enable_multithreading",
2688  [](PyMlirContext &self, bool enable) {
2689  mlirContextEnableMultithreading(self.get(), enable);
2690  },
2691  py::arg("enable"))
2692  .def(
2693  "is_registered_operation",
2694  [](PyMlirContext &self, std::string &name) {
2696  self.get(), MlirStringRef{name.data(), name.size()});
2697  },
2698  py::arg("operation_name"))
2699  .def(
2700  "append_dialect_registry",
2701  [](PyMlirContext &self, PyDialectRegistry &registry) {
2702  mlirContextAppendDialectRegistry(self.get(), registry);
2703  },
2704  py::arg("registry"))
2705  .def_property("emit_error_diagnostics", nullptr,
2706  &PyMlirContext::setEmitErrorDiagnostics,
2707  "Emit error diagnostics to diagnostic handlers. By default "
2708  "error diagnostics are captured and reported through "
2709  "MLIRError exceptions.")
2710  .def("load_all_available_dialects", [](PyMlirContext &self) {
2712  });
2713 
2714  //----------------------------------------------------------------------------
2715  // Mapping of PyDialectDescriptor
2716  //----------------------------------------------------------------------------
2717  py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2718  .def_property_readonly("namespace",
2719  [](PyDialectDescriptor &self) {
2720  MlirStringRef ns =
2721  mlirDialectGetNamespace(self.get());
2722  return py::str(ns.data, ns.length);
2723  })
2724  .def("__repr__", [](PyDialectDescriptor &self) {
2726  std::string repr("<DialectDescriptor ");
2727  repr.append(ns.data, ns.length);
2728  repr.append(">");
2729  return repr;
2730  });
2731 
2732  //----------------------------------------------------------------------------
2733  // Mapping of PyDialects
2734  //----------------------------------------------------------------------------
2735  py::class_<PyDialects>(m, "Dialects", py::module_local())
2736  .def("__getitem__",
2737  [=](PyDialects &self, std::string keyName) {
2738  MlirDialect dialect =
2739  self.getDialectForKey(keyName, /*attrError=*/false);
2740  py::object descriptor =
2741  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2742  return createCustomDialectWrapper(keyName, std::move(descriptor));
2743  })
2744  .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2745  MlirDialect dialect =
2746  self.getDialectForKey(attrName, /*attrError=*/true);
2747  py::object descriptor =
2748  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2749  return createCustomDialectWrapper(attrName, std::move(descriptor));
2750  });
2751 
2752  //----------------------------------------------------------------------------
2753  // Mapping of PyDialect
2754  //----------------------------------------------------------------------------
2755  py::class_<PyDialect>(m, "Dialect", py::module_local())
2756  .def(py::init<py::object>(), py::arg("descriptor"))
2757  .def_property_readonly(
2758  "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2759  .def("__repr__", [](py::object self) {
2760  auto clazz = self.attr("__class__");
2761  return py::str("<Dialect ") +
2762  self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2763  clazz.attr("__module__") + py::str(".") +
2764  clazz.attr("__name__") + py::str(")>");
2765  });
2766 
2767  //----------------------------------------------------------------------------
2768  // Mapping of PyDialectRegistry
2769  //----------------------------------------------------------------------------
2770  py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local())
2771  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2772  &PyDialectRegistry::getCapsule)
2773  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2774  .def(py::init<>());
2775 
2776  //----------------------------------------------------------------------------
2777  // Mapping of Location
2778  //----------------------------------------------------------------------------
2779  py::class_<PyLocation>(m, "Location", py::module_local())
2780  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2781  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2782  .def("__enter__", &PyLocation::contextEnter)
2783  .def("__exit__", &PyLocation::contextExit)
2784  .def("__eq__",
2785  [](PyLocation &self, PyLocation &other) -> bool {
2786  return mlirLocationEqual(self, other);
2787  })
2788  .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2789  .def_property_readonly_static(
2790  "current",
2791  [](py::object & /*class*/) {
2792  auto *loc = PyThreadContextEntry::getDefaultLocation();
2793  if (!loc)
2794  throw py::value_error("No current Location");
2795  return loc;
2796  },
2797  "Gets the Location bound to the current thread or raises ValueError")
2798  .def_static(
2799  "unknown",
2800  [](DefaultingPyMlirContext context) {
2801  return PyLocation(context->getRef(),
2802  mlirLocationUnknownGet(context->get()));
2803  },
2804  py::arg("context") = py::none(),
2805  "Gets a Location representing an unknown location")
2806  .def_static(
2807  "callsite",
2808  [](PyLocation callee, const std::vector<PyLocation> &frames,
2809  DefaultingPyMlirContext context) {
2810  if (frames.empty())
2811  throw py::value_error("No caller frames provided");
2812  MlirLocation caller = frames.back().get();
2813  for (const PyLocation &frame :
2814  llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2815  caller = mlirLocationCallSiteGet(frame.get(), caller);
2816  return PyLocation(context->getRef(),
2817  mlirLocationCallSiteGet(callee.get(), caller));
2818  },
2819  py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2821  .def_static(
2822  "file",
2823  [](std::string filename, int line, int col,
2824  DefaultingPyMlirContext context) {
2825  return PyLocation(
2826  context->getRef(),
2828  context->get(), toMlirStringRef(filename), line, col));
2829  },
2830  py::arg("filename"), py::arg("line"), py::arg("col"),
2831  py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2832  .def_static(
2833  "fused",
2834  [](const std::vector<PyLocation> &pyLocations,
2835  std::optional<PyAttribute> metadata,
2836  DefaultingPyMlirContext context) {
2838  locations.reserve(pyLocations.size());
2839  for (auto &pyLocation : pyLocations)
2840  locations.push_back(pyLocation.get());
2841  MlirLocation location = mlirLocationFusedGet(
2842  context->get(), locations.size(), locations.data(),
2843  metadata ? metadata->get() : MlirAttribute{0});
2844  return PyLocation(context->getRef(), location);
2845  },
2846  py::arg("locations"), py::arg("metadata") = py::none(),
2847  py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2848  .def_static(
2849  "name",
2850  [](std::string name, std::optional<PyLocation> childLoc,
2851  DefaultingPyMlirContext context) {
2852  return PyLocation(
2853  context->getRef(),
2855  context->get(), toMlirStringRef(name),
2856  childLoc ? childLoc->get()
2857  : mlirLocationUnknownGet(context->get())));
2858  },
2859  py::arg("name"), py::arg("childLoc") = py::none(),
2860  py::arg("context") = py::none(), kContextGetNameLocationDocString)
2861  .def_static(
2862  "from_attr",
2863  [](PyAttribute &attribute, DefaultingPyMlirContext context) {
2864  return PyLocation(context->getRef(),
2865  mlirLocationFromAttribute(attribute));
2866  },
2867  py::arg("attribute"), py::arg("context") = py::none(),
2868  "Gets a Location from a LocationAttr")
2869  .def_property_readonly(
2870  "context",
2871  [](PyLocation &self) { return self.getContext().getObject(); },
2872  "Context that owns the Location")
2873  .def_property_readonly(
2874  "attr",
2875  [](PyLocation &self) { return mlirLocationGetAttribute(self); },
2876  "Get the underlying LocationAttr")
2877  .def(
2878  "emit_error",
2879  [](PyLocation &self, std::string message) {
2880  mlirEmitError(self, message.c_str());
2881  },
2882  py::arg("message"), "Emits an error at this location")
2883  .def("__repr__", [](PyLocation &self) {
2884  PyPrintAccumulator printAccum;
2885  mlirLocationPrint(self, printAccum.getCallback(),
2886  printAccum.getUserData());
2887  return printAccum.join();
2888  });
2889 
2890  //----------------------------------------------------------------------------
2891  // Mapping of Module
2892  //----------------------------------------------------------------------------
2893  py::class_<PyModule>(m, "Module", py::module_local())
2894  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2895  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2896  .def_static(
2897  "parse",
2898  [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
2899  PyMlirContext::ErrorCapture errors(context->getRef());
2900  MlirModule module = mlirModuleCreateParse(
2901  context->get(), toMlirStringRef(moduleAsm));
2902  if (mlirModuleIsNull(module))
2903  throw MLIRError("Unable to parse module assembly", errors.take());
2904  return PyModule::forModule(module).releaseObject();
2905  },
2906  py::arg("asm"), py::arg("context") = py::none(),
2908  .def_static(
2909  "create",
2910  [](DefaultingPyLocation loc) {
2911  MlirModule module = mlirModuleCreateEmpty(loc);
2912  return PyModule::forModule(module).releaseObject();
2913  },
2914  py::arg("loc") = py::none(), "Creates an empty module")
2915  .def_property_readonly(
2916  "context",
2917  [](PyModule &self) { return self.getContext().getObject(); },
2918  "Context that created the Module")
2919  .def_property_readonly(
2920  "operation",
2921  [](PyModule &self) {
2922  return PyOperation::forOperation(self.getContext(),
2923  mlirModuleGetOperation(self.get()),
2924  self.getRef().releaseObject())
2925  .releaseObject();
2926  },
2927  "Accesses the module as an operation")
2928  .def_property_readonly(
2929  "body",
2930  [](PyModule &self) {
2931  PyOperationRef moduleOp = PyOperation::forOperation(
2932  self.getContext(), mlirModuleGetOperation(self.get()),
2933  self.getRef().releaseObject());
2934  PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2935  return returnBlock;
2936  },
2937  "Return the block for this module")
2938  .def(
2939  "dump",
2940  [](PyModule &self) {
2942  },
2944  .def(
2945  "__str__",
2946  [](py::object self) {
2947  // Defer to the operation's __str__.
2948  return self.attr("operation").attr("__str__")();
2949  },
2951 
2952  //----------------------------------------------------------------------------
2953  // Mapping of Operation.
2954  //----------------------------------------------------------------------------
2955  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2956  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2957  [](PyOperationBase &self) {
2958  return self.getOperation().getCapsule();
2959  })
2960  .def("__eq__",
2961  [](PyOperationBase &self, PyOperationBase &other) {
2962  return &self.getOperation() == &other.getOperation();
2963  })
2964  .def("__eq__",
2965  [](PyOperationBase &self, py::object other) { return false; })
2966  .def("__hash__",
2967  [](PyOperationBase &self) {
2968  return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2969  })
2970  .def_property_readonly("attributes",
2971  [](PyOperationBase &self) {
2972  return PyOpAttributeMap(
2973  self.getOperation().getRef());
2974  })
2975  .def_property_readonly(
2976  "context",
2977  [](PyOperationBase &self) {
2978  PyOperation &concreteOperation = self.getOperation();
2979  concreteOperation.checkValid();
2980  return concreteOperation.getContext().getObject();
2981  },
2982  "Context that owns the Operation")
2983  .def_property_readonly("name",
2984  [](PyOperationBase &self) {
2985  auto &concreteOperation = self.getOperation();
2986  concreteOperation.checkValid();
2987  MlirOperation operation =
2988  concreteOperation.get();
2990  mlirOperationGetName(operation));
2991  return py::str(name.data, name.length);
2992  })
2993  .def_property_readonly("operands",
2994  [](PyOperationBase &self) {
2995  return PyOpOperandList(
2996  self.getOperation().getRef());
2997  })
2998  .def_property_readonly("regions",
2999  [](PyOperationBase &self) {
3000  return PyRegionList(
3001  self.getOperation().getRef());
3002  })
3003  .def_property_readonly(
3004  "results",
3005  [](PyOperationBase &self) {
3006  return PyOpResultList(self.getOperation().getRef());
3007  },
3008  "Returns the list of Operation results.")
3009  .def_property_readonly(
3010  "result",
3011  [](PyOperationBase &self) {
3012  auto &operation = self.getOperation();
3013  auto numResults = mlirOperationGetNumResults(operation);
3014  if (numResults != 1) {
3015  auto name = mlirIdentifierStr(mlirOperationGetName(operation));
3016  throw py::value_error(
3017  (Twine("Cannot call .result on operation ") +
3018  StringRef(name.data, name.length) + " which has " +
3019  Twine(numResults) +
3020  " results (it is only valid for operations with a "
3021  "single result)")
3022  .str());
3023  }
3024  return PyOpResult(operation.getRef(),
3025  mlirOperationGetResult(operation, 0))
3026  .maybeDownCast();
3027  },
3028  "Shortcut to get an op result if it has only one (throws an error "
3029  "otherwise).")
3030  .def_property_readonly(
3031  "location",
3032  [](PyOperationBase &self) {
3033  PyOperation &operation = self.getOperation();
3034  return PyLocation(operation.getContext(),
3035  mlirOperationGetLocation(operation.get()));
3036  },
3037  "Returns the source location the operation was defined or derived "
3038  "from.")
3039  .def_property_readonly("parent",
3040  [](PyOperationBase &self) -> py::object {
3041  auto parent =
3042  self.getOperation().getParentOperation();
3043  if (parent)
3044  return parent->getObject();
3045  return py::none();
3046  })
3047  .def(
3048  "__str__",
3049  [](PyOperationBase &self) {
3050  return self.getAsm(/*binary=*/false,
3051  /*largeElementsLimit=*/std::nullopt,
3052  /*enableDebugInfo=*/false,
3053  /*prettyDebugInfo=*/false,
3054  /*printGenericOpForm=*/false,
3055  /*useLocalScope=*/false,
3056  /*assumeVerified=*/false,
3057  /*skipRegions=*/false);
3058  },
3059  "Returns the assembly form of the operation.")
3060  .def("print",
3061  py::overload_cast<PyAsmState &, pybind11::object, bool>(
3063  py::arg("state"), py::arg("file") = py::none(),
3064  py::arg("binary") = false, kOperationPrintStateDocstring)
3065  .def("print",
3066  py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3067  bool, py::object, bool, bool>(
3069  // Careful: Lots of arguments must match up with print method.
3070  py::arg("large_elements_limit") = py::none(),
3071  py::arg("enable_debug_info") = false,
3072  py::arg("pretty_debug_info") = false,
3073  py::arg("print_generic_op_form") = false,
3074  py::arg("use_local_scope") = false,
3075  py::arg("assume_verified") = false, py::arg("file") = py::none(),
3076  py::arg("binary") = false, py::arg("skip_regions") = false,
3078  .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
3079  py::arg("desired_version") = py::none(),
3081  .def("get_asm", &PyOperationBase::getAsm,
3082  // Careful: Lots of arguments must match up with get_asm method.
3083  py::arg("binary") = false,
3084  py::arg("large_elements_limit") = py::none(),
3085  py::arg("enable_debug_info") = false,
3086  py::arg("pretty_debug_info") = false,
3087  py::arg("print_generic_op_form") = false,
3088  py::arg("use_local_scope") = false,
3089  py::arg("assume_verified") = false, py::arg("skip_regions") = false,
3091  .def("verify", &PyOperationBase::verify,
3092  "Verify the operation. Raises MLIRError if verification fails, and "
3093  "returns true otherwise.")
3094  .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
3095  "Puts self immediately after the other operation in its parent "
3096  "block.")
3097  .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
3098  "Puts self immediately before the other operation in its parent "
3099  "block.")
3100  .def(
3101  "clone",
3102  [](PyOperationBase &self, py::object ip) {
3103  return self.getOperation().clone(ip);
3104  },
3105  py::arg("ip") = py::none())
3106  .def(
3107  "detach_from_parent",
3108  [](PyOperationBase &self) {
3109  PyOperation &operation = self.getOperation();
3110  operation.checkValid();
3111  if (!operation.isAttached())
3112  throw py::value_error("Detached operation has no parent.");
3113 
3114  operation.detachFromParent();
3115  return operation.createOpView();
3116  },
3117  "Detaches the operation from its parent block.")
3118  .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3119  .def("walk", &PyOperationBase::walk, py::arg("callback"),
3120  py::arg("walk_order") = MlirWalkPostOrder);
3121 
3122  py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
3123  .def_static("create", &PyOperation::create, py::arg("name"),
3124  py::arg("results") = py::none(),
3125  py::arg("operands") = py::none(),
3126  py::arg("attributes") = py::none(),
3127  py::arg("successors") = py::none(), py::arg("regions") = 0,
3128  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3129  py::arg("infer_type") = false, kOperationCreateDocstring)
3130  .def_static(
3131  "parse",
3132  [](const std::string &sourceStr, const std::string &sourceName,
3133  DefaultingPyMlirContext context) {
3134  return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3135  ->createOpView();
3136  },
3137  py::arg("source"), py::kw_only(), py::arg("source_name") = "",
3138  py::arg("context") = py::none(),
3139  "Parses an operation. Supports both text assembly format and binary "
3140  "bytecode format.")
3141  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3142  &PyOperation::getCapsule)
3143  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3144  .def_property_readonly("operation", [](py::object self) { return self; })
3145  .def_property_readonly("opview", &PyOperation::createOpView)
3146  .def_property_readonly(
3147  "successors",
3148  [](PyOperationBase &self) {
3149  return PyOpSuccessors(self.getOperation().getRef());
3150  },
3151  "Returns the list of Operation successors.");
3152 
3153  auto opViewClass =
3154  py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
3155  .def(py::init<py::object>(), py::arg("operation"))
3156  .def_property_readonly("operation", &PyOpView::getOperationObject)
3157  .def_property_readonly("opview", [](py::object self) { return self; })
3158  .def(
3159  "__str__",
3160  [](PyOpView &self) { return py::str(self.getOperationObject()); })
3161  .def_property_readonly(
3162  "successors",
3163  [](PyOperationBase &self) {
3164  return PyOpSuccessors(self.getOperation().getRef());
3165  },
3166  "Returns the list of Operation successors.");
3167  opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
3168  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
3169  opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
3170  opViewClass.attr("build_generic") = classmethod(
3171  &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
3172  py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
3173  py::arg("successors") = py::none(), py::arg("regions") = py::none(),
3174  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3175  "Builds a specific, generated OpView based on class level attributes.");
3176  opViewClass.attr("parse") = classmethod(
3177  [](const py::object &cls, const std::string &sourceStr,
3178  const std::string &sourceName, DefaultingPyMlirContext context) {
3179  PyOperationRef parsed =
3180  PyOperation::parse(context->getRef(), sourceStr, sourceName);
3181 
3182  // Check if the expected operation was parsed, and cast to to the
3183  // appropriate `OpView` subclass if successful.
3184  // NOTE: This accesses attributes that have been automatically added to
3185  // `OpView` subclasses, and is not intended to be used on `OpView`
3186  // directly.
3187  std::string clsOpName =
3188  py::cast<std::string>(cls.attr("OPERATION_NAME"));
3189  MlirStringRef identifier =
3191  std::string_view parsedOpName(identifier.data, identifier.length);
3192  if (clsOpName != parsedOpName)
3193  throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3194  parsedOpName + "'");
3195  return PyOpView::constructDerived(cls, *parsed.get());
3196  },
3197  py::arg("cls"), py::arg("source"), py::kw_only(),
3198  py::arg("source_name") = "", py::arg("context") = py::none(),
3199  "Parses a specific, generated OpView based on class level attributes");
3200 
3201  //----------------------------------------------------------------------------
3202  // Mapping of PyRegion.
3203  //----------------------------------------------------------------------------
3204  py::class_<PyRegion>(m, "Region", py::module_local())
3205  .def_property_readonly(
3206  "blocks",
3207  [](PyRegion &self) {
3208  return PyBlockList(self.getParentOperation(), self.get());
3209  },
3210  "Returns a forward-optimized sequence of blocks.")
3211  .def_property_readonly(
3212  "owner",
3213  [](PyRegion &self) {
3214  return self.getParentOperation()->createOpView();
3215  },
3216  "Returns the operation owning this region.")
3217  .def(
3218  "__iter__",
3219  [](PyRegion &self) {
3220  self.checkValid();
3221  MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3222  return PyBlockIterator(self.getParentOperation(), firstBlock);
3223  },
3224  "Iterates over blocks in the region.")
3225  .def("__eq__",
3226  [](PyRegion &self, PyRegion &other) {
3227  return self.get().ptr == other.get().ptr;
3228  })
3229  .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
3230 
3231  //----------------------------------------------------------------------------
3232  // Mapping of PyBlock.
3233  //----------------------------------------------------------------------------
3234  py::class_<PyBlock>(m, "Block", py::module_local())
3235  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3236  .def_property_readonly(
3237  "owner",
3238  [](PyBlock &self) {
3239  return self.getParentOperation()->createOpView();
3240  },
3241  "Returns the owning operation of this block.")
3242  .def_property_readonly(
3243  "region",
3244  [](PyBlock &self) {
3245  MlirRegion region = mlirBlockGetParentRegion(self.get());
3246  return PyRegion(self.getParentOperation(), region);
3247  },
3248  "Returns the owning region of this block.")
3249  .def_property_readonly(
3250  "arguments",
3251  [](PyBlock &self) {
3252  return PyBlockArgumentList(self.getParentOperation(), self.get());
3253  },
3254  "Returns a list of block arguments.")
3255  .def(
3256  "add_argument",
3257  [](PyBlock &self, const PyType &type, const PyLocation &loc) {
3258  return mlirBlockAddArgument(self.get(), type, loc);
3259  },
3260  "Append an argument of the specified type to the block and returns "
3261  "the newly added argument.")
3262  .def(
3263  "erase_argument",
3264  [](PyBlock &self, unsigned index) {
3265  return mlirBlockEraseArgument(self.get(), index);
3266  },
3267  "Erase the argument at 'index' and remove it from the argument list.")
3268  .def_property_readonly(
3269  "operations",
3270  [](PyBlock &self) {
3271  return PyOperationList(self.getParentOperation(), self.get());
3272  },
3273  "Returns a forward-optimized sequence of operations.")
3274  .def_static(
3275  "create_at_start",
3276  [](PyRegion &parent, const py::list &pyArgTypes,
3277  const std::optional<py::sequence> &pyArgLocs) {
3278  parent.checkValid();
3279  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3280  mlirRegionInsertOwnedBlock(parent, 0, block);
3281  return PyBlock(parent.getParentOperation(), block);
3282  },
3283  py::arg("parent"), py::arg("arg_types") = py::list(),
3284  py::arg("arg_locs") = std::nullopt,
3285  "Creates and returns a new Block at the beginning of the given "
3286  "region (with given argument types and locations).")
3287  .def(
3288  "append_to",
3289  [](PyBlock &self, PyRegion &region) {
3290  MlirBlock b = self.get();
3292  mlirBlockDetach(b);
3293  mlirRegionAppendOwnedBlock(region.get(), b);
3294  },
3295  "Append this block to a region, transferring ownership if necessary")
3296  .def(
3297  "create_before",
3298  [](PyBlock &self, const py::args &pyArgTypes,
3299  const std::optional<py::sequence> &pyArgLocs) {
3300  self.checkValid();
3301  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3302  MlirRegion region = mlirBlockGetParentRegion(self.get());
3303  mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3304  return PyBlock(self.getParentOperation(), block);
3305  },
3306  py::arg("arg_locs") = std::nullopt,
3307  "Creates and returns a new Block before this block "
3308  "(with given argument types and locations).")
3309  .def(
3310  "create_after",
3311  [](PyBlock &self, const py::args &pyArgTypes,
3312  const std::optional<py::sequence> &pyArgLocs) {
3313  self.checkValid();
3314  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3315  MlirRegion region = mlirBlockGetParentRegion(self.get());
3316  mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3317  return PyBlock(self.getParentOperation(), block);
3318  },
3319  py::arg("arg_locs") = std::nullopt,
3320  "Creates and returns a new Block after this block "
3321  "(with given argument types and locations).")
3322  .def(
3323  "__iter__",
3324  [](PyBlock &self) {
3325  self.checkValid();
3326  MlirOperation firstOperation =
3328  return PyOperationIterator(self.getParentOperation(),
3329  firstOperation);
3330  },
3331  "Iterates over operations in the block.")
3332  .def("__eq__",
3333  [](PyBlock &self, PyBlock &other) {
3334  return self.get().ptr == other.get().ptr;
3335  })
3336  .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
3337  .def("__hash__",
3338  [](PyBlock &self) {
3339  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3340  })
3341  .def(
3342  "__str__",
3343  [](PyBlock &self) {
3344  self.checkValid();
3345  PyPrintAccumulator printAccum;
3346  mlirBlockPrint(self.get(), printAccum.getCallback(),
3347  printAccum.getUserData());
3348  return printAccum.join();
3349  },
3350  "Returns the assembly form of the block.")
3351  .def(
3352  "append",
3353  [](PyBlock &self, PyOperationBase &operation) {
3354  if (operation.getOperation().isAttached())
3355  operation.getOperation().detachFromParent();
3356 
3357  MlirOperation mlirOperation = operation.getOperation().get();
3358  mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3359  operation.getOperation().setAttached(
3360  self.getParentOperation().getObject());
3361  },
3362  py::arg("operation"),
3363  "Appends an operation to this block. If the operation is currently "
3364  "in another block, it will be moved.");
3365 
3366  //----------------------------------------------------------------------------
3367  // Mapping of PyInsertionPoint.
3368  //----------------------------------------------------------------------------
3369 
3370  py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
3371  .def(py::init<PyBlock &>(), py::arg("block"),
3372  "Inserts after the last operation but still inside the block.")
3373  .def("__enter__", &PyInsertionPoint::contextEnter)
3374  .def("__exit__", &PyInsertionPoint::contextExit)
3375  .def_property_readonly_static(
3376  "current",
3377  [](py::object & /*class*/) {
3378  auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3379  if (!ip)
3380  throw py::value_error("No current InsertionPoint");
3381  return ip;
3382  },
3383  "Gets the InsertionPoint bound to the current thread or raises "
3384  "ValueError if none has been set")
3385  .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
3386  "Inserts before a referenced operation.")
3387  .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3388  py::arg("block"), "Inserts at the beginning of the block.")
3389  .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3390  py::arg("block"), "Inserts before the block terminator.")
3391  .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
3392  "Inserts an operation.")
3393  .def_property_readonly(
3394  "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3395  "Returns the block that this InsertionPoint points to.")
3396  .def_property_readonly(
3397  "ref_operation",
3398  [](PyInsertionPoint &self) -> py::object {
3399  auto refOperation = self.getRefOperation();
3400  if (refOperation)
3401  return refOperation->getObject();
3402  return py::none();
3403  },
3404  "The reference operation before which new operations are "
3405  "inserted, or None if the insertion point is at the end of "
3406  "the block");
3407 
3408  //----------------------------------------------------------------------------
3409  // Mapping of PyAttribute.
3410  //----------------------------------------------------------------------------
3411  py::class_<PyAttribute>(m, "Attribute", py::module_local())
3412  // Delegate to the PyAttribute copy constructor, which will also lifetime
3413  // extend the backing context which owns the MlirAttribute.
3414  .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
3415  "Casts the passed attribute to the generic Attribute")
3416  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3417  &PyAttribute::getCapsule)
3418  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3419  .def_static(
3420  "parse",
3421  [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3422  PyMlirContext::ErrorCapture errors(context->getRef());
3423  MlirAttribute attr = mlirAttributeParseGet(
3424  context->get(), toMlirStringRef(attrSpec));
3425  if (mlirAttributeIsNull(attr))
3426  throw MLIRError("Unable to parse attribute", errors.take());
3427  return attr;
3428  },
3429  py::arg("asm"), py::arg("context") = py::none(),
3430  "Parses an attribute from an assembly form. Raises an MLIRError on "
3431  "failure.")
3432  .def_property_readonly(
3433  "context",
3434  [](PyAttribute &self) { return self.getContext().getObject(); },
3435  "Context that owns the Attribute")
3436  .def_property_readonly(
3437  "type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
3438  .def(
3439  "get_named",
3440  [](PyAttribute &self, std::string name) {
3441  return PyNamedAttribute(self, std::move(name));
3442  },
3443  py::keep_alive<0, 1>(), "Binds a name to the attribute")
3444  .def("__eq__",
3445  [](PyAttribute &self, PyAttribute &other) { return self == other; })
3446  .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
3447  .def("__hash__",
3448  [](PyAttribute &self) {
3449  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3450  })
3451  .def(
3452  "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3454  .def(
3455  "__str__",
3456  [](PyAttribute &self) {
3457  PyPrintAccumulator printAccum;
3458  mlirAttributePrint(self, printAccum.getCallback(),
3459  printAccum.getUserData());
3460  return printAccum.join();
3461  },
3462  "Returns the assembly form of the Attribute.")
3463  .def("__repr__",
3464  [](PyAttribute &self) {
3465  // Generally, assembly formats are not printed for __repr__ because
3466  // this can cause exceptionally long debug output and exceptions.
3467  // However, attribute values are generally considered useful and
3468  // are printed. This may need to be re-evaluated if debug dumps end
3469  // up being excessive.
3470  PyPrintAccumulator printAccum;
3471  printAccum.parts.append("Attribute(");
3472  mlirAttributePrint(self, printAccum.getCallback(),
3473  printAccum.getUserData());
3474  printAccum.parts.append(")");
3475  return printAccum.join();
3476  })
3477  .def_property_readonly(
3478  "typeid",
3479  [](PyAttribute &self) -> MlirTypeID {
3480  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3481  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3482  "mlirTypeID was expected to be non-null.");
3483  return mlirTypeID;
3484  })
3486  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3487  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3488  "mlirTypeID was expected to be non-null.");
3489  std::optional<pybind11::function> typeCaster =
3490  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3491  mlirAttributeGetDialect(self));
3492  if (!typeCaster)
3493  return py::cast(self);
3494  return typeCaster.value()(self);
3495  });
3496 
3497  //----------------------------------------------------------------------------
3498  // Mapping of PyNamedAttribute
3499  //----------------------------------------------------------------------------
3500  py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
3501  .def("__repr__",
3502  [](PyNamedAttribute &self) {
3503  PyPrintAccumulator printAccum;
3504  printAccum.parts.append("NamedAttribute(");
3505  printAccum.parts.append(
3506  py::str(mlirIdentifierStr(self.namedAttr.name).data,
3507  mlirIdentifierStr(self.namedAttr.name).length));
3508  printAccum.parts.append("=");
3509  mlirAttributePrint(self.namedAttr.attribute,
3510  printAccum.getCallback(),
3511  printAccum.getUserData());
3512  printAccum.parts.append(")");
3513  return printAccum.join();
3514  })
3515  .def_property_readonly(
3516  "name",
3517  [](PyNamedAttribute &self) {
3518  return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3519  mlirIdentifierStr(self.namedAttr.name).length);
3520  },
3521  "The name of the NamedAttribute binding")
3522  .def_property_readonly(
3523  "attr",
3524  [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3525  py::keep_alive<0, 1>(),
3526  "The underlying generic attribute of the NamedAttribute binding");
3527 
3528  //----------------------------------------------------------------------------
3529  // Mapping of PyType.
3530  //----------------------------------------------------------------------------
3531  py::class_<PyType>(m, "Type", py::module_local())
3532  // Delegate to the PyType copy constructor, which will also lifetime
3533  // extend the backing context which owns the MlirType.
3534  .def(py::init<PyType &>(), py::arg("cast_from_type"),
3535  "Casts the passed type to the generic Type")
3536  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3537  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3538  .def_static(
3539  "parse",
3540  [](std::string typeSpec, DefaultingPyMlirContext context) {
3541  PyMlirContext::ErrorCapture errors(context->getRef());
3542  MlirType type =
3543  mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3544  if (mlirTypeIsNull(type))
3545  throw MLIRError("Unable to parse type", errors.take());
3546  return type;
3547  },
3548  py::arg("asm"), py::arg("context") = py::none(),
3550  .def_property_readonly(
3551  "context", [](PyType &self) { return self.getContext().getObject(); },
3552  "Context that owns the Type")
3553  .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3554  .def("__eq__", [](PyType &self, py::object &other) { return false; })
3555  .def("__hash__",
3556  [](PyType &self) {
3557  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3558  })
3559  .def(
3560  "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3561  .def(
3562  "__str__",
3563  [](PyType &self) {
3564  PyPrintAccumulator printAccum;
3565  mlirTypePrint(self, printAccum.getCallback(),
3566  printAccum.getUserData());
3567  return printAccum.join();
3568  },
3569  "Returns the assembly form of the type.")
3570  .def("__repr__",
3571  [](PyType &self) {
3572  // Generally, assembly formats are not printed for __repr__ because
3573  // this can cause exceptionally long debug output and exceptions.
3574  // However, types are an exception as they typically have compact
3575  // assembly forms and printing them is useful.
3576  PyPrintAccumulator printAccum;
3577  printAccum.parts.append("Type(");
3578  mlirTypePrint(self, printAccum.getCallback(),
3579  printAccum.getUserData());
3580  printAccum.parts.append(")");
3581  return printAccum.join();
3582  })
3584  [](PyType &self) {
3585  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3586  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3587  "mlirTypeID was expected to be non-null.");
3588  std::optional<pybind11::function> typeCaster =
3589  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3590  mlirTypeGetDialect(self));
3591  if (!typeCaster)
3592  return py::cast(self);
3593  return typeCaster.value()(self);
3594  })
3595  .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
3596  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3597  if (!mlirTypeIDIsNull(mlirTypeID))
3598  return mlirTypeID;
3599  auto origRepr =
3600  pybind11::repr(pybind11::cast(self)).cast<std::string>();
3601  throw py::value_error(
3602  (origRepr + llvm::Twine(" has no typeid.")).str());
3603  });
3604 
3605  //----------------------------------------------------------------------------
3606  // Mapping of PyTypeID.
3607  //----------------------------------------------------------------------------
3608  py::class_<PyTypeID>(m, "TypeID", py::module_local())
3609  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3610  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3611  // Note, this tests whether the underlying TypeIDs are the same,
3612  // not whether the wrapper MlirTypeIDs are the same, nor whether
3613  // the Python objects are the same (i.e., PyTypeID is a value type).
3614  .def("__eq__",
3615  [](PyTypeID &self, PyTypeID &other) { return self == other; })
3616  .def("__eq__",
3617  [](PyTypeID &self, const py::object &other) { return false; })
3618  // Note, this gives the hash value of the underlying TypeID, not the
3619  // hash value of the Python object, nor the hash value of the
3620  // MlirTypeID wrapper.
3621  .def("__hash__", [](PyTypeID &self) {
3622  return static_cast<size_t>(mlirTypeIDHashValue(self));
3623  });
3624 
3625  //----------------------------------------------------------------------------
3626  // Mapping of Value.
3627  //----------------------------------------------------------------------------
3628  py::class_<PyValue>(m, "Value", py::module_local())
3629  .def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"))
3630  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3631  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3632  .def_property_readonly(
3633  "context",
3634  [](PyValue &self) { return self.getParentOperation()->getContext(); },
3635  "Context in which the value lives.")
3636  .def(
3637  "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3639  .def_property_readonly(
3640  "owner",
3641  [](PyValue &self) -> py::object {
3642  MlirValue v = self.get();
3643  if (mlirValueIsAOpResult(v)) {
3644  assert(
3645  mlirOperationEqual(self.getParentOperation()->get(),
3646  mlirOpResultGetOwner(self.get())) &&
3647  "expected the owner of the value in Python to match that in "
3648  "the IR");
3649  return self.getParentOperation().getObject();
3650  }
3651 
3652  if (mlirValueIsABlockArgument(v)) {
3653  MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3654  return py::cast(PyBlock(self.getParentOperation(), block));
3655  }
3656 
3657  assert(false && "Value must be a block argument or an op result");
3658  return py::none();
3659  })
3660  .def_property_readonly("uses",
3661  [](PyValue &self) {
3662  return PyOpOperandIterator(
3663  mlirValueGetFirstUse(self.get()));
3664  })
3665  .def("__eq__",
3666  [](PyValue &self, PyValue &other) {
3667  return self.get().ptr == other.get().ptr;
3668  })
3669  .def("__eq__", [](PyValue &self, py::object other) { return false; })
3670  .def("__hash__",
3671  [](PyValue &self) {
3672  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3673  })
3674  .def(
3675  "__str__",
3676  [](PyValue &self) {
3677  PyPrintAccumulator printAccum;
3678  printAccum.parts.append("Value(");
3679  mlirValuePrint(self.get(), printAccum.getCallback(),
3680  printAccum.getUserData());
3681  printAccum.parts.append(")");
3682  return printAccum.join();
3683  },
3685  .def(
3686  "get_name",
3687  [](PyValue &self, bool useLocalScope) {
3688  PyPrintAccumulator printAccum;
3689  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3690  if (useLocalScope)
3692  MlirAsmState valueState =
3693  mlirAsmStateCreateForValue(self.get(), flags);
3694  mlirValuePrintAsOperand(self.get(), valueState,
3695  printAccum.getCallback(),
3696  printAccum.getUserData());
3698  mlirAsmStateDestroy(valueState);
3699  return printAccum.join();
3700  },
3701  py::arg("use_local_scope") = false)
3702  .def(
3703  "get_name",
3704  [](PyValue &self, std::reference_wrapper<PyAsmState> state) {
3705  PyPrintAccumulator printAccum;
3706  MlirAsmState valueState = state.get().get();
3707  mlirValuePrintAsOperand(self.get(), valueState,
3708  printAccum.getCallback(),
3709  printAccum.getUserData());
3710  return printAccum.join();
3711  },
3712  py::arg("state"), kGetNameAsOperand)
3713  .def_property_readonly(
3714  "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
3715  .def(
3716  "set_type",
3717  [](PyValue &self, const PyType &type) {
3718  return mlirValueSetType(self.get(), type);
3719  },
3720  py::arg("type"))
3721  .def(
3722  "replace_all_uses_with",
3723  [](PyValue &self, PyValue &with) {
3724  mlirValueReplaceAllUsesOfWith(self.get(), with.get());
3725  },
3727  .def(
3728  "replace_all_uses_except",
3729  [](MlirValue self, MlirValue with, PyOperation &exception) {
3730  MlirOperation exceptedUser = exception.get();
3731  mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
3732  },
3733  py::arg("with"), py::arg("exceptions"),
3735  .def(
3736  "replace_all_uses_except",
3737  [](MlirValue self, MlirValue with, py::list exceptions) {
3738  // Convert Python list to a SmallVector of MlirOperations
3739  llvm::SmallVector<MlirOperation> exceptionOps;
3740  for (py::handle exception : exceptions) {
3741  exceptionOps.push_back(exception.cast<PyOperation &>().get());
3742  }
3743 
3745  self, with, static_cast<intptr_t>(exceptionOps.size()),
3746  exceptionOps.data());
3747  },
3748  py::arg("with"), py::arg("exceptions"),
3751  [](PyValue &self) { return self.maybeDownCast(); });
3752  PyBlockArgument::bind(m);
3753  PyOpResult::bind(m);
3754  PyOpOperand::bind(m);
3755 
3756  py::class_<PyAsmState>(m, "AsmState", py::module_local())
3757  .def(py::init<PyValue &, bool>(), py::arg("value"),
3758  py::arg("use_local_scope") = false)
3759  .def(py::init<PyOperationBase &, bool>(), py::arg("op"),
3760  py::arg("use_local_scope") = false);
3761 
3762  //----------------------------------------------------------------------------
3763  // Mapping of SymbolTable.
3764  //----------------------------------------------------------------------------
3765  py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3766  .def(py::init<PyOperationBase &>())
3767  .def("__getitem__", &PySymbolTable::dunderGetItem)
3768  .def("insert", &PySymbolTable::insert, py::arg("operation"))
3769  .def("erase", &PySymbolTable::erase, py::arg("operation"))
3770  .def("__delitem__", &PySymbolTable::dunderDel)
3771  .def("__contains__",
3772  [](PySymbolTable &table, const std::string &name) {
3774  table, mlirStringRefCreate(name.data(), name.length())));
3775  })
3776  // Static helpers.
3777  .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3778  py::arg("symbol"), py::arg("name"))
3779  .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3780  py::arg("symbol"))
3781  .def_static("get_visibility", &PySymbolTable::getVisibility,
3782  py::arg("symbol"))
3783  .def_static("set_visibility", &PySymbolTable::setVisibility,
3784  py::arg("symbol"), py::arg("visibility"))
3785  .def_static("replace_all_symbol_uses",
3786  &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3787  py::arg("new_symbol"), py::arg("from_op"))
3788  .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3789  py::arg("from_op"), py::arg("all_sym_uses_visible"),
3790  py::arg("callback"));
3791 
3792  // Container bindings.
3793  PyBlockArgumentList::bind(m);
3794  PyBlockIterator::bind(m);
3795  PyBlockList::bind(m);
3796  PyOperationIterator::bind(m);
3797  PyOperationList::bind(m);
3798  PyOpAttributeMap::bind(m);
3799  PyOpOperandIterator::bind(m);
3800  PyOpOperandList::bind(m);
3801  PyOpResultList::bind(m);
3802  PyOpSuccessors::bind(m);
3803  PyRegionIterator::bind(m);
3804  PyRegionList::bind(m);
3805 
3806  // Debug bindings.
3808 
3809  // Attribute builder getter.
3811 
3812  py::register_local_exception_translator([](std::exception_ptr p) {
3813  // We can't define exceptions with custom fields through pybind, so instead
3814  // the exception class is defined in python and imported here.
3815  try {
3816  if (p)
3817  std::rethrow_exception(p);
3818  } catch (const MLIRError &e) {
3819  py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
3820  .attr("MLIRError")(e.message, e.errorDiagnostics);
3821  PyErr_SetObject(PyExc_Exception, obj.ptr());
3822  }
3823  });
3824 }
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition: Debug.cpp:20
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Sets multiple current debug types, similarly to `-debug-only=type1,type2` in the command-line tools.
Definition: Debug.cpp:28
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition: Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition: Debug.cpp:16
static const char kOperationPrintStateDocstring[]
Definition: IRCore.cpp:114
static const char kValueReplaceAllUsesWithDocstring[]
Definition: IRCore.cpp:176
static py::object createCustomDialectWrapper(const std::string &dialectNamespace, py::object dialectDescriptor)
Definition: IRCore.cpp:199
static const char kContextGetNameLocationDocString[]
Definition: IRCore.cpp:57
static const char kGetNameAsOperand[]
Definition: IRCore.cpp:172
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:211
static const char kModuleParseDocstring[]
Definition: IRCore.cpp:60
static const char kOperationStrDunderDocstring[]
Definition: IRCore.cpp:146
static const char kOperationPrintDocstring[]
Definition: IRCore.cpp:87
static const char kContextGetFileLocationDocstring[]
Definition: IRCore.cpp:51
static const char kDumpDocstring[]
Definition: IRCore.cpp:154
static const char kAppendBlockDocstring[]
Definition: IRCore.cpp:157
static const char kContextGetFusedLocationDocstring[]
Definition: IRCore.cpp:54
static void maybeInsertOperation(PyOperationRef &op, const py::object &maybeIp)
Definition: IRCore.cpp:1405
static MlirBlock createBlock(const py::sequence &pyArgTypes, const std::optional< py::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition: IRCore.cpp:217
static const char kOperationPrintBytecodeDocstring[]
Definition: IRCore.cpp:136
static const char kOperationGetAsmDocstring[]
Definition: IRCore.cpp:123
static void populateResultTypes(StringRef name, py::list resultTypeList, const py::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition: IRCore.cpp:1576
static const char kOperationCreateDocstring[]
Definition: IRCore.cpp:68
static const char kContextParseTypeDocstring[]
Definition: IRCore.cpp:40
static const char kContextGetCallSiteLocationDocstring[]
Definition: IRCore.cpp:48
static const char kValueDunderStrDocstring[]
Definition: IRCore.cpp:164
py::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition: IRCore.cpp:193
static const char kValueReplaceAllUsesExceptDocstring[]
Definition: IRCore.cpp:181
static MLIRContext * getContext(OpFoldResult val)
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition: Interop.h:273
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition: Interop.h:118
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition: Interop.h:348
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition: Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition: Interop.h:189
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition: Interop.h:180
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition: Interop.h:255
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition: Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition: Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts a MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition: Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition: Interop.h:357
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition: Interop.h:235
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition: Interop.h:367
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition: Interop.h:245
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition: Interop.h:57
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition: Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition: Interop.h:454
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition: Interop.h:198
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition: Interop.h:330
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition: Interop.h:264
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition: Interop.h:445
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition: Interop.h:216
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::context getDefaultContext()
const float * table
Accumulates into a Python file-like object, either writing text (default) or binary.
Definition: PybindUtils.h:125
MlirStringCallback getCallback()
Definition: PybindUtils.h:132
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
Definition: PybindUtils.h:209
Base class for all objects that directly or indirectly depend on an MlirContext.
Definition: IRModule.h:303
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:311
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:518
static PyLocation & resolve()
Definition: IRCore.cpp:1071
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:292
static PyMlirContext & resolve()
Definition: IRCore.cpp:795
ReferrentTy * get() const
Definition: PybindUtils.h:47
Wrapper around an MlirAsmState.
Definition: IRModule.h:785
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:1002
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition: IRModule.h:1004
static PyAttribute createFromCapsule(pybind11::object capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition: IRCore.cpp:1939
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition: IRCore.cpp:1935
bool operator==(const PyAttribute &other) const
Definition: IRCore.cpp:1931
Wrapper around an MlirBlock.
Definition: IRModule.h:820
MlirBlock get()
Definition: IRModule.h:827
PyOperationRef & getParentOperation()
Definition: IRModule.h:828
Represents a diagnostic handler attached to the context.
Definition: IRModule.h:399
void detach()
Detaches the handler. Does nothing if not attached.
Definition: IRCore.cpp:957
PyDiagnosticHandler(MlirContext context, pybind11::object callback)
Definition: IRCore.cpp:951
Python class mirroring the C MlirDiagnostic struct.
Definition: IRModule.h:349
pybind11::str getMessage()
Definition: IRCore.cpp:987
PyLocation getLocation()
Definition: IRCore.cpp:980
DiagnosticInfo getInfo()
Definition: IRCore.cpp:1008
PyDiagnostic(MlirDiagnostic diagnostic)
Definition: IRModule.h:351
MlirDiagnosticSeverity getSeverity()
Definition: IRCore.cpp:975
pybind11::tuple getNotes()
Definition: IRCore.cpp:995
Wrapper around an MlirDialect.
Definition: IRModule.h:454
Wrapper around an MlirDialectRegistry.
Definition: IRModule.h:491
static PyDialectRegistry createFromCapsule(pybind11::object capsule)
Definition: IRCore.cpp:1037
pybind11::object getCapsule()
Definition: IRCore.cpp:1032
User-level dialect object.
Definition: IRModule.h:478
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition: IRModule.h:467
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition: IRCore.cpp:1019
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition: Globals.h:34
std::optional< pybind11::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition: IRModule.cpp:172
std::optional< pybind11::function > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:145
An insertion point maintains a pointer to a Block and a reference operation.
Definition: IRModule.h:844
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition: IRCore.cpp:1908
PyInsertionPoint(PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition: IRCore.cpp:1863
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition: IRCore.cpp:1895
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition: IRCore.cpp:1869
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:1921
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:1917
Wrapper around an MlirLocation.
Definition: IRModule.h:318
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition: IRModule.h:320
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition: IRCore.cpp:1049
static PyLocation createFromCapsule(pybind11::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition: IRCore.cpp:1053
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:1065
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:1061
MlirLocation get() const
Definition: IRModule.h:324
pybind11::object attachDiagnosticHandler(pybind11::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition: IRCore.cpp:730
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:185
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition: IRModule.h:189
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition: IRCore.cpp:621
void clearOperationsInside(PyOperationBase &op)
Clears all operations nested inside the given op using clearOperation(MlirOperation).
Definition: IRCore.cpp:682
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition: IRCore.cpp:655
void clearOperationAndInside(PyOperationBase &op)
Clears the operation and all operations inside using clearOperation(MlirOperation).
Definition: IRCore.cpp:707
static PyMlirContext * createNewContextForInit()
For the case of a python init (py::init) method, pybind11 is quite strict about needing to return a p...
Definition: IRCore.cpp:628
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition: IRCore.cpp:718
pybind11::object contextEnter()
Enter and exit the context manager.
Definition: IRCore.cpp:720
size_t clearLiveOperations()
Clears the live operations map, returning the number of entries which were invalidated.
Definition: IRCore.cpp:666
std::vector< PyOperation * > getLiveOperationObjects()
Get a list of Python objects which are still in the live context map.
Definition: IRCore.cpp:659
void contextExit(const pybind11::object &excType, const pybind11::object &excVal, const pybind11::object &excTb)
Definition: IRCore.cpp:724
void clearOperation(MlirOperation op)
Removes an operation from the live operations map and sets it invalid.
Definition: IRCore.cpp:674
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition: IRCore.cpp:633
size_t getLiveOperationCount()
Gets the count of live operations associated with this context.
Definition: IRCore.cpp:657
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition: IRCore.cpp:617
MlirModule get()
Gets the backing MlirModule.
Definition: IRModule.h:541
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition: IRCore.cpp:1098
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition: IRCore.cpp:1131
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition: IRCore.cpp:1124
PyModule(PyModule &)=delete
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition: IRModule.h:1026
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition: IRCore.cpp:1951
MlirNamedAttribute namedAttr
Definition: IRModule.h:1035
pybind11::object getObject()
Definition: IRModule.h:88
pybind11::object releaseObject()
Releases the object held by this instance, returning it.
Definition: IRModule.h:76
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition: IRModule.h:734
PyOpView(const pybind11::object &operationObject)
Definition: IRCore.cpp:1853
static pybind11::object buildGeneric(const pybind11::object &cls, std::optional< pybind11::list > resultTypeList, pybind11::list operandList, std::optional< pybind11::dict > attributes, std::optional< std::vector< PyBlock * >> successors, std::optional< int > regions, DefaultingPyLocation location, const pybind11::object &maybeIp)
Definition: IRCore.cpp:1664
static pybind11::object constructDerived(const pybind11::object &cls, const PyOperation &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition: IRCore.cpp:1842
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition: IRModule.h:571
void print(std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, py::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
Definition: IRCore.cpp:1227
void walk(std::function< MlirWalkResult(MlirOperation)> callback, MlirWalkOrder walkOrder)
Definition: IRCore.cpp:1290
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
void writeBytecode(const pybind11::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition: IRCore.cpp:1269
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition: IRCore.cpp:1346
void moveBefore(PyOperationBase &other)
Definition: IRCore.cpp:1355
bool verify()
Verify the operation.
Definition: IRCore.cpp:1364
pybind11::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, bool skipRegions)
Definition: IRCore.cpp:1322
pybind11::object clone(const pybind11::object &ip)
Clones this operation.
Definition: IRCore.cpp:1546
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition: IRModule.h:640
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition: IRCore.cpp:1566
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition: IRCore.cpp:1391
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition: IRModule.h:618
PyOperationRef getRef()
Definition: IRModule.h:653
MlirOperation get() const
Definition: IRModule.h:648
void setAttached(const pybind11::object &parent=pybind11::object())
Definition: IRModule.h:659
static pybind11::object create(const std::string &name, std::optional< std::vector< PyType * >> results, std::optional< std::vector< PyValue * >> operands, std::optional< pybind11::dict > attributes, std::optional< std::vector< PyBlock * >> successors, int regions, DefaultingPyLocation location, const pybind11::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition: IRCore.cpp:1420
pybind11::object createOpView()
Creates an OpView suitable for this operation.
Definition: IRCore.cpp:1555
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive=pybind11::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition: IRCore.cpp:1179
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition: IRCore.cpp:1372
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition: IRCore.cpp:1382
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive=pybind11::object())
Creates a detached operation.
Definition: IRCore.cpp:1195
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition: IRCore.cpp:1209
static pybind11::object createFromCapsule(pybind11::object capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition: IRCore.cpp:1396
void checkValid() const
Definition: IRCore.cpp:1221
Wrapper around an MlirRegion.
Definition: IRModule.h:766
PyOperationRef & getParentOperation()
Definition: IRModule.h:775
MlirRegion get()
Definition: IRModule.h:774
Bindings for MLIR symbol tables.
Definition: IRModule.h:1232
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition: IRCore.cpp:2071
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition: IRCore.cpp:2142
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition: IRCore.cpp:2124
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition: IRCore.cpp:2098
MlirAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition: IRCore.cpp:2076
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition: IRCore.cpp:2061
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition: IRCore.cpp:2041
static MlirAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition: IRCore.cpp:2086
pybind11::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition: IRCore.cpp:2049
static MlirAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition: IRCore.cpp:2113
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, pybind11::object callback)
Walks all symbol tables under and including 'from'.
Definition: IRCore.cpp:2154
Tracks an entry in the thread context stack.
Definition: IRModule.h:107
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition: IRCore.cpp:815
static void popLocation(PyLocation &location)
Definition: IRCore.cpp:927
static pybind11::object pushContext(PyMlirContext &context)
Definition: IRCore.cpp:877
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition: IRCore.cpp:872
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition: IRCore.cpp:907
static void popContext(PyMlirContext &context)
Definition: IRCore.cpp:885
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and returns nullptr if not defined.
Definition: IRCore.cpp:867
static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint)
Definition: IRCore.cpp:896
static pybind11::object pushLocation(PyLocation &location)
Definition: IRCore.cpp:918
PyMlirContext * getContext()
Definition: IRCore.cpp:844
static PyMlirContext * getDefaultContext()
Gets the top of stack context and returns nullptr if not defined.
Definition: IRCore.cpp:862
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition: IRCore.cpp:810
PyInsertionPoint * getInsertionPoint()
Definition: IRCore.cpp:850
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition: IRModule.h:904
bool operator==(const PyTypeID &other) const
Definition: IRCore.cpp:1993
static PyTypeID createFromCapsule(pybind11::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition: IRCore.cpp:1987
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition: IRCore.cpp:1983
PyTypeID(MlirTypeID typeID)
Definition: IRModule.h:906
Wrapper around the generic MlirType.
Definition: IRModule.h:880
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition: IRCore.cpp:1967
static PyType createFromCapsule(pybind11::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition: IRCore.cpp:1971
PyType(PyMlirContextRef contextRef, MlirType type)
Definition: IRModule.h:882
bool operator==(const PyType &other) const
Definition: IRCore.cpp:1963
Wrapper around the generic MlirValue.
Definition: IRModule.h:1133
static PyValue createFromCapsule(pybind11::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition: IRCore.cpp:2020
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition: IRModule.h:1139
pybind11::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition: IRCore.cpp:2001
pybind11::object maybeDownCast()
Definition: IRCore.cpp:2005
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
Definition: Diagnostics.cpp:44
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
Definition: Diagnostics.cpp:28
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
Definition: Diagnostics.cpp:18
MlirDiagnosticSeverity
Severity of a diagnostic.
Definition: Diagnostics.h:32
@ MlirDiagnosticNote
Definition: Diagnostics.h:35
@ MlirDiagnosticRemark
Definition: Diagnostics.h:36
@ MlirDiagnosticWarning
Definition: Diagnostics.h:34
@ MlirDiagnosticError
Definition: Diagnostics.h:33
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
Definition: Diagnostics.cpp:50
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
Definition: Diagnostics.cpp:78
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
Definition: Diagnostics.cpp:56
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
Definition: Diagnostics.cpp:72
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition: Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
Definition: Diagnostics.cpp:24
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition: IR.cpp:252
MLIR_CAPI_EXPORTED intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Returns the position of the value in the argument list of its block.
Definition: IR.cpp:958
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1043
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition: IR.h:735
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition: IR.cpp:701
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
Definition: IR.cpp:529
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates a File/Line/Column location owned by the given context.
Definition: IR.cpp:260
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition: IR.cpp:1201
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition: IR.cpp:128
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op)
Returns the number of results of the operation.
Definition: IR.cpp:588
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition: IR.cpp:1180
MlirWalkOrder
Traversal order for operation walk.
Definition: IR.h:728
@ MlirWalkPreOrder
Definition: IR.h:729
@ MlirWalkPostOrder
Definition: IR.h:730
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Return pos-th attribute of the operation.
Definition: IR.cpp:662
MLIR_CAPI_EXPORTED void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition: IR.cpp:377
MLIR_CAPI_EXPORTED void mlirModuleDestroy(MlirModule module)
Takes a module owned by the caller and deletes it.
Definition: IR.cpp:330
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition: IR.cpp:1128
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition: IR.cpp:1185
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Use local scope when printing the operation.
Definition: IR.cpp:220
MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value)
Returns 1 if the value is a block argument, 0 otherwise.
Definition: IR.cpp:946
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition: IR.cpp:83
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:1149
static bool mlirModuleIsNull(MlirModule module)
Checks whether a module is null.
Definition: IR.h:314
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition: IR.cpp:1062
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition: IR.cpp:1045
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute)
Gets the type of this attribute.
Definition: IR.cpp:1101
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition: IR.cpp:72
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition: IR.cpp:802
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition: IR.cpp:1009
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Returns pos-th successor of the operation.
Definition: IR.cpp:600
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition: IR.cpp:992
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition: IR.cpp:288
MLIR_CAPI_EXPORTED void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Prints a type by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1082
MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Sets an attribute by name, replacing the existing if it exists or adding a new one otherwise.
Definition: IR.cpp:672
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition: IR.cpp:1033
MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute)
Gets the dialect of the attribute.
Definition: IR.cpp:1112
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1120
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition: IR.cpp:841
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition: IR.cpp:725
static bool mlirValueIsNull(MlirValue value)
Returns whether the value is null.
Definition: IR.h:898
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
Definition: IR.cpp:1013
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition: IR.cpp:693
MlirWalkResult
Operation walk result.
Definition: IR.h:721
@ MlirWalkResultInterrupt
Definition: IR.h:723
@ MlirWalkResultSkip
Definition: IR.h:724
@ MlirWalkResultAdvance
Definition: IR.h:722
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition: IR.cpp:782
MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Returns an attribute attached to the operation given its name.
Definition: IR.cpp:667
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition: IR.h:1008
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation name (e.g. `dialect.op`) is registered with the context.
Definition: IR.cpp:99
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Returns the number of successor blocks of the operation.
Definition: IR.cpp:596
MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op)
Creates a deep copy of an operation.
Definition: IR.cpp:503
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition: IR.cpp:910
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Always print operations in the generic form.
Definition: IR.cpp:216
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition: IR.cpp:271
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition: IR.cpp:891
MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state)
Destroys an AsmState created with one of the mlirAsmStateCreate* functions.
Definition: IR.cpp:187
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition: IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition: IR.cpp:94
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningl...
Definition: IR.cpp:201
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition: IR.cpp:788
MLIR_CAPI_EXPORTED void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Sets the type of the block argument to the given type.
Definition: IR.cpp:963
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op)
Gets the context this operation is associated with.
Definition: IR.cpp:515
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition: IR.cpp:825
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition: IR.h:817
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition: IR.cpp:866
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition: IR.cpp:928
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition: IR.cpp:1175
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition: IR.cpp:1066
MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value)
Prints the value to the standard error stream.
Definition: IR.cpp:984
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location)
Creates a new, empty module and transfers ownership to the caller.
Definition: IR.cpp:310
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition: IR.cpp:1031
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition: IR.cpp:1165
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition: IR.cpp:292
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetBlock(MlirOperation op)
Gets the block that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:533
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition: IR.h:282
MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Checks whether two operation handles point to the same operation.
Definition: IR.cpp:511
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition: IR.cpp:914
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition: IR.cpp:687
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition: IR.cpp:999
MLIR_CAPI_EXPORTED void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:300
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition: IR.cpp:717
MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module)
Views the module as a generic operation.
Definition: IR.cpp:336
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition: IR.cpp:1078
MLIR_CAPI_EXPORTED MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Constructs an operation state from a name and a location.
Definition: IR.cpp:348
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition: IR.cpp:1041
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition: IR.cpp:856
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Skip printing regions.
Definition: IR.cpp:228
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Returns the operation immediately following the given operation in its enclosing block.
Definition: IR.cpp:565
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Gets the operation that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:537
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module)
Gets the context that a module was created with.
Definition: IR.cpp:322
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition: IR.cpp:256
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Do not verify the operation when using custom operation printers.
Definition: IR.cpp:224
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition: IR.cpp:1070
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition: IR.cpp:1161
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition: IR.h:173
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Destroys a bytecode writer config created with mlirBytecodeWriterConfigCreate.
Definition: IR.cpp:239
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Returns pos-th operand of the operation.
Definition: IR.cpp:573
MLIR_CAPI_EXPORTED void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition: IR.cpp:389
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition: IR.cpp:845
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition: IR.cpp:267
MLIR_CAPI_EXPORTED MlirOperation mlirOpResultGetOwner(MlirValue value)
Returns an operation that produced this value as its result.
Definition: IR.cpp:967
MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value)
Returns 1 if the value is an operation result, 0 otherwise.
Definition: IR.cpp:950
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op)
Returns the number of operands of the operation.
Definition: IR.cpp:569
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition: IR.h:235
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition: IR.cpp:743
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
Definition: IR.cpp:54
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition: IR.cpp:837
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumRegions(MlirOperation op)
Returns the number of regions attached to the given operation.
Definition: IR.cpp:541
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition: IR.cpp:296
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition: IR.cpp:919
MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Removes an attribute by name.
Definition: IR.cpp:677
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition: IR.cpp:1126
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition: IR.cpp:1190
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Parses an attribute. The attribute is owned by the context.
Definition: IR.cpp:1093
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Parses a module from the string and transfers ownership to the caller.
Definition: IR.cpp:314
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition: IR.cpp:778
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition: IR.cpp:849
MLIR_CAPI_EXPORTED void mlirTypeDump(MlirType type)
Prints the type to the standard error stream.
Definition: IR.cpp:1087
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Returns pos-th result of the operation.
Definition: IR.cpp:592
MLIR_CAPI_EXPORTED MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate(void)
Creates a new bytecode writer config with defaults, intended for customization.
Definition: IR.cpp:235
MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Gets the context that an attribute was created with.
Definition: IR.cpp:1097
MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Returns the block in which this value is defined as an argument.
Definition: IR.cpp:954
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition: IR.h:756
MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op)
Takes an operation owned by the caller and destroys it.
Definition: IR.cpp:507
MLIR_CAPI_EXPORTED MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Returns pos-th region attached to the operation.
Definition: IR.cpp:545
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition: IR.cpp:1074
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition: IR.cpp:1137
MLIR_CAPI_EXPORTED void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Set pos-th successor of the operation.
Definition: IR.cpp:653
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition: IR.cpp:107
MLIR_CAPI_EXPORTED void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition: IR.cpp:381
MLIR_CAPI_EXPORTED void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition: IR.cpp:385
MLIR_CAPI_EXPORTED MlirBlock mlirModuleGetBody(MlirModule module)
Gets the body of the module, i.e. the only block it contains.
Definition: IR.cpp:326
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Destroys printing flags created with mlirOpPrintingFlagsCreate.
Definition: IR.cpp:197
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition: IR.cpp:279
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Sets the threading mode (multithreading must be disabled to use mlir-print-ir-after-all).
Definition: IR.cpp:103
MLIR_CAPI_EXPORTED void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Prints a block by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:932
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Sets the version to emit in the writer config.
Definition: IR.cpp:243
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition: IR.cpp:1157
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition: IR.cpp:765
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition: IR.cpp:905
MLIR_CAPI_EXPORTED void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Adds a list of components to the operation state.
Definition: IR.cpp:372
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Enable or disable printing of debug information (based on enable).
Definition: IR.cpp:211
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op)
Gets the location of the operation.
Definition: IR.cpp:519
MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute)
Gets the type id of the attribute.
Definition: IR.cpp:1108
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Sets the pos-th operand of the operation.
Definition: IR.cpp:577
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition: IR.cpp:715
MLIR_CAPI_EXPORTED intptr_t mlirOpResultGetResultNumber(MlirValue value)
Returns the position of the value in the list of results of the operation that produced it.
Definition: IR.cpp:971
MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void)
Creates new printing flags with defaults, intended for customization.
Definition: IR.cpp:193
MLIR_CAPI_EXPORTED MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Creates new AsmState from value.
Definition: IR.cpp:169
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state)
Creates an operation and transfers ownership to the caller.
Definition: IR.cpp:457
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition: IR.h:1098
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition: IR.cpp:76
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition: IR.cpp:721
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Returns the number of attributes attached to the operation.
Definition: IR.cpp:658
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:986
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecode could not be honored.
Definition: IR.cpp:708
MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type)
Set the type of the value.
Definition: IR.cpp:980
MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value)
Returns the type of the value.
Definition: IR.cpp:976
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition: IR.cpp:70
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Parses an operation, giving ownership to the caller.
Definition: IR.cpp:494
MLIR_CAPI_EXPORTED bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Checks if two attributes are equal.
Definition: IR.cpp:1116
static bool mlirOperationIsNull(MlirOperation op)
Checks whether the underlying operation is null.
Definition: IR.h:519
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition: IR.cpp:771
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition: Support.h:138
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition: Support.cpp:51
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition: Support.h:132
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition: Support.h:163
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition: Support.cpp:47
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
inline ::llvm::hash_code hash_value(const PolynomialBase< D, T > &arg)
Definition: Polynomial.h:262
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition: IRModule.h:162
PyObjectRef< PyModule > PyModuleRef
Definition: IRModule.h:530
void populateIRCore(pybind11::module &m)
PyObjectRef< PyOperation > PyOperationRef
Definition: IRModule.h:614
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:426
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition: Diagnostics.h:26
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
Named MLIR attribute.
Definition: IR.h:76
MlirAttribute attribute
Definition: IR.h:78
MlirIdentifier name
Definition: IR.h:77
An auxiliary class for constructing operations.
Definition: IR.h:340
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
static bool dunderContains(const std::string &attributeKind)
Definition: IRCore.cpp:268
static void dundeSetItemNamed(const std::string &attributeKind, py::function func, bool replace)
Definition: IRCore.cpp:277
static py::function dundeGetItemNamed(const std::string &attributeKind)
Definition: IRCore.cpp:271
static void bind(py::module &m)
Definition: IRCore.cpp:283
Wrapper for the global LLVM debugging flag.
Definition: IRCore.cpp:241
static void bind(py::module &m)
Definition: IRCore.cpp:246
static bool get(const py::object &)
Definition: IRCore.cpp:244
static void set(py::object &o, bool enable)
Definition: IRCore.cpp:242
Accumulates into a python string from a method that accepts an MlirStringCallback.
Definition: PybindUtils.h:102
pybind11::list parts
Definition: PybindUtils.h:103
pybind11::str join()
Definition: PybindUtils.h:117
MlirStringCallback getCallback()
Definition: PybindUtils.h:107
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1284
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Definition: IRModule.h:1289
Materialized diagnostic information.
Definition: IRModule.h:361
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:427
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition: IRModule.h:437
ErrorCapture(PyMlirContextRef ctx)
Definition: IRModule.h:428