MLIR  20.0.0git
IRCore.cpp
Go to the documentation of this file.
//===- IRCore.cpp - Core IR bindings of the nanobind module ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <optional>
10 #include <utility>
11 
12 #include "Globals.h"
13 #include "IRModule.h"
14 #include "NanobindUtils.h"
16 #include "mlir-c/Debug.h"
17 #include "mlir-c/Diagnostics.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Support.h"
22 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/SmallVector.h"
25 
26 namespace nb = nanobind;
27 using namespace nb::literals;
28 using namespace mlir;
29 using namespace mlir::python;
30 
31 using llvm::SmallVector;
32 using llvm::StringRef;
33 using llvm::Twine;
34 
35 //------------------------------------------------------------------------------
36 // Docstrings (trivial, non-duplicated docstrings are included inline).
37 //------------------------------------------------------------------------------
38 
// Python docstring for the context-level type parsing binding. The raw string
// is emitted verbatim, including its trailing newline.
static constexpr char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";
46 
48  R"(Gets a Location representing a caller and callsite)";
49 
// One-line Python docstrings for the Location factory bindings.
static constexpr char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static constexpr char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static constexpr char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";
58 
// Python docstring for module parsing; the text is emitted verbatim.
static constexpr char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Python docstring for generic operation creation, listing every keyword the
// binding documents.
static constexpr char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
 name: Operation name (e.g. "dialect.operation").
 results: Sequence of Type representing op result types.
 attributes: Dict of str:Attribute.
 successors: List of Block for the operation's successors.
 regions: Number of regions to create.
 location: A Location object (defaults to resolve from context manager).
 ip: An InsertionPoint (defaults to resolve from context manager or set to
 False to disable insertion, even with an insertion point set in the
 context manager).
 infer_type: Whether to infer result types.
Returns:
 A new "detached" Operation object. Detached operations can be added
 to blocks, which causes them to become "attached."
)";
85 
// Python docstring for the full-featured operation print binding. Each "Args"
// entry must use the exact (snake_case) Python keyword name, since users copy
// these names into calls; the local-scope flag was previously documented with
// a stray capital ("use_local_Scope").
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
 file: The file like object to write to. Defaults to sys.stdout.
 binary: Whether to write bytes (True) or str (False). Defaults to False.
 large_elements_limit: Whether to elide elements attributes above this
 number of elements. Defaults to None (no limit).
 enable_debug_info: Whether to print debug/location information. Defaults
 to False.
 pretty_debug_info: Whether to format debug information for easier reading
 by a human (warning: the result is unparseable).
 print_generic_op_form: Whether to print the generic assembly forms of all
 ops. Defaults to False.
 use_local_scope: Whether to print in a way that is more optimized for
 multi-threaded access but may not be consistent with how the overall
 module prints.
 assume_verified: By default, if not printing generic form, the verifier
 will be run and if it fails, generic form will be printed with a comment
 about failed verification. While a reasonable default for interactive use,
 for systematic use, it is often better for the caller to verify explicitly
 and report failures in a more robust fashion. Set this to True if doing this
 in order to avoid running a redundant verification. If the IR is actually
 invalid, behavior is undefined.
 skip_regions: Whether to skip printing regions. Defaults to False.
)";
112 
// Python docstring for printing an operation with a caller-supplied AsmState.
static constexpr char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
 file: The file like object to write to. Defaults to sys.stdout.
 binary: Whether to write bytes (True) or str (False). Defaults to False.
 state: AsmState capturing the operation numbering and flags.
)";

// Python docstring for returning (rather than printing) the assembly form.
static constexpr char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
 binary: Whether to return a bytes (True) or str (False) object. Defaults to
 False.
 ... others ...: See the print() method for common keyword arguments for
 configuring the printout.
Returns:
 Either a bytes or str object, depending on the setting of the 'binary'
 argument.
)";

// Python docstring for writing an operation's bytecode form.
static constexpr char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
 file: The file like object to write to.
 desired_version: The version of bytecode to emit.
Returns:
 The bytecode writer status.
)";
144 
// Python docstring for an operation's default string conversion.
static constexpr char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared one-line docstring for debug dump helpers.
static constexpr char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Python docstring for appending a block to a block list.
static constexpr char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
 The created block.
)";

// Python docstring for a value's string conversion.
static constexpr char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

// Python docstring for rendering a value as an operand name.
static constexpr char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";
174 
176  R"(Replace all uses of value with the new value, updating anything in
177 the IR that uses 'self' to use the other value instead.
178 )";
179 
181  R"("Replace all uses of this value with the 'with' value, except for those
182 in 'exceptions'. 'exceptions' can be either a single operation or a list of
183 operations.
184 )";
185 
186 //------------------------------------------------------------------------------
187 // Utilities.
188 //------------------------------------------------------------------------------
189 
190 /// Helper for creating an @classmethod.
191 template <class Func, typename... Args>
192 nb::object classmethod(Func f, Args... args) {
193  nb::object cf = nb::cpp_function(f, args...);
194  return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
195 }
196 
197 static nb::object
198 createCustomDialectWrapper(const std::string &dialectNamespace,
199  nb::object dialectDescriptor) {
200  auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
201  if (!dialectClass) {
202  // Use the base class.
203  return nb::cast(PyDialect(std::move(dialectDescriptor)));
204  }
205 
206  // Create the custom implementation.
207  return (*dialectClass)(std::move(dialectDescriptor));
208 }
209 
210 static MlirStringRef toMlirStringRef(const std::string &s) {
211  return mlirStringRefCreate(s.data(), s.size());
212 }
213 
214 static MlirStringRef toMlirStringRef(const nb::bytes &s) {
215  return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
216 }
217 
218 /// Create a block, using the current location context if no locations are
219 /// specified.
220 static MlirBlock createBlock(const nb::sequence &pyArgTypes,
221  const std::optional<nb::sequence> &pyArgLocs) {
222  SmallVector<MlirType> argTypes;
223  argTypes.reserve(nb::len(pyArgTypes));
224  for (const auto &pyType : pyArgTypes)
225  argTypes.push_back(nb::cast<PyType &>(pyType));
226 
228  if (pyArgLocs) {
229  argLocs.reserve(nb::len(*pyArgLocs));
230  for (const auto &pyLoc : *pyArgLocs)
231  argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
232  } else if (!argTypes.empty()) {
233  argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
234  }
235 
236  if (argTypes.size() != argLocs.size())
237  throw nb::value_error(("Expected " + Twine(argTypes.size()) +
238  " locations, got: " + Twine(argLocs.size()))
239  .str()
240  .c_str());
241  return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
242 }
243 
244 /// Wrapper for the global LLVM debugging flag.
246  static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
247 
248  static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
249 
250  static void bind(nb::module_ &m) {
251  // Debug flags.
252  nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
253  .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
254  &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
255  .def_static(
256  "set_types",
257  [](const std::string &type) {
258  mlirSetGlobalDebugType(type.c_str());
259  },
260  "types"_a, "Sets specific debug types to be produced by LLVM")
261  .def_static("set_types", [](const std::vector<std::string> &types) {
262  std::vector<const char *> pointers;
263  pointers.reserve(types.size());
264  for (const std::string &str : types)
265  pointers.push_back(str.c_str());
266  mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
267  });
268  }
269 };
270 
272  static bool dunderContains(const std::string &attributeKind) {
273  return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
274  }
275  static nb::callable dundeGetItemNamed(const std::string &attributeKind) {
276  auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
277  if (!builder)
278  throw nb::key_error(attributeKind.c_str());
279  return *builder;
280  }
281  static void dundeSetItemNamed(const std::string &attributeKind,
282  nb::callable func, bool replace) {
283  PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
284  replace);
285  }
286 
287  static void bind(nb::module_ &m) {
288  nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
289  .def_static("contains", &PyAttrBuilderMap::dunderContains)
290  .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
291  .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed,
292  "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
293  "Register an attribute builder for building MLIR "
294  "attributes from python values.");
295  }
296 };
297 
298 //------------------------------------------------------------------------------
299 // PyBlock
300 //------------------------------------------------------------------------------
301 
302 nb::object PyBlock::getCapsule() {
303  return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
304 }
305 
306 //------------------------------------------------------------------------------
307 // Collections.
308 //------------------------------------------------------------------------------
309 
310 namespace {
311 
312 class PyRegionIterator {
313 public:
314  PyRegionIterator(PyOperationRef operation)
315  : operation(std::move(operation)) {}
316 
317  PyRegionIterator &dunderIter() { return *this; }
318 
319  PyRegion dunderNext() {
320  operation->checkValid();
321  if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
322  throw nb::stop_iteration();
323  }
324  MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
325  return PyRegion(operation, region);
326  }
327 
328  static void bind(nb::module_ &m) {
329  nb::class_<PyRegionIterator>(m, "RegionIterator")
330  .def("__iter__", &PyRegionIterator::dunderIter)
331  .def("__next__", &PyRegionIterator::dunderNext);
332  }
333 
334 private:
335  PyOperationRef operation;
336  int nextIndex = 0;
337 };
338 
339 /// Regions of an op are fixed length and indexed numerically so are represented
340 /// with a sequence-like container.
341 class PyRegionList {
342 public:
343  PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
344 
345  PyRegionIterator dunderIter() {
346  operation->checkValid();
347  return PyRegionIterator(operation);
348  }
349 
350  intptr_t dunderLen() {
351  operation->checkValid();
352  return mlirOperationGetNumRegions(operation->get());
353  }
354 
355  PyRegion dunderGetItem(intptr_t index) {
356  // dunderLen checks validity.
357  if (index < 0 || index >= dunderLen()) {
358  throw nb::index_error("attempt to access out of bounds region");
359  }
360  MlirRegion region = mlirOperationGetRegion(operation->get(), index);
361  return PyRegion(operation, region);
362  }
363 
364  static void bind(nb::module_ &m) {
365  nb::class_<PyRegionList>(m, "RegionSequence")
366  .def("__len__", &PyRegionList::dunderLen)
367  .def("__iter__", &PyRegionList::dunderIter)
368  .def("__getitem__", &PyRegionList::dunderGetItem);
369  }
370 
371 private:
372  PyOperationRef operation;
373 };
374 
375 class PyBlockIterator {
376 public:
377  PyBlockIterator(PyOperationRef operation, MlirBlock next)
378  : operation(std::move(operation)), next(next) {}
379 
380  PyBlockIterator &dunderIter() { return *this; }
381 
382  PyBlock dunderNext() {
383  operation->checkValid();
384  if (mlirBlockIsNull(next)) {
385  throw nb::stop_iteration();
386  }
387 
388  PyBlock returnBlock(operation, next);
389  next = mlirBlockGetNextInRegion(next);
390  return returnBlock;
391  }
392 
393  static void bind(nb::module_ &m) {
394  nb::class_<PyBlockIterator>(m, "BlockIterator")
395  .def("__iter__", &PyBlockIterator::dunderIter)
396  .def("__next__", &PyBlockIterator::dunderNext);
397  }
398 
399 private:
400  PyOperationRef operation;
401  MlirBlock next;
402 };
403 
404 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
405 /// we present them as a more full-featured list-like container but optimize
406 /// it for forward iteration. Blocks are always owned by a region.
407 class PyBlockList {
408 public:
409  PyBlockList(PyOperationRef operation, MlirRegion region)
410  : operation(std::move(operation)), region(region) {}
411 
412  PyBlockIterator dunderIter() {
413  operation->checkValid();
414  return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
415  }
416 
417  intptr_t dunderLen() {
418  operation->checkValid();
419  intptr_t count = 0;
420  MlirBlock block = mlirRegionGetFirstBlock(region);
421  while (!mlirBlockIsNull(block)) {
422  count += 1;
423  block = mlirBlockGetNextInRegion(block);
424  }
425  return count;
426  }
427 
428  PyBlock dunderGetItem(intptr_t index) {
429  operation->checkValid();
430  if (index < 0) {
431  throw nb::index_error("attempt to access out of bounds block");
432  }
433  MlirBlock block = mlirRegionGetFirstBlock(region);
434  while (!mlirBlockIsNull(block)) {
435  if (index == 0) {
436  return PyBlock(operation, block);
437  }
438  block = mlirBlockGetNextInRegion(block);
439  index -= 1;
440  }
441  throw nb::index_error("attempt to access out of bounds block");
442  }
443 
444  PyBlock appendBlock(const nb::args &pyArgTypes,
445  const std::optional<nb::sequence> &pyArgLocs) {
446  operation->checkValid();
447  MlirBlock block =
448  createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
449  mlirRegionAppendOwnedBlock(region, block);
450  return PyBlock(operation, block);
451  }
452 
453  static void bind(nb::module_ &m) {
454  nb::class_<PyBlockList>(m, "BlockList")
455  .def("__getitem__", &PyBlockList::dunderGetItem)
456  .def("__iter__", &PyBlockList::dunderIter)
457  .def("__len__", &PyBlockList::dunderLen)
458  .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
459  nb::arg("args"), nb::kw_only(),
460  nb::arg("arg_locs") = std::nullopt);
461  }
462 
463 private:
464  PyOperationRef operation;
465  MlirRegion region;
466 };
467 
468 class PyOperationIterator {
469 public:
470  PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
471  : parentOperation(std::move(parentOperation)), next(next) {}
472 
473  PyOperationIterator &dunderIter() { return *this; }
474 
475  nb::object dunderNext() {
476  parentOperation->checkValid();
477  if (mlirOperationIsNull(next)) {
478  throw nb::stop_iteration();
479  }
480 
481  PyOperationRef returnOperation =
482  PyOperation::forOperation(parentOperation->getContext(), next);
483  next = mlirOperationGetNextInBlock(next);
484  return returnOperation->createOpView();
485  }
486 
487  static void bind(nb::module_ &m) {
488  nb::class_<PyOperationIterator>(m, "OperationIterator")
489  .def("__iter__", &PyOperationIterator::dunderIter)
490  .def("__next__", &PyOperationIterator::dunderNext);
491  }
492 
493 private:
494  PyOperationRef parentOperation;
495  MlirOperation next;
496 };
497 
498 /// Operations are exposed by the C-API as a forward-only linked list. In
499 /// Python, we present them as a more full-featured list-like container but
500 /// optimize it for forward iteration. Iterable operations are always owned
501 /// by a block.
502 class PyOperationList {
503 public:
504  PyOperationList(PyOperationRef parentOperation, MlirBlock block)
505  : parentOperation(std::move(parentOperation)), block(block) {}
506 
507  PyOperationIterator dunderIter() {
508  parentOperation->checkValid();
509  return PyOperationIterator(parentOperation,
511  }
512 
513  intptr_t dunderLen() {
514  parentOperation->checkValid();
515  intptr_t count = 0;
516  MlirOperation childOp = mlirBlockGetFirstOperation(block);
517  while (!mlirOperationIsNull(childOp)) {
518  count += 1;
519  childOp = mlirOperationGetNextInBlock(childOp);
520  }
521  return count;
522  }
523 
524  nb::object dunderGetItem(intptr_t index) {
525  parentOperation->checkValid();
526  if (index < 0) {
527  throw nb::index_error("attempt to access out of bounds operation");
528  }
529  MlirOperation childOp = mlirBlockGetFirstOperation(block);
530  while (!mlirOperationIsNull(childOp)) {
531  if (index == 0) {
532  return PyOperation::forOperation(parentOperation->getContext(), childOp)
533  ->createOpView();
534  }
535  childOp = mlirOperationGetNextInBlock(childOp);
536  index -= 1;
537  }
538  throw nb::index_error("attempt to access out of bounds operation");
539  }
540 
541  static void bind(nb::module_ &m) {
542  nb::class_<PyOperationList>(m, "OperationList")
543  .def("__getitem__", &PyOperationList::dunderGetItem)
544  .def("__iter__", &PyOperationList::dunderIter)
545  .def("__len__", &PyOperationList::dunderLen);
546  }
547 
548 private:
549  PyOperationRef parentOperation;
550  MlirBlock block;
551 };
552 
553 class PyOpOperand {
554 public:
555  PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
556 
557  nb::object getOwner() {
558  MlirOperation owner = mlirOpOperandGetOwner(opOperand);
559  PyMlirContextRef context =
560  PyMlirContext::forContext(mlirOperationGetContext(owner));
561  return PyOperation::forOperation(context, owner)->createOpView();
562  }
563 
564  size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
565 
566  static void bind(nb::module_ &m) {
567  nb::class_<PyOpOperand>(m, "OpOperand")
568  .def_prop_ro("owner", &PyOpOperand::getOwner)
569  .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber);
570  }
571 
572 private:
573  MlirOpOperand opOperand;
574 };
575 
576 class PyOpOperandIterator {
577 public:
578  PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
579 
580  PyOpOperandIterator &dunderIter() { return *this; }
581 
582  PyOpOperand dunderNext() {
583  if (mlirOpOperandIsNull(opOperand))
584  throw nb::stop_iteration();
585 
586  PyOpOperand returnOpOperand(opOperand);
587  opOperand = mlirOpOperandGetNextUse(opOperand);
588  return returnOpOperand;
589  }
590 
591  static void bind(nb::module_ &m) {
592  nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
593  .def("__iter__", &PyOpOperandIterator::dunderIter)
594  .def("__next__", &PyOpOperandIterator::dunderNext);
595  }
596 
597 private:
598  MlirOpOperand opOperand;
599 };
600 
601 } // namespace
602 
603 //------------------------------------------------------------------------------
604 // PyMlirContext
605 //------------------------------------------------------------------------------
606 
607 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
608  nb::gil_scoped_acquire acquire;
609  auto &liveContexts = getLiveContexts();
610  liveContexts[context.ptr] = this;
611 }
612 
614  // Note that the only public way to construct an instance is via the
615  // forContext method, which always puts the associated handle into
616  // liveContexts.
617  nb::gil_scoped_acquire acquire;
618  getLiveContexts().erase(context.ptr);
619  mlirContextDestroy(context);
620 }
621 
623  return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
624 }
625 
626 nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
627  MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
628  if (mlirContextIsNull(rawContext))
629  throw nb::python_error();
630  return forContext(rawContext).releaseObject();
631 }
632 
634  nb::gil_scoped_acquire acquire;
635  auto &liveContexts = getLiveContexts();
636  auto it = liveContexts.find(context.ptr);
637  if (it == liveContexts.end()) {
638  // Create.
639  PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
640  nb::object pyRef = nb::cast(unownedContextWrapper);
641  assert(pyRef && "cast to nb::object failed");
642  liveContexts[context.ptr] = unownedContextWrapper;
643  return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
644  }
645  // Use existing.
646  nb::object pyRef = nb::cast(it->second);
647  return PyMlirContextRef(it->second, std::move(pyRef));
648 }
649 
650 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
651  static LiveContextMap liveContexts;
652  return liveContexts;
653 }
654 
655 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
656 
657 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
658 
659 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
660  std::vector<PyOperation *> liveObjects;
661  for (auto &entry : liveOperations)
662  liveObjects.push_back(entry.second.second);
663  return liveObjects;
664 }
665 
667  for (auto &op : liveOperations)
668  op.second.second->setInvalid();
669  size_t numInvalidated = liveOperations.size();
670  liveOperations.clear();
671  return numInvalidated;
672 }
673 
674 void PyMlirContext::clearOperation(MlirOperation op) {
675  auto it = liveOperations.find(op.ptr);
676  if (it != liveOperations.end()) {
677  it->second.second->setInvalid();
678  liveOperations.erase(it);
679  }
680 }
681 
683  typedef struct {
684  PyOperation &rootOp;
685  bool rootSeen;
686  } callBackData;
687  callBackData data{op.getOperation(), false};
688  // Mark all ops below the op that the passmanager will be rooted
689  // at (but not op itself - note the preorder) as invalid.
690  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
691  void *userData) {
692  callBackData *data = static_cast<callBackData *>(userData);
693  if (LLVM_LIKELY(data->rootSeen))
694  data->rootOp.getOperation().getContext()->clearOperation(op);
695  else
696  data->rootSeen = true;
698  };
699  mlirOperationWalk(op.getOperation(), invalidatingCallback,
700  static_cast<void *>(&data), MlirWalkPreOrder);
701 }
702 void PyMlirContext::clearOperationsInside(MlirOperation op) {
705 }
706 
708  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
709  void *userData) {
710  PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
711  contextRef->clearOperation(op);
713  };
714  mlirOperationWalk(op.getOperation(), invalidatingCallback,
716 }
717 
718 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
719 
720 nb::object PyMlirContext::contextEnter(nb::object context) {
721  return PyThreadContextEntry::pushContext(context);
722 }
723 
724 void PyMlirContext::contextExit(const nb::object &excType,
725  const nb::object &excVal,
726  const nb::object &excTb) {
728 }
729 
730 nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
731  // Note that ownership is transferred to the delete callback below by way of
732  // an explicit inc_ref (borrow).
733  PyDiagnosticHandler *pyHandler =
734  new PyDiagnosticHandler(get(), std::move(callback));
735  nb::object pyHandlerObject =
736  nb::cast(pyHandler, nb::rv_policy::take_ownership);
737  pyHandlerObject.inc_ref();
738 
739  // In these C callbacks, the userData is a PyDiagnosticHandler* that is
740  // guaranteed to be known to pybind.
741  auto handlerCallback =
742  +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
743  PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
744  nb::object pyDiagnosticObject =
745  nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
746 
747  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
748  bool result = false;
749  {
750  // Since this can be called from arbitrary C++ contexts, always get the
751  // gil.
752  nb::gil_scoped_acquire gil;
753  try {
754  result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
755  } catch (std::exception &e) {
756  fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
757  e.what());
758  pyHandler->hadError = true;
759  }
760  }
761 
762  pyDiagnostic->invalidate();
764  };
765  auto deleteCallback = +[](void *userData) {
766  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
767  assert(pyHandler->registeredID && "handler is not registered");
768  pyHandler->registeredID.reset();
769 
770  // Decrement reference, balancing the inc_ref() above.
771  nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
772  pyHandlerObject.dec_ref();
773  };
774 
775  pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
776  get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
777  return pyHandlerObject;
778 }
779 
780 MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
781  void *userData) {
782  auto *self = static_cast<ErrorCapture *>(userData);
783  // Check if the context requested we emit errors instead of capturing them.
784  if (self->ctx->emitErrorDiagnostics)
785  return mlirLogicalResultFailure();
786 
788  return mlirLogicalResultFailure();
789 
790  self->errors.emplace_back(PyDiagnostic(diag).getInfo());
791  return mlirLogicalResultSuccess();
792 }
793 
796  if (!context) {
797  throw std::runtime_error(
798  "An MLIR function requires a Context but none was provided in the call "
799  "or from the surrounding environment. Either pass to the function with "
800  "a 'context=' argument or establish a default using 'with Context():'");
801  }
802  return *context;
803 }
804 
805 //------------------------------------------------------------------------------
806 // PyThreadContextEntry management
807 //------------------------------------------------------------------------------
808 
809 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
810  static thread_local std::vector<PyThreadContextEntry> stack;
811  return stack;
812 }
813 
815  auto &stack = getStack();
816  if (stack.empty())
817  return nullptr;
818  return &stack.back();
819 }
820 
821 void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
822  nb::object insertionPoint,
823  nb::object location) {
824  auto &stack = getStack();
825  stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
826  std::move(location));
827  // If the new stack has more than one entry and the context of the new top
828  // entry matches the previous, copy the insertionPoint and location from the
829  // previous entry if missing from the new top entry.
830  if (stack.size() > 1) {
831  auto &prev = *(stack.rbegin() + 1);
832  auto &current = stack.back();
833  if (current.context.is(prev.context)) {
834  // Default non-context objects from the previous entry.
835  if (!current.insertionPoint)
836  current.insertionPoint = prev.insertionPoint;
837  if (!current.location)
838  current.location = prev.location;
839  }
840  }
841 }
842 
844  if (!context)
845  return nullptr;
846  return nb::cast<PyMlirContext *>(context);
847 }
848 
850  if (!insertionPoint)
851  return nullptr;
852  return nb::cast<PyInsertionPoint *>(insertionPoint);
853 }
854 
856  if (!location)
857  return nullptr;
858  return nb::cast<PyLocation *>(location);
859 }
860 
862  auto *tos = getTopOfStack();
863  return tos ? tos->getContext() : nullptr;
864 }
865 
867  auto *tos = getTopOfStack();
868  return tos ? tos->getInsertionPoint() : nullptr;
869 }
870 
872  auto *tos = getTopOfStack();
873  return tos ? tos->getLocation() : nullptr;
874 }
875 
876 nb::object PyThreadContextEntry::pushContext(nb::object context) {
877  push(FrameKind::Context, /*context=*/context,
878  /*insertionPoint=*/nb::object(),
879  /*location=*/nb::object());
880  return context;
881 }
882 
884  auto &stack = getStack();
885  if (stack.empty())
886  throw std::runtime_error("Unbalanced Context enter/exit");
887  auto &tos = stack.back();
888  if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
889  throw std::runtime_error("Unbalanced Context enter/exit");
890  stack.pop_back();
891 }
892 
893 nb::object
894 PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
895  PyInsertionPoint &insertionPoint =
896  nb::cast<PyInsertionPoint &>(insertionPointObj);
897  nb::object contextObj =
898  insertionPoint.getBlock().getParentOperation()->getContext().getObject();
899  push(FrameKind::InsertionPoint,
900  /*context=*/contextObj,
901  /*insertionPoint=*/insertionPointObj,
902  /*location=*/nb::object());
903  return insertionPointObj;
904 }
905 
907  auto &stack = getStack();
908  if (stack.empty())
909  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
910  auto &tos = stack.back();
911  if (tos.frameKind != FrameKind::InsertionPoint &&
912  tos.getInsertionPoint() != &insertionPoint)
913  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
914  stack.pop_back();
915 }
916 
917 nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
918  PyLocation &location = nb::cast<PyLocation &>(locationObj);
919  nb::object contextObj = location.getContext().getObject();
920  push(FrameKind::Location, /*context=*/contextObj,
921  /*insertionPoint=*/nb::object(),
922  /*location=*/locationObj);
923  return locationObj;
924 }
925 
927  auto &stack = getStack();
928  if (stack.empty())
929  throw std::runtime_error("Unbalanced Location enter/exit");
930  auto &tos = stack.back();
931  if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
932  throw std::runtime_error("Unbalanced Location enter/exit");
933  stack.pop_back();
934 }
935 
936 //------------------------------------------------------------------------------
937 // PyDiagnostic*
938 //------------------------------------------------------------------------------
939 
941  valid = false;
942  if (materializedNotes) {
943  for (nb::handle noteObject : *materializedNotes) {
944  PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
945  note->invalidate();
946  }
947  }
948 }
949 
951  nb::object callback)
952  : context(context), callback(std::move(callback)) {}
953 
955 
957  if (!registeredID)
958  return;
959  MlirDiagnosticHandlerID localID = *registeredID;
960  mlirContextDetachDiagnosticHandler(context, localID);
961  assert(!registeredID && "should have unregistered");
962  // Not strictly necessary but keeps stale pointers from being around to cause
963  // issues.
964  context = {nullptr};
965 }
966 
967 void PyDiagnostic::checkValid() {
968  if (!valid) {
969  throw std::invalid_argument(
970  "Diagnostic is invalid (used outside of callback)");
971  }
972 }
973 
975  checkValid();
976  return mlirDiagnosticGetSeverity(diagnostic);
977 }
978 
980  checkValid();
981  MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
982  MlirContext context = mlirLocationGetContext(loc);
983  return PyLocation(PyMlirContext::forContext(context), loc);
984 }
985 
987  checkValid();
988  nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
989  PyFileAccumulator accum(fileObject, /*binary=*/false);
990  mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
991  return nb::cast<nb::str>(fileObject.attr("getvalue")());
992 }
993 
995  checkValid();
996  if (materializedNotes)
997  return *materializedNotes;
998  intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
999  nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
1000  for (intptr_t i = 0; i < numNotes; ++i) {
1001  MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
1002  nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
1003  PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
1004  }
1005  materializedNotes = std::move(notes);
1006 
1007  return *materializedNotes;
1008 }
1009 
1011  std::vector<DiagnosticInfo> notes;
1012  for (nb::handle n : getNotes())
1013  notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
1014  return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
1015  std::move(notes)};
1016 }
1017 
1018 //------------------------------------------------------------------------------
1019 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
1020 //------------------------------------------------------------------------------
1021 
1022 MlirDialect PyDialects::getDialectForKey(const std::string &key,
1023  bool attrError) {
1024  MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1025  {key.data(), key.size()});
1026  if (mlirDialectIsNull(dialect)) {
1027  std::string msg = (Twine("Dialect '") + key + "' not found").str();
1028  if (attrError)
1029  throw nb::attribute_error(msg.c_str());
1030  throw nb::index_error(msg.c_str());
1031  }
1032  return dialect;
1033 }
1034 
1036  return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
1037 }
1038 
1040  MlirDialectRegistry rawRegistry =
1041  mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1042  if (mlirDialectRegistryIsNull(rawRegistry))
1043  throw nb::python_error();
1044  return PyDialectRegistry(rawRegistry);
1045 }
1046 
1047 //------------------------------------------------------------------------------
1048 // PyLocation
1049 //------------------------------------------------------------------------------
1050 
1052  return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
1053 }
1054 
1056  MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1057  if (mlirLocationIsNull(rawLoc))
1058  throw nb::python_error();
1060  rawLoc);
1061 }
1062 
1063 nb::object PyLocation::contextEnter(nb::object locationObj) {
1064  return PyThreadContextEntry::pushLocation(locationObj);
1065 }
1066 
1067 void PyLocation::contextExit(const nb::object &excType,
1068  const nb::object &excVal,
1069  const nb::object &excTb) {
1071 }
1072 
1074  auto *location = PyThreadContextEntry::getDefaultLocation();
1075  if (!location) {
1076  throw std::runtime_error(
1077  "An MLIR function requires a Location but none was provided in the "
1078  "call or from the surrounding environment. Either pass to the function "
1079  "with a 'loc=' argument or establish a default using 'with loc:'");
1080  }
1081  return *location;
1082 }
1083 
1084 //------------------------------------------------------------------------------
1085 // PyModule
1086 //------------------------------------------------------------------------------
1087 
1088 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1089  : BaseContextObject(std::move(contextRef)), module(module) {}
1090 
1092  nb::gil_scoped_acquire acquire;
1093  auto &liveModules = getContext()->liveModules;
1094  assert(liveModules.count(module.ptr) == 1 &&
1095  "destroying module not in live map");
1096  liveModules.erase(module.ptr);
1097  mlirModuleDestroy(module);
1098 }
1099 
1100 PyModuleRef PyModule::forModule(MlirModule module) {
1101  MlirContext context = mlirModuleGetContext(module);
1102  PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1103 
1104  nb::gil_scoped_acquire acquire;
1105  auto &liveModules = contextRef->liveModules;
1106  auto it = liveModules.find(module.ptr);
1107  if (it == liveModules.end()) {
1108  // Create.
1109  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1110  // Note that the default return value policy on cast is automatic_reference,
1111  // which does not take ownership (delete will not be called).
1112  // Just be explicit.
1113  nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1114  unownedModule->handle = pyRef;
1115  liveModules[module.ptr] =
1116  std::make_pair(unownedModule->handle, unownedModule);
1117  return PyModuleRef(unownedModule, std::move(pyRef));
1118  }
1119  // Use existing.
1120  PyModule *existing = it->second.second;
1121  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1122  return PyModuleRef(existing, std::move(pyRef));
1123 }
1124 
1125 nb::object PyModule::createFromCapsule(nb::object capsule) {
1126  MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1127  if (mlirModuleIsNull(rawModule))
1128  throw nb::python_error();
1129  return forModule(rawModule).releaseObject();
1130 }
1131 
1132 nb::object PyModule::getCapsule() {
1133  return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
1134 }
1135 
1136 //------------------------------------------------------------------------------
1137 // PyOperation
1138 //------------------------------------------------------------------------------
1139 
1140 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1141  : BaseContextObject(std::move(contextRef)), operation(operation) {}
1142 
1144  // If the operation has already been invalidated there is nothing to do.
1145  if (!valid)
1146  return;
1147 
1148  // Otherwise, invalidate the operation and remove it from live map when it is
1149  // attached.
1150  if (isAttached()) {
1151  getContext()->clearOperation(*this);
1152  } else {
1153  // And destroy it when it is detached, i.e. owned by Python, in which case
1154  // all nested operations must be invalidated at removed from the live map as
1155  // well.
1156  erase();
1157  }
1158 }
1159 
1160 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1161  MlirOperation operation,
1162  nb::object parentKeepAlive) {
1163  auto &liveOperations = contextRef->liveOperations;
1164  // Create.
1165  PyOperation *unownedOperation =
1166  new PyOperation(std::move(contextRef), operation);
1167  // Note that the default return value policy on cast is automatic_reference,
1168  // which does not take ownership (delete will not be called).
1169  // Just be explicit.
1170  nb::object pyRef = nb::cast(unownedOperation, nb::rv_policy::take_ownership);
1171  unownedOperation->handle = pyRef;
1172  if (parentKeepAlive) {
1173  unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1174  }
1175  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
1176  return PyOperationRef(unownedOperation, std::move(pyRef));
1177 }
1178 
1180  MlirOperation operation,
1181  nb::object parentKeepAlive) {
1182  auto &liveOperations = contextRef->liveOperations;
1183  auto it = liveOperations.find(operation.ptr);
1184  if (it == liveOperations.end()) {
1185  // Create.
1186  return createInstance(std::move(contextRef), operation,
1187  std::move(parentKeepAlive));
1188  }
1189  // Use existing.
1190  PyOperation *existing = it->second.second;
1191  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1192  return PyOperationRef(existing, std::move(pyRef));
1193 }
1194 
1196  MlirOperation operation,
1197  nb::object parentKeepAlive) {
1198  auto &liveOperations = contextRef->liveOperations;
1199  assert(liveOperations.count(operation.ptr) == 0 &&
1200  "cannot create detached operation that already exists");
1201  (void)liveOperations;
1202 
1203  PyOperationRef created = createInstance(std::move(contextRef), operation,
1204  std::move(parentKeepAlive));
1205  created->attached = false;
1206  return created;
1207 }
1208 
1210  const std::string &sourceStr,
1211  const std::string &sourceName) {
1212  PyMlirContext::ErrorCapture errors(contextRef);
1213  MlirOperation op =
1214  mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1215  toMlirStringRef(sourceName));
1216  if (mlirOperationIsNull(op))
1217  throw MLIRError("Unable to parse operation assembly", errors.take());
1218  return PyOperation::createDetached(std::move(contextRef), op);
1219 }
1220 
1222  if (!valid) {
1223  throw std::runtime_error("the operation has been invalidated");
1224  }
1225 }
1226 
1227 void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1228  bool enableDebugInfo, bool prettyDebugInfo,
1229  bool printGenericOpForm, bool useLocalScope,
1230  bool assumeVerified, nb::object fileObject,
1231  bool binary, bool skipRegions) {
1232  PyOperation &operation = getOperation();
1233  operation.checkValid();
1234  if (fileObject.is_none())
1235  fileObject = nb::module_::import_("sys").attr("stdout");
1236 
1237  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1238  if (largeElementsLimit)
1239  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1240  if (enableDebugInfo)
1241  mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1242  /*prettyForm=*/prettyDebugInfo);
1243  if (printGenericOpForm)
1245  if (useLocalScope)
1247  if (assumeVerified)
1249  if (skipRegions)
1251 
1252  PyFileAccumulator accum(fileObject, binary);
1253  mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1254  accum.getUserData());
1256 }
1257 
1258 void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1259  bool binary) {
1260  PyOperation &operation = getOperation();
1261  operation.checkValid();
1262  if (fileObject.is_none())
1263  fileObject = nb::module_::import_("sys").attr("stdout");
1264  PyFileAccumulator accum(fileObject, binary);
1265  mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1266  accum.getUserData());
1267 }
1268 
1269 void PyOperationBase::writeBytecode(const nb::object &fileObject,
1270  std::optional<int64_t> bytecodeVersion) {
1271  PyOperation &operation = getOperation();
1272  operation.checkValid();
1273  PyFileAccumulator accum(fileObject, /*binary=*/true);
1274 
1275  if (!bytecodeVersion.has_value())
1276  return mlirOperationWriteBytecode(operation, accum.getCallback(),
1277  accum.getUserData());
1278 
1279  MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1282  operation, config, accum.getCallback(), accum.getUserData());
1284  if (mlirLogicalResultIsFailure(res))
1285  throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
1286  Twine(*bytecodeVersion))
1287  .str()
1288  .c_str());
1289 }
1290 
1292  std::function<MlirWalkResult(MlirOperation)> callback,
1293  MlirWalkOrder walkOrder) {
1294  PyOperation &operation = getOperation();
1295  operation.checkValid();
1296  struct UserData {
1297  std::function<MlirWalkResult(MlirOperation)> callback;
1298  bool gotException;
1299  std::string exceptionWhat;
1300  nb::object exceptionType;
1301  };
1302  UserData userData{callback, false, {}, {}};
1303  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1304  void *userData) {
1305  UserData *calleeUserData = static_cast<UserData *>(userData);
1306  try {
1307  return (calleeUserData->callback)(op);
1308  } catch (nb::python_error &e) {
1309  calleeUserData->gotException = true;
1310  calleeUserData->exceptionWhat = std::string(e.what());
1311  calleeUserData->exceptionType = nb::borrow(e.type());
1313  }
1314  };
1315  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1316  if (userData.gotException) {
1317  std::string message("Exception raised in callback: ");
1318  message.append(userData.exceptionWhat);
1319  throw std::runtime_error(message);
1320  }
1321 }
1322 
1323 nb::object PyOperationBase::getAsm(bool binary,
1324  std::optional<int64_t> largeElementsLimit,
1325  bool enableDebugInfo, bool prettyDebugInfo,
1326  bool printGenericOpForm, bool useLocalScope,
1327  bool assumeVerified, bool skipRegions) {
1328  nb::object fileObject;
1329  if (binary) {
1330  fileObject = nb::module_::import_("io").attr("BytesIO")();
1331  } else {
1332  fileObject = nb::module_::import_("io").attr("StringIO")();
1333  }
1334  print(/*largeElementsLimit=*/largeElementsLimit,
1335  /*enableDebugInfo=*/enableDebugInfo,
1336  /*prettyDebugInfo=*/prettyDebugInfo,
1337  /*printGenericOpForm=*/printGenericOpForm,
1338  /*useLocalScope=*/useLocalScope,
1339  /*assumeVerified=*/assumeVerified,
1340  /*fileObject=*/fileObject,
1341  /*binary=*/binary,
1342  /*skipRegions=*/skipRegions);
1343 
1344  return fileObject.attr("getvalue")();
1345 }
1346 
1348  PyOperation &operation = getOperation();
1349  PyOperation &otherOp = other.getOperation();
1350  operation.checkValid();
1351  otherOp.checkValid();
1352  mlirOperationMoveAfter(operation, otherOp);
1353  operation.parentKeepAlive = otherOp.parentKeepAlive;
1354 }
1355 
1357  PyOperation &operation = getOperation();
1358  PyOperation &otherOp = other.getOperation();
1359  operation.checkValid();
1360  otherOp.checkValid();
1361  mlirOperationMoveBefore(operation, otherOp);
1362  operation.parentKeepAlive = otherOp.parentKeepAlive;
1363 }
1364 
1366  PyOperation &op = getOperation();
1368  if (!mlirOperationVerify(op.get()))
1369  throw MLIRError("Verification failed", errors.take());
1370  return true;
1371 }
1372 
1373 std::optional<PyOperationRef> PyOperation::getParentOperation() {
1374  checkValid();
1375  if (!isAttached())
1376  throw nb::value_error("Detached operations have no parent");
1377  MlirOperation operation = mlirOperationGetParentOperation(get());
1378  if (mlirOperationIsNull(operation))
1379  return {};
1380  return PyOperation::forOperation(getContext(), operation);
1381 }
1382 
1384  checkValid();
1385  std::optional<PyOperationRef> parentOperation = getParentOperation();
1386  MlirBlock block = mlirOperationGetBlock(get());
1387  assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1388  assert(parentOperation && "Operation has no parent");
1389  return PyBlock{std::move(*parentOperation), block};
1390 }
1391 
1393  checkValid();
1394  return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1395 }
1396 
1397 nb::object PyOperation::createFromCapsule(nb::object capsule) {
1398  MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1399  if (mlirOperationIsNull(rawOperation))
1400  throw nb::python_error();
1401  MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1402  return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1403  .releaseObject();
1404 }
1405 
1407  const nb::object &maybeIp) {
1408  // InsertPoint active?
1409  if (!maybeIp.is(nb::cast(false))) {
1410  PyInsertionPoint *ip;
1411  if (maybeIp.is_none()) {
1413  } else {
1414  ip = nb::cast<PyInsertionPoint *>(maybeIp);
1415  }
1416  if (ip)
1417  ip->insert(*op.get());
1418  }
1419 }
1420 
1421 nb::object PyOperation::create(const std::string &name,
1422  std::optional<std::vector<PyType *>> results,
1423  std::optional<std::vector<PyValue *>> operands,
1424  std::optional<nb::dict> attributes,
1425  std::optional<std::vector<PyBlock *>> successors,
1426  int regions, DefaultingPyLocation location,
1427  const nb::object &maybeIp, bool inferType) {
1428  llvm::SmallVector<MlirValue, 4> mlirOperands;
1429  llvm::SmallVector<MlirType, 4> mlirResults;
1430  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
1432 
1433  // General parameter validation.
1434  if (regions < 0)
1435  throw nb::value_error("number of regions must be >= 0");
1436 
1437  // Unpack/validate operands.
1438  if (operands) {
1439  mlirOperands.reserve(operands->size());
1440  for (PyValue *operand : *operands) {
1441  if (!operand)
1442  throw nb::value_error("operand value cannot be None");
1443  mlirOperands.push_back(operand->get());
1444  }
1445  }
1446 
1447  // Unpack/validate results.
1448  if (results) {
1449  mlirResults.reserve(results->size());
1450  for (PyType *result : *results) {
1451  // TODO: Verify result type originate from the same context.
1452  if (!result)
1453  throw nb::value_error("result type cannot be None");
1454  mlirResults.push_back(*result);
1455  }
1456  }
1457  // Unpack/validate attributes.
1458  if (attributes) {
1459  mlirAttributes.reserve(attributes->size());
1460  for (std::pair<nb::handle, nb::handle> it : *attributes) {
1461  std::string key;
1462  try {
1463  key = nb::cast<std::string>(it.first);
1464  } catch (nb::cast_error &err) {
1465  std::string msg = "Invalid attribute key (not a string) when "
1466  "attempting to create the operation \"" +
1467  name + "\" (" + err.what() + ")";
1468  throw nb::type_error(msg.c_str());
1469  }
1470  try {
1471  auto &attribute = nb::cast<PyAttribute &>(it.second);
1472  // TODO: Verify attribute originates from the same context.
1473  mlirAttributes.emplace_back(std::move(key), attribute);
1474  } catch (nb::cast_error &err) {
1475  std::string msg = "Invalid attribute value for the key \"" + key +
1476  "\" when attempting to create the operation \"" +
1477  name + "\" (" + err.what() + ")";
1478  throw nb::type_error(msg.c_str());
1479  } catch (std::runtime_error &) {
1480  // This exception seems thrown when the value is "None".
1481  std::string msg =
1482  "Found an invalid (`None`?) attribute value for the key \"" + key +
1483  "\" when attempting to create the operation \"" + name + "\"";
1484  throw std::runtime_error(msg);
1485  }
1486  }
1487  }
1488  // Unpack/validate successors.
1489  if (successors) {
1490  mlirSuccessors.reserve(successors->size());
1491  for (auto *successor : *successors) {
1492  // TODO: Verify successor originate from the same context.
1493  if (!successor)
1494  throw nb::value_error("successor block cannot be None");
1495  mlirSuccessors.push_back(successor->get());
1496  }
1497  }
1498 
1499  // Apply unpacked/validated to the operation state. Beyond this
1500  // point, exceptions cannot be thrown or else the state will leak.
1501  MlirOperationState state =
1502  mlirOperationStateGet(toMlirStringRef(name), location);
1503  if (!mlirOperands.empty())
1504  mlirOperationStateAddOperands(&state, mlirOperands.size(),
1505  mlirOperands.data());
1506  state.enableResultTypeInference = inferType;
1507  if (!mlirResults.empty())
1508  mlirOperationStateAddResults(&state, mlirResults.size(),
1509  mlirResults.data());
1510  if (!mlirAttributes.empty()) {
1511  // Note that the attribute names directly reference bytes in
1512  // mlirAttributes, so that vector must not be changed from here
1513  // on.
1514  llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1515  mlirNamedAttributes.reserve(mlirAttributes.size());
1516  for (auto &it : mlirAttributes)
1517  mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1519  toMlirStringRef(it.first)),
1520  it.second));
1521  mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1522  mlirNamedAttributes.data());
1523  }
1524  if (!mlirSuccessors.empty())
1525  mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1526  mlirSuccessors.data());
1527  if (regions) {
1529  mlirRegions.resize(regions);
1530  for (int i = 0; i < regions; ++i)
1531  mlirRegions[i] = mlirRegionCreate();
1532  mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1533  mlirRegions.data());
1534  }
1535 
1536  // Construct the operation.
1537  MlirOperation operation = mlirOperationCreate(&state);
1538  if (!operation.ptr)
1539  throw nb::value_error("Operation creation failed");
1540  PyOperationRef created =
1541  PyOperation::createDetached(location->getContext(), operation);
1542  maybeInsertOperation(created, maybeIp);
1543 
1544  return created.getObject();
1545 }
1546 
1547 nb::object PyOperation::clone(const nb::object &maybeIp) {
1548  MlirOperation clonedOperation = mlirOperationClone(operation);
1549  PyOperationRef cloned =
1550  PyOperation::createDetached(getContext(), clonedOperation);
1551  maybeInsertOperation(cloned, maybeIp);
1552 
1553  return cloned->createOpView();
1554 }
1555 
1557  checkValid();
1558  MlirIdentifier ident = mlirOperationGetName(get());
1559  MlirStringRef identStr = mlirIdentifierStr(ident);
1560  auto operationCls = PyGlobals::get().lookupOperationClass(
1561  StringRef(identStr.data, identStr.length));
1562  if (operationCls)
1563  return PyOpView::constructDerived(*operationCls, getRef().getObject());
1564  return nb::cast(PyOpView(getRef().getObject()));
1565 }
1566 
1568  checkValid();
1570  mlirOperationDestroy(operation);
1571 }
1572 
1573 //------------------------------------------------------------------------------
1574 // PyOpView
1575 //------------------------------------------------------------------------------
1576 
1577 static void populateResultTypes(StringRef name, nb::list resultTypeList,
1578  const nb::object &resultSegmentSpecObj,
1579  std::vector<int32_t> &resultSegmentLengths,
1580  std::vector<PyType *> &resultTypes) {
1581  resultTypes.reserve(resultTypeList.size());
1582  if (resultSegmentSpecObj.is_none()) {
1583  // Non-variadic result unpacking.
1584  for (const auto &it : llvm::enumerate(resultTypeList)) {
1585  try {
1586  resultTypes.push_back(nb::cast<PyType *>(it.value()));
1587  if (!resultTypes.back())
1588  throw nb::cast_error();
1589  } catch (nb::cast_error &err) {
1590  throw nb::value_error((llvm::Twine("Result ") +
1591  llvm::Twine(it.index()) + " of operation \"" +
1592  name + "\" must be a Type (" + err.what() + ")")
1593  .str()
1594  .c_str());
1595  }
1596  }
1597  } else {
1598  // Sized result unpacking.
1599  auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1600  if (resultSegmentSpec.size() != resultTypeList.size()) {
1601  throw nb::value_error((llvm::Twine("Operation \"") + name +
1602  "\" requires " +
1603  llvm::Twine(resultSegmentSpec.size()) +
1604  " result segments but was provided " +
1605  llvm::Twine(resultTypeList.size()))
1606  .str()
1607  .c_str());
1608  }
1609  resultSegmentLengths.reserve(resultTypeList.size());
1610  for (const auto &it :
1611  llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1612  int segmentSpec = std::get<1>(it.value());
1613  if (segmentSpec == 1 || segmentSpec == 0) {
1614  // Unpack unary element.
1615  try {
1616  auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
1617  if (resultType) {
1618  resultTypes.push_back(resultType);
1619  resultSegmentLengths.push_back(1);
1620  } else if (segmentSpec == 0) {
1621  // Allowed to be optional.
1622  resultSegmentLengths.push_back(0);
1623  } else {
1624  throw nb::value_error(
1625  (llvm::Twine("Result ") + llvm::Twine(it.index()) +
1626  " of operation \"" + name +
1627  "\" must be a Type (was None and result is not optional)")
1628  .str()
1629  .c_str());
1630  }
1631  } catch (nb::cast_error &err) {
1632  throw nb::value_error((llvm::Twine("Result ") +
1633  llvm::Twine(it.index()) + " of operation \"" +
1634  name + "\" must be a Type (" + err.what() +
1635  ")")
1636  .str()
1637  .c_str());
1638  }
1639  } else if (segmentSpec == -1) {
1640  // Unpack sequence by appending.
1641  try {
1642  if (std::get<0>(it.value()).is_none()) {
1643  // Treat it as an empty list.
1644  resultSegmentLengths.push_back(0);
1645  } else {
1646  // Unpack the list.
1647  auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1648  for (nb::handle segmentItem : segment) {
1649  resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1650  if (!resultTypes.back()) {
1651  throw nb::type_error("contained a None item");
1652  }
1653  }
1654  resultSegmentLengths.push_back(nb::len(segment));
1655  }
1656  } catch (std::exception &err) {
1657  // NOTE: Sloppy to be using a catch-all here, but there are at least
1658  // three different unrelated exceptions that can be thrown in the
1659  // above "casts". Just keep the scope above small and catch them all.
1660  throw nb::value_error((llvm::Twine("Result ") +
1661  llvm::Twine(it.index()) + " of operation \"" +
1662  name + "\" must be a Sequence of Types (" +
1663  err.what() + ")")
1664  .str()
1665  .c_str());
1666  }
1667  } else {
1668  throw nb::value_error("Unexpected segment spec");
1669  }
1670  }
1671  }
1672 }
1673 
1675  const nb::object &cls, std::optional<nb::list> resultTypeList,
1676  nb::list operandList, std::optional<nb::dict> attributes,
1677  std::optional<std::vector<PyBlock *>> successors,
1678  std::optional<int> regions, DefaultingPyLocation location,
1679  const nb::object &maybeIp) {
1680  PyMlirContextRef context = location->getContext();
1681  // Class level operation construction metadata.
1682  std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
1683  // Operand and result segment specs are either none, which does no
1684  // variadic unpacking, or a list of ints with segment sizes, where each
1685  // element is either a positive number (typically 1 for a scalar) or -1 to
1686  // indicate that it is derived from the length of the same-indexed operand
1687  // or result (implying that it is a list at that position).
1688  nb::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1689  nb::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1690 
1691  std::vector<int32_t> operandSegmentLengths;
1692  std::vector<int32_t> resultSegmentLengths;
1693 
1694  // Validate/determine region count.
1695  auto opRegionSpec = nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1696  int opMinRegionCount = std::get<0>(opRegionSpec);
1697  bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1698  if (!regions) {
1699  regions = opMinRegionCount;
1700  }
1701  if (*regions < opMinRegionCount) {
1702  throw nb::value_error(
1703  (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1704  llvm::Twine(opMinRegionCount) +
1705  " regions but was built with regions=" + llvm::Twine(*regions))
1706  .str()
1707  .c_str());
1708  }
1709  if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1710  throw nb::value_error(
1711  (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1712  llvm::Twine(opMinRegionCount) +
1713  " regions but was built with regions=" + llvm::Twine(*regions))
1714  .str()
1715  .c_str());
1716  }
1717 
1718  // Unpack results.
1719  std::vector<PyType *> resultTypes;
1720  if (resultTypeList.has_value()) {
1721  populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1722  resultSegmentLengths, resultTypes);
1723  }
1724 
1725  // Unpack operands.
1726  std::vector<PyValue *> operands;
1727  operands.reserve(operands.size());
1728  if (operandSegmentSpecObj.is_none()) {
1729  // Non-sized operand unpacking.
1730  for (const auto &it : llvm::enumerate(operandList)) {
1731  try {
1732  operands.push_back(nb::cast<PyValue *>(it.value()));
1733  if (!operands.back())
1734  throw nb::cast_error();
1735  } catch (nb::cast_error &err) {
1736  throw nb::value_error((llvm::Twine("Operand ") +
1737  llvm::Twine(it.index()) + " of operation \"" +
1738  name + "\" must be a Value (" + err.what() + ")")
1739  .str()
1740  .c_str());
1741  }
1742  }
1743  } else {
1744  // Sized operand unpacking.
1745  auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1746  if (operandSegmentSpec.size() != operandList.size()) {
1747  throw nb::value_error((llvm::Twine("Operation \"") + name +
1748  "\" requires " +
1749  llvm::Twine(operandSegmentSpec.size()) +
1750  "operand segments but was provided " +
1751  llvm::Twine(operandList.size()))
1752  .str()
1753  .c_str());
1754  }
1755  operandSegmentLengths.reserve(operandList.size());
1756  for (const auto &it :
1757  llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1758  int segmentSpec = std::get<1>(it.value());
1759  if (segmentSpec == 1 || segmentSpec == 0) {
1760  // Unpack unary element.
1761  try {
1762  auto *operandValue = nb::cast<PyValue *>(std::get<0>(it.value()));
1763  if (operandValue) {
1764  operands.push_back(operandValue);
1765  operandSegmentLengths.push_back(1);
1766  } else if (segmentSpec == 0) {
1767  // Allowed to be optional.
1768  operandSegmentLengths.push_back(0);
1769  } else {
1770  throw nb::value_error(
1771  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
1772  " of operation \"" + name +
1773  "\" must be a Value (was None and operand is not optional)")
1774  .str()
1775  .c_str());
1776  }
1777  } catch (nb::cast_error &err) {
1778  throw nb::value_error((llvm::Twine("Operand ") +
1779  llvm::Twine(it.index()) + " of operation \"" +
1780  name + "\" must be a Value (" + err.what() +
1781  ")")
1782  .str()
1783  .c_str());
1784  }
1785  } else if (segmentSpec == -1) {
1786  // Unpack sequence by appending.
1787  try {
1788  if (std::get<0>(it.value()).is_none()) {
1789  // Treat it as an empty list.
1790  operandSegmentLengths.push_back(0);
1791  } else {
1792  // Unpack the list.
1793  auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1794  for (nb::handle segmentItem : segment) {
1795  operands.push_back(nb::cast<PyValue *>(segmentItem));
1796  if (!operands.back()) {
1797  throw nb::type_error("contained a None item");
1798  }
1799  }
1800  operandSegmentLengths.push_back(nb::len(segment));
1801  }
1802  } catch (std::exception &err) {
1803  // NOTE: Sloppy to be using a catch-all here, but there are at least
1804  // three different unrelated exceptions that can be thrown in the
1805  // above "casts". Just keep the scope above small and catch them all.
1806  throw nb::value_error((llvm::Twine("Operand ") +
1807  llvm::Twine(it.index()) + " of operation \"" +
1808  name + "\" must be a Sequence of Values (" +
1809  err.what() + ")")
1810  .str()
1811  .c_str());
1812  }
1813  } else {
1814  throw nb::value_error("Unexpected segment spec");
1815  }
1816  }
1817  }
1818 
1819  // Merge operand/result segment lengths into attributes if needed.
1820  if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1821  // Dup.
1822  if (attributes) {
1823  attributes = nb::dict(*attributes);
1824  } else {
1825  attributes = nb::dict();
1826  }
1827  if (attributes->contains("resultSegmentSizes") ||
1828  attributes->contains("operandSegmentSizes")) {
1829  throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
1830  "'operandSegmentSizes' attribute is unsupported. "
1831  "Use Operation.create for such low-level access.");
1832  }
1833 
1834  // Add resultSegmentSizes attribute.
1835  if (!resultSegmentLengths.empty()) {
1836  MlirAttribute segmentLengthAttr =
1837  mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1838  resultSegmentLengths.data());
1839  (*attributes)["resultSegmentSizes"] =
1840  PyAttribute(context, segmentLengthAttr);
1841  }
1842 
1843  // Add operandSegmentSizes attribute.
1844  if (!operandSegmentLengths.empty()) {
1845  MlirAttribute segmentLengthAttr =
1846  mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1847  operandSegmentLengths.data());
1848  (*attributes)["operandSegmentSizes"] =
1849  PyAttribute(context, segmentLengthAttr);
1850  }
1851  }
1852 
1853  // Delegate to create.
1854  return PyOperation::create(name,
1855  /*results=*/std::move(resultTypes),
1856  /*operands=*/std::move(operands),
1857  /*attributes=*/std::move(attributes),
1858  /*successors=*/std::move(successors),
1859  /*regions=*/*regions, location, maybeIp,
1860  !resultTypeList);
1861 }
1862 
1863 nb::object PyOpView::constructDerived(const nb::object &cls,
1864  const nb::object &operation) {
1865  nb::handle opViewType = nb::type<PyOpView>();
1866  nb::object instance = cls.attr("__new__")(cls);
1867  opViewType.attr("__init__")(instance, operation);
1868  return instance;
1869 }
1870 
1871 PyOpView::PyOpView(const nb::object &operationObject)
1872  // Casting through the PyOperationBase base-class and then back to the
1873  // Operation lets us accept any PyOperationBase subclass.
1874  : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
1875  operationObject(operation.getRef().getObject()) {}
1876 
1877 //------------------------------------------------------------------------------
1878 // PyInsertionPoint.
1879 //------------------------------------------------------------------------------
1880 
1882 
1884  : refOperation(beforeOperationBase.getOperation().getRef()),
1885  block((*refOperation)->getBlock()) {}
1886 
1888  PyOperation &operation = operationBase.getOperation();
1889  if (operation.isAttached())
1890  throw nb::value_error(
1891  "Attempt to insert operation that is already attached");
1892  block.getParentOperation()->checkValid();
1893  MlirOperation beforeOp = {nullptr};
1894  if (refOperation) {
1895  // Insert before operation.
1896  (*refOperation)->checkValid();
1897  beforeOp = (*refOperation)->get();
1898  } else {
1899  // Insert at end (before null) is only valid if the block does not
1900  // already end in a known terminator (violating this will cause assertion
1901  // failures later).
1902  if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1903  throw nb::index_error("Cannot insert operation at the end of a block "
1904  "that already has a terminator. Did you mean to "
1905  "use 'InsertionPoint.at_block_terminator(block)' "
1906  "versus 'InsertionPoint(block)'?");
1907  }
1908  }
1909  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1910  operation.setAttached();
1911 }
1912 
1914  MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1915  if (mlirOperationIsNull(firstOp)) {
1916  // Just insert at end.
1917  return PyInsertionPoint(block);
1918  }
1919 
1920  // Insert before first op.
1922  block.getParentOperation()->getContext(), firstOp);
1923  return PyInsertionPoint{block, std::move(firstOpRef)};
1924 }
1925 
1927  MlirOperation terminator = mlirBlockGetTerminator(block.get());
1928  if (mlirOperationIsNull(terminator))
1929  throw nb::value_error("Block has no terminator");
1930  PyOperationRef terminatorOpRef = PyOperation::forOperation(
1931  block.getParentOperation()->getContext(), terminator);
1932  return PyInsertionPoint{block, std::move(terminatorOpRef)};
1933 }
1934 
1935 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
1936  return PyThreadContextEntry::pushInsertionPoint(insertPoint);
1937 }
1938 
1939 void PyInsertionPoint::contextExit(const nb::object &excType,
1940  const nb::object &excVal,
1941  const nb::object &excTb) {
1943 }
1944 
1945 //------------------------------------------------------------------------------
1946 // PyAttribute.
1947 //------------------------------------------------------------------------------
1948 
1949 bool PyAttribute::operator==(const PyAttribute &other) const {
1950  return mlirAttributeEqual(attr, other.attr);
1951 }
1952 
1954  return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
1955 }
1956 
1958  MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1959  if (mlirAttributeIsNull(rawAttr))
1960  throw nb::python_error();
1961  return PyAttribute(
1963 }
1964 
1965 //------------------------------------------------------------------------------
1966 // PyNamedAttribute.
1967 //------------------------------------------------------------------------------
1968 
1969 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1970  : ownedName(new std::string(std::move(ownedName))) {
1973  toMlirStringRef(*this->ownedName)),
1974  attr);
1975 }
1976 
1977 //------------------------------------------------------------------------------
1978 // PyType.
1979 //------------------------------------------------------------------------------
1980 
1981 bool PyType::operator==(const PyType &other) const {
1982  return mlirTypeEqual(type, other.type);
1983 }
1984 
1985 nb::object PyType::getCapsule() {
1986  return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
1987 }
1988 
1989 PyType PyType::createFromCapsule(nb::object capsule) {
1990  MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1991  if (mlirTypeIsNull(rawType))
1992  throw nb::python_error();
1994  rawType);
1995 }
1996 
1997 //------------------------------------------------------------------------------
1998 // PyTypeID.
1999 //------------------------------------------------------------------------------
2000 
2001 nb::object PyTypeID::getCapsule() {
2002  return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
2003 }
2004 
2006  MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
2007  if (mlirTypeIDIsNull(mlirTypeID))
2008  throw nb::python_error();
2009  return PyTypeID(mlirTypeID);
2010 }
2011 bool PyTypeID::operator==(const PyTypeID &other) const {
2012  return mlirTypeIDEqual(typeID, other.typeID);
2013 }
2014 
2015 //------------------------------------------------------------------------------
2016 // PyValue and subclasses.
2017 //------------------------------------------------------------------------------
2018 
2019 nb::object PyValue::getCapsule() {
2020  return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
2021 }
2022 
2024  MlirType type = mlirValueGetType(get());
2025  MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2026  assert(!mlirTypeIDIsNull(mlirTypeID) &&
2027  "mlirTypeID was expected to be non-null.");
2028  std::optional<nb::callable> valueCaster =
2030  // nb::rv_policy::move means use std::move to move the return value
2031  // contents into a new instance that will be owned by Python.
2032  nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2033  if (!valueCaster)
2034  return thisObj;
2035  return valueCaster.value()(thisObj);
2036 }
2037 
2039  MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2040  if (mlirValueIsNull(value))
2041  throw nb::python_error();
2042  MlirOperation owner;
2043  if (mlirValueIsAOpResult(value))
2044  owner = mlirOpResultGetOwner(value);
2045  if (mlirValueIsABlockArgument(value))
2047  if (mlirOperationIsNull(owner))
2048  throw nb::python_error();
2049  MlirContext ctx = mlirOperationGetContext(owner);
2050  PyOperationRef ownerRef =
2052  return PyValue(ownerRef, value);
2053 }
2054 
2055 //------------------------------------------------------------------------------
2056 // PySymbolTable.
2057 //------------------------------------------------------------------------------
2058 
2060  : operation(operation.getOperation().getRef()) {
2061  symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2062  if (mlirSymbolTableIsNull(symbolTable)) {
2063  throw nb::type_error("Operation is not a Symbol Table.");
2064  }
2065 }
2066 
2067 nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2068  operation->checkValid();
2069  MlirOperation symbol = mlirSymbolTableLookup(
2070  symbolTable, mlirStringRefCreate(name.data(), name.length()));
2071  if (mlirOperationIsNull(symbol))
2072  throw nb::key_error(
2073  ("Symbol '" + name + "' not in the symbol table.").c_str());
2074 
2075  return PyOperation::forOperation(operation->getContext(), symbol,
2076  operation.getObject())
2077  ->createOpView();
2078 }
2079 
2081  operation->checkValid();
2082  symbol.getOperation().checkValid();
2083  mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2084  // The operation is also erased, so we must invalidate it. There may be Python
2085  // references to this operation so we don't want to delete it from the list of
2086  // live operations here.
2087  symbol.getOperation().valid = false;
2088 }
2089 
2090 void PySymbolTable::dunderDel(const std::string &name) {
2091  nb::object operation = dunderGetItem(name);
2092  erase(nb::cast<PyOperationBase &>(operation));
2093 }
2094 
2095 MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2096  operation->checkValid();
2097  symbol.getOperation().checkValid();
2098  MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2100  if (mlirAttributeIsNull(symbolAttr))
2101  throw nb::value_error("Expected operation to have a symbol name.");
2102  return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2103 }
2104 
2106  // Op must already be a symbol.
2107  PyOperation &operation = symbol.getOperation();
2108  operation.checkValid();
2110  MlirAttribute existingNameAttr =
2111  mlirOperationGetAttributeByName(operation.get(), attrName);
2112  if (mlirAttributeIsNull(existingNameAttr))
2113  throw nb::value_error("Expected operation to have a symbol name.");
2114  return existingNameAttr;
2115 }
2116 
2118  const std::string &name) {
2119  // Op must already be a symbol.
2120  PyOperation &operation = symbol.getOperation();
2121  operation.checkValid();
2123  MlirAttribute existingNameAttr =
2124  mlirOperationGetAttributeByName(operation.get(), attrName);
2125  if (mlirAttributeIsNull(existingNameAttr))
2126  throw nb::value_error("Expected operation to have a symbol name.");
2127  MlirAttribute newNameAttr =
2128  mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2129  mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2130 }
2131 
2133  PyOperation &operation = symbol.getOperation();
2134  operation.checkValid();
2136  MlirAttribute existingVisAttr =
2137  mlirOperationGetAttributeByName(operation.get(), attrName);
2138  if (mlirAttributeIsNull(existingVisAttr))
2139  throw nb::value_error("Expected operation to have a symbol visibility.");
2140  return existingVisAttr;
2141 }
2142 
2144  const std::string &visibility) {
2145  if (visibility != "public" && visibility != "private" &&
2146  visibility != "nested")
2147  throw nb::value_error(
2148  "Expected visibility to be 'public', 'private' or 'nested'");
2149  PyOperation &operation = symbol.getOperation();
2150  operation.checkValid();
2152  MlirAttribute existingVisAttr =
2153  mlirOperationGetAttributeByName(operation.get(), attrName);
2154  if (mlirAttributeIsNull(existingVisAttr))
2155  throw nb::value_error("Expected operation to have a symbol visibility.");
2156  MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2157  toMlirStringRef(visibility));
2158  mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2159 }
2160 
2161 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2162  const std::string &newSymbol,
2163  PyOperationBase &from) {
2164  PyOperation &fromOperation = from.getOperation();
2165  fromOperation.checkValid();
2167  toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2168  from.getOperation())))
2169 
2170  throw nb::value_error("Symbol rename failed");
2171 }
2172 
2174  bool allSymUsesVisible,
2175  nb::object callback) {
2176  PyOperation &fromOperation = from.getOperation();
2177  fromOperation.checkValid();
2178  struct UserData {
2179  PyMlirContextRef context;
2180  nb::object callback;
2181  bool gotException;
2182  std::string exceptionWhat;
2183  nb::object exceptionType;
2184  };
2185  UserData userData{
2186  fromOperation.getContext(), std::move(callback), false, {}, {}};
2188  fromOperation.get(), allSymUsesVisible,
2189  [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2190  UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2191  auto pyFoundOp =
2192  PyOperation::forOperation(calleeUserData->context, foundOp);
2193  if (calleeUserData->gotException)
2194  return;
2195  try {
2196  calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2197  } catch (nb::python_error &e) {
2198  calleeUserData->gotException = true;
2199  calleeUserData->exceptionWhat = e.what();
2200  calleeUserData->exceptionType = nb::borrow(e.type());
2201  }
2202  },
2203  static_cast<void *>(&userData));
2204  if (userData.gotException) {
2205  std::string message("Exception raised in callback: ");
2206  message.append(userData.exceptionWhat);
2207  throw std::runtime_error(message);
2208  }
2209 }
2210 
2211 namespace {
2212 /// CRTP base class for Python MLIR values that subclass Value and should be
2213 /// castable from it. The value hierarchy is one level deep and is not supposed
2214 /// to accommodate other levels unless core MLIR changes.
2215 template <typename DerivedTy>
2216 class PyConcreteValue : public PyValue {
2217 public:
2218  // Derived classes must define statics for:
2219  // IsAFunctionTy isaFunction
2220  // const char *pyClassName
2221  // and redefine bindDerived.
2222  using ClassTy = nb::class_<DerivedTy, PyValue>;
2223  using IsAFunctionTy = bool (*)(MlirValue);
2224 
2225  PyConcreteValue() = default;
2226  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
2227  : PyValue(operationRef, value) {}
2228  PyConcreteValue(PyValue &orig)
2229  : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
2230 
2231  /// Attempts to cast the original value to the derived type and throws on
2232  /// type mismatches.
2233  static MlirValue castFrom(PyValue &orig) {
2234  if (!DerivedTy::isaFunction(orig.get())) {
2235  auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
2236  throw nb::value_error((Twine("Cannot cast value to ") +
2237  DerivedTy::pyClassName + " (from " + origRepr +
2238  ")")
2239  .str()
2240  .c_str());
2241  }
2242  return orig.get();
2243  }
2244 
2245  /// Binds the Python module objects to functions of this class.
2246  static void bind(nb::module_ &m) {
2247  auto cls = ClassTy(m, DerivedTy::pyClassName);
2248  cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
2249  cls.def_static(
2250  "isinstance",
2251  [](PyValue &otherValue) -> bool {
2252  return DerivedTy::isaFunction(otherValue);
2253  },
2254  nb::arg("other_value"));
2256  [](DerivedTy &self) { return self.maybeDownCast(); });
2257  DerivedTy::bindDerived(cls);
2258  }
2259 
2260  /// Implemented by derived classes to add methods to the Python subclass.
2261  static void bindDerived(ClassTy &m) {}
2262 };
2263 
2264 /// Python wrapper for MlirBlockArgument.
2265 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2266 public:
2267  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2268  static constexpr const char *pyClassName = "BlockArgument";
2269  using PyConcreteValue::PyConcreteValue;
2270 
2271  static void bindDerived(ClassTy &c) {
2272  c.def_prop_ro("owner", [](PyBlockArgument &self) {
2273  return PyBlock(self.getParentOperation(),
2274  mlirBlockArgumentGetOwner(self.get()));
2275  });
2276  c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
2277  return mlirBlockArgumentGetArgNumber(self.get());
2278  });
2279  c.def(
2280  "set_type",
2281  [](PyBlockArgument &self, PyType type) {
2282  return mlirBlockArgumentSetType(self.get(), type);
2283  },
2284  nb::arg("type"));
2285  }
2286 };
2287 
2288 /// Python wrapper for MlirOpResult.
2289 class PyOpResult : public PyConcreteValue<PyOpResult> {
2290 public:
2291  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
2292  static constexpr const char *pyClassName = "OpResult";
2293  using PyConcreteValue::PyConcreteValue;
2294 
2295  static void bindDerived(ClassTy &c) {
2296  c.def_prop_ro("owner", [](PyOpResult &self) {
2297  assert(
2298  mlirOperationEqual(self.getParentOperation()->get(),
2299  mlirOpResultGetOwner(self.get())) &&
2300  "expected the owner of the value in Python to match that in the IR");
2301  return self.getParentOperation().getObject();
2302  });
2303  c.def_prop_ro("result_number", [](PyOpResult &self) {
2304  return mlirOpResultGetResultNumber(self.get());
2305  });
2306  }
2307 };
2308 
2309 /// Returns the list of types of the values held by container.
2310 template <typename Container>
2311 static std::vector<MlirType> getValueTypes(Container &container,
2312  PyMlirContextRef &context) {
2313  std::vector<MlirType> result;
2314  result.reserve(container.size());
2315  for (int i = 0, e = container.size(); i < e; ++i) {
2316  result.push_back(mlirValueGetType(container.getElement(i).get()));
2317  }
2318  return result;
2319 }
2320 
2321 /// A list of block arguments. Internally, these are stored as consecutive
2322 /// elements, random access is cheap. The argument list is associated with the
2323 /// operation that contains the block (detached blocks are not allowed in
2324 /// Python bindings) and extends its lifetime.
2325 class PyBlockArgumentList
2326  : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2327 public:
2328  static constexpr const char *pyClassName = "BlockArgumentList";
2330 
2331  PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2332  intptr_t startIndex = 0, intptr_t length = -1,
2333  intptr_t step = 1)
2334  : Sliceable(startIndex,
2335  length == -1 ? mlirBlockGetNumArguments(block) : length,
2336  step),
2337  operation(std::move(operation)), block(block) {}
2338 
2339  static void bindDerived(ClassTy &c) {
2340  c.def_prop_ro("types", [](PyBlockArgumentList &self) {
2341  return getValueTypes(self, self.operation->getContext());
2342  });
2343  }
2344 
2345 private:
2346  /// Give the parent CRTP class access to hook implementations below.
2347  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2348 
2349  /// Returns the number of arguments in the list.
2350  intptr_t getRawNumElements() {
2351  operation->checkValid();
2352  return mlirBlockGetNumArguments(block);
2353  }
2354 
2355  /// Returns `pos`-the element in the list.
2356  PyBlockArgument getRawElement(intptr_t pos) {
2357  MlirValue argument = mlirBlockGetArgument(block, pos);
2358  return PyBlockArgument(operation, argument);
2359  }
2360 
2361  /// Returns a sublist of this list.
2362  PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2363  intptr_t step) {
2364  return PyBlockArgumentList(operation, block, startIndex, length, step);
2365  }
2366 
2367  PyOperationRef operation;
2368  MlirBlock block;
2369 };
2370 
2371 /// A list of operation operands. Internally, these are stored as consecutive
2372 /// elements, random access is cheap. The (returned) operand list is associated
2373 /// with the operation whose operands these are, and thus extends the lifetime
2374 /// of this operation.
2375 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2376 public:
2377  static constexpr const char *pyClassName = "OpOperandList";
2378  using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2379 
2380  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2381  intptr_t length = -1, intptr_t step = 1)
2382  : Sliceable(startIndex,
2383  length == -1 ? mlirOperationGetNumOperands(operation->get())
2384  : length,
2385  step),
2386  operation(operation) {}
2387 
2388  void dunderSetItem(intptr_t index, PyValue value) {
2389  index = wrapIndex(index);
2390  mlirOperationSetOperand(operation->get(), index, value.get());
2391  }
2392 
2393  static void bindDerived(ClassTy &c) {
2394  c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2395  }
2396 
2397 private:
2398  /// Give the parent CRTP class access to hook implementations below.
2399  friend class Sliceable<PyOpOperandList, PyValue>;
2400 
2401  intptr_t getRawNumElements() {
2402  operation->checkValid();
2403  return mlirOperationGetNumOperands(operation->get());
2404  }
2405 
2406  PyValue getRawElement(intptr_t pos) {
2407  MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2408  MlirOperation owner;
2409  if (mlirValueIsAOpResult(operand))
2410  owner = mlirOpResultGetOwner(operand);
2411  else if (mlirValueIsABlockArgument(operand))
2413  else
2414  assert(false && "Value must be an block arg or op result.");
2415  PyOperationRef pyOwner =
2416  PyOperation::forOperation(operation->getContext(), owner);
2417  return PyValue(pyOwner, operand);
2418  }
2419 
2420  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2421  return PyOpOperandList(operation, startIndex, length, step);
2422  }
2423 
2424  PyOperationRef operation;
2425 };
2426 
2427 /// A list of operation results. Internally, these are stored as consecutive
2428 /// elements, random access is cheap. The (returned) result list is associated
2429 /// with the operation whose results these are, and thus extends the lifetime of
2430 /// this operation.
2431 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2432 public:
2433  static constexpr const char *pyClassName = "OpResultList";
2434  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
2435 
2436  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2437  intptr_t length = -1, intptr_t step = 1)
2438  : Sliceable(startIndex,
2439  length == -1 ? mlirOperationGetNumResults(operation->get())
2440  : length,
2441  step),
2442  operation(std::move(operation)) {}
2443 
2444  static void bindDerived(ClassTy &c) {
2445  c.def_prop_ro("types", [](PyOpResultList &self) {
2446  return getValueTypes(self, self.operation->getContext());
2447  });
2448  c.def_prop_ro("owner", [](PyOpResultList &self) {
2449  return self.operation->createOpView();
2450  });
2451  }
2452 
2453 private:
2454  /// Give the parent CRTP class access to hook implementations below.
2455  friend class Sliceable<PyOpResultList, PyOpResult>;
2456 
2457  intptr_t getRawNumElements() {
2458  operation->checkValid();
2459  return mlirOperationGetNumResults(operation->get());
2460  }
2461 
2462  PyOpResult getRawElement(intptr_t index) {
2463  PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2464  return PyOpResult(value);
2465  }
2466 
2467  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2468  return PyOpResultList(operation, startIndex, length, step);
2469  }
2470 
2471  PyOperationRef operation;
2472 };
2473 
2474 /// A list of operation successors. Internally, these are stored as consecutive
2475 /// elements, random access is cheap. The (returned) successor list is
2476 /// associated with the operation whose successors these are, and thus extends
2477 /// the lifetime of this operation.
2478 class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2479 public:
2480  static constexpr const char *pyClassName = "OpSuccessors";
2481 
2482  PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2483  intptr_t length = -1, intptr_t step = 1)
2484  : Sliceable(startIndex,
2485  length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2486  : length,
2487  step),
2488  operation(operation) {}
2489 
2490  void dunderSetItem(intptr_t index, PyBlock block) {
2491  index = wrapIndex(index);
2492  mlirOperationSetSuccessor(operation->get(), index, block.get());
2493  }
2494 
2495  static void bindDerived(ClassTy &c) {
2496  c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2497  }
2498 
2499 private:
2500  /// Give the parent CRTP class access to hook implementations below.
2501  friend class Sliceable<PyOpSuccessors, PyBlock>;
2502 
2503  intptr_t getRawNumElements() {
2504  operation->checkValid();
2505  return mlirOperationGetNumSuccessors(operation->get());
2506  }
2507 
2508  PyBlock getRawElement(intptr_t pos) {
2509  MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2510  return PyBlock(operation, block);
2511  }
2512 
2513  PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2514  return PyOpSuccessors(operation, startIndex, length, step);
2515  }
2516 
2517  PyOperationRef operation;
2518 };
2519 
2520 /// A list of operation attributes. Can be indexed by name, producing
2521 /// attributes, or by index, producing named attributes.
2522 class PyOpAttributeMap {
2523 public:
2524  PyOpAttributeMap(PyOperationRef operation)
2525  : operation(std::move(operation)) {}
2526 
2527  MlirAttribute dunderGetItemNamed(const std::string &name) {
2528  MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2529  toMlirStringRef(name));
2530  if (mlirAttributeIsNull(attr)) {
2531  throw nb::key_error("attempt to access a non-existent attribute");
2532  }
2533  return attr;
2534  }
2535 
2536  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2537  if (index < 0 || index >= dunderLen()) {
2538  throw nb::index_error("attempt to access out of bounds attribute");
2539  }
2540  MlirNamedAttribute namedAttr =
2541  mlirOperationGetAttribute(operation->get(), index);
2542  return PyNamedAttribute(
2543  namedAttr.attribute,
2544  std::string(mlirIdentifierStr(namedAttr.name).data,
2545  mlirIdentifierStr(namedAttr.name).length));
2546  }
2547 
2548  void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2550  attr);
2551  }
2552 
2553  void dunderDelItem(const std::string &name) {
2554  int removed = mlirOperationRemoveAttributeByName(operation->get(),
2555  toMlirStringRef(name));
2556  if (!removed)
2557  throw nb::key_error("attempt to delete a non-existent attribute");
2558  }
2559 
2560  intptr_t dunderLen() {
2561  return mlirOperationGetNumAttributes(operation->get());
2562  }
2563 
2564  bool dunderContains(const std::string &name) {
2566  operation->get(), toMlirStringRef(name)));
2567  }
2568 
2569  static void bind(nb::module_ &m) {
2570  nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2571  .def("__contains__", &PyOpAttributeMap::dunderContains)
2572  .def("__len__", &PyOpAttributeMap::dunderLen)
2573  .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2574  .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2575  .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2576  .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2577  }
2578 
2579 private:
2580  PyOperationRef operation;
2581 };
2582 
2583 } // namespace
2584 
2585 //------------------------------------------------------------------------------
2586 // Populates the core exports of the 'ir' submodule.
2587 //------------------------------------------------------------------------------
2588 
2589 void mlir::python::populateIRCore(nb::module_ &m) {
2590  // disable leak warnings which tend to be false positives.
2591  nb::set_leak_warnings(false);
2592  //----------------------------------------------------------------------------
2593  // Enums.
2594  //----------------------------------------------------------------------------
2595  nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
2596  .value("ERROR", MlirDiagnosticError)
2597  .value("WARNING", MlirDiagnosticWarning)
2598  .value("NOTE", MlirDiagnosticNote)
2599  .value("REMARK", MlirDiagnosticRemark);
2600 
2601  nb::enum_<MlirWalkOrder>(m, "WalkOrder")
2602  .value("PRE_ORDER", MlirWalkPreOrder)
2603  .value("POST_ORDER", MlirWalkPostOrder);
2604 
2605  nb::enum_<MlirWalkResult>(m, "WalkResult")
2606  .value("ADVANCE", MlirWalkResultAdvance)
2607  .value("INTERRUPT", MlirWalkResultInterrupt)
2608  .value("SKIP", MlirWalkResultSkip);
2609 
2610  //----------------------------------------------------------------------------
2611  // Mapping of Diagnostics.
2612  //----------------------------------------------------------------------------
2613  nb::class_<PyDiagnostic>(m, "Diagnostic")
2614  .def_prop_ro("severity", &PyDiagnostic::getSeverity)
2615  .def_prop_ro("location", &PyDiagnostic::getLocation)
2616  .def_prop_ro("message", &PyDiagnostic::getMessage)
2617  .def_prop_ro("notes", &PyDiagnostic::getNotes)
2618  .def("__str__", [](PyDiagnostic &self) -> nb::str {
2619  if (!self.isValid())
2620  return nb::str("<Invalid Diagnostic>");
2621  return self.getMessage();
2622  });
2623 
2624  nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2625  .def("__init__",
2627  new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2628  })
2629  .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity)
2630  .def_ro("location", &PyDiagnostic::DiagnosticInfo::location)
2631  .def_ro("message", &PyDiagnostic::DiagnosticInfo::message)
2632  .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes)
2633  .def("__str__",
2634  [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2635 
2636  nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
2637  .def("detach", &PyDiagnosticHandler::detach)
2638  .def_prop_ro("attached", &PyDiagnosticHandler::isAttached)
2639  .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError)
2640  .def("__enter__", &PyDiagnosticHandler::contextEnter)
2641  .def("__exit__", &PyDiagnosticHandler::contextExit,
2642  nb::arg("exc_type").none(), nb::arg("exc_value").none(),
2643  nb::arg("traceback").none());
2644 
2645  //----------------------------------------------------------------------------
2646  // Mapping of MlirContext.
2647  // Note that this is exported as _BaseContext. The containing, Python level
2648  // __init__.py will subclass it with site-specific functionality and set a
2649  // "Context" attribute on this module.
2650  //----------------------------------------------------------------------------
2651  nb::class_<PyMlirContext>(m, "_BaseContext")
2652  .def("__init__",
2653  [](PyMlirContext &self) {
2654  MlirContext context = mlirContextCreateWithThreading(false);
2655  new (&self) PyMlirContext(context);
2656  })
2657  .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2658  .def("_get_context_again",
2659  [](PyMlirContext &self) {
2660  PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2661  return ref.releaseObject();
2662  })
2663  .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2664  .def("_get_live_operation_objects",
2665  &PyMlirContext::getLiveOperationObjects)
2666  .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2667  .def("_clear_live_operations_inside",
2668  nb::overload_cast<MlirOperation>(
2669  &PyMlirContext::clearOperationsInside))
2670  .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2671  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
2672  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2673  .def("__enter__", &PyMlirContext::contextEnter)
2674  .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
2675  nb::arg("exc_value").none(), nb::arg("traceback").none())
2676  .def_prop_ro_static(
2677  "current",
2678  [](nb::object & /*class*/) {
2679  auto *context = PyThreadContextEntry::getDefaultContext();
2680  if (!context)
2681  return nb::none();
2682  return nb::cast(context);
2683  },
2684  "Gets the Context bound to the current thread or raises ValueError")
2685  .def_prop_ro(
2686  "dialects",
2687  [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2688  "Gets a container for accessing dialects by name")
2689  .def_prop_ro(
2690  "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2691  "Alias for 'dialect'")
2692  .def(
2693  "get_dialect_descriptor",
2694  [=](PyMlirContext &self, std::string &name) {
2695  MlirDialect dialect = mlirContextGetOrLoadDialect(
2696  self.get(), {name.data(), name.size()});
2697  if (mlirDialectIsNull(dialect)) {
2698  throw nb::value_error(
2699  (Twine("Dialect '") + name + "' not found").str().c_str());
2700  }
2701  return PyDialectDescriptor(self.getRef(), dialect);
2702  },
2703  nb::arg("dialect_name"),
2704  "Gets or loads a dialect by name, returning its descriptor object")
2705  .def_prop_rw(
2706  "allow_unregistered_dialects",
2707  [](PyMlirContext &self) -> bool {
2709  },
2710  [](PyMlirContext &self, bool value) {
2712  })
2713  .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2714  nb::arg("callback"),
2715  "Attaches a diagnostic handler that will receive callbacks")
2716  .def(
2717  "enable_multithreading",
2718  [](PyMlirContext &self, bool enable) {
2719  mlirContextEnableMultithreading(self.get(), enable);
2720  },
2721  nb::arg("enable"))
2722  .def(
2723  "is_registered_operation",
2724  [](PyMlirContext &self, std::string &name) {
2726  self.get(), MlirStringRef{name.data(), name.size()});
2727  },
2728  nb::arg("operation_name"))
2729  .def(
2730  "append_dialect_registry",
2731  [](PyMlirContext &self, PyDialectRegistry &registry) {
2732  mlirContextAppendDialectRegistry(self.get(), registry);
2733  },
2734  nb::arg("registry"))
2735  .def_prop_rw("emit_error_diagnostics", nullptr,
2736  &PyMlirContext::setEmitErrorDiagnostics,
2737  "Emit error diagnostics to diagnostic handlers. By default "
2738  "error diagnostics are captured and reported through "
2739  "MLIRError exceptions.")
2740  .def("load_all_available_dialects", [](PyMlirContext &self) {
2742  });
2743 
2744  //----------------------------------------------------------------------------
2745  // Mapping of PyDialectDescriptor
2746  //----------------------------------------------------------------------------
2747  nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
2748  .def_prop_ro("namespace",
2749  [](PyDialectDescriptor &self) {
2751  return nb::str(ns.data, ns.length);
2752  })
2753  .def("__repr__", [](PyDialectDescriptor &self) {
2755  std::string repr("<DialectDescriptor ");
2756  repr.append(ns.data, ns.length);
2757  repr.append(">");
2758  return repr;
2759  });
2760 
2761  //----------------------------------------------------------------------------
2762  // Mapping of PyDialects
2763  //----------------------------------------------------------------------------
2764  nb::class_<PyDialects>(m, "Dialects")
2765  .def("__getitem__",
2766  [=](PyDialects &self, std::string keyName) {
2767  MlirDialect dialect =
2768  self.getDialectForKey(keyName, /*attrError=*/false);
2769  nb::object descriptor =
2770  nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2771  return createCustomDialectWrapper(keyName, std::move(descriptor));
2772  })
2773  .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2774  MlirDialect dialect =
2775  self.getDialectForKey(attrName, /*attrError=*/true);
2776  nb::object descriptor =
2777  nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2778  return createCustomDialectWrapper(attrName, std::move(descriptor));
2779  });
2780 
2781  //----------------------------------------------------------------------------
2782  // Mapping of PyDialect
2783  //----------------------------------------------------------------------------
2784  nb::class_<PyDialect>(m, "Dialect")
2785  .def(nb::init<nb::object>(), nb::arg("descriptor"))
2786  .def_prop_ro("descriptor",
2787  [](PyDialect &self) { return self.getDescriptor(); })
2788  .def("__repr__", [](nb::object self) {
2789  auto clazz = self.attr("__class__");
2790  return nb::str("<Dialect ") +
2791  self.attr("descriptor").attr("namespace") + nb::str(" (class ") +
2792  clazz.attr("__module__") + nb::str(".") +
2793  clazz.attr("__name__") + nb::str(")>");
2794  });
2795 
2796  //----------------------------------------------------------------------------
2797  // Mapping of PyDialectRegistry
2798  //----------------------------------------------------------------------------
2799  nb::class_<PyDialectRegistry>(m, "DialectRegistry")
2800  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
2801  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2802  .def(nb::init<>());
2803 
2804  //----------------------------------------------------------------------------
2805  // Mapping of Location
2806  //----------------------------------------------------------------------------
2807  nb::class_<PyLocation>(m, "Location")
2808  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2809  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2810  .def("__enter__", &PyLocation::contextEnter)
2811  .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
2812  nb::arg("exc_value").none(), nb::arg("traceback").none())
2813  .def("__eq__",
2814  [](PyLocation &self, PyLocation &other) -> bool {
2815  return mlirLocationEqual(self, other);
2816  })
2817  .def("__eq__", [](PyLocation &self, nb::object other) { return false; })
2818  .def_prop_ro_static(
2819  "current",
2820  [](nb::object & /*class*/) {
2821  auto *loc = PyThreadContextEntry::getDefaultLocation();
2822  if (!loc)
2823  throw nb::value_error("No current Location");
2824  return loc;
2825  },
2826  "Gets the Location bound to the current thread or raises ValueError")
2827  .def_static(
2828  "unknown",
2829  [](DefaultingPyMlirContext context) {
2830  return PyLocation(context->getRef(),
2831  mlirLocationUnknownGet(context->get()));
2832  },
2833  nb::arg("context").none() = nb::none(),
2834  "Gets a Location representing an unknown location")
2835  .def_static(
2836  "callsite",
2837  [](PyLocation callee, const std::vector<PyLocation> &frames,
2838  DefaultingPyMlirContext context) {
2839  if (frames.empty())
2840  throw nb::value_error("No caller frames provided");
2841  MlirLocation caller = frames.back().get();
2842  for (const PyLocation &frame :
2843  llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2844  caller = mlirLocationCallSiteGet(frame.get(), caller);
2845  return PyLocation(context->getRef(),
2846  mlirLocationCallSiteGet(callee.get(), caller));
2847  },
2848  nb::arg("callee"), nb::arg("frames"),
2849  nb::arg("context").none() = nb::none(),
2851  .def_static(
2852  "file",
2853  [](std::string filename, int line, int col,
2854  DefaultingPyMlirContext context) {
2855  return PyLocation(
2856  context->getRef(),
2858  context->get(), toMlirStringRef(filename), line, col));
2859  },
2860  nb::arg("filename"), nb::arg("line"), nb::arg("col"),
2861  nb::arg("context").none() = nb::none(),
2863  .def_static(
2864  "fused",
2865  [](const std::vector<PyLocation> &pyLocations,
2866  std::optional<PyAttribute> metadata,
2867  DefaultingPyMlirContext context) {
2869  locations.reserve(pyLocations.size());
2870  for (auto &pyLocation : pyLocations)
2871  locations.push_back(pyLocation.get());
2872  MlirLocation location = mlirLocationFusedGet(
2873  context->get(), locations.size(), locations.data(),
2874  metadata ? metadata->get() : MlirAttribute{0});
2875  return PyLocation(context->getRef(), location);
2876  },
2877  nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
2878  nb::arg("context").none() = nb::none(),
2880  .def_static(
2881  "name",
2882  [](std::string name, std::optional<PyLocation> childLoc,
2883  DefaultingPyMlirContext context) {
2884  return PyLocation(
2885  context->getRef(),
2887  context->get(), toMlirStringRef(name),
2888  childLoc ? childLoc->get()
2889  : mlirLocationUnknownGet(context->get())));
2890  },
2891  nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
2892  nb::arg("context").none() = nb::none(),
2894  .def_static(
2895  "from_attr",
2896  [](PyAttribute &attribute, DefaultingPyMlirContext context) {
2897  return PyLocation(context->getRef(),
2898  mlirLocationFromAttribute(attribute));
2899  },
2900  nb::arg("attribute"), nb::arg("context").none() = nb::none(),
2901  "Gets a Location from a LocationAttr")
2902  .def_prop_ro(
2903  "context",
2904  [](PyLocation &self) { return self.getContext().getObject(); },
2905  "Context that owns the Location")
2906  .def_prop_ro(
2907  "attr",
2908  [](PyLocation &self) { return mlirLocationGetAttribute(self); },
2909  "Get the underlying LocationAttr")
2910  .def(
2911  "emit_error",
2912  [](PyLocation &self, std::string message) {
2913  mlirEmitError(self, message.c_str());
2914  },
2915  nb::arg("message"), "Emits an error at this location")
2916  .def("__repr__", [](PyLocation &self) {
2917  PyPrintAccumulator printAccum;
2918  mlirLocationPrint(self, printAccum.getCallback(),
2919  printAccum.getUserData());
2920  return printAccum.join();
2921  });
2922 
2923  //----------------------------------------------------------------------------
2924  // Mapping of Module
2925  //----------------------------------------------------------------------------
2926  nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
2927  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2928  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2929  .def_static(
2930  "parse",
2931  [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
2932  PyMlirContext::ErrorCapture errors(context->getRef());
2933  MlirModule module = mlirModuleCreateParse(
2934  context->get(), toMlirStringRef(moduleAsm));
2935  if (mlirModuleIsNull(module))
2936  throw MLIRError("Unable to parse module assembly", errors.take());
2937  return PyModule::forModule(module).releaseObject();
2938  },
2939  nb::arg("asm"), nb::arg("context").none() = nb::none(),
2941  .def_static(
2942  "parse",
2943  [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
2944  PyMlirContext::ErrorCapture errors(context->getRef());
2945  MlirModule module = mlirModuleCreateParse(
2946  context->get(), toMlirStringRef(moduleAsm));
2947  if (mlirModuleIsNull(module))
2948  throw MLIRError("Unable to parse module assembly", errors.take());
2949  return PyModule::forModule(module).releaseObject();
2950  },
2951  nb::arg("asm"), nb::arg("context").none() = nb::none(),
2953  .def_static(
2954  "create",
2955  [](DefaultingPyLocation loc) {
2956  MlirModule module = mlirModuleCreateEmpty(loc);
2957  return PyModule::forModule(module).releaseObject();
2958  },
2959  nb::arg("loc").none() = nb::none(), "Creates an empty module")
2960  .def_prop_ro(
2961  "context",
2962  [](PyModule &self) { return self.getContext().getObject(); },
2963  "Context that created the Module")
2964  .def_prop_ro(
2965  "operation",
2966  [](PyModule &self) {
2967  return PyOperation::forOperation(self.getContext(),
2968  mlirModuleGetOperation(self.get()),
2969  self.getRef().releaseObject())
2970  .releaseObject();
2971  },
2972  "Accesses the module as an operation")
2973  .def_prop_ro(
2974  "body",
2975  [](PyModule &self) {
2976  PyOperationRef moduleOp = PyOperation::forOperation(
2977  self.getContext(), mlirModuleGetOperation(self.get()),
2978  self.getRef().releaseObject());
2979  PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2980  return returnBlock;
2981  },
2982  "Return the block for this module")
2983  .def(
2984  "dump",
2985  [](PyModule &self) {
2987  },
2989  .def(
2990  "__str__",
2991  [](nb::object self) {
2992  // Defer to the operation's __str__.
2993  return self.attr("operation").attr("__str__")();
2994  },
2996 
2997  //----------------------------------------------------------------------------
2998  // Mapping of Operation.
2999  //----------------------------------------------------------------------------
3000  nb::class_<PyOperationBase>(m, "_OperationBase")
3001  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
3002  [](PyOperationBase &self) {
3003  return self.getOperation().getCapsule();
3004  })
3005  .def("__eq__",
3006  [](PyOperationBase &self, PyOperationBase &other) {
3007  return &self.getOperation() == &other.getOperation();
3008  })
3009  .def("__eq__",
3010  [](PyOperationBase &self, nb::object other) { return false; })
3011  .def("__hash__",
3012  [](PyOperationBase &self) {
3013  return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
3014  })
3015  .def_prop_ro("attributes",
3016  [](PyOperationBase &self) {
3017  return PyOpAttributeMap(self.getOperation().getRef());
3018  })
3019  .def_prop_ro(
3020  "context",
3021  [](PyOperationBase &self) {
3022  PyOperation &concreteOperation = self.getOperation();
3023  concreteOperation.checkValid();
3024  return concreteOperation.getContext().getObject();
3025  },
3026  "Context that owns the Operation")
3027  .def_prop_ro("name",
3028  [](PyOperationBase &self) {
3029  auto &concreteOperation = self.getOperation();
3030  concreteOperation.checkValid();
3031  MlirOperation operation = concreteOperation.get();
3032  MlirStringRef name =
3034  return nb::str(name.data, name.length);
3035  })
3036  .def_prop_ro("operands",
3037  [](PyOperationBase &self) {
3038  return PyOpOperandList(self.getOperation().getRef());
3039  })
3040  .def_prop_ro("regions",
3041  [](PyOperationBase &self) {
3042  return PyRegionList(self.getOperation().getRef());
3043  })
3044  .def_prop_ro(
3045  "results",
3046  [](PyOperationBase &self) {
3047  return PyOpResultList(self.getOperation().getRef());
3048  },
3049  "Returns the list of Operation results.")
3050  .def_prop_ro(
3051  "result",
3052  [](PyOperationBase &self) {
3053  auto &operation = self.getOperation();
3054  auto numResults = mlirOperationGetNumResults(operation);
3055  if (numResults != 1) {
3056  auto name = mlirIdentifierStr(mlirOperationGetName(operation));
3057  throw nb::value_error(
3058  (Twine("Cannot call .result on operation ") +
3059  StringRef(name.data, name.length) + " which has " +
3060  Twine(numResults) +
3061  " results (it is only valid for operations with a "
3062  "single result)")
3063  .str()
3064  .c_str());
3065  }
3066  return PyOpResult(operation.getRef(),
3067  mlirOperationGetResult(operation, 0))
3068  .maybeDownCast();
3069  },
3070  "Shortcut to get an op result if it has only one (throws an error "
3071  "otherwise).")
3072  .def_prop_ro(
3073  "location",
3074  [](PyOperationBase &self) {
3075  PyOperation &operation = self.getOperation();
3076  return PyLocation(operation.getContext(),
3077  mlirOperationGetLocation(operation.get()));
3078  },
3079  "Returns the source location the operation was defined or derived "
3080  "from.")
3081  .def_prop_ro("parent",
3082  [](PyOperationBase &self) -> nb::object {
3083  auto parent = self.getOperation().getParentOperation();
3084  if (parent)
3085  return parent->getObject();
3086  return nb::none();
3087  })
3088  .def(
3089  "__str__",
3090  [](PyOperationBase &self) {
3091  return self.getAsm(/*binary=*/false,
3092  /*largeElementsLimit=*/std::nullopt,
3093  /*enableDebugInfo=*/false,
3094  /*prettyDebugInfo=*/false,
3095  /*printGenericOpForm=*/false,
3096  /*useLocalScope=*/false,
3097  /*assumeVerified=*/false,
3098  /*skipRegions=*/false);
3099  },
3100  "Returns the assembly form of the operation.")
3101  .def("print",
3102  nb::overload_cast<PyAsmState &, nb::object, bool>(
3104  nb::arg("state"), nb::arg("file").none() = nb::none(),
3105  nb::arg("binary") = false, kOperationPrintStateDocstring)
3106  .def("print",
3107  nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3108  bool, nb::object, bool, bool>(
3110  // Careful: Lots of arguments must match up with print method.
3111  nb::arg("large_elements_limit").none() = nb::none(),
3112  nb::arg("enable_debug_info") = false,
3113  nb::arg("pretty_debug_info") = false,
3114  nb::arg("print_generic_op_form") = false,
3115  nb::arg("use_local_scope") = false,
3116  nb::arg("assume_verified") = false,
3117  nb::arg("file").none() = nb::none(), nb::arg("binary") = false,
3118  nb::arg("skip_regions") = false, kOperationPrintDocstring)
3119  .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
3120  nb::arg("desired_version").none() = nb::none(),
3122  .def("get_asm", &PyOperationBase::getAsm,
3123  // Careful: Lots of arguments must match up with get_asm method.
3124  nb::arg("binary") = false,
3125  nb::arg("large_elements_limit").none() = nb::none(),
3126  nb::arg("enable_debug_info") = false,
3127  nb::arg("pretty_debug_info") = false,
3128  nb::arg("print_generic_op_form") = false,
3129  nb::arg("use_local_scope") = false,
3130  nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
3132  .def("verify", &PyOperationBase::verify,
3133  "Verify the operation. Raises MLIRError if verification fails, and "
3134  "returns true otherwise.")
3135  .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
3136  "Puts self immediately after the other operation in its parent "
3137  "block.")
3138  .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
3139  "Puts self immediately before the other operation in its parent "
3140  "block.")
3141  .def(
3142  "clone",
3143  [](PyOperationBase &self, nb::object ip) {
3144  return self.getOperation().clone(ip);
3145  },
3146  nb::arg("ip").none() = nb::none())
3147  .def(
3148  "detach_from_parent",
3149  [](PyOperationBase &self) {
3150  PyOperation &operation = self.getOperation();
3151  operation.checkValid();
3152  if (!operation.isAttached())
3153  throw nb::value_error("Detached operation has no parent.");
3154 
3155  operation.detachFromParent();
3156  return operation.createOpView();
3157  },
3158  "Detaches the operation from its parent block.")
3159  .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3160  .def("walk", &PyOperationBase::walk, nb::arg("callback"),
3161  nb::arg("walk_order") = MlirWalkPostOrder);
3162 
3163  nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3164  .def_static("create", &PyOperation::create, nb::arg("name"),
3165  nb::arg("results").none() = nb::none(),
3166  nb::arg("operands").none() = nb::none(),
3167  nb::arg("attributes").none() = nb::none(),
3168  nb::arg("successors").none() = nb::none(),
3169  nb::arg("regions") = 0, nb::arg("loc").none() = nb::none(),
3170  nb::arg("ip").none() = nb::none(),
3171  nb::arg("infer_type") = false, kOperationCreateDocstring)
3172  .def_static(
3173  "parse",
3174  [](const std::string &sourceStr, const std::string &sourceName,
3175  DefaultingPyMlirContext context) {
3176  return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3177  ->createOpView();
3178  },
3179  nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
3180  nb::arg("context").none() = nb::none(),
3181  "Parses an operation. Supports both text assembly format and binary "
3182  "bytecode format.")
3183  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
3184  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3185  .def_prop_ro("operation", [](nb::object self) { return self; })
3186  .def_prop_ro("opview", &PyOperation::createOpView)
3187  .def_prop_ro(
3188  "successors",
3189  [](PyOperationBase &self) {
3190  return PyOpSuccessors(self.getOperation().getRef());
3191  },
3192  "Returns the list of Operation successors.");
3193 
3194  auto opViewClass =
3195  nb::class_<PyOpView, PyOperationBase>(m, "OpView")
3196  .def(nb::init<nb::object>(), nb::arg("operation"))
3197  .def_prop_ro("operation", &PyOpView::getOperationObject)
3198  .def_prop_ro("opview", [](nb::object self) { return self; })
3199  .def(
3200  "__str__",
3201  [](PyOpView &self) { return nb::str(self.getOperationObject()); })
3202  .def_prop_ro(
3203  "successors",
3204  [](PyOperationBase &self) {
3205  return PyOpSuccessors(self.getOperation().getRef());
3206  },
3207  "Returns the list of Operation successors.");
3208  opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
3209  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
3210  opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
3211  opViewClass.attr("build_generic") = classmethod(
3212  &PyOpView::buildGeneric, nb::arg("cls"),
3213  nb::arg("results").none() = nb::none(),
3214  nb::arg("operands").none() = nb::none(),
3215  nb::arg("attributes").none() = nb::none(),
3216  nb::arg("successors").none() = nb::none(),
3217  nb::arg("regions").none() = nb::none(),
3218  nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3219  "Builds a specific, generated OpView based on class level attributes.");
3220  opViewClass.attr("parse") = classmethod(
3221  [](const nb::object &cls, const std::string &sourceStr,
3222  const std::string &sourceName, DefaultingPyMlirContext context) {
3223  PyOperationRef parsed =
3224  PyOperation::parse(context->getRef(), sourceStr, sourceName);
3225 
3226  // Check if the expected operation was parsed, and cast to to the
3227  // appropriate `OpView` subclass if successful.
3228  // NOTE: This accesses attributes that have been automatically added to
3229  // `OpView` subclasses, and is not intended to be used on `OpView`
3230  // directly.
3231  std::string clsOpName =
3232  nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3233  MlirStringRef identifier =
3235  std::string_view parsedOpName(identifier.data, identifier.length);
3236  if (clsOpName != parsedOpName)
3237  throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3238  parsedOpName + "'");
3239  return PyOpView::constructDerived(cls, parsed.getObject());
3240  },
3241  nb::arg("cls"), nb::arg("source"), nb::kw_only(),
3242  nb::arg("source_name") = "", nb::arg("context").none() = nb::none(),
3243  "Parses a specific, generated OpView based on class level attributes");
3244 
3245  //----------------------------------------------------------------------------
3246  // Mapping of PyRegion.
3247  //----------------------------------------------------------------------------
3248  nb::class_<PyRegion>(m, "Region")
3249  .def_prop_ro(
3250  "blocks",
3251  [](PyRegion &self) {
3252  return PyBlockList(self.getParentOperation(), self.get());
3253  },
3254  "Returns a forward-optimized sequence of blocks.")
3255  .def_prop_ro(
3256  "owner",
3257  [](PyRegion &self) {
3258  return self.getParentOperation()->createOpView();
3259  },
3260  "Returns the operation owning this region.")
3261  .def(
3262  "__iter__",
3263  [](PyRegion &self) {
3264  self.checkValid();
3265  MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3266  return PyBlockIterator(self.getParentOperation(), firstBlock);
3267  },
3268  "Iterates over blocks in the region.")
3269  .def("__eq__",
3270  [](PyRegion &self, PyRegion &other) {
3271  return self.get().ptr == other.get().ptr;
3272  })
3273  .def("__eq__", [](PyRegion &self, nb::object &other) { return false; });
3274 
3275  //----------------------------------------------------------------------------
3276  // Mapping of PyBlock.
3277  //----------------------------------------------------------------------------
3278  nb::class_<PyBlock>(m, "Block")
3279  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3280  .def_prop_ro(
3281  "owner",
3282  [](PyBlock &self) {
3283  return self.getParentOperation()->createOpView();
3284  },
3285  "Returns the owning operation of this block.")
3286  .def_prop_ro(
3287  "region",
3288  [](PyBlock &self) {
3289  MlirRegion region = mlirBlockGetParentRegion(self.get());
3290  return PyRegion(self.getParentOperation(), region);
3291  },
3292  "Returns the owning region of this block.")
3293  .def_prop_ro(
3294  "arguments",
3295  [](PyBlock &self) {
3296  return PyBlockArgumentList(self.getParentOperation(), self.get());
3297  },
3298  "Returns a list of block arguments.")
3299  .def(
3300  "add_argument",
3301  [](PyBlock &self, const PyType &type, const PyLocation &loc) {
3302  return mlirBlockAddArgument(self.get(), type, loc);
3303  },
3304  "Append an argument of the specified type to the block and returns "
3305  "the newly added argument.")
3306  .def(
3307  "erase_argument",
3308  [](PyBlock &self, unsigned index) {
3309  return mlirBlockEraseArgument(self.get(), index);
3310  },
3311  "Erase the argument at 'index' and remove it from the argument list.")
3312  .def_prop_ro(
3313  "operations",
3314  [](PyBlock &self) {
3315  return PyOperationList(self.getParentOperation(), self.get());
3316  },
3317  "Returns a forward-optimized sequence of operations.")
3318  .def_static(
3319  "create_at_start",
3320  [](PyRegion &parent, const nb::sequence &pyArgTypes,
3321  const std::optional<nb::sequence> &pyArgLocs) {
3322  parent.checkValid();
3323  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3324  mlirRegionInsertOwnedBlock(parent, 0, block);
3325  return PyBlock(parent.getParentOperation(), block);
3326  },
3327  nb::arg("parent"), nb::arg("arg_types") = nb::list(),
3328  nb::arg("arg_locs") = std::nullopt,
3329  "Creates and returns a new Block at the beginning of the given "
3330  "region (with given argument types and locations).")
3331  .def(
3332  "append_to",
3333  [](PyBlock &self, PyRegion &region) {
3334  MlirBlock b = self.get();
3336  mlirBlockDetach(b);
3337  mlirRegionAppendOwnedBlock(region.get(), b);
3338  },
3339  "Append this block to a region, transferring ownership if necessary")
3340  .def(
3341  "create_before",
3342  [](PyBlock &self, const nb::args &pyArgTypes,
3343  const std::optional<nb::sequence> &pyArgLocs) {
3344  self.checkValid();
3345  MlirBlock block =
3346  createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3347  MlirRegion region = mlirBlockGetParentRegion(self.get());
3348  mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3349  return PyBlock(self.getParentOperation(), block);
3350  },
3351  nb::arg("arg_types"), nb::kw_only(),
3352  nb::arg("arg_locs") = std::nullopt,
3353  "Creates and returns a new Block before this block "
3354  "(with given argument types and locations).")
3355  .def(
3356  "create_after",
3357  [](PyBlock &self, const nb::args &pyArgTypes,
3358  const std::optional<nb::sequence> &pyArgLocs) {
3359  self.checkValid();
3360  MlirBlock block =
3361  createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3362  MlirRegion region = mlirBlockGetParentRegion(self.get());
3363  mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3364  return PyBlock(self.getParentOperation(), block);
3365  },
3366  nb::arg("arg_types"), nb::kw_only(),
3367  nb::arg("arg_locs") = std::nullopt,
3368  "Creates and returns a new Block after this block "
3369  "(with given argument types and locations).")
3370  .def(
3371  "__iter__",
3372  [](PyBlock &self) {
3373  self.checkValid();
3374  MlirOperation firstOperation =
3376  return PyOperationIterator(self.getParentOperation(),
3377  firstOperation);
3378  },
3379  "Iterates over operations in the block.")
3380  .def("__eq__",
3381  [](PyBlock &self, PyBlock &other) {
3382  return self.get().ptr == other.get().ptr;
3383  })
3384  .def("__eq__", [](PyBlock &self, nb::object &other) { return false; })
3385  .def("__hash__",
3386  [](PyBlock &self) {
3387  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3388  })
3389  .def(
3390  "__str__",
3391  [](PyBlock &self) {
3392  self.checkValid();
3393  PyPrintAccumulator printAccum;
3394  mlirBlockPrint(self.get(), printAccum.getCallback(),
3395  printAccum.getUserData());
3396  return printAccum.join();
3397  },
3398  "Returns the assembly form of the block.")
3399  .def(
3400  "append",
3401  [](PyBlock &self, PyOperationBase &operation) {
3402  if (operation.getOperation().isAttached())
3403  operation.getOperation().detachFromParent();
3404 
3405  MlirOperation mlirOperation = operation.getOperation().get();
3406  mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3407  operation.getOperation().setAttached(
3408  self.getParentOperation().getObject());
3409  },
3410  nb::arg("operation"),
3411  "Appends an operation to this block. If the operation is currently "
3412  "in another block, it will be moved.");
3413 
3414  //----------------------------------------------------------------------------
3415  // Mapping of PyInsertionPoint.
3416  //----------------------------------------------------------------------------
3417 
3418  nb::class_<PyInsertionPoint>(m, "InsertionPoint")
3419  .def(nb::init<PyBlock &>(), nb::arg("block"),
3420  "Inserts after the last operation but still inside the block.")
3421  .def("__enter__", &PyInsertionPoint::contextEnter)
3422  .def("__exit__", &PyInsertionPoint::contextExit,
3423  nb::arg("exc_type").none(), nb::arg("exc_value").none(),
3424  nb::arg("traceback").none())
3425  .def_prop_ro_static(
3426  "current",
3427  [](nb::object & /*class*/) {
3428  auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3429  if (!ip)
3430  throw nb::value_error("No current InsertionPoint");
3431  return ip;
3432  },
3433  "Gets the InsertionPoint bound to the current thread or raises "
3434  "ValueError if none has been set")
3435  .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
3436  "Inserts before a referenced operation.")
3437  .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3438  nb::arg("block"), "Inserts at the beginning of the block.")
3439  .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3440  nb::arg("block"), "Inserts before the block terminator.")
3441  .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
3442  "Inserts an operation.")
3443  .def_prop_ro(
3444  "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3445  "Returns the block that this InsertionPoint points to.")
3446  .def_prop_ro(
3447  "ref_operation",
3448  [](PyInsertionPoint &self) -> nb::object {
3449  auto refOperation = self.getRefOperation();
3450  if (refOperation)
3451  return refOperation->getObject();
3452  return nb::none();
3453  },
3454  "The reference operation before which new operations are "
3455  "inserted, or None if the insertion point is at the end of "
3456  "the block");
3457 
3458  //----------------------------------------------------------------------------
3459  // Mapping of PyAttribute.
3460  //----------------------------------------------------------------------------
3461  nb::class_<PyAttribute>(m, "Attribute")
3462  // Delegate to the PyAttribute copy constructor, which will also lifetime
3463  // extend the backing context which owns the MlirAttribute.
3464  .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
3465  "Casts the passed attribute to the generic Attribute")
3466  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
3467  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3468  .def_static(
3469  "parse",
3470  [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3471  PyMlirContext::ErrorCapture errors(context->getRef());
3472  MlirAttribute attr = mlirAttributeParseGet(
3473  context->get(), toMlirStringRef(attrSpec));
3474  if (mlirAttributeIsNull(attr))
3475  throw MLIRError("Unable to parse attribute", errors.take());
3476  return attr;
3477  },
3478  nb::arg("asm"), nb::arg("context").none() = nb::none(),
3479  "Parses an attribute from an assembly form. Raises an MLIRError on "
3480  "failure.")
3481  .def_prop_ro(
3482  "context",
3483  [](PyAttribute &self) { return self.getContext().getObject(); },
3484  "Context that owns the Attribute")
3485  .def_prop_ro("type",
3486  [](PyAttribute &self) { return mlirAttributeGetType(self); })
3487  .def(
3488  "get_named",
3489  [](PyAttribute &self, std::string name) {
3490  return PyNamedAttribute(self, std::move(name));
3491  },
3492  nb::keep_alive<0, 1>(), "Binds a name to the attribute")
3493  .def("__eq__",
3494  [](PyAttribute &self, PyAttribute &other) { return self == other; })
3495  .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; })
3496  .def("__hash__",
3497  [](PyAttribute &self) {
3498  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3499  })
3500  .def(
3501  "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3503  .def(
3504  "__str__",
3505  [](PyAttribute &self) {
3506  PyPrintAccumulator printAccum;
3507  mlirAttributePrint(self, printAccum.getCallback(),
3508  printAccum.getUserData());
3509  return printAccum.join();
3510  },
3511  "Returns the assembly form of the Attribute.")
3512  .def("__repr__",
3513  [](PyAttribute &self) {
3514  // Generally, assembly formats are not printed for __repr__ because
3515  // this can cause exceptionally long debug output and exceptions.
3516  // However, attribute values are generally considered useful and
3517  // are printed. This may need to be re-evaluated if debug dumps end
3518  // up being excessive.
3519  PyPrintAccumulator printAccum;
3520  printAccum.parts.append("Attribute(");
3521  mlirAttributePrint(self, printAccum.getCallback(),
3522  printAccum.getUserData());
3523  printAccum.parts.append(")");
3524  return printAccum.join();
3525  })
3526  .def_prop_ro("typeid",
3527  [](PyAttribute &self) -> MlirTypeID {
3528  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3529  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3530  "mlirTypeID was expected to be non-null.");
3531  return mlirTypeID;
3532  })
3534  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3535  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3536  "mlirTypeID was expected to be non-null.");
3537  std::optional<nb::callable> typeCaster =
3538  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3539  mlirAttributeGetDialect(self));
3540  if (!typeCaster)
3541  return nb::cast(self);
3542  return typeCaster.value()(self);
3543  });
3544 
3545  //----------------------------------------------------------------------------
3546  // Mapping of PyNamedAttribute
3547  //----------------------------------------------------------------------------
3548  nb::class_<PyNamedAttribute>(m, "NamedAttribute")
3549  .def("__repr__",
3550  [](PyNamedAttribute &self) {
3551  PyPrintAccumulator printAccum;
3552  printAccum.parts.append("NamedAttribute(");
3553  printAccum.parts.append(
3554  nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3555  mlirIdentifierStr(self.namedAttr.name).length));
3556  printAccum.parts.append("=");
3557  mlirAttributePrint(self.namedAttr.attribute,
3558  printAccum.getCallback(),
3559  printAccum.getUserData());
3560  printAccum.parts.append(")");
3561  return printAccum.join();
3562  })
3563  .def_prop_ro(
3564  "name",
3565  [](PyNamedAttribute &self) {
3566  return nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3567  mlirIdentifierStr(self.namedAttr.name).length);
3568  },
3569  "The name of the NamedAttribute binding")
3570  .def_prop_ro(
3571  "attr",
3572  [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3573  nb::keep_alive<0, 1>(),
3574  "The underlying generic attribute of the NamedAttribute binding");
3575 
3576  //----------------------------------------------------------------------------
3577  // Mapping of PyType.
3578  //----------------------------------------------------------------------------
3579  nb::class_<PyType>(m, "Type")
3580  // Delegate to the PyType copy constructor, which will also lifetime
3581  // extend the backing context which owns the MlirType.
3582  .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
3583  "Casts the passed type to the generic Type")
3584  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3585  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3586  .def_static(
3587  "parse",
3588  [](std::string typeSpec, DefaultingPyMlirContext context) {
3589  PyMlirContext::ErrorCapture errors(context->getRef());
3590  MlirType type =
3591  mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3592  if (mlirTypeIsNull(type))
3593  throw MLIRError("Unable to parse type", errors.take());
3594  return type;
3595  },
3596  nb::arg("asm"), nb::arg("context").none() = nb::none(),
3598  .def_prop_ro(
3599  "context", [](PyType &self) { return self.getContext().getObject(); },
3600  "Context that owns the Type")
3601  .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3602  .def(
3603  "__eq__", [](PyType &self, nb::object &other) { return false; },
3604  nb::arg("other").none())
3605  .def("__hash__",
3606  [](PyType &self) {
3607  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3608  })
3609  .def(
3610  "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3611  .def(
3612  "__str__",
3613  [](PyType &self) {
3614  PyPrintAccumulator printAccum;
3615  mlirTypePrint(self, printAccum.getCallback(),
3616  printAccum.getUserData());
3617  return printAccum.join();
3618  },
3619  "Returns the assembly form of the type.")
3620  .def("__repr__",
3621  [](PyType &self) {
3622  // Generally, assembly formats are not printed for __repr__ because
3623  // this can cause exceptionally long debug output and exceptions.
3624  // However, types are an exception as they typically have compact
3625  // assembly forms and printing them is useful.
3626  PyPrintAccumulator printAccum;
3627  printAccum.parts.append("Type(");
3628  mlirTypePrint(self, printAccum.getCallback(),
3629  printAccum.getUserData());
3630  printAccum.parts.append(")");
3631  return printAccum.join();
3632  })
3634  [](PyType &self) {
3635  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3636  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3637  "mlirTypeID was expected to be non-null.");
3638  std::optional<nb::callable> typeCaster =
3639  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3640  mlirTypeGetDialect(self));
3641  if (!typeCaster)
3642  return nb::cast(self);
3643  return typeCaster.value()(self);
3644  })
3645  .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID {
3646  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3647  if (!mlirTypeIDIsNull(mlirTypeID))
3648  return mlirTypeID;
3649  auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
3650  throw nb::value_error(
3651  (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
3652  });
3653 
3654  //----------------------------------------------------------------------------
3655  // Mapping of PyTypeID.
3656  //----------------------------------------------------------------------------
3657  nb::class_<PyTypeID>(m, "TypeID")
3658  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3659  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3660  // Note, this tests whether the underlying TypeIDs are the same,
3661  // not whether the wrapper MlirTypeIDs are the same, nor whether
3662  // the Python objects are the same (i.e., PyTypeID is a value type).
3663  .def("__eq__",
3664  [](PyTypeID &self, PyTypeID &other) { return self == other; })
3665  .def("__eq__",
3666  [](PyTypeID &self, const nb::object &other) { return false; })
3667  // Note, this gives the hash value of the underlying TypeID, not the
3668  // hash value of the Python object, nor the hash value of the
3669  // MlirTypeID wrapper.
3670  .def("__hash__", [](PyTypeID &self) {
3671  return static_cast<size_t>(mlirTypeIDHashValue(self));
3672  });
3673 
3674  //----------------------------------------------------------------------------
3675  // Mapping of Value.
3676  //----------------------------------------------------------------------------
3677  nb::class_<PyValue>(m, "Value")
3678  .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
3679  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3680  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3681  .def_prop_ro(
3682  "context",
3683  [](PyValue &self) { return self.getParentOperation()->getContext(); },
3684  "Context in which the value lives.")
3685  .def(
3686  "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3688  .def_prop_ro(
3689  "owner",
3690  [](PyValue &self) -> nb::object {
3691  MlirValue v = self.get();
3692  if (mlirValueIsAOpResult(v)) {
3693  assert(
3694  mlirOperationEqual(self.getParentOperation()->get(),
3695  mlirOpResultGetOwner(self.get())) &&
3696  "expected the owner of the value in Python to match that in "
3697  "the IR");
3698  return self.getParentOperation().getObject();
3699  }
3700 
3701  if (mlirValueIsABlockArgument(v)) {
3702  MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3703  return nb::cast(PyBlock(self.getParentOperation(), block));
3704  }
3705 
3706  assert(false && "Value must be a block argument or an op result");
3707  return nb::none();
3708  })
3709  .def_prop_ro("uses",
3710  [](PyValue &self) {
3711  return PyOpOperandIterator(
3712  mlirValueGetFirstUse(self.get()));
3713  })
3714  .def("__eq__",
3715  [](PyValue &self, PyValue &other) {
3716  return self.get().ptr == other.get().ptr;
3717  })
3718  .def("__eq__", [](PyValue &self, nb::object other) { return false; })
3719  .def("__hash__",
3720  [](PyValue &self) {
3721  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3722  })
3723  .def(
3724  "__str__",
3725  [](PyValue &self) {
3726  PyPrintAccumulator printAccum;
3727  printAccum.parts.append("Value(");
3728  mlirValuePrint(self.get(), printAccum.getCallback(),
3729  printAccum.getUserData());
3730  printAccum.parts.append(")");
3731  return printAccum.join();
3732  },
3734  .def(
3735  "get_name",
3736  [](PyValue &self, bool useLocalScope) {
3737  PyPrintAccumulator printAccum;
3738  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3739  if (useLocalScope)
3741  MlirAsmState valueState =
3742  mlirAsmStateCreateForValue(self.get(), flags);
3743  mlirValuePrintAsOperand(self.get(), valueState,
3744  printAccum.getCallback(),
3745  printAccum.getUserData());
3747  mlirAsmStateDestroy(valueState);
3748  return printAccum.join();
3749  },
3750  nb::arg("use_local_scope") = false)
3751  .def(
3752  "get_name",
3753  [](PyValue &self, PyAsmState &state) {
3754  PyPrintAccumulator printAccum;
3755  MlirAsmState valueState = state.get();
3756  mlirValuePrintAsOperand(self.get(), valueState,
3757  printAccum.getCallback(),
3758  printAccum.getUserData());
3759  return printAccum.join();
3760  },
3761  nb::arg("state"), kGetNameAsOperand)
3762  .def_prop_ro("type",
3763  [](PyValue &self) { return mlirValueGetType(self.get()); })
3764  .def(
3765  "set_type",
3766  [](PyValue &self, const PyType &type) {
3767  return mlirValueSetType(self.get(), type);
3768  },
3769  nb::arg("type"))
3770  .def(
3771  "replace_all_uses_with",
3772  [](PyValue &self, PyValue &with) {
3773  mlirValueReplaceAllUsesOfWith(self.get(), with.get());
3774  },
3776  .def(
3777  "replace_all_uses_except",
3778  [](MlirValue self, MlirValue with, PyOperation &exception) {
3779  MlirOperation exceptedUser = exception.get();
3780  mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
3781  },
3782  nb::arg("with"), nb::arg("exceptions"),
3784  .def(
3785  "replace_all_uses_except",
3786  [](MlirValue self, MlirValue with, nb::list exceptions) {
3787  // Convert Python list to a SmallVector of MlirOperations
3788  llvm::SmallVector<MlirOperation> exceptionOps;
3789  for (nb::handle exception : exceptions) {
3790  exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
3791  }
3792 
3794  self, with, static_cast<intptr_t>(exceptionOps.size()),
3795  exceptionOps.data());
3796  },
3797  nb::arg("with"), nb::arg("exceptions"),
3800  [](PyValue &self) { return self.maybeDownCast(); });
3801  PyBlockArgument::bind(m);
3802  PyOpResult::bind(m);
3803  PyOpOperand::bind(m);
3804 
3805  nb::class_<PyAsmState>(m, "AsmState")
3806  .def(nb::init<PyValue &, bool>(), nb::arg("value"),
3807  nb::arg("use_local_scope") = false)
3808  .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
3809  nb::arg("use_local_scope") = false);
3810 
3811  //----------------------------------------------------------------------------
3812  // Mapping of SymbolTable.
3813  //----------------------------------------------------------------------------
3814  nb::class_<PySymbolTable>(m, "SymbolTable")
3815  .def(nb::init<PyOperationBase &>())
3816  .def("__getitem__", &PySymbolTable::dunderGetItem)
3817  .def("insert", &PySymbolTable::insert, nb::arg("operation"))
3818  .def("erase", &PySymbolTable::erase, nb::arg("operation"))
3819  .def("__delitem__", &PySymbolTable::dunderDel)
3820  .def("__contains__",
3821  [](PySymbolTable &table, const std::string &name) {
3823  table, mlirStringRefCreate(name.data(), name.length())));
3824  })
3825  // Static helpers.
3826  .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3827  nb::arg("symbol"), nb::arg("name"))
3828  .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3829  nb::arg("symbol"))
3830  .def_static("get_visibility", &PySymbolTable::getVisibility,
3831  nb::arg("symbol"))
3832  .def_static("set_visibility", &PySymbolTable::setVisibility,
3833  nb::arg("symbol"), nb::arg("visibility"))
3834  .def_static("replace_all_symbol_uses",
3835  &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
3836  nb::arg("new_symbol"), nb::arg("from_op"))
3837  .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3838  nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
3839  nb::arg("callback"));
3840 
3841  // Container bindings.
3842  PyBlockArgumentList::bind(m);
3843  PyBlockIterator::bind(m);
3844  PyBlockList::bind(m);
3845  PyOperationIterator::bind(m);
3846  PyOperationList::bind(m);
3847  PyOpAttributeMap::bind(m);
3848  PyOpOperandIterator::bind(m);
3849  PyOpOperandList::bind(m);
3850  PyOpResultList::bind(m);
3851  PyOpSuccessors::bind(m);
3852  PyRegionIterator::bind(m);
3853  PyRegionList::bind(m);
3854 
3855  // Debug bindings.
3857 
3858  // Attribute builder getter.
3860 
3861  nb::register_exception_translator([](const std::exception_ptr &p,
3862  void *payload) {
3863  // We can't define exceptions with custom fields through pybind, so instead
3864  // the exception class is defined in python and imported here.
3865  try {
3866  if (p)
3867  std::rethrow_exception(p);
3868  } catch (const MLIRError &e) {
3869  nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
3870  .attr("MLIRError")(e.message, e.errorDiagnostics);
3871  PyErr_SetObject(PyExc_Exception, obj.ptr());
3872  }
3873  });
3874 }
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition: Debug.cpp:20
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Sets multiple current debug types, similarly to `-debug-only=type1,type2` in the command-line tools.
Definition: Debug.cpp:28
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition: Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition: Debug.cpp:16
static const char kOperationPrintStateDocstring[]
Definition: IRCore.cpp:113
static const char kValueReplaceAllUsesWithDocstring[]
Definition: IRCore.cpp:175
static const char kContextGetNameLocationDocString[]
Definition: IRCore.cpp:56
static const char kGetNameAsOperand[]
Definition: IRCore.cpp:171
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:210
static const char kModuleParseDocstring[]
Definition: IRCore.cpp:59
static const char kOperationStrDunderDocstring[]
Definition: IRCore.cpp:145
static const char kOperationPrintDocstring[]
Definition: IRCore.cpp:86
static const char kContextGetFileLocationDocstring[]
Definition: IRCore.cpp:50
static const char kDumpDocstring[]
Definition: IRCore.cpp:153
static const char kAppendBlockDocstring[]
Definition: IRCore.cpp:156
static const char kContextGetFusedLocationDocstring[]
Definition: IRCore.cpp:53
static void maybeInsertOperation(PyOperationRef &op, const nb::object &maybeIp)
Definition: IRCore.cpp:1406
static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor)
Definition: IRCore.cpp:198
nb::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition: IRCore.cpp:192
static const char kOperationPrintBytecodeDocstring[]
Definition: IRCore.cpp:135
static const char kOperationGetAsmDocstring[]
Definition: IRCore.cpp:122
static MlirBlock createBlock(const nb::sequence &pyArgTypes, const std::optional< nb::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition: IRCore.cpp:220
static const char kOperationCreateDocstring[]
Definition: IRCore.cpp:67
static const char kContextParseTypeDocstring[]
Definition: IRCore.cpp:39
static void populateResultTypes(StringRef name, nb::list resultTypeList, const nb::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition: IRCore.cpp:1577
static const char kContextGetCallSiteLocationDocstring[]
Definition: IRCore.cpp:47
static const char kValueDunderStrDocstring[]
Definition: IRCore.cpp:163
static const char kValueReplaceAllUsesExceptDocstring[]
Definition: IRCore.cpp:180
static MLIRContext * getContext(OpFoldResult val)
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition: Interop.h:273
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition: Interop.h:118
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition: Interop.h:348
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition: Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition: Interop.h:189
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition: Interop.h:180
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition: Interop.h:255
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition: Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition: Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts an MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition: Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition: Interop.h:357
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition: Interop.h:235
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition: Interop.h:367
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition: Interop.h:245
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition: Interop.h:57
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition: Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition: Interop.h:454
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition: Interop.h:198
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition: Interop.h:330
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition: Interop.h:264
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition: Interop.h:445
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition: Interop.h:216
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::context getDefaultContext()
const float * table
Accumulates into a Python file-like object, either writing text (default) or binary.
MlirStringCallback getCallback()
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
Base class for all objects that directly or indirectly depend on an MlirContext.
Definition: IRModule.h:302
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:310
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:517
static PyLocation & resolve()
Definition: IRCore.cpp:1073
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:291
static PyMlirContext & resolve()
Definition: IRCore.cpp:794
ReferrentTy * get() const
Definition: NanobindUtils.h:55
Wrapper around an MlirAsmState.
Definition: IRModule.h:782
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:1003
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition: IRModule.h:1005
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition: IRCore.cpp:1953
static PyAttribute createFromCapsule(nanobind::object capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition: IRCore.cpp:1957
bool operator==(const PyAttribute &other) const
Definition: IRCore.cpp:1949
Wrapper around an MlirBlock.
Definition: IRModule.h:817
MlirBlock get()
Definition: IRModule.h:824
PyOperationRef & getParentOperation()
Definition: IRModule.h:825
Represents a diagnostic handler attached to the context.
Definition: IRModule.h:398
PyDiagnosticHandler(MlirContext context, nanobind::object callback)
Definition: IRCore.cpp:950
void detach()
Detaches the handler. Does nothing if not attached.
Definition: IRCore.cpp:956
Python class mirroring the C MlirDiagnostic struct.
Definition: IRModule.h:348
PyLocation getLocation()
Definition: IRCore.cpp:979
nanobind::tuple getNotes()
Definition: IRCore.cpp:994
nanobind::str getMessage()
Definition: IRCore.cpp:986
DiagnosticInfo getInfo()
Definition: IRCore.cpp:1010
PyDiagnostic(MlirDiagnostic diagnostic)
Definition: IRModule.h:350
MlirDiagnosticSeverity getSeverity()
Definition: IRCore.cpp:974
Wrapper around an MlirDialect.
Definition: IRModule.h:453
Wrapper around an MlirDialectRegistry.
Definition: IRModule.h:490
nanobind::object getCapsule()
Definition: IRCore.cpp:1035
static PyDialectRegistry createFromCapsule(nanobind::object capsule)
Definition: IRCore.cpp:1039
User-level dialect object.
Definition: IRModule.h:477
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition: IRModule.h:466
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition: IRCore.cpp:1022
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:144
std::optional< nanobind::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition: IRModule.cpp:171
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition: Globals.h:33
An insertion point maintains a pointer to a Block and a reference operation.
Definition: IRModule.h:841
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition: IRCore.cpp:1926
PyInsertionPoint(PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition: IRCore.cpp:1881
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition: IRCore.cpp:1913
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition: IRCore.cpp:1939
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition: IRCore.cpp:1887
static nanobind::object contextEnter(nanobind::object insertionPoint)
Enter and exit the context manager.
Definition: IRCore.cpp:1935
Wrapper around an MlirLocation.
Definition: IRModule.h:317
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition: IRCore.cpp:1051
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition: IRModule.h:319
static PyLocation createFromCapsule(nanobind::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition: IRCore.cpp:1055
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition: IRCore.cpp:1067
static nanobind::object contextEnter(nanobind::object location)
Enter and exit the context manager.
Definition: IRCore.cpp:1063
MlirLocation get() const
Definition: IRModule.h:323
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:186
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition: IRModule.h:190
void clearOperationsInside(PyOperationBase &op)
Clears all operations nested inside the given op using clearOperation(MlirOperation).
Definition: IRCore.cpp:682
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition: IRCore.cpp:655
void clearOperationAndInside(PyOperationBase &op)
Clears the operation and all operations inside using clearOperation(MlirOperation).
Definition: IRCore.cpp:707
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition: IRCore.cpp:718
nanobind::object attachDiagnosticHandler(nanobind::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition: IRCore.cpp:730
size_t clearLiveOperations()
Clears the live operations map, returning the number of entries which were invalidated.
Definition: IRCore.cpp:666
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition: IRCore.cpp:622
std::vector< PyOperation * > getLiveOperationObjects()
Get a list of Python objects which are still in the live context map.
Definition: IRCore.cpp:659
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition: IRCore.cpp:724
void clearOperation(MlirOperation op)
Removes an operation from the live operations map and sets it invalid.
Definition: IRCore.cpp:674
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition: IRCore.cpp:633
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition: IRCore.cpp:626
size_t getLiveOperationCount()
Gets the count of live operations associated with this context.
Definition: IRCore.cpp:657
static nanobind::object contextEnter(nanobind::object context)
Enter and exit the context manager.
Definition: IRCore.cpp:720
MlirModule get()
Gets the backing MlirModule.
Definition: IRModule.h:540
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition: IRCore.cpp:1100
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition: IRCore.cpp:1125
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition: IRCore.cpp:1132
PyModule(PyModule &)=delete
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition: IRModule.h:1027
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition: IRCore.cpp:1969
MlirNamedAttribute namedAttr
Definition: IRModule.h:1036
nanobind::object getObject()
Definition: IRModule.h:88
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
Definition: IRModule.h:76
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition: IRModule.h:731
static nanobind::object constructDerived(const nanobind::object &cls, const nanobind::object &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition: IRCore.cpp:1863
PyOpView(const nanobind::object &operationObject)
Definition: IRCore.cpp:1871
static nanobind::object buildGeneric(const nanobind::object &cls, std::optional< nanobind::list > resultTypeList, nanobind::list operandList, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * >> successors, std::optional< int > regions, DefaultingPyLocation location, const nanobind::object &maybeIp)
Definition: IRCore.cpp:1674
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition: IRModule.h:569
void walk(std::function< MlirWalkResult(MlirOperation)> callback, MlirWalkOrder walkOrder)
Definition: IRCore.cpp:1291
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
void print(std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition: IRCore.cpp:1347
nanobind::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, bool skipRegions)
Definition: IRCore.cpp:1323
void writeBytecode(const nanobind::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition: IRCore.cpp:1269
void moveBefore(PyOperationBase &other)
Definition: IRCore.cpp:1356
bool verify()
Verify the operation.
Definition: IRCore.cpp:1365
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition: IRModule.h:638
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition: IRCore.cpp:1567
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition: IRCore.cpp:1392
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Creates a detached operation.
Definition: IRCore.cpp:1195
nanobind::object clone(const nanobind::object &ip)
Clones this operation.
Definition: IRCore.cpp:1547
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition: IRModule.h:616
static nanobind::object create(const std::string &name, std::optional< std::vector< PyType * >> results, std::optional< std::vector< PyValue * >> operands, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * >> successors, int regions, DefaultingPyLocation location, const nanobind::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition: IRCore.cpp:1421
PyOperationRef getRef()
Definition: IRModule.h:651
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition: IRCore.cpp:1397
MlirOperation get() const
Definition: IRModule.h:646
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition: IRCore.cpp:1179
void setAttached(const nanobind::object &parent=nanobind::object())
Definition: IRModule.h:656
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition: IRCore.cpp:1373
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition: IRCore.cpp:1556
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition: IRCore.cpp:1383
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition: IRCore.cpp:1209
void checkValid() const
Definition: IRCore.cpp:1221
Wrapper around an MlirRegion.
Definition: IRModule.h:763
PyOperationRef & getParentOperation()
Definition: IRModule.h:772
MlirRegion get()
Definition: IRModule.h:771
Bindings for MLIR symbol tables.
Definition: IRModule.h:1242
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition: IRCore.cpp:2090
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, nanobind::object callback)
Walks all symbol tables under and including 'from'.
Definition: IRCore.cpp:2173
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition: IRCore.cpp:2161
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition: IRCore.cpp:2143
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition: IRCore.cpp:2117
MlirAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition: IRCore.cpp:2095
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition: IRCore.cpp:2080
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition: IRCore.cpp:2059
static MlirAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition: IRCore.cpp:2105
nanobind::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition: IRCore.cpp:2067
static MlirAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition: IRCore.cpp:2132
Tracks an entry in the thread context stack.
Definition: IRModule.h:107
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition: IRCore.cpp:814
static void popLocation(PyLocation &location)
Definition: IRCore.cpp:926
static nanobind::object pushLocation(nanobind::object location)
Definition: IRCore.cpp:917
static nanobind::object pushContext(nanobind::object context)
Definition: IRCore.cpp:876
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition: IRCore.cpp:871
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition: IRCore.cpp:906
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint)
Definition: IRCore.cpp:894
static void popContext(PyMlirContext &context)
Definition: IRCore.cpp:883
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and returns nullptr if not defined.
Definition: IRCore.cpp:866
PyMlirContext * getContext()
Definition: IRCore.cpp:843
static PyMlirContext * getDefaultContext()
Gets the top of stack context and returns nullptr if not defined.
Definition: IRCore.cpp:861
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition: IRCore.cpp:809
PyInsertionPoint * getInsertionPoint()
Definition: IRCore.cpp:849
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition: IRModule.h:901
static PyTypeID createFromCapsule(nanobind::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition: IRCore.cpp:2005
bool operator==(const PyTypeID &other) const
Definition: IRCore.cpp:2011
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition: IRCore.cpp:2001
PyTypeID(MlirTypeID typeID)
Definition: IRModule.h:903
Wrapper around the generic MlirType.
Definition: IRModule.h:877
PyType(PyMlirContextRef contextRef, MlirType type)
Definition: IRModule.h:879
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition: IRCore.cpp:1985
static PyType createFromCapsule(nanobind::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition: IRCore.cpp:1989
bool operator==(const PyType &other) const
Definition: IRCore.cpp:1981
Wrapper around the generic MlirValue.
Definition: IRModule.h:1143
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition: IRModule.h:1149
static PyValue createFromCapsule(nanobind::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition: IRCore.cpp:2038
nanobind::object maybeDownCast()
Definition: IRCore.cpp:2023
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition: IRCore.cpp:2019
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
Definition: Diagnostics.cpp:44
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
Definition: Diagnostics.cpp:28
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
Definition: Diagnostics.cpp:18
MlirDiagnosticSeverity
Severity of a diagnostic.
Definition: Diagnostics.h:32
@ MlirDiagnosticNote
Definition: Diagnostics.h:35
@ MlirDiagnosticRemark
Definition: Diagnostics.h:36
@ MlirDiagnosticWarning
Definition: Diagnostics.h:34
@ MlirDiagnosticError
Definition: Diagnostics.h:33
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
Definition: Diagnostics.cpp:50
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
Definition: Diagnostics.cpp:78
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
Definition: Diagnostics.cpp:56
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
Definition: Diagnostics.cpp:72
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition: Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
Definition: Diagnostics.cpp:24
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition: IR.cpp:252
MLIR_CAPI_EXPORTED intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Returns the position of the value in the argument list of its block.
Definition: IR.cpp:958
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1043
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition: IR.h:735
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition: IR.cpp:701
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
Definition: IR.cpp:529
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates a File/Line/Column location owned by the given context.
Definition: IR.cpp:260
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition: IR.cpp:1201
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition: IR.cpp:128
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op)
Returns the number of results of the operation.
Definition: IR.cpp:588
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition: IR.cpp:1180
MlirWalkOrder
Traversal order for operation walk.
Definition: IR.h:728
@ MlirWalkPreOrder
Definition: IR.h:729
@ MlirWalkPostOrder
Definition: IR.h:730
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Return pos-th attribute of the operation.
Definition: IR.cpp:662
MLIR_CAPI_EXPORTED void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition: IR.cpp:377
MLIR_CAPI_EXPORTED void mlirModuleDestroy(MlirModule module)
Takes a module owned by the caller and deletes it.
Definition: IR.cpp:330
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition: IR.cpp:1128
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition: IR.cpp:1185
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Use local scope when printing the operation.
Definition: IR.cpp:220
MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value)
Returns 1 if the value is a block argument, 0 otherwise.
Definition: IR.cpp:946
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition: IR.cpp:83
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:1149
static bool mlirModuleIsNull(MlirModule module)
Checks whether a module is null.
Definition: IR.h:314
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition: IR.cpp:1062
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition: IR.cpp:1045
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute)
Gets the type of this attribute.
Definition: IR.cpp:1101
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition: IR.cpp:72
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition: IR.cpp:802
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition: IR.cpp:1009
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Returns pos-th successor of the operation.
Definition: IR.cpp:600
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition: IR.cpp:992
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition: IR.cpp:288
MLIR_CAPI_EXPORTED void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Prints a type by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1082
MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Sets an attribute by name, replacing the existing if it exists or adding a new one otherwise.
Definition: IR.cpp:672
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition: IR.cpp:1033
MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute)
Gets the dialect of the attribute.
Definition: IR.cpp:1112
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1120
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition: IR.cpp:841
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition: IR.cpp:725
static bool mlirValueIsNull(MlirValue value)
Returns whether the value is null.
Definition: IR.h:898
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
Definition: IR.cpp:1013
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition: IR.cpp:693
MlirWalkResult
Operation walk result.
Definition: IR.h:721
@ MlirWalkResultInterrupt
Definition: IR.h:723
@ MlirWalkResultSkip
Definition: IR.h:724
@ MlirWalkResultAdvance
Definition: IR.h:722
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition: IR.cpp:782
MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Returns an attribute attached to the operation given its name.
Definition: IR.cpp:667
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition: IR.h:1008
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e. 'dialect.operation') is registered with the context.
Definition: IR.cpp:99
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Returns the number of successor blocks of the operation.
Definition: IR.cpp:596
MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op)
Creates a deep copy of an operation.
Definition: IR.cpp:503
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition: IR.cpp:910
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Always print operations in the generic form.
Definition: IR.cpp:216
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition: IR.cpp:271
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition: IR.cpp:891
MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state)
Destroys an AsmState created with mlirAsmStateCreate.
Definition: IR.cpp:187
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition: IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition: IR.cpp:94
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningl...
Definition: IR.cpp:201
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition: IR.cpp:788
MLIR_CAPI_EXPORTED void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Sets the type of the block argument to the given type.
Definition: IR.cpp:963
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op)
Gets the context this operation is associated with.
Definition: IR.cpp:515
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition: IR.cpp:825
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition: IR.h:817
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition: IR.cpp:866
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition: IR.cpp:928
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition: IR.cpp:1175
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition: IR.cpp:1066
MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value)
Prints the value to the standard error stream.
Definition: IR.cpp:984
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location)
Creates a new, empty module and transfers ownership to the caller.
Definition: IR.cpp:310
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition: IR.cpp:1031
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition: IR.cpp:1165
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition: IR.cpp:292
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetBlock(MlirOperation op)
Gets the block that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:533
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition: IR.h:282
MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Checks whether two operation handles point to the same operation.
Definition: IR.cpp:511
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition: IR.cpp:914
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition: IR.cpp:687
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition: IR.cpp:999
MLIR_CAPI_EXPORTED void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:300
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition: IR.cpp:717
MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module)
Views the module as a generic operation.
Definition: IR.cpp:336
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition: IR.cpp:1078
MLIR_CAPI_EXPORTED MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Constructs an operation state from a name and a location.
Definition: IR.cpp:348
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition: IR.cpp:1041
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition: IR.cpp:856
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Skip printing regions.
Definition: IR.cpp:228
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Returns an operation immediately following the given operation in its enclosing block.
Definition: IR.cpp:565
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Gets the operation that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:537
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module)
Gets the context that a module was created with.
Definition: IR.cpp:322
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition: IR.cpp:256
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Do not verify the operation when using custom operation printers.
Definition: IR.cpp:224
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition: IR.cpp:1070
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition: IR.cpp:1161
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition: IR.h:173
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Destroys a bytecode writer config created with mlirBytecodeWriterConfigCreate.
Definition: IR.cpp:239
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Returns pos-th operand of the operation.
Definition: IR.cpp:573
MLIR_CAPI_EXPORTED void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition: IR.cpp:389
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition: IR.cpp:845
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition: IR.cpp:267
MLIR_CAPI_EXPORTED MlirOperation mlirOpResultGetOwner(MlirValue value)
Returns an operation that produced this value as its result.
Definition: IR.cpp:967
MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value)
Returns 1 if the value is an operation result, 0 otherwise.
Definition: IR.cpp:950
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op)
Returns the number of operands of the operation.
Definition: IR.cpp:569
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition: IR.h:235
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition: IR.cpp:743
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with multithreading explicitly enabled or disabled and transfers its ownership to the caller.
Definition: IR.cpp:54
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition: IR.cpp:837
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumRegions(MlirOperation op)
Returns the number of regions attached to the given operation.
Definition: IR.cpp:541
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition: IR.cpp:296
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition: IR.cpp:919
MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Removes an attribute by name.
Definition: IR.cpp:677
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition: IR.cpp:1126
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition: IR.cpp:1190
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Parses an attribute. The attribute is owned by the context.
Definition: IR.cpp:1093
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Parses a module from the string and transfers ownership to the caller.
Definition: IR.cpp:314
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition: IR.cpp:778
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition: IR.cpp:849
MLIR_CAPI_EXPORTED void mlirTypeDump(MlirType type)
Prints the type to the standard error stream.
Definition: IR.cpp:1087
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Returns pos-th result of the operation.
Definition: IR.cpp:592
MLIR_CAPI_EXPORTED MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate(void)
Creates a new bytecode writer config with defaults, intended for customization.
Definition: IR.cpp:235
MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Gets the context that an attribute was created with.
Definition: IR.cpp:1097
MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Returns the block in which this value is defined as an argument.
Definition: IR.cpp:954
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition: IR.h:756
MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op)
Takes an operation owned by the caller and destroys it.
Definition: IR.cpp:507
MLIR_CAPI_EXPORTED MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Returns pos-th region attached to the operation.
Definition: IR.cpp:545
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition: IR.cpp:1074
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition: IR.cpp:1137
MLIR_CAPI_EXPORTED void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Set pos-th successor of the operation.
Definition: IR.cpp:653
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition: IR.cpp:107
MLIR_CAPI_EXPORTED void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition: IR.cpp:381
MLIR_CAPI_EXPORTED void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition: IR.cpp:385
MLIR_CAPI_EXPORTED MlirBlock mlirModuleGetBody(MlirModule module)
Gets the body of the module, i.e. the only block it contains.
Definition: IR.cpp:326
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Destroys printing flags created with mlirOpPrintingFlagsCreate.
Definition: IR.cpp:197
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition: IR.cpp:279
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Sets the threading mode (must be set to false to use mlir-print-ir-after-all).
Definition: IR.cpp:103
MLIR_CAPI_EXPORTED void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Prints a block by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:932
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Sets the version to emit in the writer config.
Definition: IR.cpp:243
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition: IR.cpp:1157
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition: IR.cpp:765
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition: IR.cpp:905
MLIR_CAPI_EXPORTED void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Adds a list of components to the operation state.
Definition: IR.cpp:372
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Enable or disable printing of debug information (based on enable).
Definition: IR.cpp:211
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op)
Gets the location of the operation.
Definition: IR.cpp:519
MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute)
Gets the type id of the attribute.
Definition: IR.cpp:1108
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Sets the pos-th operand of the operation.
Definition: IR.cpp:577
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition: IR.cpp:715
MLIR_CAPI_EXPORTED intptr_t mlirOpResultGetResultNumber(MlirValue value)
Returns the position of the value in the list of results of the operation that produced it.
Definition: IR.cpp:971
MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void)
Creates new printing flags with defaults, intended for customization.
Definition: IR.cpp:193
MLIR_CAPI_EXPORTED MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Creates new AsmState from value.
Definition: IR.cpp:169
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state)
Creates an operation and transfers ownership to the caller.
Definition: IR.cpp:457
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition: IR.h:1098
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition: IR.cpp:76
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition: IR.cpp:721
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Returns the number of attributes attached to the operation.
Definition: IR.cpp:658
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:986
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition: IR.cpp:708
MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type)
Set the type of the value.
Definition: IR.cpp:980
MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value)
Returns the type of the value.
Definition: IR.cpp:976
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition: IR.cpp:70
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Parses an operation, giving ownership to the caller.
Definition: IR.cpp:494
MLIR_CAPI_EXPORTED bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Checks if two attributes are equal.
Definition: IR.cpp:1116
static bool mlirOperationIsNull(MlirOperation op)
Checks whether the underlying operation is null.
Definition: IR.h:519
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition: IR.cpp:771
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition: Support.h:138
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition: Support.cpp:51
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition: Support.h:132
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition: Support.h:163
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition: Support.cpp:47
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
inline ::llvm::hash_code hash_value(const PolynomialBase< D, T > &arg)
Definition: Polynomial.h:262
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition: IRModule.h:162
PyObjectRef< PyModule > PyModuleRef
Definition: IRModule.h:529
void populateIRCore(nanobind::module_ &m)
PyObjectRef< PyOperation > PyOperationRef
Definition: IRModule.h:612
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition: Diagnostics.h:26
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
Named MLIR attribute.
Definition: IR.h:76
MlirAttribute attribute
Definition: IR.h:78
MlirIdentifier name
Definition: IR.h:77
An auxiliary class for constructing operations.
Definition: IR.h:340
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
static bool dunderContains(const std::string &attributeKind)
Definition: IRCore.cpp:272
static nb::callable dundeGetItemNamed(const std::string &attributeKind)
Definition: IRCore.cpp:275
static void dundeSetItemNamed(const std::string &attributeKind, nb::callable func, bool replace)
Definition: IRCore.cpp:281
static void bind(nb::module_ &m)
Definition: IRCore.cpp:287
Wrapper for the global LLVM debugging flag.
Definition: IRCore.cpp:245
static void bind(nb::module_ &m)
Definition: IRCore.cpp:250
static void set(nb::object &o, bool enable)
Definition: IRCore.cpp:246
static bool get(const nb::object &)
Definition: IRCore.cpp:248
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1294
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Definition: IRModule.h:1299
Materialized diagnostic information.
Definition: IRModule.h:360
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:426
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition: IRModule.h:436
ErrorCapture(PyMlirContextRef ctx)
Definition: IRModule.h:427