MLIR  22.0.0git
IRCore.cpp
Go to the documentation of this file.
//===- IRCore.cpp - IR Submodules of nanobind module ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Globals.h"
10 #include "IRModule.h"
11 #include "NanobindUtils.h"
12 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
14 #include "mlir-c/Debug.h"
15 #include "mlir-c/Diagnostics.h"
16 #include "mlir-c/IR.h"
17 #include "mlir-c/Support.h"
20 #include "nanobind/nanobind.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/SmallVector.h"
23 
24 #include <optional>
25 
26 namespace nb = nanobind;
27 using namespace nb::literals;
28 using namespace mlir;
29 using namespace mlir::python;
30 
31 using llvm::SmallVector;
32 using llvm::StringRef;
33 using llvm::Twine;
34 
35 //------------------------------------------------------------------------------
36 // Docstrings (trivial, non-duplicated docstrings are included inline).
37 //------------------------------------------------------------------------------
38 
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFileRangeDocstring[] =
    R"(Gets a Location representing a file, line and column range)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kModuleCAPICreate[] =
    R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr).
Note this returns a new object BUT _clear_mlir_module(module) must be called to
prevent double-frees (of the underlying mlir::Module).
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
  infer_type: Whether to infer result types.
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  large_resource_limit: Whether to elide resource attributes above this
    number of characters. Defaults to None (no limit). If large_elements_limit
    is set and this is None, the behavior will be to use large_elements_limit
    as large_resource_limit.
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
  skip_regions: Whether to skip printing regions. Defaults to False.
)";

static const char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  state: AsmState capturing the operation numbering and flags.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
  file: The file like object to write to.
  desired_version: The version of bytecode to emit.
Returns:
  The bytecode writer status.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

static const char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

static const char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";

// Note: the raw string must not start with a literal '"' -- it would show up
// verbatim at the beginning of the Python docstring.
static const char kValueReplaceAllUsesExceptDocstring[] =
    R"(Replace all uses of this value with the 'with' value, except for those
in 'exceptions'. 'exceptions' can be either a single operation or a list of
operations.
)";
198 
199 //------------------------------------------------------------------------------
200 // Utilities.
201 //------------------------------------------------------------------------------
202 
203 /// Helper for creating an @classmethod.
204 template <class Func, typename... Args>
205 static nb::object classmethod(Func f, Args... args) {
206  nb::object cf = nb::cpp_function(f, args...);
207  return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
208 }
209 
210 static nb::object
211 createCustomDialectWrapper(const std::string &dialectNamespace,
212  nb::object dialectDescriptor) {
213  auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
214  if (!dialectClass) {
215  // Use the base class.
216  return nb::cast(PyDialect(std::move(dialectDescriptor)));
217  }
218 
219  // Create the custom implementation.
220  return (*dialectClass)(std::move(dialectDescriptor));
221 }
222 
223 static MlirStringRef toMlirStringRef(const std::string &s) {
224  return mlirStringRefCreate(s.data(), s.size());
225 }
226 
227 static MlirStringRef toMlirStringRef(std::string_view s) {
228  return mlirStringRefCreate(s.data(), s.size());
229 }
230 
231 static MlirStringRef toMlirStringRef(const nb::bytes &s) {
232  return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
233 }
234 
235 /// Create a block, using the current location context if no locations are
236 /// specified.
237 static MlirBlock createBlock(const nb::sequence &pyArgTypes,
238  const std::optional<nb::sequence> &pyArgLocs) {
239  SmallVector<MlirType> argTypes;
240  argTypes.reserve(nb::len(pyArgTypes));
241  for (const auto &pyType : pyArgTypes)
242  argTypes.push_back(nb::cast<PyType &>(pyType));
243 
245  if (pyArgLocs) {
246  argLocs.reserve(nb::len(*pyArgLocs));
247  for (const auto &pyLoc : *pyArgLocs)
248  argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
249  } else if (!argTypes.empty()) {
250  argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
251  }
252 
253  if (argTypes.size() != argLocs.size())
254  throw nb::value_error(("Expected " + Twine(argTypes.size()) +
255  " locations, got: " + Twine(argLocs.size()))
256  .str()
257  .c_str());
258  return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
259 }
260 
261 /// Wrapper for the global LLVM debugging flag.
263  static void set(nb::object &o, bool enable) {
264  nb::ft_lock_guard lock(mutex);
265  mlirEnableGlobalDebug(enable);
266  }
267 
268  static bool get(const nb::object &) {
269  nb::ft_lock_guard lock(mutex);
270  return mlirIsGlobalDebugEnabled();
271  }
272 
273  static void bind(nb::module_ &m) {
274  // Debug flags.
275  nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
276  .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
277  &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
278  .def_static(
279  "set_types",
280  [](const std::string &type) {
281  nb::ft_lock_guard lock(mutex);
282  mlirSetGlobalDebugType(type.c_str());
283  },
284  "types"_a, "Sets specific debug types to be produced by LLVM")
285  .def_static("set_types", [](const std::vector<std::string> &types) {
286  std::vector<const char *> pointers;
287  pointers.reserve(types.size());
288  for (const std::string &str : types)
289  pointers.push_back(str.c_str());
290  nb::ft_lock_guard lock(mutex);
291  mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
292  });
293  }
294 
295 private:
296  static nb::ft_mutex mutex;
297 };
298 
299 nb::ft_mutex PyGlobalDebugFlag::mutex;
300 
302  static bool dunderContains(const std::string &attributeKind) {
303  return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
304  }
305  static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
306  auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
307  if (!builder)
308  throw nb::key_error(attributeKind.c_str());
309  return *builder;
310  }
311  static void dunderSetItemNamed(const std::string &attributeKind,
312  nb::callable func, bool replace) {
313  PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
314  replace);
315  }
316 
317  static void bind(nb::module_ &m) {
318  nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
319  .def_static("contains", &PyAttrBuilderMap::dunderContains)
320  .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed)
321  .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
322  "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
323  "Register an attribute builder for building MLIR "
324  "attributes from python values.");
325  }
326 };
327 
328 //------------------------------------------------------------------------------
329 // PyBlock
330 //------------------------------------------------------------------------------
331 
332 nb::object PyBlock::getCapsule() {
333  return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
334 }
335 
336 //------------------------------------------------------------------------------
337 // Collections.
338 //------------------------------------------------------------------------------
339 
340 namespace {
341 
342 class PyRegionIterator {
343 public:
344  PyRegionIterator(PyOperationRef operation)
345  : operation(std::move(operation)) {}
346 
347  PyRegionIterator &dunderIter() { return *this; }
348 
349  PyRegion dunderNext() {
350  operation->checkValid();
351  if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
352  throw nb::stop_iteration();
353  }
354  MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
355  return PyRegion(operation, region);
356  }
357 
358  static void bind(nb::module_ &m) {
359  nb::class_<PyRegionIterator>(m, "RegionIterator")
360  .def("__iter__", &PyRegionIterator::dunderIter)
361  .def("__next__", &PyRegionIterator::dunderNext);
362  }
363 
364 private:
365  PyOperationRef operation;
366  int nextIndex = 0;
367 };
368 
369 /// Regions of an op are fixed length and indexed numerically so are represented
370 /// with a sequence-like container.
371 class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
372 public:
373  static constexpr const char *pyClassName = "RegionSequence";
374 
375  PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
376  intptr_t length = -1, intptr_t step = 1)
377  : Sliceable(startIndex,
378  length == -1 ? mlirOperationGetNumRegions(operation->get())
379  : length,
380  step),
381  operation(std::move(operation)) {}
382 
383  PyRegionIterator dunderIter() {
384  operation->checkValid();
385  return PyRegionIterator(operation);
386  }
387 
388  static void bindDerived(ClassTy &c) {
389  c.def("__iter__", &PyRegionList::dunderIter);
390  }
391 
392 private:
393  /// Give the parent CRTP class access to hook implementations below.
394  friend class Sliceable<PyRegionList, PyRegion>;
395 
396  intptr_t getRawNumElements() {
397  operation->checkValid();
398  return mlirOperationGetNumRegions(operation->get());
399  }
400 
401  PyRegion getRawElement(intptr_t pos) {
402  operation->checkValid();
403  return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
404  }
405 
406  PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
407  return PyRegionList(operation, startIndex, length, step);
408  }
409 
410  PyOperationRef operation;
411 };
412 
413 class PyBlockIterator {
414 public:
415  PyBlockIterator(PyOperationRef operation, MlirBlock next)
416  : operation(std::move(operation)), next(next) {}
417 
418  PyBlockIterator &dunderIter() { return *this; }
419 
420  PyBlock dunderNext() {
421  operation->checkValid();
422  if (mlirBlockIsNull(next)) {
423  throw nb::stop_iteration();
424  }
425 
426  PyBlock returnBlock(operation, next);
427  next = mlirBlockGetNextInRegion(next);
428  return returnBlock;
429  }
430 
431  static void bind(nb::module_ &m) {
432  nb::class_<PyBlockIterator>(m, "BlockIterator")
433  .def("__iter__", &PyBlockIterator::dunderIter)
434  .def("__next__", &PyBlockIterator::dunderNext);
435  }
436 
437 private:
438  PyOperationRef operation;
439  MlirBlock next;
440 };
441 
442 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
443 /// we present them as a more full-featured list-like container but optimize
444 /// it for forward iteration. Blocks are always owned by a region.
445 class PyBlockList {
446 public:
447  PyBlockList(PyOperationRef operation, MlirRegion region)
448  : operation(std::move(operation)), region(region) {}
449 
450  PyBlockIterator dunderIter() {
451  operation->checkValid();
452  return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
453  }
454 
455  intptr_t dunderLen() {
456  operation->checkValid();
457  intptr_t count = 0;
458  MlirBlock block = mlirRegionGetFirstBlock(region);
459  while (!mlirBlockIsNull(block)) {
460  count += 1;
461  block = mlirBlockGetNextInRegion(block);
462  }
463  return count;
464  }
465 
466  PyBlock dunderGetItem(intptr_t index) {
467  operation->checkValid();
468  if (index < 0) {
469  index += dunderLen();
470  }
471  if (index < 0) {
472  throw nb::index_error("attempt to access out of bounds block");
473  }
474  MlirBlock block = mlirRegionGetFirstBlock(region);
475  while (!mlirBlockIsNull(block)) {
476  if (index == 0) {
477  return PyBlock(operation, block);
478  }
479  block = mlirBlockGetNextInRegion(block);
480  index -= 1;
481  }
482  throw nb::index_error("attempt to access out of bounds block");
483  }
484 
485  PyBlock appendBlock(const nb::args &pyArgTypes,
486  const std::optional<nb::sequence> &pyArgLocs) {
487  operation->checkValid();
488  MlirBlock block =
489  createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
490  mlirRegionAppendOwnedBlock(region, block);
491  return PyBlock(operation, block);
492  }
493 
494  static void bind(nb::module_ &m) {
495  nb::class_<PyBlockList>(m, "BlockList")
496  .def("__getitem__", &PyBlockList::dunderGetItem)
497  .def("__iter__", &PyBlockList::dunderIter)
498  .def("__len__", &PyBlockList::dunderLen)
499  .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
500  nb::arg("args"), nb::kw_only(),
501  nb::arg("arg_locs") = std::nullopt);
502  }
503 
504 private:
505  PyOperationRef operation;
506  MlirRegion region;
507 };
508 
509 class PyOperationIterator {
510 public:
511  PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
512  : parentOperation(std::move(parentOperation)), next(next) {}
513 
514  PyOperationIterator &dunderIter() { return *this; }
515 
516  nb::typed<nb::object, PyOpView> dunderNext() {
517  parentOperation->checkValid();
518  if (mlirOperationIsNull(next)) {
519  throw nb::stop_iteration();
520  }
521 
522  PyOperationRef returnOperation =
523  PyOperation::forOperation(parentOperation->getContext(), next);
524  next = mlirOperationGetNextInBlock(next);
525  return returnOperation->createOpView();
526  }
527 
528  static void bind(nb::module_ &m) {
529  nb::class_<PyOperationIterator>(m, "OperationIterator")
530  .def("__iter__", &PyOperationIterator::dunderIter)
531  .def("__next__", &PyOperationIterator::dunderNext);
532  }
533 
534 private:
535  PyOperationRef parentOperation;
536  MlirOperation next;
537 };
538 
539 /// Operations are exposed by the C-API as a forward-only linked list. In
540 /// Python, we present them as a more full-featured list-like container but
541 /// optimize it for forward iteration. Iterable operations are always owned
542 /// by a block.
543 class PyOperationList {
544 public:
545  PyOperationList(PyOperationRef parentOperation, MlirBlock block)
546  : parentOperation(std::move(parentOperation)), block(block) {}
547 
548  PyOperationIterator dunderIter() {
549  parentOperation->checkValid();
550  return PyOperationIterator(parentOperation,
552  }
553 
554  intptr_t dunderLen() {
555  parentOperation->checkValid();
556  intptr_t count = 0;
557  MlirOperation childOp = mlirBlockGetFirstOperation(block);
558  while (!mlirOperationIsNull(childOp)) {
559  count += 1;
560  childOp = mlirOperationGetNextInBlock(childOp);
561  }
562  return count;
563  }
564 
565  nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
566  parentOperation->checkValid();
567  if (index < 0) {
568  index += dunderLen();
569  }
570  if (index < 0) {
571  throw nb::index_error("attempt to access out of bounds operation");
572  }
573  MlirOperation childOp = mlirBlockGetFirstOperation(block);
574  while (!mlirOperationIsNull(childOp)) {
575  if (index == 0) {
576  return PyOperation::forOperation(parentOperation->getContext(), childOp)
577  ->createOpView();
578  }
579  childOp = mlirOperationGetNextInBlock(childOp);
580  index -= 1;
581  }
582  throw nb::index_error("attempt to access out of bounds operation");
583  }
584 
585  static void bind(nb::module_ &m) {
586  nb::class_<PyOperationList>(m, "OperationList")
587  .def("__getitem__", &PyOperationList::dunderGetItem)
588  .def("__iter__", &PyOperationList::dunderIter)
589  .def("__len__", &PyOperationList::dunderLen);
590  }
591 
592 private:
593  PyOperationRef parentOperation;
594  MlirBlock block;
595 };
596 
597 class PyOpOperand {
598 public:
599  PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
600 
601  PyOpView getOwner() {
602  MlirOperation owner = mlirOpOperandGetOwner(opOperand);
603  PyMlirContextRef context =
604  PyMlirContext::forContext(mlirOperationGetContext(owner));
605  return PyOperation::forOperation(context, owner)->createOpView();
606  }
607 
608  size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
609 
610  static void bind(nb::module_ &m) {
611  nb::class_<PyOpOperand>(m, "OpOperand")
612  .def_prop_ro("owner", &PyOpOperand::getOwner)
613  .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber);
614  }
615 
616 private:
617  MlirOpOperand opOperand;
618 };
619 
620 class PyOpOperandIterator {
621 public:
622  PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
623 
624  PyOpOperandIterator &dunderIter() { return *this; }
625 
626  PyOpOperand dunderNext() {
627  if (mlirOpOperandIsNull(opOperand))
628  throw nb::stop_iteration();
629 
630  PyOpOperand returnOpOperand(opOperand);
631  opOperand = mlirOpOperandGetNextUse(opOperand);
632  return returnOpOperand;
633  }
634 
635  static void bind(nb::module_ &m) {
636  nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
637  .def("__iter__", &PyOpOperandIterator::dunderIter)
638  .def("__next__", &PyOpOperandIterator::dunderNext);
639  }
640 
641 private:
642  MlirOpOperand opOperand;
643 };
644 
645 } // namespace
646 
647 //------------------------------------------------------------------------------
648 // PyMlirContext
649 //------------------------------------------------------------------------------
650 
651 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
652  nb::gil_scoped_acquire acquire;
653  nb::ft_lock_guard lock(live_contexts_mutex);
654  auto &liveContexts = getLiveContexts();
655  liveContexts[context.ptr] = this;
656 }
657 
659  // Note that the only public way to construct an instance is via the
660  // forContext method, which always puts the associated handle into
661  // liveContexts.
662  nb::gil_scoped_acquire acquire;
663  {
664  nb::ft_lock_guard lock(live_contexts_mutex);
665  getLiveContexts().erase(context.ptr);
666  }
667  mlirContextDestroy(context);
668 }
669 
671  return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
672 }
673 
674 nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
675  MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
676  if (mlirContextIsNull(rawContext))
677  throw nb::python_error();
678  return forContext(rawContext).releaseObject();
679 }
680 
682  nb::gil_scoped_acquire acquire;
683  nb::ft_lock_guard lock(live_contexts_mutex);
684  auto &liveContexts = getLiveContexts();
685  auto it = liveContexts.find(context.ptr);
686  if (it == liveContexts.end()) {
687  // Create.
688  PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
689  nb::object pyRef = nb::cast(unownedContextWrapper);
690  assert(pyRef && "cast to nb::object failed");
691  liveContexts[context.ptr] = unownedContextWrapper;
692  return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
693  }
694  // Use existing.
695  nb::object pyRef = nb::cast(it->second);
696  return PyMlirContextRef(it->second, std::move(pyRef));
697 }
698 
699 nb::ft_mutex PyMlirContext::live_contexts_mutex;
700 
701 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
702  static LiveContextMap liveContexts;
703  return liveContexts;
704 }
705 
707  nb::ft_lock_guard lock(live_contexts_mutex);
708  return getLiveContexts().size();
709 }
710 
711 nb::object PyMlirContext::contextEnter(nb::object context) {
712  return PyThreadContextEntry::pushContext(context);
713 }
714 
715 void PyMlirContext::contextExit(const nb::object &excType,
716  const nb::object &excVal,
717  const nb::object &excTb) {
719 }
720 
721 nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
722  // Note that ownership is transferred to the delete callback below by way of
723  // an explicit inc_ref (borrow).
724  PyDiagnosticHandler *pyHandler =
725  new PyDiagnosticHandler(get(), std::move(callback));
726  nb::object pyHandlerObject =
727  nb::cast(pyHandler, nb::rv_policy::take_ownership);
728  (void)pyHandlerObject.inc_ref();
729 
730  // In these C callbacks, the userData is a PyDiagnosticHandler* that is
731  // guaranteed to be known to pybind.
732  auto handlerCallback =
733  +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
734  PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
735  nb::object pyDiagnosticObject =
736  nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
737 
738  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
739  bool result = false;
740  {
741  // Since this can be called from arbitrary C++ contexts, always get the
742  // gil.
743  nb::gil_scoped_acquire gil;
744  try {
745  result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
746  } catch (std::exception &e) {
747  fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
748  e.what());
749  pyHandler->hadError = true;
750  }
751  }
752 
753  pyDiagnostic->invalidate();
755  };
756  auto deleteCallback = +[](void *userData) {
757  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
758  assert(pyHandler->registeredID && "handler is not registered");
759  pyHandler->registeredID.reset();
760 
761  // Decrement reference, balancing the inc_ref() above.
762  nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
763  pyHandlerObject.dec_ref();
764  };
765 
766  pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
767  get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
768  return pyHandlerObject;
769 }
770 
771 MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
772  void *userData) {
773  auto *self = static_cast<ErrorCapture *>(userData);
774  // Check if the context requested we emit errors instead of capturing them.
775  if (self->ctx->emitErrorDiagnostics)
776  return mlirLogicalResultFailure();
777 
779  return mlirLogicalResultFailure();
780 
781  self->errors.emplace_back(PyDiagnostic(diag).getInfo());
782  return mlirLogicalResultSuccess();
783 }
784 
787  if (!context) {
788  throw std::runtime_error(
789  "An MLIR function requires a Context but none was provided in the call "
790  "or from the surrounding environment. Either pass to the function with "
791  "a 'context=' argument or establish a default using 'with Context():'");
792  }
793  return *context;
794 }
795 
796 //------------------------------------------------------------------------------
797 // PyThreadContextEntry management
798 //------------------------------------------------------------------------------
799 
800 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
801  static thread_local std::vector<PyThreadContextEntry> stack;
802  return stack;
803 }
804 
806  auto &stack = getStack();
807  if (stack.empty())
808  return nullptr;
809  return &stack.back();
810 }
811 
812 void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
813  nb::object insertionPoint,
814  nb::object location) {
815  auto &stack = getStack();
816  stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
817  std::move(location));
818  // If the new stack has more than one entry and the context of the new top
819  // entry matches the previous, copy the insertionPoint and location from the
820  // previous entry if missing from the new top entry.
821  if (stack.size() > 1) {
822  auto &prev = *(stack.rbegin() + 1);
823  auto &current = stack.back();
824  if (current.context.is(prev.context)) {
825  // Default non-context objects from the previous entry.
826  if (!current.insertionPoint)
827  current.insertionPoint = prev.insertionPoint;
828  if (!current.location)
829  current.location = prev.location;
830  }
831  }
832 }
833 
835  if (!context)
836  return nullptr;
837  return nb::cast<PyMlirContext *>(context);
838 }
839 
841  if (!insertionPoint)
842  return nullptr;
843  return nb::cast<PyInsertionPoint *>(insertionPoint);
844 }
845 
847  if (!location)
848  return nullptr;
849  return nb::cast<PyLocation *>(location);
850 }
851 
853  auto *tos = getTopOfStack();
854  return tos ? tos->getContext() : nullptr;
855 }
856 
858  auto *tos = getTopOfStack();
859  return tos ? tos->getInsertionPoint() : nullptr;
860 }
861 
863  auto *tos = getTopOfStack();
864  return tos ? tos->getLocation() : nullptr;
865 }
866 
867 nb::object PyThreadContextEntry::pushContext(nb::object context) {
868  push(FrameKind::Context, /*context=*/context,
869  /*insertionPoint=*/nb::object(),
870  /*location=*/nb::object());
871  return context;
872 }
873 
875  auto &stack = getStack();
876  if (stack.empty())
877  throw std::runtime_error("Unbalanced Context enter/exit");
878  auto &tos = stack.back();
879  if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
880  throw std::runtime_error("Unbalanced Context enter/exit");
881  stack.pop_back();
882 }
883 
884 nb::object
885 PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
886  PyInsertionPoint &insertionPoint =
887  nb::cast<PyInsertionPoint &>(insertionPointObj);
888  nb::object contextObj =
889  insertionPoint.getBlock().getParentOperation()->getContext().getObject();
890  push(FrameKind::InsertionPoint,
891  /*context=*/contextObj,
892  /*insertionPoint=*/insertionPointObj,
893  /*location=*/nb::object());
894  return insertionPointObj;
895 }
896 
898  auto &stack = getStack();
899  if (stack.empty())
900  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
901  auto &tos = stack.back();
902  if (tos.frameKind != FrameKind::InsertionPoint &&
903  tos.getInsertionPoint() != &insertionPoint)
904  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
905  stack.pop_back();
906 }
907 
908 nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
909  PyLocation &location = nb::cast<PyLocation &>(locationObj);
910  nb::object contextObj = location.getContext().getObject();
911  push(FrameKind::Location, /*context=*/contextObj,
912  /*insertionPoint=*/nb::object(),
913  /*location=*/locationObj);
914  return locationObj;
915 }
916 
918  auto &stack = getStack();
919  if (stack.empty())
920  throw std::runtime_error("Unbalanced Location enter/exit");
921  auto &tos = stack.back();
922  if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
923  throw std::runtime_error("Unbalanced Location enter/exit");
924  stack.pop_back();
925 }
926 
927 //------------------------------------------------------------------------------
928 // PyDiagnostic*
929 //------------------------------------------------------------------------------
930 
932  valid = false;
933  if (materializedNotes) {
934  for (nb::handle noteObject : *materializedNotes) {
935  PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
936  note->invalidate();
937  }
938  }
939 }
940 
                                         nb::object callback)
    // Stores the MLIR context handle and takes ownership of the Python
    // callback object (moved in).
    : context(context), callback(std::move(callback)) {}
944 
946 
  // Detaches this handler from its MLIR context. No-op when `registeredID` is
  // empty (never registered, or already detached).
  if (!registeredID)
    return;
  // Copy the ID out first: the detach call below is expected to reset
  // `registeredID` as a side effect (see the assert), so it must not be read
  // afterwards.
  MlirDiagnosticHandlerID localID = *registeredID;
  mlirContextDetachDiagnosticHandler(context, localID);
  // NOTE(review): relies on the detach call (presumably via the handler's
  // user-data deleter) having cleared `registeredID`; confirm.
  assert(!registeredID && "should have unregistered");
  // Not strictly necessary but keeps stale pointers from being around to cause
  // issues.
  context = {nullptr};
}
957 
958 void PyDiagnostic::checkValid() {
959  if (!valid) {
960  throw std::invalid_argument(
961  "Diagnostic is invalid (used outside of callback)");
962  }
963 }
964 
  // Returns the severity of the wrapped diagnostic; throws if invalidated.
  checkValid();
  return mlirDiagnosticGetSeverity(diagnostic);
}
969 
971  checkValid();
972  MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
973  MlirContext context = mlirLocationGetContext(loc);
974  return PyLocation(PyMlirContext::forContext(context), loc);
975 }
976 
978  checkValid();
979  nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
980  PyFileAccumulator accum(fileObject, /*binary=*/false);
981  mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
982  return nb::cast<nb::str>(fileObject.attr("getvalue")());
983 }
984 
986  checkValid();
987  if (materializedNotes)
988  return *materializedNotes;
989  intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
990  nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
991  for (intptr_t i = 0; i < numNotes; ++i) {
992  MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
993  nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
994  PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
995  }
996  materializedNotes = std::move(notes);
997 
998  return *materializedNotes;
999 }
1000 
1002  std::vector<DiagnosticInfo> notes;
1003  for (nb::handle n : getNotes())
1004  notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
1005  return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
1006  std::move(notes)};
1007 }
1008 
1009 //------------------------------------------------------------------------------
1010 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
1011 //------------------------------------------------------------------------------
1012 
1013 MlirDialect PyDialects::getDialectForKey(const std::string &key,
1014  bool attrError) {
1015  MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1016  {key.data(), key.size()});
1017  if (mlirDialectIsNull(dialect)) {
1018  std::string msg = (Twine("Dialect '") + key + "' not found").str();
1019  if (attrError)
1020  throw nb::attribute_error(msg.c_str());
1021  throw nb::index_error(msg.c_str());
1022  }
1023  return dialect;
1024 }
1025 
  // Wraps this registry in a new Python capsule; the C helper returns a new
  // reference, which the returned object adopts via nb::steal.
  return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
}
1029 
1031  MlirDialectRegistry rawRegistry =
1032  mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1033  if (mlirDialectRegistryIsNull(rawRegistry))
1034  throw nb::python_error();
1035  return PyDialectRegistry(rawRegistry);
1036 }
1037 
1038 //------------------------------------------------------------------------------
1039 // PyLocation
1040 //------------------------------------------------------------------------------
1041 
  // Wraps this location in a new Python capsule (the new reference is adopted
  // via nb::steal).
  return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
}
1045 
  // Unwraps a Python capsule into an MlirLocation. A null location means the
  // capsule was not a valid location capsule; the pending Python error is
  // re-raised.
  MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
  if (mlirLocationIsNull(rawLoc))
    throw nb::python_error();
                    rawLoc);
}
1053 
/// Context-manager entry for `with loc:`. Pushes this location onto the thread
/// context stack and returns the same Python object.
nb::object PyLocation::contextEnter(nb::object locationObj) {
  return PyThreadContextEntry::pushLocation(locationObj);
}
1057 
/// Context-manager exit counterpart of contextEnter; the three arguments
/// follow the Python `__exit__` (type, value, traceback) protocol.
void PyLocation::contextExit(const nb::object &excType,
                             const nb::object &excVal,
                             const nb::object &excTb) {
}
1063 
1065  auto *location = PyThreadContextEntry::getDefaultLocation();
1066  if (!location) {
1067  throw std::runtime_error(
1068  "An MLIR function requires a Location but none was provided in the "
1069  "call or from the surrounding environment. Either pass to the function "
1070  "with a 'loc=' argument or establish a default using 'with loc:'");
1071  }
1072  return *location;
1073 }
1074 
1075 //------------------------------------------------------------------------------
1076 // PyModule
1077 //------------------------------------------------------------------------------
1078 
/// Wraps `module`, holding a reference to its Python context. The constructor
/// does not register the wrapper in the context's live-module map; forModule
/// does that.
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
    : BaseContextObject(std::move(contextRef)), module(module) {}
1081 
  // Hold the GIL while updating the context's live-module map (which stores
  // Python handles) and destroying the module.
  nb::gil_scoped_acquire acquire;
  auto &liveModules = getContext()->liveModules;
  assert(liveModules.count(module.ptr) == 1 &&
         "destroying module not in live map");
  // Unregister before destroying, so the map never keys a destroyed module.
  liveModules.erase(module.ptr);
  mlirModuleDestroy(module);
}
1090 
1091 PyModuleRef PyModule::forModule(MlirModule module) {
1092  MlirContext context = mlirModuleGetContext(module);
1093  PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1094 
1095  nb::gil_scoped_acquire acquire;
1096  auto &liveModules = contextRef->liveModules;
1097  auto it = liveModules.find(module.ptr);
1098  if (it == liveModules.end()) {
1099  // Create.
1100  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1101  // Note that the default return value policy on cast is automatic_reference,
1102  // which does not take ownership (delete will not be called).
1103  // Just be explicit.
1104  nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1105  unownedModule->handle = pyRef;
1106  liveModules[module.ptr] =
1107  std::make_pair(unownedModule->handle, unownedModule);
1108  return PyModuleRef(unownedModule, std::move(pyRef));
1109  }
1110  // Use existing.
1111  PyModule *existing = it->second.second;
1112  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1113  return PyModuleRef(existing, std::move(pyRef));
1114 }
1115 
1116 nb::object PyModule::createFromCapsule(nb::object capsule) {
1117  MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1118  if (mlirModuleIsNull(rawModule))
1119  throw nb::python_error();
1120  return forModule(rawModule).releaseObject();
1121 }
1122 
/// Wraps the underlying MlirModule in a new Python capsule (new reference,
/// adopted via nb::steal).
nb::object PyModule::getCapsule() {
  return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
}
1126 
1127 //------------------------------------------------------------------------------
1128 // PyOperation
1129 //------------------------------------------------------------------------------
1130 
/// Wraps `operation`, holding a reference to its Python context. Instances are
/// created through createInstance (via makeObjectRef) in this file.
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
    : BaseContextObject(std::move(contextRef)), operation(operation) {}
1133 
1135  // If the operation has already been invalidated there is nothing to do.
1136  if (!valid)
1137  return;
1138  // Otherwise, invalidate the operation when it is attached.
1139  if (isAttached())
1140  setInvalid();
1141  else {
1142  // And destroy it when it is detached, i.e. owned by Python.
1143  erase();
1144  }
1145 }
1146 
namespace {

// Constructs a new object of type T in-place on the Python heap, returning a
// PyObjectRef to it, loosely analogous to std::make_shared<T>().
template <typename T, class... Args>
PyObjectRef<T> makeObjectRef(Args &&...args) {
  // Allocate a Python instance of T's bound type with storage for a T.
  nb::handle type = nb::type<T>();
  nb::object instance = nb::inst_alloc(type);
  // `ptr` points into the Python object's storage, so it stays valid after
  // `instance` is moved into the returned ref below.
  T *ptr = nb::inst_ptr<T>(instance);
  // Construct T in that storage, then tell nanobind the C++ object is ready.
  new (ptr) T(std::forward<Args>(args)...);
  nb::inst_mark_ready(instance);
  return PyObjectRef<T>(ptr, std::move(instance));
}

} // namespace
1162 
1163 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1164  MlirOperation operation,
1165  nb::object parentKeepAlive) {
1166  // Create.
1167  PyOperationRef unownedOperation =
1168  makeObjectRef<PyOperation>(std::move(contextRef), operation);
1169  unownedOperation->handle = unownedOperation.getObject();
1170  if (parentKeepAlive) {
1171  unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1172  }
1173  return unownedOperation;
1174 }
1175 
                                         MlirOperation operation,
                                         nb::object parentKeepAlive) {
  // Creates a new wrapper via createInstance; unlike createDetached, the
  // `attached` flag is left at its default.
  return createInstance(std::move(contextRef), operation,
                        std::move(parentKeepAlive));
}
1182 
                                           MlirOperation operation,
                                           nb::object parentKeepAlive) {
  // Like forOperation, but marks the wrapper as detached. PyOperation's
  // destructor erases detached operations, so this wrapper owns the op.
  PyOperationRef created = createInstance(std::move(contextRef), operation,
                                          std::move(parentKeepAlive));
  created->attached = false;
  return created;
}
1191 
                                   const std::string &sourceStr,
                                   const std::string &sourceName) {
  // Parses `sourceStr` as a single operation (`sourceName` is presumably used
  // as the source file name in locations/diagnostics). Diagnostics emitted
  // during parsing are collected by `errors` and raised as an MLIRError on
  // failure; on success the op is returned as a detached operation.
  PyMlirContext::ErrorCapture errors(contextRef);
  MlirOperation op =
      mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
                               toMlirStringRef(sourceName));
  if (mlirOperationIsNull(op))
    throw MLIRError("Unable to parse operation assembly", errors.take());
  return PyOperation::createDetached(std::move(contextRef), op);
}
1203 
1205  if (!valid) {
1206  throw std::runtime_error("the operation has been invalidated");
1207  }
1208 }
1209 
/// Prints the operation to `fileObject` (sys.stdout when None), in text or
/// binary mode. The optional limits and boolean switches are translated into
/// an MlirOpPrintingFlags object that is passed to the printer. Throws if the
/// operation has been invalidated.
void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
                            std::optional<int64_t> largeResourceLimit,
                            bool enableDebugInfo, bool prettyDebugInfo,
                            bool printGenericOpForm, bool useLocalScope,
                            bool useNameLocAsPrefix, bool assumeVerified,
                            nb::object fileObject, bool binary,
                            bool skipRegions) {
  PyOperation &operation = getOperation();
  operation.checkValid();
  if (fileObject.is_none())
    fileObject = nb::module_::import_("sys").attr("stdout");

  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
  if (largeElementsLimit)
    mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
  if (largeResourceLimit)
    mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit);
  if (enableDebugInfo)
    mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
                                       /*prettyForm=*/prettyDebugInfo);
  if (printGenericOpForm)
  if (useLocalScope)
  if (assumeVerified)
  if (skipRegions)
  if (useNameLocAsPrefix)

  PyFileAccumulator accum(fileObject, binary);
  mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
                              accum.getUserData());
}
1246 
1247 void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1248  bool binary) {
1249  PyOperation &operation = getOperation();
1250  operation.checkValid();
1251  if (fileObject.is_none())
1252  fileObject = nb::module_::import_("sys").attr("stdout");
1253  PyFileAccumulator accum(fileObject, binary);
1254  mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1255  accum.getUserData());
1256 }
1257 
/// Serializes the operation as MLIR bytecode into `fileOrStringObject`
/// (written in binary mode). Without `bytecodeVersion` the default writer is
/// used. With a version, a writer config is used, and ValueError is raised if
/// the writer reports failure to honor that version.
void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
                                    std::optional<int64_t> bytecodeVersion) {
  PyOperation &operation = getOperation();
  operation.checkValid();
  PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);

  if (!bytecodeVersion.has_value())
    return mlirOperationWriteBytecode(operation, accum.getCallback(),
                                      accum.getUserData());

  MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
      operation, config, accum.getCallback(), accum.getUserData());
  if (mlirLogicalResultIsFailure(res))
    throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
                           Twine(*bytecodeVersion))
                              .str()
                              .c_str());
}
1279 
    std::function<MlirWalkResult(MlirOperation)> callback,
    MlirWalkOrder walkOrder) {
  // Walks the operation in `walkOrder`, invoking `callback` on each visited
  // op. A Python exception thrown by the callback cannot propagate through the
  // C walk API, so it is recorded in `userData` and re-thrown as a
  // std::runtime_error after the walk returns.
  PyOperation &operation = getOperation();
  operation.checkValid();
  struct UserData {
    std::function<MlirWalkResult(MlirOperation)> callback;
    bool gotException;
    std::string exceptionWhat;
    // NOTE(review): set in the callback below but never read, so the original
    // Python exception type is lost when re-throwing; consider using it.
    nb::object exceptionType;
  };
  UserData userData{callback, false, {}, {}};
  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
                                              void *userData) {
    UserData *calleeUserData = static_cast<UserData *>(userData);
    try {
      return (calleeUserData->callback)(op);
    } catch (nb::python_error &e) {
      calleeUserData->gotException = true;
      calleeUserData->exceptionWhat = std::string(e.what());
      calleeUserData->exceptionType = nb::borrow(e.type());
    }
  };
  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
  if (userData.gotException) {
    std::string message("Exception raised in callback: ");
    message.append(userData.exceptionWhat);
    throw std::runtime_error(message);
  }
}
1311 
1312 nb::object PyOperationBase::getAsm(bool binary,
1313  std::optional<int64_t> largeElementsLimit,
1314  std::optional<int64_t> largeResourceLimit,
1315  bool enableDebugInfo, bool prettyDebugInfo,
1316  bool printGenericOpForm, bool useLocalScope,
1317  bool useNameLocAsPrefix, bool assumeVerified,
1318  bool skipRegions) {
1319  nb::object fileObject;
1320  if (binary) {
1321  fileObject = nb::module_::import_("io").attr("BytesIO")();
1322  } else {
1323  fileObject = nb::module_::import_("io").attr("StringIO")();
1324  }
1325  print(/*largeElementsLimit=*/largeElementsLimit,
1326  /*largeResourceLimit=*/largeResourceLimit,
1327  /*enableDebugInfo=*/enableDebugInfo,
1328  /*prettyDebugInfo=*/prettyDebugInfo,
1329  /*printGenericOpForm=*/printGenericOpForm,
1330  /*useLocalScope=*/useLocalScope,
1331  /*useNameLocAsPrefix=*/useNameLocAsPrefix,
1332  /*assumeVerified=*/assumeVerified,
1333  /*fileObject=*/fileObject,
1334  /*binary=*/binary,
1335  /*skipRegions=*/skipRegions);
1336 
1337  return fileObject.attr("getvalue")();
1338 }
1339 
1341  PyOperation &operation = getOperation();
1342  PyOperation &otherOp = other.getOperation();
1343  operation.checkValid();
1344  otherOp.checkValid();
1345  mlirOperationMoveAfter(operation, otherOp);
1346  operation.parentKeepAlive = otherOp.parentKeepAlive;
1347 }
1348 
1350  PyOperation &operation = getOperation();
1351  PyOperation &otherOp = other.getOperation();
1352  operation.checkValid();
1353  otherOp.checkValid();
1354  mlirOperationMoveBefore(operation, otherOp);
1355  operation.parentKeepAlive = otherOp.parentKeepAlive;
1356 }
1357 
1359  PyOperation &operation = getOperation();
1360  PyOperation &otherOp = other.getOperation();
1361  operation.checkValid();
1362  otherOp.checkValid();
1363  return mlirOperationIsBeforeInBlock(operation, otherOp);
1364 }
1365 
  // Runs the MLIR verifier on the operation. On failure, throws MLIRError
  // carrying the diagnostics collected in `errors`; returns true on success.
  PyOperation &op = getOperation();
  if (!mlirOperationVerify(op.get()))
    throw MLIRError("Verification failed", errors.take());
  return true;
}
1373 
1374 std::optional<PyOperationRef> PyOperation::getParentOperation() {
1375  checkValid();
1376  if (!isAttached())
1377  throw nb::value_error("Detached operations have no parent");
1378  MlirOperation operation = mlirOperationGetParentOperation(get());
1379  if (mlirOperationIsNull(operation))
1380  return {};
1381  return PyOperation::forOperation(getContext(), operation);
1382 }
1383 
1385  checkValid();
1386  std::optional<PyOperationRef> parentOperation = getParentOperation();
1387  MlirBlock block = mlirOperationGetBlock(get());
1388  assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1389  assert(parentOperation && "Operation has no parent");
1390  return PyBlock{std::move(*parentOperation), block};
1391 }
1392 
  // Wraps the (valid) operation in a new Python capsule; the new reference is
  // adopted via nb::steal.
  checkValid();
  return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
}
1397 
1398 nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
1399  MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1400  if (mlirOperationIsNull(rawOperation))
1401  throw nb::python_error();
1402  MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1403  return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1404  .releaseObject();
1405 }
1406 
                                 const nb::object &maybeIp) {
  // Interprets `maybeIp`: the Python `False` object disables insertion
  // entirely. `None` takes the first branch below. Otherwise `maybeIp` must be
  // an InsertionPoint and the op is inserted there.
  // InsertPoint active?
  if (!maybeIp.is(nb::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
    } else {
      ip = nb::cast<PyInsertionPoint *>(maybeIp);
    }
    if (ip)
      ip->insert(*op.get());
  }
}
1421 
/// Creates a new operation named `name` at `location` from already-unpacked
/// pieces: result types, operand values, an attribute dict (str -> Attribute),
/// successor blocks and `regions` freshly created regions. All inputs are
/// validated before the MlirOperationState is built. `inferType` sets the
/// state's result-type-inference flag. The op is created detached, then
/// passed to maybeInsertOperation with `maybeIp`, and its Python object is
/// returned.
/// Throws ValueError or TypeError for invalid arguments, RuntimeError for an
/// attribute value that fails with std::runtime_error (e.g. None), and
/// ValueError if mlirOperationCreate returns null.
nb::object PyOperation::create(std::string_view name,
                               std::optional<std::vector<PyType *>> results,
                               llvm::ArrayRef<MlirValue> operands,
                               std::optional<nb::dict> attributes,
                               std::optional<std::vector<PyBlock *>> successors,
                               int regions, PyLocation &location,
                               const nb::object &maybeIp, bool inferType) {
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;

  // General parameter validation.
  if (regions < 0)
    throw nb::value_error("number of regions must be >= 0");

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw nb::value_error("result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (std::pair<nb::handle, nb::handle> it : *attributes) {
      std::string key;
      try {
        key = nb::cast<std::string>(it.first);
      } catch (nb::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          std::string(name) + "\" (" + err.what() + ")";
        throw nb::type_error(msg.c_str());
      }
      try {
        auto &attribute = nb::cast<PyAttribute &>(it.second);
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (nb::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          std::string(name) + "\" (" + err.what() + ")";
        throw nb::type_error(msg.c_str());
      } catch (std::runtime_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" +
            std::string(name) + "\"";
        throw std::runtime_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw nb::value_error("successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!operands.empty())
    mlirOperationStateAddOperands(&state, operands.size(), operands.data());
  state.enableResultTypeInference = inferType;
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
              toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  if (!operation.ptr)
    throw nb::value_error("Operation creation failed");
  // The new op is owned by its (detached) wrapper until it gets inserted.
  PyOperationRef created =
      PyOperation::createDetached(location.getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created.getObject();
}
1536 
1537 nb::object PyOperation::clone(const nb::object &maybeIp) {
1538  MlirOperation clonedOperation = mlirOperationClone(operation);
1539  PyOperationRef cloned =
1540  PyOperation::createDetached(getContext(), clonedOperation);
1541  maybeInsertOperation(cloned, maybeIp);
1542 
1543  return cloned->createOpView();
1544 }
1545 
1547  checkValid();
1548  MlirIdentifier ident = mlirOperationGetName(get());
1549  MlirStringRef identStr = mlirIdentifierStr(ident);
1550  auto operationCls = PyGlobals::get().lookupOperationClass(
1551  StringRef(identStr.data, identStr.length));
1552  if (operationCls)
1553  return PyOpView::constructDerived(*operationCls, getRef().getObject());
1554  return nb::cast(PyOpView(getRef().getObject()));
1555 }
1556 
  // Destroys the underlying operation. The wrapper is marked invalid first, so
  // later checkValid() calls through it presumably throw rather than touch
  // freed IR.
  checkValid();
  setInvalid();
  mlirOperationDestroy(operation);
}
1562 
namespace {
/// CRTP base class for Python MLIR values that subclass Value and should be
/// castable from it. The value hierarchy is one level deep and is not supposed
/// to accommodate other levels unless core MLIR changes.
template <typename DerivedTy>
class PyConcreteValue : public PyValue {
public:
  // Derived classes must define statics for:
  //   IsAFunctionTy isaFunction
  //   const char *pyClassName
  // and redefine bindDerived.
  using ClassTy = nb::class_<DerivedTy, PyValue>;
  using IsAFunctionTy = bool (*)(MlirValue);

  PyConcreteValue() = default;
  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
      : PyValue(operationRef, value) {}
  // Converting constructor: checks `orig` with castFrom (throws on mismatch)
  // and shares its parent operation reference.
  PyConcreteValue(PyValue &orig)
      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}

  /// Attempts to cast the original value to the derived type and throws on
  /// type mismatches.
  static MlirValue castFrom(PyValue &orig) {
    if (!DerivedTy::isaFunction(orig.get())) {
      auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
      throw nb::value_error((Twine("Cannot cast value to ") +
                             DerivedTy::pyClassName + " (from " + origRepr +
                             ")")
                                .str()
                                .c_str());
    }
    return orig.get();
  }

  /// Binds the Python module objects to functions of this class.
  static void bind(nb::module_ &m) {
    auto cls = ClassTy(m, DerivedTy::pyClassName);
    cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
    cls.def_static(
        "isinstance",
        [](PyValue &otherValue) -> bool {
          return DerivedTy::isaFunction(otherValue);
        },
        nb::arg("other_value"));
        [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
          return self.maybeDownCast();
        });
    DerivedTy::bindDerived(cls);
  }

  /// Implemented by derived classes to add methods to the Python subclass.
  static void bindDerived(ClassTy &m) {}
};

} // namespace
1619 
1620 /// Python wrapper for MlirOpResult.
1621 class PyOpResult : public PyConcreteValue<PyOpResult> {
1622 public:
1623  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1624  static constexpr const char *pyClassName = "OpResult";
1625  using PyConcreteValue::PyConcreteValue;
1626 
1627  static void bindDerived(ClassTy &c) {
1628  c.def_prop_ro(
1629  "owner", [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
1630  assert(mlirOperationEqual(self.getParentOperation()->get(),
1631  mlirOpResultGetOwner(self.get())) &&
1632  "expected the owner of the value in Python to match that in "
1633  "the IR");
1634  return self.getParentOperation().getObject();
1635  });
1636  c.def_prop_ro("result_number", [](PyOpResult &self) {
1637  return mlirOpResultGetResultNumber(self.get());
1638  });
1639  }
1640 };
1641 
1642 /// Returns the list of types of the values held by container.
1643 template <typename Container>
1644 static std::vector<nb::typed<nb::object, PyType>>
1645 getValueTypes(Container &container, PyMlirContextRef &context) {
1646  std::vector<nb::typed<nb::object, PyType>> result;
1647  result.reserve(container.size());
1648  for (int i = 0, e = container.size(); i < e; ++i) {
1649  result.push_back(PyType(context->getRef(),
1650  mlirValueGetType(container.getElement(i).get()))
1651  .maybeDownCast());
1652  }
1653  return result;
1654 }
1655 
/// A list of operation results. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) result list is associated
/// with the operation whose results these are, and thus extends the lifetime of
/// this operation.
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
public:
  static constexpr const char *pyClassName = "OpResultList";

  // A `length` of -1 means "all results of the operation". The base is
  // initialized before the `operation` member, so the parameter is not yet
  // moved-from when it is read in the length computation.
  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
                 intptr_t length = -1, intptr_t step = 1)
      : Sliceable(startIndex,
                  length == -1 ? mlirOperationGetNumResults(operation->get())
                               : length,
                  step),
        operation(std::move(operation)) {}

  // Adds the read-only `types` and `owner` properties.
  static void bindDerived(ClassTy &c) {
    c.def_prop_ro("types", [](PyOpResultList &self) {
      return getValueTypes(self, self.operation->getContext());
    });
    c.def_prop_ro("owner",
                  [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
                    return self.operation->createOpView();
                  });
  }

  PyOperationRef &getOperation() { return operation; }

private:
  /// Give the parent CRTP class access to hook implementations below.
  friend class Sliceable<PyOpResultList, PyOpResult>;

  intptr_t getRawNumElements() {
    operation->checkValid();
    return mlirOperationGetNumResults(operation->get());
  }

  // Wraps the index-th result; PyOpResult's converting constructor checks
  // that it really is an OpResult.
  PyOpResult getRawElement(intptr_t index) {
    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
    return PyOpResult(value);
  }

  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
    return PyOpResultList(operation, startIndex, length, step);
  }

  PyOperationRef operation;
};
1705 
1706 //------------------------------------------------------------------------------
1707 // PyOpView
1708 //------------------------------------------------------------------------------
1709 
/// Unpacks `resultTypeList` for operation `name` into `resultTypes`. When
/// `resultSegmentSpecObj` is not None (a list of ints, one per entry), it also
/// records one length per segment in `resultSegmentLengths`.
///
/// Segment spec values:
///   1  -> exactly one Type (None is an error);
///   0  -> optional single Type (None gives a 0-length segment);
///   -1 -> a sequence of Types (None is treated as an empty sequence).
/// Any other value raises ValueError("Unexpected segment spec"). Malformed
/// entries raise ValueError naming the offending result index.
static void populateResultTypes(StringRef name, nb::list resultTypeList,
                                const nb::object &resultSegmentSpecObj,
                                std::vector<int32_t> &resultSegmentLengths,
                                std::vector<PyType *> &resultTypes) {
  resultTypes.reserve(resultTypeList.size());
  if (resultSegmentSpecObj.is_none()) {
    // Non-variadic result unpacking.
    for (const auto &it : llvm::enumerate(resultTypeList)) {
      try {
        resultTypes.push_back(nb::cast<PyType *>(it.value()));
        // A None entry casts to nullptr; report it like a failed cast.
        if (!resultTypes.back())
          throw nb::cast_error();
      } catch (nb::cast_error &err) {
        throw nb::value_error((llvm::Twine("Result ") +
                               llvm::Twine(it.index()) + " of operation \"" +
                               name + "\" must be a Type (" + err.what() + ")")
                                  .str()
                                  .c_str());
      }
    }
  } else {
    // Sized result unpacking.
    auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
    if (resultSegmentSpec.size() != resultTypeList.size()) {
      throw nb::value_error((llvm::Twine("Operation \"") + name +
                             "\" requires " +
                             llvm::Twine(resultSegmentSpec.size()) +
                             " result segments but was provided " +
                             llvm::Twine(resultTypeList.size()))
                                .str()
                                .c_str());
    }
    resultSegmentLengths.reserve(resultTypeList.size());
    for (const auto &it :
         llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
      int segmentSpec = std::get<1>(it.value());
      if (segmentSpec == 1 || segmentSpec == 0) {
        // Unpack unary element.
        try {
          auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
          if (resultType) {
            resultTypes.push_back(resultType);
            resultSegmentLengths.push_back(1);
          } else if (segmentSpec == 0) {
            // Allowed to be optional.
            resultSegmentLengths.push_back(0);
          } else {
            // A ValueError thrown here is not caught by the cast_error handler
            // below and propagates unchanged.
            throw nb::value_error(
                (llvm::Twine("Result ") + llvm::Twine(it.index()) +
                 " of operation \"" + name +
                 "\" must be a Type (was None and result is not optional)")
                    .str()
                    .c_str());
          }
        } catch (nb::cast_error &err) {
          throw nb::value_error((llvm::Twine("Result ") +
                                 llvm::Twine(it.index()) + " of operation \"" +
                                 name + "\" must be a Type (" + err.what() +
                                 ")")
                                    .str()
                                    .c_str());
        }
      } else if (segmentSpec == -1) {
        // Unpack sequence by appending.
        try {
          if (std::get<0>(it.value()).is_none()) {
            // Treat it as an empty list.
            resultSegmentLengths.push_back(0);
          } else {
            // Unpack the list.
            auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
            for (nb::handle segmentItem : segment) {
              resultTypes.push_back(nb::cast<PyType *>(segmentItem));
              if (!resultTypes.back()) {
                throw nb::type_error("contained a None item");
              }
            }
            resultSegmentLengths.push_back(nb::len(segment));
          }
        } catch (std::exception &err) {
          // NOTE: Sloppy to be using a catch-all here, but there are at least
          // three different unrelated exceptions that can be thrown in the
          // above "casts". Just keep the scope above small and catch them all.
          throw nb::value_error((llvm::Twine("Result ") +
                                 llvm::Twine(it.index()) + " of operation \"" +
                                 name + "\" must be a Sequence of Types (" +
                                 err.what() + ")")
                                    .str()
                                    .c_str());
        }
      } else {
        throw nb::value_error("Unexpected segment spec");
      }
    }
  }
}
1806 
1807 static MlirValue getUniqueResult(MlirOperation operation) {
1808  auto numResults = mlirOperationGetNumResults(operation);
1809  if (numResults != 1) {
1810  auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1811  throw nb::value_error((Twine("Cannot call .result on operation ") +
1812  StringRef(name.data, name.length) + " which has " +
1813  Twine(numResults) +
1814  " results (it is only valid for operations with a "
1815  "single result)")
1816  .str()
1817  .c_str());
1818  }
1819  return mlirOperationGetResult(operation, 0);
1820 }
1821 
1822 static MlirValue getOpResultOrValue(nb::handle operand) {
1823  if (operand.is_none()) {
1824  throw nb::value_error("contained a None item");
1825  }
1826  PyOperationBase *op;
1827  if (nb::try_cast<PyOperationBase *>(operand, op)) {
1828  return getUniqueResult(op->getOperation());
1829  }
1830  PyOpResultList *opResultList;
1831  if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1832  return getUniqueResult(opResultList->getOperation()->get());
1833  }
1834  PyValue *value;
1835  if (nb::try_cast<PyValue *>(operand, value)) {
1836  return value->get();
1837  }
1838  throw nb::value_error("is not a Value");
1839 }
1840 
1842  std::string_view name, std::tuple<int, bool> opRegionSpec,
1843  nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1844  std::optional<nb::list> resultTypeList, nb::list operandList,
1845  std::optional<nb::dict> attributes,
1846  std::optional<std::vector<PyBlock *>> successors,
1847  std::optional<int> regions, PyLocation &location,
1848  const nb::object &maybeIp) {
1849  PyMlirContextRef context = location.getContext();
1850 
1851  // Class level operation construction metadata.
1852  // Operand and result segment specs are either none, which does no
1853  // variadic unpacking, or a list of ints with segment sizes, where each
1854  // element is either a positive number (typically 1 for a scalar) or -1 to
1855  // indicate that it is derived from the length of the same-indexed operand
1856  // or result (implying that it is a list at that position).
1857  std::vector<int32_t> operandSegmentLengths;
1858  std::vector<int32_t> resultSegmentLengths;
1859 
1860  // Validate/determine region count.
1861  int opMinRegionCount = std::get<0>(opRegionSpec);
1862  bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1863  if (!regions) {
1864  regions = opMinRegionCount;
1865  }
1866  if (*regions < opMinRegionCount) {
1867  throw nb::value_error(
1868  (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1869  llvm::Twine(opMinRegionCount) +
1870  " regions but was built with regions=" + llvm::Twine(*regions))
1871  .str()
1872  .c_str());
1873  }
1874  if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1875  throw nb::value_error(
1876  (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1877  llvm::Twine(opMinRegionCount) +
1878  " regions but was built with regions=" + llvm::Twine(*regions))
1879  .str()
1880  .c_str());
1881  }
1882 
1883  // Unpack results.
1884  std::vector<PyType *> resultTypes;
1885  if (resultTypeList.has_value()) {
1886  populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1887  resultSegmentLengths, resultTypes);
1888  }
1889 
1890  // Unpack operands.
1892  operands.reserve(operands.size());
1893  if (operandSegmentSpecObj.is_none()) {
1894  // Non-sized operand unpacking.
1895  for (const auto &it : llvm::enumerate(operandList)) {
1896  try {
1897  operands.push_back(getOpResultOrValue(it.value()));
1898  } catch (nb::builtin_exception &err) {
1899  throw nb::value_error((llvm::Twine("Operand ") +
1900  llvm::Twine(it.index()) + " of operation \"" +
1901  name + "\" must be a Value (" + err.what() + ")")
1902  .str()
1903  .c_str());
1904  }
1905  }
1906  } else {
1907  // Sized operand unpacking.
1908  auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1909  if (operandSegmentSpec.size() != operandList.size()) {
1910  throw nb::value_error((llvm::Twine("Operation \"") + name +
1911  "\" requires " +
1912  llvm::Twine(operandSegmentSpec.size()) +
1913  "operand segments but was provided " +
1914  llvm::Twine(operandList.size()))
1915  .str()
1916  .c_str());
1917  }
1918  operandSegmentLengths.reserve(operandList.size());
1919  for (const auto &it :
1920  llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1921  int segmentSpec = std::get<1>(it.value());
1922  if (segmentSpec == 1 || segmentSpec == 0) {
1923  // Unpack unary element.
1924  auto &operand = std::get<0>(it.value());
1925  if (!operand.is_none()) {
1926  try {
1927 
1928  operands.push_back(getOpResultOrValue(operand));
1929  } catch (nb::builtin_exception &err) {
1930  throw nb::value_error((llvm::Twine("Operand ") +
1931  llvm::Twine(it.index()) +
1932  " of operation \"" + name +
1933  "\" must be a Value (" + err.what() + ")")
1934  .str()
1935  .c_str());
1936  }
1937 
1938  operandSegmentLengths.push_back(1);
1939  } else if (segmentSpec == 0) {
1940  // Allowed to be optional.
1941  operandSegmentLengths.push_back(0);
1942  } else {
1943  throw nb::value_error(
1944  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
1945  " of operation \"" + name +
1946  "\" must be a Value (was None and operand is not optional)")
1947  .str()
1948  .c_str());
1949  }
1950  } else if (segmentSpec == -1) {
1951  // Unpack sequence by appending.
1952  try {
1953  if (std::get<0>(it.value()).is_none()) {
1954  // Treat it as an empty list.
1955  operandSegmentLengths.push_back(0);
1956  } else {
1957  // Unpack the list.
1958  auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1959  for (nb::handle segmentItem : segment) {
1960  operands.push_back(getOpResultOrValue(segmentItem));
1961  }
1962  operandSegmentLengths.push_back(nb::len(segment));
1963  }
1964  } catch (std::exception &err) {
1965  // NOTE: Sloppy to be using a catch-all here, but there are at least
1966  // three different unrelated exceptions that can be thrown in the
1967  // above "casts". Just keep the scope above small and catch them all.
1968  throw nb::value_error((llvm::Twine("Operand ") +
1969  llvm::Twine(it.index()) + " of operation \"" +
1970  name + "\" must be a Sequence of Values (" +
1971  err.what() + ")")
1972  .str()
1973  .c_str());
1974  }
1975  } else {
1976  throw nb::value_error("Unexpected segment spec");
1977  }
1978  }
1979  }
1980 
1981  // Merge operand/result segment lengths into attributes if needed.
1982  if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1983  // Dup.
1984  if (attributes) {
1985  attributes = nb::dict(*attributes);
1986  } else {
1987  attributes = nb::dict();
1988  }
1989  if (attributes->contains("resultSegmentSizes") ||
1990  attributes->contains("operandSegmentSizes")) {
1991  throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
1992  "'operandSegmentSizes' attribute is unsupported. "
1993  "Use Operation.create for such low-level access.");
1994  }
1995 
1996  // Add resultSegmentSizes attribute.
1997  if (!resultSegmentLengths.empty()) {
1998  MlirAttribute segmentLengthAttr =
1999  mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
2000  resultSegmentLengths.data());
2001  (*attributes)["resultSegmentSizes"] =
2002  PyAttribute(context, segmentLengthAttr);
2003  }
2004 
2005  // Add operandSegmentSizes attribute.
2006  if (!operandSegmentLengths.empty()) {
2007  MlirAttribute segmentLengthAttr =
2008  mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
2009  operandSegmentLengths.data());
2010  (*attributes)["operandSegmentSizes"] =
2011  PyAttribute(context, segmentLengthAttr);
2012  }
2013  }
2014 
2015  // Delegate to create.
2016  return PyOperation::create(name,
2017  /*results=*/std::move(resultTypes),
2018  /*operands=*/operands,
2019  /*attributes=*/std::move(attributes),
2020  /*successors=*/std::move(successors),
2021  /*regions=*/*regions, location, maybeIp,
2022  !resultTypeList);
2023 }
2024 
2025 nb::object PyOpView::constructDerived(const nb::object &cls,
2026  const nb::object &operation) {
2027  nb::handle opViewType = nb::type<PyOpView>();
2028  nb::object instance = cls.attr("__new__")(cls);
2029  opViewType.attr("__init__")(instance, operation);
2030  return instance;
2031 }
2032 
/// Wraps the operation carried by `operationObject`, which may be any
/// PyOperationBase subclass. Stores the underlying PyOperation and the Python
/// object obtained from its reference (`getRef().getObject()`).
PyOpView::PyOpView(const nb::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
      // Reads the `operation` member initialized just above; this assumes
      // `operation` is declared before `operationObject` in the class
      // (declaration not visible here).
      operationObject(operation.getRef().getObject()) {}
2038 
2039 //------------------------------------------------------------------------------
2040 // PyInsertionPoint.
2041 //------------------------------------------------------------------------------
2042 
2043 PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
2044 
    // Insertion point before the operation wrapped by `beforeOperationBase`,
    // inside that operation's block. `block` is initialized from
    // `refOperation`, which assumes `refOperation` is declared before `block`
    // in the class (declaration not visible here).
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
2048 
    // Insertion point before the operation referenced by
    // `beforeOperationRef`, inside that operation's block.
    : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
2051 
  // Inserts the detached operation `operationBase` at this insertion point
  // and marks it as attached.
  // Raises ValueError if the operation is already attached.
  // Raises IndexError when appending to a block that already ends in a
  // terminator.
  PyOperation &operation = operationBase.getOperation();
  if (operation.isAttached())
    throw nb::value_error(
        "Attempt to insert operation that is already attached");
  block.getParentOperation()->checkValid();
  // A null `beforeOp` means "insert at the end of the block".
  MlirOperation beforeOp = {nullptr};
  if (refOperation) {
    // Insert before operation.
    (*refOperation)->checkValid();
    beforeOp = (*refOperation)->get();
  } else {
    // Insert at end (before null) is only valid if the block does not
    // already end in a known terminator (violating this will cause assertion
    // failures later).
    if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
      throw nb::index_error("Cannot insert operation at the end of a block "
                            "that already has a terminator. Did you mean to "
                            "use 'InsertionPoint.at_block_terminator(block)' "
                            "versus 'InsertionPoint(block)'?");
    }
  }
  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
  operation.setAttached();
}
2077 
  // Insertion point at the start of `block`: before its first operation, or
  // at the end of the block when it has no operations.
  MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
  if (mlirOperationIsNull(firstOp)) {
    // Just insert at end.
    return PyInsertionPoint(block);
  }

  // Insert before first op.
      block.getParentOperation()->getContext(), firstOp);
  return PyInsertionPoint{block, std::move(firstOpRef)};
}
2090 
  // Insertion point immediately before `block`'s terminator.
  // Raises ValueError if the block has no terminator.
  MlirOperation terminator = mlirBlockGetTerminator(block.get());
  if (mlirOperationIsNull(terminator))
    throw nb::value_error("Block has no terminator");
  PyOperationRef terminatorOpRef = PyOperation::forOperation(
      block.getParentOperation()->getContext(), terminator);
  return PyInsertionPoint{block, std::move(terminatorOpRef)};
}
2099 
  // Insertion point immediately after `op` in the block that contains it:
  // before the next operation, or at the end of the block when `op` is the
  // last operation.
  PyOperation &operation = op.getOperation();
  PyBlock block = operation.getBlock();
  MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
  if (mlirOperationIsNull(nextOperation))
    return PyInsertionPoint(block);
      block.getParentOperation()->getContext(), nextOperation);
  return PyInsertionPoint{block, std::move(nextOpRef)};
}
2110 
2111 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
2112 
2113 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
2114  return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
2115 }
2116 
/// Context-manager exit hook, paired with contextEnter. Receives the
/// exception type, value and traceback of the `with` block.
void PyInsertionPoint::contextExit(const nb::object &excType,
                                   const nb::object &excVal,
                                   const nb::object &excTb) {
}
2122 
2123 //------------------------------------------------------------------------------
2124 // PyAttribute.
2125 //------------------------------------------------------------------------------
2126 
2127 bool PyAttribute::operator==(const PyAttribute &other) const {
2128  return mlirAttributeEqual(attr, other.attr);
2129 }
2130 
  // Wraps this attribute in a new Python capsule; nb::steal adopts the
  // reference returned by mlirPythonAttributeToCapsule.
  return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
}
2134 
/// Rebuilds a PyAttribute from a Python capsule. Throws nb::python_error when
/// the capsule conversion yields a null attribute (presumably with a Python
/// error already set by the conversion — TODO confirm).
PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) {
  MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
  if (mlirAttributeIsNull(rawAttr))
    throw nb::python_error();
  return PyAttribute(
}
2142 
  // Converts this attribute to a Python object. If a type caster is
  // registered for the attribute's TypeID/dialect, the caster's result is
  // returned; otherwise the generic wrapper is returned.
  // NOTE: the rv_policy::move cast below moves out of *this.
  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get());
  assert(!mlirTypeIDIsNull(mlirTypeID) &&
         "mlirTypeID was expected to be non-null.");
  std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
      mlirTypeID, mlirAttributeGetDialect(this->get()));
  // nb::rv_policy::move means use std::move to move the return value
  // contents into a new instance that will be owned by Python.
  nb::object thisObj = nb::cast(this, nb::rv_policy::move);
  if (!typeCaster)
    return thisObj;
  return typeCaster.value()(thisObj);
}
2156 
2157 //------------------------------------------------------------------------------
2158 // PyNamedAttribute.
2159 //------------------------------------------------------------------------------
2160 
/// Creates a named attribute pairing `attr` with `ownedName`.
///
/// The name is moved into a heap-allocated string owned by this object, and
/// the string reference handed to MLIR points at that owned storage. The
/// likely intent is that the storage address stays stable if this object is
/// moved — TODO confirm.
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
    : ownedName(new std::string(std::move(ownedName))) {
      toMlirStringRef(*this->ownedName)),
      attr);
}
2168 
2169 //------------------------------------------------------------------------------
2170 // PyType.
2171 //------------------------------------------------------------------------------
2172 
2173 bool PyType::operator==(const PyType &other) const {
2174  return mlirTypeEqual(type, other.type);
2175 }
2176 
2177 nb::object PyType::getCapsule() {
2178  return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
2179 }
2180 
/// Rebuilds a PyType from a Python capsule. Throws nb::python_error when the
/// capsule conversion yields a null type (presumably with a Python error
/// already set by the conversion — TODO confirm).
PyType PyType::createFromCapsule(nb::object capsule) {
  MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
  if (mlirTypeIsNull(rawType))
    throw nb::python_error();
      rawType);
}
2188 
  // Converts this type to a Python object. If a type caster is registered for
  // the type's TypeID/dialect, the caster's result is returned; otherwise the
  // generic wrapper is returned.
  // NOTE: the rv_policy::move cast below moves out of *this.
  MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get());
  assert(!mlirTypeIDIsNull(mlirTypeID) &&
         "mlirTypeID was expected to be non-null.");
  std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
      mlirTypeID, mlirTypeGetDialect(this->get()));
  // nb::rv_policy::move means use std::move to move the return value
  // contents into a new instance that will be owned by Python.
  nb::object thisObj = nb::cast(this, nb::rv_policy::move);
  if (!typeCaster)
    return thisObj;
  return typeCaster.value()(thisObj);
}
2202 
2203 //------------------------------------------------------------------------------
2204 // PyTypeID.
2205 //------------------------------------------------------------------------------
2206 
2207 nb::object PyTypeID::getCapsule() {
2208  return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
2209 }
2210 
  // Rebuilds a PyTypeID from a Python capsule; throws nb::python_error when
  // the capsule conversion yields a null TypeID.
  MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
  if (mlirTypeIDIsNull(mlirTypeID))
    throw nb::python_error();
  return PyTypeID(mlirTypeID);
}
2217 bool PyTypeID::operator==(const PyTypeID &other) const {
2218  return mlirTypeIDEqual(typeID, other.typeID);
2219 }
2220 
2221 //------------------------------------------------------------------------------
2222 // PyValue and subclasses.
2223 //------------------------------------------------------------------------------
2224 
2225 nb::object PyValue::getCapsule() {
2226  return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
2227 }
2228 
  // Converts this value to a Python object. When `valueCaster` is set, the
  // caster's result is returned; otherwise the generic wrapper is returned.
  // NOTE: the rv_policy::move cast below moves out of *this.
  MlirType type = mlirValueGetType(get());
  MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
  assert(!mlirTypeIDIsNull(mlirTypeID) &&
         "mlirTypeID was expected to be non-null.");
  std::optional<nb::callable> valueCaster =
  // nb::rv_policy::move means use std::move to move the return value
  // contents into a new instance that will be owned by Python.
  nb::object thisObj = nb::cast(this, nb::rv_policy::move);
  if (!valueCaster)
    return thisObj;
  return valueCaster.value()(thisObj);
}
2243 
  // Rebuilds a PyValue from a Python capsule, looking up the value's owning
  // operation so the returned PyValue keeps it alive.
  MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
  if (mlirValueIsNull(value))
    throw nb::python_error();
  // NOTE(review): `owner` has no initializer. If neither branch below assigns
  // it, mlirOperationIsNull(owner) reads an indeterminate value; consider
  // initializing it to {nullptr}. Also, no Python error appears to be set
  // when nb::python_error is thrown for a null owner — confirm.
  MlirOperation owner;
  if (mlirValueIsAOpResult(value))
    owner = mlirOpResultGetOwner(value);
  if (mlirValueIsABlockArgument(value))
  if (mlirOperationIsNull(owner))
    throw nb::python_error();
  MlirContext ctx = mlirOperationGetContext(owner);
  PyOperationRef ownerRef =
  return PyValue(ownerRef, value);
}
2260 
2261 //------------------------------------------------------------------------------
2262 // PySymbolTable.
2263 //------------------------------------------------------------------------------
2264 
    // Keeps a reference to the operation and creates the C symbol table for
    // it. Raises TypeError when the operation is not a symbol table.
    : operation(operation.getOperation().getRef()) {
  symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
  if (mlirSymbolTableIsNull(symbolTable)) {
    throw nb::type_error("Operation is not a Symbol Table.");
  }
}
2272 
2273 nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2274  operation->checkValid();
2275  MlirOperation symbol = mlirSymbolTableLookup(
2276  symbolTable, mlirStringRefCreate(name.data(), name.length()));
2277  if (mlirOperationIsNull(symbol))
2278  throw nb::key_error(
2279  ("Symbol '" + name + "' not in the symbol table.").c_str());
2280 
2281  return PyOperation::forOperation(operation->getContext(), symbol,
2282  operation.getObject())
2283  ->createOpView();
2284 }
2285 
  // Erases `symbol` from this symbol table, which also erases the operation,
  // and then marks the Python-side operation invalid.
  operation->checkValid();
  symbol.getOperation().checkValid();
  mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
  // The operation is also erased, so we must invalidate it. There may be Python
  // references to this operation so we don't want to delete it from the list of
  // live operations here.
  symbol.getOperation().valid = false;
}
2295 
2296 void PySymbolTable::dunderDel(const std::string &name) {
2297  nb::object operation = dunderGetItem(name);
2298  erase(nb::cast<PyOperationBase &>(operation));
2299 }
2300 
  // Inserts `symbol` into this symbol table and returns, as a StringAttr, the
  // attribute produced by mlirSymbolTableInsert (per the MLIR C API, the
  // symbol's name after insertion). Raises ValueError if the operation has no
  // symbol name.
  operation->checkValid();
  symbol.getOperation().checkValid();
  MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
  if (mlirAttributeIsNull(symbolAttr))
    throw nb::value_error("Expected operation to have a symbol name.");
  return PyStringAttribute(
      symbol.getOperation().getContext(),
      mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
}
2312 
  // Returns the symbol name attribute (`attrName`) of `symbol` as a
  // StringAttr; raises ValueError if the operation lacks it.
  // Op must already be a symbol.
  PyOperation &operation = symbol.getOperation();
  operation.checkValid();
  MlirAttribute existingNameAttr =
      mlirOperationGetAttributeByName(operation.get(), attrName);
  if (mlirAttributeIsNull(existingNameAttr))
    throw nb::value_error("Expected operation to have a symbol name.");
  return PyStringAttribute(symbol.getOperation().getContext(),
                           existingNameAttr);
}
2325 
                                  const std::string &name) {
  // Replaces the symbol name attribute of `symbol` with a StringAttr holding
  // `name`. Raises ValueError if the operation has no existing symbol name.
  // Only the attribute is rewritten here; no uses of the old name are
  // updated.
  // Op must already be a symbol.
  PyOperation &operation = symbol.getOperation();
  operation.checkValid();
  MlirAttribute existingNameAttr =
      mlirOperationGetAttributeByName(operation.get(), attrName);
  if (mlirAttributeIsNull(existingNameAttr))
    throw nb::value_error("Expected operation to have a symbol name.");
  MlirAttribute newNameAttr =
      mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
  mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
}
2340 
  // Returns the symbol visibility attribute of `symbol` as a StringAttr;
  // raises ValueError if the operation has none.
  PyOperation &operation = symbol.getOperation();
  operation.checkValid();
  MlirAttribute existingVisAttr =
      mlirOperationGetAttributeByName(operation.get(), attrName);
  if (mlirAttributeIsNull(existingVisAttr))
    throw nb::value_error("Expected operation to have a symbol visibility.");
  return PyStringAttribute(symbol.getOperation().getContext(), existingVisAttr);
}
2351 
                                     const std::string &visibility) {
  // Sets the visibility of `symbol`. `visibility` must be "public",
  // "private" or "nested". The operation must already carry a visibility
  // attribute; otherwise ValueError is raised.
  if (visibility != "public" && visibility != "private" &&
      visibility != "nested")
    throw nb::value_error(
        "Expected visibility to be 'public', 'private' or 'nested'");
  PyOperation &operation = symbol.getOperation();
  operation.checkValid();
  MlirAttribute existingVisAttr =
      mlirOperationGetAttributeByName(operation.get(), attrName);
  if (mlirAttributeIsNull(existingVisAttr))
    throw nb::value_error("Expected operation to have a symbol visibility.");
  MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
                                               toMlirStringRef(visibility));
  mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
}
2369 
/// Renames uses of symbol `oldSymbol` to `newSymbol` within `from`.
/// Raises ValueError("Symbol rename failed") when the rename call's result
/// indicates failure.
void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
                                         const std::string &newSymbol,
                                         PyOperationBase &from) {
  PyOperation &fromOperation = from.getOperation();
  fromOperation.checkValid();
      toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
      from.getOperation())))

    throw nb::value_error("Symbol rename failed");
}
2381 
                                      bool allSymUsesVisible,
                                      nb::object callback) {
  // Walks symbol tables starting from `from` and invokes
  // `callback(op, isVisible)` for each operation the walk reports.
  // If the callback raises a Python exception:
  //  * the exception message is recorded,
  //  * later invocations are skipped,
  //  * after the walk a std::runtime_error carrying that message is thrown.
  // NOTE(review): `exceptionType` is captured but never used when re-raising,
  // so the original Python exception type is lost.
  PyOperation &fromOperation = from.getOperation();
  fromOperation.checkValid();
  // State shared with the C callback through its `void *` user-data pointer.
  struct UserData {
    PyMlirContextRef context;
    nb::object callback;
    bool gotException;
    std::string exceptionWhat;
    nb::object exceptionType;
  };
  UserData userData{
      fromOperation.getContext(), std::move(callback), false, {}, {}};
      fromOperation.get(), allSymUsesVisible,
      [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
        UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
        // NOTE(review): the Python wrapper is created even after a previous
        // callback failed; the gotException check could come first.
        auto pyFoundOp =
            PyOperation::forOperation(calleeUserData->context, foundOp);
        if (calleeUserData->gotException)
          return;
        try {
          calleeUserData->callback(pyFoundOp.getObject(), isVisible);
        } catch (nb::python_error &e) {
          calleeUserData->gotException = true;
          calleeUserData->exceptionWhat = e.what();
          calleeUserData->exceptionType = nb::borrow(e.type());
        }
      },
      static_cast<void *>(&userData));
  if (userData.gotException) {
    std::string message("Exception raised in callback: ");
    message.append(userData.exceptionWhat);
    throw std::runtime_error(message);
  }
}
2419 
2420 namespace {
2421 
2422 /// Python wrapper for MlirBlockArgument.
2423 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2424 public:
2425  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2426  static constexpr const char *pyClassName = "BlockArgument";
2427  using PyConcreteValue::PyConcreteValue;
2428 
2429  static void bindDerived(ClassTy &c) {
2430  c.def_prop_ro("owner", [](PyBlockArgument &self) {
2431  return PyBlock(self.getParentOperation(),
2432  mlirBlockArgumentGetOwner(self.get()));
2433  });
2434  c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
2435  return mlirBlockArgumentGetArgNumber(self.get());
2436  });
2437  c.def(
2438  "set_type",
2439  [](PyBlockArgument &self, PyType type) {
2440  return mlirBlockArgumentSetType(self.get(), type);
2441  },
2442  nb::arg("type"));
2443  }
2444 };
2445 
/// A list of block arguments. Internally, these are stored as consecutive
/// elements, random access is cheap. The argument list is associated with the
/// operation that contains the block (detached blocks are not allowed in
/// Python bindings) and extends its lifetime.
class PyBlockArgumentList
    : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
public:
  static constexpr const char *pyClassName = "BlockArgumentList";

  /// Creates a (possibly sliced) view over `block`'s arguments; a `length` of
  /// -1 selects all arguments the block currently has.
  PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
                      intptr_t startIndex = 0, intptr_t length = -1,
                      intptr_t step = 1)
      : Sliceable(startIndex,
                  length == -1 ? mlirBlockGetNumArguments(block) : length,
                  step),
        operation(std::move(operation)), block(block) {}

  /// Adds the `types` property listing the argument types.
  static void bindDerived(ClassTy &c) {
    c.def_prop_ro("types", [](PyBlockArgumentList &self) {
      return getValueTypes(self, self.operation->getContext());
    });
  }

private:
  /// Give the parent CRTP class access to hook implementations below.
  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;

  /// Returns the number of arguments in the list.
  intptr_t getRawNumElements() {
    operation->checkValid();
    return mlirBlockGetNumArguments(block);
  }

  /// Returns the `pos`-th element in the list.
  PyBlockArgument getRawElement(intptr_t pos) {
    MlirValue argument = mlirBlockGetArgument(block, pos);
    return PyBlockArgument(operation, argument);
  }

  /// Returns a sublist of this list.
  PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
                            intptr_t step) {
    return PyBlockArgumentList(operation, block, startIndex, length, step);
  }

  PyOperationRef operation;
  MlirBlock block;
};
2495 
/// A list of operation operands. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) operand list is associated
/// with the operation whose operands these are, and thus extends the lifetime
/// of this operation.
class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
public:
  static constexpr const char *pyClassName = "OpOperandList";
  using SliceableT = Sliceable<PyOpOperandList, PyValue>;

  /// Creates a (possibly sliced) view over `operation`'s operands; a `length`
  /// of -1 selects all operands the operation currently has.
  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
                  intptr_t length = -1, intptr_t step = 1)
      : Sliceable(startIndex,
                  length == -1 ? mlirOperationGetNumOperands(operation->get())
                               : length,
                  step),
        operation(operation) {}

  /// Implements `operands[index] = value`; `index` is normalized by
  /// wrapIndex.
  /// NOTE(review): unlike getRawNumElements, this does not call
  /// operation->checkValid() before mutating the operation.
  void dunderSetItem(intptr_t index, PyValue value) {
    index = wrapIndex(index);
    mlirOperationSetOperand(operation->get(), index, value.get());
  }

  static void bindDerived(ClassTy &c) {
    c.def("__setitem__", &PyOpOperandList::dunderSetItem);
  }

private:
  /// Give the parent CRTP class access to hook implementations below.
  friend class Sliceable<PyOpOperandList, PyValue>;

  intptr_t getRawNumElements() {
    operation->checkValid();
    return mlirOperationGetNumOperands(operation->get());
  }

  /// Returns operand `pos` wrapped as a PyValue tied to the operand's owner
  /// (the defining operation for an OpResult).
  /// NOTE(review): `owner` has no initializer and the fallback branch is only
  /// an assert, so in NDEBUG builds an unexpected value kind would use an
  /// indeterminate `owner`.
  PyValue getRawElement(intptr_t pos) {
    MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
    MlirOperation owner;
    if (mlirValueIsAOpResult(operand))
      owner = mlirOpResultGetOwner(operand);
    else if (mlirValueIsABlockArgument(operand))
    else
      assert(false && "Value must be an block arg or op result.");
    PyOperationRef pyOwner =
        PyOperation::forOperation(operation->getContext(), owner);
    return PyValue(pyOwner, operand);
  }

  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
    return PyOpOperandList(operation, startIndex, length, step);
  }

  PyOperationRef operation;
};
2551 
2552 /// A list of operation successors. Internally, these are stored as consecutive
2553 /// elements, random access is cheap. The (returned) successor list is
2554 /// associated with the operation whose successors these are, and thus extends
2555 /// the lifetime of this operation.
2556 class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2557 public:
2558  static constexpr const char *pyClassName = "OpSuccessors";
2559 
2560  PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2561  intptr_t length = -1, intptr_t step = 1)
2562  : Sliceable(startIndex,
2563  length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2564  : length,
2565  step),
2566  operation(operation) {}
2567 
2568  void dunderSetItem(intptr_t index, PyBlock block) {
2569  index = wrapIndex(index);
2570  mlirOperationSetSuccessor(operation->get(), index, block.get());
2571  }
2572 
2573  static void bindDerived(ClassTy &c) {
2574  c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2575  }
2576 
2577 private:
2578  /// Give the parent CRTP class access to hook implementations below.
2579  friend class Sliceable<PyOpSuccessors, PyBlock>;
2580 
2581  intptr_t getRawNumElements() {
2582  operation->checkValid();
2583  return mlirOperationGetNumSuccessors(operation->get());
2584  }
2585 
2586  PyBlock getRawElement(intptr_t pos) {
2587  MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2588  return PyBlock(operation, block);
2589  }
2590 
2591  PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2592  return PyOpSuccessors(operation, startIndex, length, step);
2593  }
2594 
2595  PyOperationRef operation;
2596 };
2597 
2598 /// A list of block successors. Internally, these are stored as consecutive
2599 /// elements, random access is cheap. The (returned) successor list is
2600 /// associated with the operation and block whose successors these are, and thus
2601 /// extends the lifetime of this operation and block.
2602 class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
2603 public:
2604  static constexpr const char *pyClassName = "BlockSuccessors";
2605 
2606  PyBlockSuccessors(PyBlock block, PyOperationRef operation,
2607  intptr_t startIndex = 0, intptr_t length = -1,
2608  intptr_t step = 1)
2609  : Sliceable(startIndex,
2610  length == -1 ? mlirBlockGetNumSuccessors(block.get())
2611  : length,
2612  step),
2613  operation(operation), block(block) {}
2614 
2615 private:
2616  /// Give the parent CRTP class access to hook implementations below.
2617  friend class Sliceable<PyBlockSuccessors, PyBlock>;
2618 
2619  intptr_t getRawNumElements() {
2620  block.checkValid();
2621  return mlirBlockGetNumSuccessors(block.get());
2622  }
2623 
2624  PyBlock getRawElement(intptr_t pos) {
2625  MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2626  return PyBlock(operation, block);
2627  }
2628 
2629  PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2630  return PyBlockSuccessors(block, operation, startIndex, length, step);
2631  }
2632 
2633  PyOperationRef operation;
2634  PyBlock block;
2635 };
2636 
2637 /// A list of block predecessors. The (returned) predecessor list is
2638 /// associated with the operation and block whose predecessors these are, and
2639 /// thus extends the lifetime of this operation and block.
2640 ///
2641 /// WARNING: This Sliceable is more expensive than the others here because
2642 /// mlirBlockGetPredecessor actually iterates the use-def chain (of block
2643 /// operands) anew for each indexed access.
2644 class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
2645 public:
2646  static constexpr const char *pyClassName = "BlockPredecessors";
2647 
2648  PyBlockPredecessors(PyBlock block, PyOperationRef operation,
2649  intptr_t startIndex = 0, intptr_t length = -1,
2650  intptr_t step = 1)
2651  : Sliceable(startIndex,
2652  length == -1 ? mlirBlockGetNumPredecessors(block.get())
2653  : length,
2654  step),
2655  operation(operation), block(block) {}
2656 
2657 private:
2658  /// Give the parent CRTP class access to hook implementations below.
2659  friend class Sliceable<PyBlockPredecessors, PyBlock>;
2660 
2661  intptr_t getRawNumElements() {
2662  block.checkValid();
2663  return mlirBlockGetNumPredecessors(block.get());
2664  }
2665 
2666  PyBlock getRawElement(intptr_t pos) {
2667  MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2668  return PyBlock(operation, block);
2669  }
2670 
2671  PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
2672  intptr_t step) {
2673  return PyBlockPredecessors(block, operation, startIndex, length, step);
2674  }
2675 
2676  PyOperationRef operation;
2677  PyBlock block;
2678 };
2679 
/// A list of operation attributes. Can be indexed by name, producing
/// attributes, or by index, producing named attributes.
/// NOTE(review): unlike the Sliceable lists above, no method here calls
/// operation->checkValid() before touching the operation.
class PyOpAttributeMap {
public:
  PyOpAttributeMap(PyOperationRef operation)
      : operation(std::move(operation)) {}

  /// `attrs[name]`: the named attribute, downcast via maybeDownCast. Raises
  /// KeyError if the operation has no attribute with that name.
  nb::typed<nb::object, PyAttribute>
  dunderGetItemNamed(const std::string &name) {
    MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
                                                         toMlirStringRef(name));
    if (mlirAttributeIsNull(attr)) {
      throw nb::key_error("attempt to access a non-existent attribute");
    }
    return PyAttribute(operation->getContext(), attr).maybeDownCast();
  }

  /// `attrs[index]`: the `index`-th attribute as a PyNamedAttribute. Negative
  /// indices count from the end; out-of-range indices raise IndexError.
  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
    if (index < 0) {
      index += dunderLen();
    }
    if (index < 0 || index >= dunderLen()) {
      throw nb::index_error("attempt to access out of bounds attribute");
    }
    MlirNamedAttribute namedAttr =
        mlirOperationGetAttribute(operation->get(), index);
    return PyNamedAttribute(
        namedAttr.attribute,
        std::string(mlirIdentifierStr(namedAttr.name).data,
                    mlirIdentifierStr(namedAttr.name).length));
  }

  /// `attrs[name] = attr`.
  void dunderSetItem(const std::string &name, const PyAttribute &attr) {
                                    attr);
  }

  /// `del attrs[name]`; raises KeyError if no such attribute was removed.
  void dunderDelItem(const std::string &name) {
    int removed = mlirOperationRemoveAttributeByName(operation->get(),
                                                     toMlirStringRef(name));
    if (!removed)
      throw nb::key_error("attempt to delete a non-existent attribute");
  }

  /// `len(attrs)`: number of attributes on the operation.
  intptr_t dunderLen() {
    return mlirOperationGetNumAttributes(operation->get());
  }

  /// `name in attrs`.
  bool dunderContains(const std::string &name) {
        operation->get(), toMlirStringRef(name)));
  }

  /// Registers the `OpAttributeMap` Python class. `__getitem__` is bound
  /// twice so both string (name) and integer (index) keys are accepted.
  static void bind(nb::module_ &m) {
    nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
        .def("__contains__", &PyOpAttributeMap::dunderContains)
        .def("__len__", &PyOpAttributeMap::dunderLen)
        .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
        .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
        .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
        .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
  }

private:
  PyOperationRef operation;
};
2746 
// Compatibility helpers adapted from pythoncapi_compat, see:
// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h

// C-style cast helper; only defined when the Python headers lack it.
#if !defined(_Py_CAST)
#define _Py_CAST(type, expr) ((type)(expr))
#endif

// Null pointer constant for the static inline helpers below. Spelling it
// _Py_NULL instead of a bare NULL avoids C++ compiler warnings: it becomes
// nullptr for C++11 and newer and for C23 and newer, and NULL otherwise.
#if !defined(_Py_NULL)
#if (defined(__cplusplus) && __cplusplus >= 201103) || \
    (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L)
#define _Py_NULL nullptr
#else
#define _Py_NULL NULL
#endif
#endif
2765 
2766 // Python 3.10.0a3
2767 #if PY_VERSION_HEX < 0x030A00A3
2768 
2769 // bpo-42262 added Py_XNewRef()
2770 #if !defined(Py_XNewRef)
2771 [[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
2772  Py_XINCREF(obj);
2773  return obj;
2774 }
2775 #define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
2776 #endif
2777 
2778 // bpo-42262 added Py_NewRef()
2779 #if !defined(Py_NewRef)
2780 [[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
2781  Py_INCREF(obj);
2782  return obj;
2783 }
2784 #define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
2785 #endif
2786 
2787 #endif // Python 3.10.0a3
2788 
2789 // Python 3.9.0b1
2790 #if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
2791 
2792 // bpo-40429 added PyThreadState_GetFrame()
2793 PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
2794  assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
2795  return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
2796 }
2797 
2798 // bpo-40421 added PyFrame_GetBack()
2799 PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
2800  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
2801  return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
2802 }
2803 
2804 // bpo-40421 added PyFrame_GetCode()
2805 PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
2806  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
2807  assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
2808  return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
2809 }
2810 
2811 #endif // Python 3.9.0b1
2812 
2813 MlirLocation tracebackToLocation(MlirContext ctx) {
2814  size_t framesLimit =
2815  PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
2816  // Use a thread_local here to avoid requiring a large amount of space.
2817  thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
2818  frames;
2819  size_t count = 0;
2820 
2821  nb::gil_scoped_acquire acquire;
2822  PyThreadState *tstate = PyThreadState_GET();
2823  PyFrameObject *next;
2824  PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
2825  // In the increment expression:
2826  // 1. get the next prev frame;
2827  // 2. decrement the ref count on the current frame (in order that it can get
2828  // gc'd, along with any objects in its closure and etc);
2829  // 3. set current = next.
2830  for (; pyFrame != nullptr && count < framesLimit;
2831  next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
2832  PyCodeObject *code = PyFrame_GetCode(pyFrame);
2833  auto fileNameStr =
2834  nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
2835  llvm::StringRef fileName(fileNameStr);
2836  if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
2837  continue;
2838 
2839  // co_qualname and PyCode_Addr2Location added in py3.11
2840 #if PY_VERSION_HEX < 0x030B00F0
2841  std::string name =
2842  nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
2843  llvm::StringRef funcName(name);
2844  int startLine = PyFrame_GetLineNumber(pyFrame);
2845  MlirLocation loc =
2846  mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
2847 #else
2848  std::string name =
2849  nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
2850  llvm::StringRef funcName(name);
2851  int startLine, startCol, endLine, endCol;
2852  int lasti = PyFrame_GetLasti(pyFrame);
2853  if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
2854  &endCol)) {
2855  throw nb::python_error();
2856  }
2857  MlirLocation loc = mlirLocationFileLineColRangeGet(
2858  ctx, wrap(fileName), startLine, startCol, endLine, endCol);
2859 #endif
2860 
2861  frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
2862  ++count;
2863  }
2864  // When the loop breaks (after the last iter), current frame (if non-null)
2865  // is leaked without this.
2866  Py_XDECREF(pyFrame);
2867 
2868  if (count == 0)
2869  return mlirLocationUnknownGet(ctx);
2870 
2871  MlirLocation callee = frames[0];
2872  assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
2873  if (count == 1)
2874  return callee;
2875 
2876  MlirLocation caller = frames[count - 1];
2877  assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
2878  for (int i = count - 2; i >= 1; i--)
2879  caller = mlirLocationCallSiteGet(frames[i], caller);
2880 
2881  return mlirLocationCallSiteGet(callee, caller);
2882 }
2883 
2884 PyLocation
2885 maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
2886  if (location.has_value())
2887  return location.value();
2888  if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
2889  return DefaultingPyLocation::resolve();
2890 
2891  PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
2892  MlirLocation mlirLoc = tracebackToLocation(ctx.get());
2893  PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
2894  return {ref, mlirLoc};
2895 }
2896 
2897 } // namespace
2898 
2899 //------------------------------------------------------------------------------
2900 // Populates the core exports of the 'ir' submodule.
2901 //------------------------------------------------------------------------------
2902 
2903 void mlir::python::populateIRCore(nb::module_ &m) {
2904  // disable leak warnings which tend to be false positives.
2905  nb::set_leak_warnings(false);
2906  //----------------------------------------------------------------------------
2907  // Enums.
2908  //----------------------------------------------------------------------------
2909  nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
2910  .value("ERROR", MlirDiagnosticError)
2911  .value("WARNING", MlirDiagnosticWarning)
2912  .value("NOTE", MlirDiagnosticNote)
2913  .value("REMARK", MlirDiagnosticRemark);
2914 
2915  nb::enum_<MlirWalkOrder>(m, "WalkOrder")
2916  .value("PRE_ORDER", MlirWalkPreOrder)
2917  .value("POST_ORDER", MlirWalkPostOrder);
2918 
2919  nb::enum_<MlirWalkResult>(m, "WalkResult")
2920  .value("ADVANCE", MlirWalkResultAdvance)
2921  .value("INTERRUPT", MlirWalkResultInterrupt)
2922  .value("SKIP", MlirWalkResultSkip);
2923 
2924  //----------------------------------------------------------------------------
2925  // Mapping of Diagnostics.
2926  //----------------------------------------------------------------------------
2927  nb::class_<PyDiagnostic>(m, "Diagnostic")
2928  .def_prop_ro("severity", &PyDiagnostic::getSeverity)
2929  .def_prop_ro("location", &PyDiagnostic::getLocation)
2930  .def_prop_ro("message", &PyDiagnostic::getMessage)
2931  .def_prop_ro("notes", &PyDiagnostic::getNotes)
2932  .def("__str__", [](PyDiagnostic &self) -> nb::str {
2933  if (!self.isValid())
2934  return nb::str("<Invalid Diagnostic>");
2935  return self.getMessage();
2936  });
2937 
2938  nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2939  .def("__init__",
2941  new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2942  })
2943  .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity)
2944  .def_ro("location", &PyDiagnostic::DiagnosticInfo::location)
2945  .def_ro("message", &PyDiagnostic::DiagnosticInfo::message)
2946  .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes)
2947  .def("__str__",
2948  [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2949 
2950  nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
2951  .def("detach", &PyDiagnosticHandler::detach)
2952  .def_prop_ro("attached", &PyDiagnosticHandler::isAttached)
2953  .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError)
2954  .def("__enter__", &PyDiagnosticHandler::contextEnter)
2955  .def("__exit__", &PyDiagnosticHandler::contextExit,
2956  nb::arg("exc_type").none(), nb::arg("exc_value").none(),
2957  nb::arg("traceback").none());
2958 
2959  // Expose DefaultThreadPool to python
2960  nb::class_<PyThreadPool>(m, "ThreadPool")
2961  .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
2962  .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency)
2963  .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr);
2964 
2965  nb::class_<PyMlirContext>(m, "Context")
2966  .def("__init__",
2967  [](PyMlirContext &self) {
2968  MlirContext context = mlirContextCreateWithThreading(false);
2969  new (&self) PyMlirContext(context);
2970  })
2971  .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2972  .def("_get_context_again",
2973  [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
2974  PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2975  return ref.releaseObject();
2976  })
2977  .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2978  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
2979  .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
2980  &PyMlirContext::createFromCapsule)
2981  .def("__enter__", &PyMlirContext::contextEnter)
2982  .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
2983  nb::arg("exc_value").none(), nb::arg("traceback").none())
2984  .def_prop_ro_static(
2985  "current",
2986  [](nb::object & /*class*/)
2987  -> std::optional<nb::typed<nb::object, PyMlirContext>> {
2988  auto *context = PyThreadContextEntry::getDefaultContext();
2989  if (!context)
2990  return {};
2991  return nb::cast(context);
2992  },
2993  nb::sig("def current(/) -> Context | None"),
2994  "Gets the Context bound to the current thread or raises ValueError")
2995  .def_prop_ro(
2996  "dialects",
2997  [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2998  "Gets a container for accessing dialects by name")
2999  .def_prop_ro(
3000  "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3001  "Alias for 'dialect'")
3002  .def(
3003  "get_dialect_descriptor",
3004  [=](PyMlirContext &self, std::string &name) {
3005  MlirDialect dialect = mlirContextGetOrLoadDialect(
3006  self.get(), {name.data(), name.size()});
3007  if (mlirDialectIsNull(dialect)) {
3008  throw nb::value_error(
3009  (Twine("Dialect '") + name + "' not found").str().c_str());
3010  }
3011  return PyDialectDescriptor(self.getRef(), dialect);
3012  },
3013  nb::arg("dialect_name"),
3014  "Gets or loads a dialect by name, returning its descriptor object")
3015  .def_prop_rw(
3016  "allow_unregistered_dialects",
3017  [](PyMlirContext &self) -> bool {
3019  },
3020  [](PyMlirContext &self, bool value) {
3022  })
3023  .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
3024  nb::arg("callback"),
3025  "Attaches a diagnostic handler that will receive callbacks")
3026  .def(
3027  "enable_multithreading",
3028  [](PyMlirContext &self, bool enable) {
3029  mlirContextEnableMultithreading(self.get(), enable);
3030  },
3031  nb::arg("enable"))
3032  .def("set_thread_pool",
3033  [](PyMlirContext &self, PyThreadPool &pool) {
3034  // we should disable multi-threading first before setting
3035  // new thread pool otherwise the assert in
3036  // MLIRContext::setThreadPool will be raised.
3037  mlirContextEnableMultithreading(self.get(), false);
3038  mlirContextSetThreadPool(self.get(), pool.get());
3039  })
3040  .def("get_num_threads",
3041  [](PyMlirContext &self) {
3042  return mlirContextGetNumThreads(self.get());
3043  })
3044  .def("_mlir_thread_pool_ptr",
3045  [](PyMlirContext &self) {
3046  MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
3047  std::stringstream ss;
3048  ss << pool.ptr;
3049  return ss.str();
3050  })
3051  .def(
3052  "is_registered_operation",
3053  [](PyMlirContext &self, std::string &name) {
3055  self.get(), MlirStringRef{name.data(), name.size()});
3056  },
3057  nb::arg("operation_name"))
3058  .def(
3059  "append_dialect_registry",
3060  [](PyMlirContext &self, PyDialectRegistry &registry) {
3061  mlirContextAppendDialectRegistry(self.get(), registry);
3062  },
3063  nb::arg("registry"))
3064  .def_prop_rw("emit_error_diagnostics",
3065  &PyMlirContext::getEmitErrorDiagnostics,
3066  &PyMlirContext::setEmitErrorDiagnostics,
3067  "Emit error diagnostics to diagnostic handlers. By default "
3068  "error diagnostics are captured and reported through "
3069  "MLIRError exceptions.")
3070  .def("load_all_available_dialects", [](PyMlirContext &self) {
3072  });
3073 
3074  //----------------------------------------------------------------------------
3075  // Mapping of PyDialectDescriptor
3076  //----------------------------------------------------------------------------
3077  nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
3078  .def_prop_ro("namespace",
3079  [](PyDialectDescriptor &self) {
3081  return nb::str(ns.data, ns.length);
3082  })
3083  .def(
3084  "__repr__",
3085  [](PyDialectDescriptor &self) {
3087  std::string repr("<DialectDescriptor ");
3088  repr.append(ns.data, ns.length);
3089  repr.append(">");
3090  return repr;
3091  },
3092  nb::sig("def __repr__(self) -> str"));
3093 
3094  //----------------------------------------------------------------------------
3095  // Mapping of PyDialects
3096  //----------------------------------------------------------------------------
3097  nb::class_<PyDialects>(m, "Dialects")
3098  .def("__getitem__",
3099  [=](PyDialects &self, std::string keyName) {
3100  MlirDialect dialect =
3101  self.getDialectForKey(keyName, /*attrError=*/false);
3102  nb::object descriptor =
3103  nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3104  return createCustomDialectWrapper(keyName, std::move(descriptor));
3105  })
3106  .def("__getattr__", [=](PyDialects &self, std::string attrName) {
3107  MlirDialect dialect =
3108  self.getDialectForKey(attrName, /*attrError=*/true);
3109  nb::object descriptor =
3110  nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3111  return createCustomDialectWrapper(attrName, std::move(descriptor));
3112  });
3113 
3114  //----------------------------------------------------------------------------
3115  // Mapping of PyDialect
3116  //----------------------------------------------------------------------------
3117  nb::class_<PyDialect>(m, "Dialect")
3118  .def(nb::init<nb::object>(), nb::arg("descriptor"))
3119  .def_prop_ro("descriptor",
3120  [](PyDialect &self) { return self.getDescriptor(); })
3121  .def(
3122  "__repr__",
3123  [](const nb::object &self) {
3124  auto clazz = self.attr("__class__");
3125  return nb::str("<Dialect ") +
3126  self.attr("descriptor").attr("namespace") +
3127  nb::str(" (class ") + clazz.attr("__module__") +
3128  nb::str(".") + clazz.attr("__name__") + nb::str(")>");
3129  },
3130  nb::sig("def __repr__(self) -> str"));
3131 
3132  //----------------------------------------------------------------------------
3133  // Mapping of PyDialectRegistry
3134  //----------------------------------------------------------------------------
3135  nb::class_<PyDialectRegistry>(m, "DialectRegistry")
3136  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
3137  .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
3138  &PyDialectRegistry::createFromCapsule)
3139  .def(nb::init<>());
3140 
3141  //----------------------------------------------------------------------------
3142  // Mapping of Location
3143  //----------------------------------------------------------------------------
3144  nb::class_<PyLocation>(m, "Location")
3145  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
3146  .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
3147  .def("__enter__", &PyLocation::contextEnter)
3148  .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
3149  nb::arg("exc_value").none(), nb::arg("traceback").none())
3150  .def("__eq__",
3151  [](PyLocation &self, PyLocation &other) -> bool {
3152  return mlirLocationEqual(self, other);
3153  })
3154  .def("__eq__", [](PyLocation &self, nb::object other) { return false; })
3155  .def_prop_ro_static(
3156  "current",
3157  [](nb::object & /*class*/) -> std::optional<PyLocation *> {
3158  auto *loc = PyThreadContextEntry::getDefaultLocation();
3159  if (!loc)
3160  return std::nullopt;
3161  return loc;
3162  },
3163  // clang-format off
3164  nb::sig("def current(/) -> Location | None"),
3165  // clang-format on
3166  "Gets the Location bound to the current thread or raises ValueError")
3167  .def_static(
3168  "unknown",
3169  [](DefaultingPyMlirContext context) {
3170  return PyLocation(context->getRef(),
3171  mlirLocationUnknownGet(context->get()));
3172  },
3173  nb::arg("context") = nb::none(),
3174  "Gets a Location representing an unknown location")
3175  .def_static(
3176  "callsite",
3177  [](PyLocation callee, const std::vector<PyLocation> &frames,
3178  DefaultingPyMlirContext context) {
3179  if (frames.empty())
3180  throw nb::value_error("No caller frames provided");
3181  MlirLocation caller = frames.back().get();
3182  for (const PyLocation &frame :
3183  llvm::reverse(llvm::ArrayRef(frames).drop_back()))
3184  caller = mlirLocationCallSiteGet(frame.get(), caller);
3185  return PyLocation(context->getRef(),
3186  mlirLocationCallSiteGet(callee.get(), caller));
3187  },
3188  nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
3190  .def("is_a_callsite", mlirLocationIsACallSite)
3191  .def_prop_ro("callee",
3192  [](PyLocation &self) {
3193  return PyLocation(self.getContext(),
3195  })
3196  .def_prop_ro("caller",
3197  [](PyLocation &self) {
3198  return PyLocation(self.getContext(),
3200  })
3201  .def_static(
3202  "file",
3203  [](std::string filename, int line, int col,
3204  DefaultingPyMlirContext context) {
3205  return PyLocation(
3206  context->getRef(),
3208  context->get(), toMlirStringRef(filename), line, col));
3209  },
3210  nb::arg("filename"), nb::arg("line"), nb::arg("col"),
3211  nb::arg("context") = nb::none(), kContextGetFileLocationDocstring)
3212  .def_static(
3213  "file",
3214  [](std::string filename, int startLine, int startCol, int endLine,
3215  int endCol, DefaultingPyMlirContext context) {
3216  return PyLocation(context->getRef(),
3218  context->get(), toMlirStringRef(filename),
3219  startLine, startCol, endLine, endCol));
3220  },
3221  nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
3222  nb::arg("end_line"), nb::arg("end_col"),
3223  nb::arg("context") = nb::none(), kContextGetFileRangeDocstring)
3224  .def("is_a_file", mlirLocationIsAFileLineColRange)
3225  .def_prop_ro("filename",
3226  [](MlirLocation loc) {
3227  return mlirIdentifierStr(
3229  })
3230  .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
3231  .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
3232  .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
3233  .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn)
3234  .def_static(
3235  "fused",
3236  [](const std::vector<PyLocation> &pyLocations,
3237  std::optional<PyAttribute> metadata,
3238  DefaultingPyMlirContext context) {
3240  locations.reserve(pyLocations.size());
3241  for (auto &pyLocation : pyLocations)
3242  locations.push_back(pyLocation.get());
3243  MlirLocation location = mlirLocationFusedGet(
3244  context->get(), locations.size(), locations.data(),
3245  metadata ? metadata->get() : MlirAttribute{0});
3246  return PyLocation(context->getRef(), location);
3247  },
3248  nb::arg("locations"), nb::arg("metadata") = nb::none(),
3249  nb::arg("context") = nb::none(), kContextGetFusedLocationDocstring)
3250  .def("is_a_fused", mlirLocationIsAFused)
3251  .def_prop_ro(
3252  "locations",
3253  [](PyLocation &self) {
3254  unsigned numLocations = mlirLocationFusedGetNumLocations(self);
3255  std::vector<MlirLocation> locations(numLocations);
3256  if (numLocations)
3257  mlirLocationFusedGetLocations(self, locations.data());
3258  std::vector<PyLocation> pyLocations{};
3259  pyLocations.reserve(numLocations);
3260  for (unsigned i = 0; i < numLocations; ++i)
3261  pyLocations.emplace_back(self.getContext(), locations[i]);
3262  return pyLocations;
3263  })
3264  .def_static(
3265  "name",
3266  [](std::string name, std::optional<PyLocation> childLoc,
3267  DefaultingPyMlirContext context) {
3268  return PyLocation(
3269  context->getRef(),
3271  context->get(), toMlirStringRef(name),
3272  childLoc ? childLoc->get()
3273  : mlirLocationUnknownGet(context->get())));
3274  },
3275  nb::arg("name"), nb::arg("childLoc") = nb::none(),
3276  nb::arg("context") = nb::none(), kContextGetNameLocationDocString)
3277  .def("is_a_name", mlirLocationIsAName)
3278  .def_prop_ro("name_str",
3279  [](MlirLocation loc) {
3281  })
3282  .def_prop_ro("child_loc",
3283  [](PyLocation &self) {
3284  return PyLocation(self.getContext(),
3286  })
3287  .def_static(
3288  "from_attr",
3289  [](PyAttribute &attribute, DefaultingPyMlirContext context) {
3290  return PyLocation(context->getRef(),
3291  mlirLocationFromAttribute(attribute));
3292  },
3293  nb::arg("attribute"), nb::arg("context") = nb::none(),
3294  "Gets a Location from a LocationAttr")
3295  .def_prop_ro(
3296  "context",
3297  [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
3298  return self.getContext().getObject();
3299  },
3300  "Context that owns the Location")
3301  .def_prop_ro(
3302  "attr",
3303  [](PyLocation &self) {
3304  return PyAttribute(self.getContext(),
3305  mlirLocationGetAttribute(self));
3306  },
3307  "Get the underlying LocationAttr")
3308  .def(
3309  "emit_error",
3310  [](PyLocation &self, std::string message) {
3311  mlirEmitError(self, message.c_str());
3312  },
3313  nb::arg("message"), "Emits an error at this location")
3314  .def("__repr__", [](PyLocation &self) {
3315  PyPrintAccumulator printAccum;
3316  mlirLocationPrint(self, printAccum.getCallback(),
3317  printAccum.getUserData());
3318  return printAccum.join();
3319  });
3320 
3321  //----------------------------------------------------------------------------
3322  // Mapping of Module
3323  //----------------------------------------------------------------------------
3324  nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3325  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3326  .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
3328  .def("_clear_mlir_module", &PyModule::clearMlirModule)
3329  .def_static(
3330  "parse",
3331  [](const std::string &moduleAsm, DefaultingPyMlirContext context)
3332  -> nb::typed<nb::object, PyModule> {
3333  PyMlirContext::ErrorCapture errors(context->getRef());
3334  MlirModule module = mlirModuleCreateParse(
3335  context->get(), toMlirStringRef(moduleAsm));
3336  if (mlirModuleIsNull(module))
3337  throw MLIRError("Unable to parse module assembly", errors.take());
3338  return PyModule::forModule(module).releaseObject();
3339  },
3340  nb::arg("asm"), nb::arg("context") = nb::none(),
3342  .def_static(
3343  "parse",
3344  [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
3345  -> nb::typed<nb::object, PyModule> {
3346  PyMlirContext::ErrorCapture errors(context->getRef());
3347  MlirModule module = mlirModuleCreateParse(
3348  context->get(), toMlirStringRef(moduleAsm));
3349  if (mlirModuleIsNull(module))
3350  throw MLIRError("Unable to parse module assembly", errors.take());
3351  return PyModule::forModule(module).releaseObject();
3352  },
3353  nb::arg("asm"), nb::arg("context") = nb::none(),
3355  .def_static(
3356  "parseFile",
3357  [](const std::string &path, DefaultingPyMlirContext context)
3358  -> nb::typed<nb::object, PyModule> {
3359  PyMlirContext::ErrorCapture errors(context->getRef());
3360  MlirModule module = mlirModuleCreateParseFromFile(
3361  context->get(), toMlirStringRef(path));
3362  if (mlirModuleIsNull(module))
3363  throw MLIRError("Unable to parse module assembly", errors.take());
3364  return PyModule::forModule(module).releaseObject();
3365  },
3366  nb::arg("path"), nb::arg("context") = nb::none(),
3368  .def_static(
3369  "create",
3370  [](const std::optional<PyLocation> &loc)
3371  -> nb::typed<nb::object, PyModule> {
3372  PyLocation pyLoc = maybeGetTracebackLocation(loc);
3373  MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
3374  return PyModule::forModule(module).releaseObject();
3375  },
3376  nb::arg("loc") = nb::none(), "Creates an empty module")
3377  .def_prop_ro(
3378  "context",
3379  [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
3380  return self.getContext().getObject();
3381  },
3382  "Context that created the Module")
3383  .def_prop_ro(
3384  "operation",
3385  [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
3386  return PyOperation::forOperation(self.getContext(),
3387  mlirModuleGetOperation(self.get()),
3388  self.getRef().releaseObject())
3389  .releaseObject();
3390  },
3391  "Accesses the module as an operation")
3392  .def_prop_ro(
3393  "body",
3394  [](PyModule &self) {
3395  PyOperationRef moduleOp = PyOperation::forOperation(
3396  self.getContext(), mlirModuleGetOperation(self.get()),
3397  self.getRef().releaseObject());
3398  PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3399  return returnBlock;
3400  },
3401  "Return the block for this module")
3402  .def(
3403  "dump",
3404  [](PyModule &self) {
3406  },
3408  .def(
3409  "__str__",
3410  [](const nb::object &self) {
3411  // Defer to the operation's __str__.
3412  return self.attr("operation").attr("__str__")();
3413  },
3414  nb::sig("def __str__(self) -> str"), kOperationStrDunderDocstring)
3415  .def(
3416  "__eq__",
3417  [](PyModule &self, PyModule &other) {
3418  return mlirModuleEqual(self.get(), other.get());
3419  },
3420  "other"_a)
3421  .def("__hash__",
3422  [](PyModule &self) { return mlirModuleHashValue(self.get()); });
3423 
3424  //----------------------------------------------------------------------------
3425  // Mapping of Operation.
3426  //----------------------------------------------------------------------------
3427  nb::class_<PyOperationBase>(m, "_OperationBase")
3428  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
3429  [](PyOperationBase &self) {
3430  return self.getOperation().getCapsule();
3431  })
3432  .def("__eq__",
3433  [](PyOperationBase &self, PyOperationBase &other) {
3434  return mlirOperationEqual(self.getOperation().get(),
3435  other.getOperation().get());
3436  })
3437  .def("__eq__",
3438  [](PyOperationBase &self, nb::object other) { return false; })
3439  .def("__hash__",
3440  [](PyOperationBase &self) {
3441  return mlirOperationHashValue(self.getOperation().get());
3442  })
3443  .def_prop_ro("attributes",
3444  [](PyOperationBase &self) {
3445  return PyOpAttributeMap(self.getOperation().getRef());
3446  })
3447  .def_prop_ro(
3448  "context",
3449  [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
3450  PyOperation &concreteOperation = self.getOperation();
3451  concreteOperation.checkValid();
3452  return concreteOperation.getContext().getObject();
3453  },
3454  "Context that owns the Operation")
3455  .def_prop_ro("name",
3456  [](PyOperationBase &self) {
3457  auto &concreteOperation = self.getOperation();
3458  concreteOperation.checkValid();
3459  MlirOperation operation = concreteOperation.get();
3460  return mlirIdentifierStr(mlirOperationGetName(operation));
3461  })
3462  .def_prop_ro("operands",
3463  [](PyOperationBase &self) {
3464  return PyOpOperandList(self.getOperation().getRef());
3465  })
3466  .def_prop_ro("regions",
3467  [](PyOperationBase &self) {
3468  return PyRegionList(self.getOperation().getRef());
3469  })
3470  .def_prop_ro(
3471  "results",
3472  [](PyOperationBase &self) {
3473  return PyOpResultList(self.getOperation().getRef());
3474  },
3475  "Returns the list of Operation results.")
3476  .def_prop_ro(
3477  "result",
3478  [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
3479  auto &operation = self.getOperation();
3480  return PyOpResult(operation.getRef(), getUniqueResult(operation))
3481  .maybeDownCast();
3482  },
3483  "Shortcut to get an op result if it has only one (throws an error "
3484  "otherwise).")
3485  .def_prop_rw(
3486  "location",
3487  [](PyOperationBase &self) {
3488  PyOperation &operation = self.getOperation();
3489  return PyLocation(operation.getContext(),
3490  mlirOperationGetLocation(operation.get()));
3491  },
3492  [](PyOperationBase &self, const PyLocation &location) {
3493  PyOperation &operation = self.getOperation();
3494  mlirOperationSetLocation(operation.get(), location.get());
3495  },
3496  nb::for_getter("Returns the source location the operation was "
3497  "defined or derived from."),
3498  nb::for_setter("Sets the source location the operation was defined "
3499  "or derived from."))
3500  .def_prop_ro("parent",
3501  [](PyOperationBase &self)
3502  -> std::optional<nb::typed<nb::object, PyOperation>> {
3503  auto parent = self.getOperation().getParentOperation();
3504  if (parent)
3505  return parent->getObject();
3506  return {};
3507  })
3508  .def(
3509  "__str__",
3510  [](PyOperationBase &self) {
3511  return self.getAsm(/*binary=*/false,
3512  /*largeElementsLimit=*/std::nullopt,
3513  /*largeResourceLimit=*/std::nullopt,
3514  /*enableDebugInfo=*/false,
3515  /*prettyDebugInfo=*/false,
3516  /*printGenericOpForm=*/false,
3517  /*useLocalScope=*/false,
3518  /*useNameLocAsPrefix=*/false,
3519  /*assumeVerified=*/false,
3520  /*skipRegions=*/false);
3521  },
3522  nb::sig("def __str__(self) -> str"),
3523  "Returns the assembly form of the operation.")
3524  .def("print",
3525  nb::overload_cast<PyAsmState &, nb::object, bool>(
3527  nb::arg("state"), nb::arg("file") = nb::none(),
3528  nb::arg("binary") = false, kOperationPrintStateDocstring)
3529  .def("print",
3530  nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
3531  bool, bool, bool, bool, bool, bool, nb::object,
3532  bool, bool>(&PyOperationBase::print),
3533  // Careful: Lots of arguments must match up with print method.
3534  nb::arg("large_elements_limit") = nb::none(),
3535  nb::arg("large_resource_limit") = nb::none(),
3536  nb::arg("enable_debug_info") = false,
3537  nb::arg("pretty_debug_info") = false,
3538  nb::arg("print_generic_op_form") = false,
3539  nb::arg("use_local_scope") = false,
3540  nb::arg("use_name_loc_as_prefix") = false,
3541  nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
3542  nb::arg("binary") = false, nb::arg("skip_regions") = false,
3544  .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
3545  nb::arg("desired_version") = nb::none(),
3547  .def("get_asm", &PyOperationBase::getAsm,
3548  // Careful: Lots of arguments must match up with get_asm method.
3549  nb::arg("binary") = false,
3550  nb::arg("large_elements_limit") = nb::none(),
3551  nb::arg("large_resource_limit") = nb::none(),
3552  nb::arg("enable_debug_info") = false,
3553  nb::arg("pretty_debug_info") = false,
3554  nb::arg("print_generic_op_form") = false,
3555  nb::arg("use_local_scope") = false,
3556  nb::arg("use_name_loc_as_prefix") = false,
3557  nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
3559  .def("verify", &PyOperationBase::verify,
3560  "Verify the operation. Raises MLIRError if verification fails, and "
3561  "returns true otherwise.")
3562  .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
3563  "Puts self immediately after the other operation in its parent "
3564  "block.")
3565  .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
3566  "Puts self immediately before the other operation in its parent "
3567  "block.")
3568  .def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
3569  nb::arg("other"),
3570  "Given an operation 'other' that is within the same parent block, "
3571  "return"
3572  "whether the current operation is before 'other' in the operation "
3573  "list"
3574  "of the parent block.")
3575  .def(
3576  "clone",
3577  [](PyOperationBase &self,
3578  const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
3579  return self.getOperation().clone(ip);
3580  },
3581  nb::arg("ip") = nb::none())
3582  .def(
3583  "detach_from_parent",
3584  [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
3585  PyOperation &operation = self.getOperation();
3586  operation.checkValid();
3587  if (!operation.isAttached())
3588  throw nb::value_error("Detached operation has no parent.");
3589 
3590  operation.detachFromParent();
3591  return operation.createOpView();
3592  },
3593  "Detaches the operation from its parent block.")
3594  .def_prop_ro(
3595  "attached",
3596  [](PyOperationBase &self) {
3597  PyOperation &operation = self.getOperation();
3598  operation.checkValid();
3599  return operation.isAttached();
3600  },
3601  "Reports if the operation is attached to its parent block.")
3602  .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3603  .def("walk", &PyOperationBase::walk, nb::arg("callback"),
3604  nb::arg("walk_order") = MlirWalkPostOrder,
3605  // clang-format off
3606  nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None")
3607  // clang-format on
3608  );
3609 
3610  nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3611  .def_static(
3612  "create",
3613  [](std::string_view name,
3614  std::optional<std::vector<PyType *>> results,
3615  std::optional<std::vector<PyValue *>> operands,
3616  std::optional<nb::dict> attributes,
3617  std::optional<std::vector<PyBlock *>> successors, int regions,
3618  const std::optional<PyLocation> &location,
3619  const nb::object &maybeIp,
3620  bool inferType) -> nb::typed<nb::object, PyOperation> {
3621  // Unpack/validate operands.
3622  llvm::SmallVector<MlirValue, 4> mlirOperands;
3623  if (operands) {
3624  mlirOperands.reserve(operands->size());
3625  for (PyValue *operand : *operands) {
3626  if (!operand)
3627  throw nb::value_error("operand value cannot be None");
3628  mlirOperands.push_back(operand->get());
3629  }
3630  }
3631 
3632  PyLocation pyLoc = maybeGetTracebackLocation(location);
3633  return PyOperation::create(name, results, mlirOperands, attributes,
3634  successors, regions, pyLoc, maybeIp,
3635  inferType);
3636  },
3637  nb::arg("name"), nb::arg("results") = nb::none(),
3638  nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
3639  nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
3640  nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
3641  nb::arg("infer_type") = false, kOperationCreateDocstring)
3642  .def_static(
3643  "parse",
3644  [](const std::string &sourceStr, const std::string &sourceName,
3645  DefaultingPyMlirContext context)
3646  -> nb::typed<nb::object, PyOpView> {
3647  return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3648  ->createOpView();
3649  },
3650  nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
3651  nb::arg("context") = nb::none(),
3652  "Parses an operation. Supports both text assembly format and binary "
3653  "bytecode format.")
3654  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
3655  .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
3656  &PyOperation::createFromCapsule)
3657  .def_prop_ro("operation",
3658  [](nb::object self) -> nb::typed<nb::object, PyOperation> {
3659  return self;
3660  })
3661  .def_prop_ro("opview",
3662  [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
3663  return self.createOpView();
3664  })
3665  .def_prop_ro("block", &PyOperation::getBlock)
3666  .def_prop_ro(
3667  "successors",
3668  [](PyOperationBase &self) {
3669  return PyOpSuccessors(self.getOperation().getRef());
3670  },
3671  "Returns the list of Operation successors.")
3672  .def("_set_invalid", &PyOperation::setInvalid,
3673  "Invalidate the operation.");
3674 
3675  auto opViewClass =
3676  nb::class_<PyOpView, PyOperationBase>(m, "OpView")
3677  .def(nb::init<nb::typed<nb::object, PyOperation>>(),
3678  nb::arg("operation"))
3679  .def(
3680  "__init__",
3681  [](PyOpView *self, std::string_view name,
3682  std::tuple<int, bool> opRegionSpec,
3683  nb::object operandSegmentSpecObj,
3684  nb::object resultSegmentSpecObj,
3685  std::optional<nb::list> resultTypeList, nb::list operandList,
3686  std::optional<nb::dict> attributes,
3687  std::optional<std::vector<PyBlock *>> successors,
3688  std::optional<int> regions,
3689  const std::optional<PyLocation> &location,
3690  const nb::object &maybeIp) {
3691  PyLocation pyLoc = maybeGetTracebackLocation(location);
3692  new (self) PyOpView(PyOpView::buildGeneric(
3693  name, opRegionSpec, operandSegmentSpecObj,
3694  resultSegmentSpecObj, resultTypeList, operandList,
3695  attributes, successors, regions, pyLoc, maybeIp));
3696  },
3697  nb::arg("name"), nb::arg("opRegionSpec"),
3698  nb::arg("operandSegmentSpecObj") = nb::none(),
3699  nb::arg("resultSegmentSpecObj") = nb::none(),
3700  nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(),
3701  nb::arg("attributes") = nb::none(),
3702  nb::arg("successors") = nb::none(),
3703  nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
3704  nb::arg("ip") = nb::none())
3705  .def_prop_ro(
3706  "operation",
3707  [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
3708  return self.getOperationObject();
3709  })
3710  .def_prop_ro("opview",
3711  [](nb::object self) -> nb::typed<nb::object, PyOpView> {
3712  return self;
3713  })
3714  .def(
3715  "__str__",
3716  [](PyOpView &self) { return nb::str(self.getOperationObject()); })
3717  .def_prop_ro(
3718  "successors",
3719  [](PyOperationBase &self) {
3720  return PyOpSuccessors(self.getOperation().getRef());
3721  },
3722  "Returns the list of Operation successors.")
3723  .def(
3724  "_set_invalid",
3725  [](PyOpView &self) { self.getOperation().setInvalid(); },
3726  "Invalidate the operation.");
3727  opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
3728  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
3729  opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
3730  // It is faster to pass the operation_name, ods_regions, and
3731  // ods_operand_segments/ods_result_segments as arguments to the constructor,
3732  // rather than to access them as attributes.
3733  opViewClass.attr("build_generic") = classmethod(
3734  [](nb::handle cls, std::optional<nb::list> resultTypeList,
3735  nb::list operandList, std::optional<nb::dict> attributes,
3736  std::optional<std::vector<PyBlock *>> successors,
3737  std::optional<int> regions, std::optional<PyLocation> location,
3738  const nb::object &maybeIp) {
3739  std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3740  std::tuple<int, bool> opRegionSpec =
3741  nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
3742  nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
3743  nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3744  PyLocation pyLoc = maybeGetTracebackLocation(location);
3745  return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3746  resultSegmentSpec, resultTypeList,
3747  operandList, attributes, successors,
3748  regions, pyLoc, maybeIp);
3749  },
3750  nb::arg("cls"), nb::arg("results") = nb::none(),
3751  nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
3752  nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(),
3753  nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
3754  "Builds a specific, generated OpView based on class level attributes.");
3755  opViewClass.attr("parse") = classmethod(
3756  [](const nb::object &cls, const std::string &sourceStr,
3757  const std::string &sourceName,
3758  DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
3759  PyOperationRef parsed =
3760  PyOperation::parse(context->getRef(), sourceStr, sourceName);
3761 
3762  // Check if the expected operation was parsed, and cast to to the
3763  // appropriate `OpView` subclass if successful.
3764  // NOTE: This accesses attributes that have been automatically added to
3765  // `OpView` subclasses, and is not intended to be used on `OpView`
3766  // directly.
3767  std::string clsOpName =
3768  nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3769  MlirStringRef identifier =
3771  std::string_view parsedOpName(identifier.data, identifier.length);
3772  if (clsOpName != parsedOpName)
3773  throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3774  parsedOpName + "'");
3775  return PyOpView::constructDerived(cls, parsed.getObject());
3776  },
3777  nb::arg("cls"), nb::arg("source"), nb::kw_only(),
3778  nb::arg("source_name") = "", nb::arg("context") = nb::none(),
3779  "Parses a specific, generated OpView based on class level attributes");
3780 
3781  //----------------------------------------------------------------------------
3782  // Mapping of PyRegion.
3783  //----------------------------------------------------------------------------
3784  nb::class_<PyRegion>(m, "Region")
3785  .def_prop_ro(
3786  "blocks",
3787  [](PyRegion &self) {
3788  return PyBlockList(self.getParentOperation(), self.get());
3789  },
3790  "Returns a forward-optimized sequence of blocks.")
3791  .def_prop_ro(
3792  "owner",
3793  [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
3794  return self.getParentOperation()->createOpView();
3795  },
3796  "Returns the operation owning this region.")
3797  .def(
3798  "__iter__",
3799  [](PyRegion &self) {
3800  self.checkValid();
3801  MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3802  return PyBlockIterator(self.getParentOperation(), firstBlock);
3803  },
3804  "Iterates over blocks in the region.")
3805  .def("__eq__",
3806  [](PyRegion &self, PyRegion &other) {
3807  return self.get().ptr == other.get().ptr;
3808  })
3809  .def("__eq__", [](PyRegion &self, nb::object &other) { return false; });
3810 
3811  //----------------------------------------------------------------------------
3812  // Mapping of PyBlock.
3813  //----------------------------------------------------------------------------
3814  nb::class_<PyBlock>(m, "Block")
3815  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3816  .def_prop_ro(
3817  "owner",
3818  [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
3819  return self.getParentOperation()->createOpView();
3820  },
3821  "Returns the owning operation of this block.")
3822  .def_prop_ro(
3823  "region",
3824  [](PyBlock &self) {
3825  MlirRegion region = mlirBlockGetParentRegion(self.get());
3826  return PyRegion(self.getParentOperation(), region);
3827  },
3828  "Returns the owning region of this block.")
3829  .def_prop_ro(
3830  "arguments",
3831  [](PyBlock &self) {
3832  return PyBlockArgumentList(self.getParentOperation(), self.get());
3833  },
3834  "Returns a list of block arguments.")
3835  .def(
3836  "add_argument",
3837  [](PyBlock &self, const PyType &type, const PyLocation &loc) {
3838  return PyBlockArgument(self.getParentOperation(),
3839  mlirBlockAddArgument(self.get(), type, loc));
3840  },
3841  "type"_a, "loc"_a,
3842  "Append an argument of the specified type to the block and returns "
3843  "the newly added argument.")
3844  .def(
3845  "erase_argument",
3846  [](PyBlock &self, unsigned index) {
3847  return mlirBlockEraseArgument(self.get(), index);
3848  },
3849  "Erase the argument at 'index' and remove it from the argument list.")
3850  .def_prop_ro(
3851  "operations",
3852  [](PyBlock &self) {
3853  return PyOperationList(self.getParentOperation(), self.get());
3854  },
3855  "Returns a forward-optimized sequence of operations.")
3856  .def_static(
3857  "create_at_start",
3858  [](PyRegion &parent, const nb::sequence &pyArgTypes,
3859  const std::optional<nb::sequence> &pyArgLocs) {
3860  parent.checkValid();
3861  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3862  mlirRegionInsertOwnedBlock(parent, 0, block);
3863  return PyBlock(parent.getParentOperation(), block);
3864  },
3865  nb::arg("parent"), nb::arg("arg_types") = nb::list(),
3866  nb::arg("arg_locs") = std::nullopt,
3867  "Creates and returns a new Block at the beginning of the given "
3868  "region (with given argument types and locations).")
3869  .def(
3870  "append_to",
3871  [](PyBlock &self, PyRegion &region) {
3872  MlirBlock b = self.get();
3874  mlirBlockDetach(b);
3875  mlirRegionAppendOwnedBlock(region.get(), b);
3876  },
3877  "Append this block to a region, transferring ownership if necessary")
3878  .def(
3879  "create_before",
3880  [](PyBlock &self, const nb::args &pyArgTypes,
3881  const std::optional<nb::sequence> &pyArgLocs) {
3882  self.checkValid();
3883  MlirBlock block =
3884  createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3885  MlirRegion region = mlirBlockGetParentRegion(self.get());
3886  mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3887  return PyBlock(self.getParentOperation(), block);
3888  },
3889  nb::arg("arg_types"), nb::kw_only(),
3890  nb::arg("arg_locs") = std::nullopt,
3891  "Creates and returns a new Block before this block "
3892  "(with given argument types and locations).")
3893  .def(
3894  "create_after",
3895  [](PyBlock &self, const nb::args &pyArgTypes,
3896  const std::optional<nb::sequence> &pyArgLocs) {
3897  self.checkValid();
3898  MlirBlock block =
3899  createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3900  MlirRegion region = mlirBlockGetParentRegion(self.get());
3901  mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3902  return PyBlock(self.getParentOperation(), block);
3903  },
3904  nb::arg("arg_types"), nb::kw_only(),
3905  nb::arg("arg_locs") = std::nullopt,
3906  "Creates and returns a new Block after this block "
3907  "(with given argument types and locations).")
3908  .def(
3909  "__iter__",
3910  [](PyBlock &self) {
3911  self.checkValid();
3912  MlirOperation firstOperation =
3914  return PyOperationIterator(self.getParentOperation(),
3915  firstOperation);
3916  },
3917  "Iterates over operations in the block.")
3918  .def("__eq__",
3919  [](PyBlock &self, PyBlock &other) {
3920  return self.get().ptr == other.get().ptr;
3921  })
3922  .def("__eq__", [](PyBlock &self, nb::object &other) { return false; })
3923  .def("__hash__",
3924  [](PyBlock &self) {
3925  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3926  })
3927  .def(
3928  "__str__",
3929  [](PyBlock &self) {
3930  self.checkValid();
3931  PyPrintAccumulator printAccum;
3932  mlirBlockPrint(self.get(), printAccum.getCallback(),
3933  printAccum.getUserData());
3934  return printAccum.join();
3935  },
3936  "Returns the assembly form of the block.")
3937  .def(
3938  "append",
3939  [](PyBlock &self, PyOperationBase &operation) {
3940  if (operation.getOperation().isAttached())
3941  operation.getOperation().detachFromParent();
3942 
3943  MlirOperation mlirOperation = operation.getOperation().get();
3944  mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3945  operation.getOperation().setAttached(
3946  self.getParentOperation().getObject());
3947  },
3948  nb::arg("operation"),
3949  "Appends an operation to this block. If the operation is currently "
3950  "in another block, it will be moved.")
3951  .def_prop_ro(
3952  "successors",
3953  [](PyBlock &self) {
3954  return PyBlockSuccessors(self, self.getParentOperation());
3955  },
3956  "Returns the list of Block successors.")
3957  .def_prop_ro(
3958  "predecessors",
3959  [](PyBlock &self) {
3960  return PyBlockPredecessors(self, self.getParentOperation());
3961  },
3962  "Returns the list of Block predecessors.");
3963 
3964  //----------------------------------------------------------------------------
3965  // Mapping of PyInsertionPoint.
3966  //----------------------------------------------------------------------------
3967 
3968  nb::class_<PyInsertionPoint>(m, "InsertionPoint")
3969  .def(nb::init<PyBlock &>(), nb::arg("block"),
3970  "Inserts after the last operation but still inside the block.")
3971  .def("__enter__", &PyInsertionPoint::contextEnter)
3972  .def("__exit__", &PyInsertionPoint::contextExit,
3973  nb::arg("exc_type").none(), nb::arg("exc_value").none(),
3974  nb::arg("traceback").none())
3975  .def_prop_ro_static(
3976  "current",
3977  [](nb::object & /*class*/) {
3978  auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3979  if (!ip)
3980  throw nb::value_error("No current InsertionPoint");
3981  return ip;
3982  },
3983  nb::sig("def current(/) -> InsertionPoint"),
3984  "Gets the InsertionPoint bound to the current thread or raises "
3985  "ValueError if none has been set")
3986  .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
3987  "Inserts before a referenced operation.")
3988  .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3989  nb::arg("block"), "Inserts at the beginning of the block.")
3990  .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3991  nb::arg("block"), "Inserts before the block terminator.")
3992  .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
3993  "Inserts after the operation.")
3994  .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
3995  "Inserts an operation.")
3996  .def_prop_ro(
3997  "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3998  "Returns the block that this InsertionPoint points to.")
3999  .def_prop_ro(
4000  "ref_operation",
4001  [](PyInsertionPoint &self)
4002  -> std::optional<nb::typed<nb::object, PyOperation>> {
4003  auto refOperation = self.getRefOperation();
4004  if (refOperation)
4005  return refOperation->getObject();
4006  return {};
4007  },
4008  "The reference operation before which new operations are "
4009  "inserted, or None if the insertion point is at the end of "
4010  "the block");
4011 
4012  //----------------------------------------------------------------------------
4013  // Mapping of PyAttribute.
4014  //----------------------------------------------------------------------------
4015  nb::class_<PyAttribute>(m, "Attribute")
4016  // Delegate to the PyAttribute copy constructor, which will also lifetime
4017  // extend the backing context which owns the MlirAttribute.
4018  .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
4019  "Casts the passed attribute to the generic Attribute")
4020  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
4021  .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
4022  &PyAttribute::createFromCapsule)
4023  .def_static(
4024  "parse",
4025  [](const std::string &attrSpec, DefaultingPyMlirContext context)
4026  -> nb::typed<nb::object, PyAttribute> {
4027  PyMlirContext::ErrorCapture errors(context->getRef());
4028  MlirAttribute attr = mlirAttributeParseGet(
4029  context->get(), toMlirStringRef(attrSpec));
4030  if (mlirAttributeIsNull(attr))
4031  throw MLIRError("Unable to parse attribute", errors.take());
4032  return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
4033  },
4034  nb::arg("asm"), nb::arg("context") = nb::none(),
4035  "Parses an attribute from an assembly form. Raises an MLIRError on "
4036  "failure.")
4037  .def_prop_ro(
4038  "context",
4039  [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
4040  return self.getContext().getObject();
4041  },
4042  "Context that owns the Attribute")
4043  .def_prop_ro("type",
4044  [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
4045  return PyType(self.getContext(),
4046  mlirAttributeGetType(self))
4047  .maybeDownCast();
4048  })
4049  .def(
4050  "get_named",
4051  [](PyAttribute &self, std::string name) {
4052  return PyNamedAttribute(self, std::move(name));
4053  },
4054  nb::keep_alive<0, 1>(), "Binds a name to the attribute")
4055  .def("__eq__",
4056  [](PyAttribute &self, PyAttribute &other) { return self == other; })
4057  .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; })
4058  .def("__hash__",
4059  [](PyAttribute &self) {
4060  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4061  })
4062  .def(
4063  "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
4065  .def(
4066  "__str__",
4067  [](PyAttribute &self) {
4068  PyPrintAccumulator printAccum;
4069  mlirAttributePrint(self, printAccum.getCallback(),
4070  printAccum.getUserData());
4071  return printAccum.join();
4072  },
4073  "Returns the assembly form of the Attribute.")
4074  .def("__repr__",
4075  [](PyAttribute &self) {
4076  // Generally, assembly formats are not printed for __repr__ because
4077  // this can cause exceptionally long debug output and exceptions.
4078  // However, attribute values are generally considered useful and
4079  // are printed. This may need to be re-evaluated if debug dumps end
4080  // up being excessive.
4081  PyPrintAccumulator printAccum;
4082  printAccum.parts.append("Attribute(");
4083  mlirAttributePrint(self, printAccum.getCallback(),
4084  printAccum.getUserData());
4085  printAccum.parts.append(")");
4086  return printAccum.join();
4087  })
4088  .def_prop_ro("typeid",
4089  [](PyAttribute &self) {
4090  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
4091  assert(!mlirTypeIDIsNull(mlirTypeID) &&
4092  "mlirTypeID was expected to be non-null.");
4093  return PyTypeID(mlirTypeID);
4094  })
4096  [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
4097  return self.maybeDownCast();
4098  });
4099 
4100  //----------------------------------------------------------------------------
4101  // Mapping of PyNamedAttribute
4102  //----------------------------------------------------------------------------
4103  nb::class_<PyNamedAttribute>(m, "NamedAttribute")
4104  .def("__repr__",
4105  [](PyNamedAttribute &self) {
4106  PyPrintAccumulator printAccum;
4107  printAccum.parts.append("NamedAttribute(");
4108  printAccum.parts.append(
4109  nb::str(mlirIdentifierStr(self.namedAttr.name).data,
4110  mlirIdentifierStr(self.namedAttr.name).length));
4111  printAccum.parts.append("=");
4112  mlirAttributePrint(self.namedAttr.attribute,
4113  printAccum.getCallback(),
4114  printAccum.getUserData());
4115  printAccum.parts.append(")");
4116  return printAccum.join();
4117  })
4118  .def_prop_ro(
4119  "name",
4120  [](PyNamedAttribute &self) {
4121  return mlirIdentifierStr(self.namedAttr.name);
4122  },
4123  "The name of the NamedAttribute binding")
4124  .def_prop_ro(
4125  "attr",
4126  [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
4127  nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
4128  "The underlying generic attribute of the NamedAttribute binding");
4129 
4130  //----------------------------------------------------------------------------
4131  // Mapping of PyType.
4132  //----------------------------------------------------------------------------
4133  nb::class_<PyType>(m, "Type")
4134  // Delegate to the PyType copy constructor, which will also lifetime
4135  // extend the backing context which owns the MlirType.
4136  .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
4137  "Casts the passed type to the generic Type")
4138  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
4139  .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
4140  .def_static(
4141  "parse",
4142  [](std::string typeSpec,
4143  DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
4144  PyMlirContext::ErrorCapture errors(context->getRef());
4145  MlirType type =
4146  mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
4147  if (mlirTypeIsNull(type))
4148  throw MLIRError("Unable to parse type", errors.take());
4149  return PyType(context.get()->getRef(), type).maybeDownCast();
4150  },
4151  nb::arg("asm"), nb::arg("context") = nb::none(),
4153  .def_prop_ro(
4154  "context",
4155  [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
4156  return self.getContext().getObject();
4157  },
4158  "Context that owns the Type")
4159  .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
4160  .def(
4161  "__eq__", [](PyType &self, nb::object &other) { return false; },
4162  nb::arg("other").none())
4163  .def("__hash__",
4164  [](PyType &self) {
4165  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4166  })
4167  .def(
4168  "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
4169  .def(
4170  "__str__",
4171  [](PyType &self) {
4172  PyPrintAccumulator printAccum;
4173  mlirTypePrint(self, printAccum.getCallback(),
4174  printAccum.getUserData());
4175  return printAccum.join();
4176  },
4177  "Returns the assembly form of the type.")
4178  .def("__repr__",
4179  [](PyType &self) {
4180  // Generally, assembly formats are not printed for __repr__ because
4181  // this can cause exceptionally long debug output and exceptions.
4182  // However, types are an exception as they typically have compact
4183  // assembly forms and printing them is useful.
4184  PyPrintAccumulator printAccum;
4185  printAccum.parts.append("Type(");
4186  mlirTypePrint(self, printAccum.getCallback(),
4187  printAccum.getUserData());
4188  printAccum.parts.append(")");
4189  return printAccum.join();
4190  })
4192  [](PyType &self) -> nb::typed<nb::object, PyType> {
4193  return self.maybeDownCast();
4194  })
4195  .def_prop_ro("typeid", [](PyType &self) {
4196  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
4197  if (!mlirTypeIDIsNull(mlirTypeID))
4198  return PyTypeID(mlirTypeID);
4199  auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
4200  throw nb::value_error(
4201  (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
4202  });
4203 
4204  //----------------------------------------------------------------------------
4205  // Mapping of PyTypeID.
4206  //----------------------------------------------------------------------------
4207  nb::class_<PyTypeID>(m, "TypeID")
4208  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
4209  .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
4210  // Note, this tests whether the underlying TypeIDs are the same,
4211  // not whether the wrapper MlirTypeIDs are the same, nor whether
4212  // the Python objects are the same (i.e., PyTypeID is a value type).
4213  .def("__eq__",
4214  [](PyTypeID &self, PyTypeID &other) { return self == other; })
4215  .def("__eq__",
4216  [](PyTypeID &self, const nb::object &other) { return false; })
4217  // Note, this gives the hash value of the underlying TypeID, not the
4218  // hash value of the Python object, nor the hash value of the
4219  // MlirTypeID wrapper.
4220  .def("__hash__", [](PyTypeID &self) {
4221  return static_cast<size_t>(mlirTypeIDHashValue(self));
4222  });
4223 
4224  //----------------------------------------------------------------------------
4225  // Mapping of Value.
4226  //----------------------------------------------------------------------------
4227  nb::class_<PyValue>(m, "Value")
4228  .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
4229  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
4230  .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
4231  .def_prop_ro(
4232  "context",
4233  [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
4234  return self.getParentOperation()->getContext().getObject();
4235  },
4236  "Context in which the value lives.")
4237  .def(
4238  "dump", [](PyValue &self) { mlirValueDump(self.get()); },
4240  .def_prop_ro(
4241  "owner",
4242  [](PyValue &self) -> nb::object {
4243  MlirValue v = self.get();
4244  if (mlirValueIsAOpResult(v)) {
4245  assert(mlirOperationEqual(self.getParentOperation()->get(),
4246  mlirOpResultGetOwner(self.get())) &&
4247  "expected the owner of the value in Python to match "
4248  "that in "
4249  "the IR");
4250  return self.getParentOperation().getObject();
4251  }
4252 
4253  if (mlirValueIsABlockArgument(v)) {
4254  MlirBlock block = mlirBlockArgumentGetOwner(self.get());
4255  return nb::cast(PyBlock(self.getParentOperation(), block));
4256  }
4257 
4258  assert(false && "Value must be a block argument or an op result");
4259  return nb::none();
4260  },
4261  // clang-format off
4262  nb::sig("def owner(self) -> Operation | Block | None"))
4263  // clang-format on
4264  .def_prop_ro("uses",
4265  [](PyValue &self) {
4266  return PyOpOperandIterator(
4267  mlirValueGetFirstUse(self.get()));
4268  })
4269  .def("__eq__",
4270  [](PyValue &self, PyValue &other) {
4271  return self.get().ptr == other.get().ptr;
4272  })
4273  .def("__eq__", [](PyValue &self, nb::object other) { return false; })
4274  .def("__hash__",
4275  [](PyValue &self) {
4276  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4277  })
4278  .def(
4279  "__str__",
4280  [](PyValue &self) {
4281  PyPrintAccumulator printAccum;
4282  printAccum.parts.append("Value(");
4283  mlirValuePrint(self.get(), printAccum.getCallback(),
4284  printAccum.getUserData());
4285  printAccum.parts.append(")");
4286  return printAccum.join();
4287  },
4289  .def(
4290  "get_name",
4291  [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
4292  PyPrintAccumulator printAccum;
4293  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
4294  if (useLocalScope)
4296  if (useNameLocAsPrefix)
4298  MlirAsmState valueState =
4299  mlirAsmStateCreateForValue(self.get(), flags);
4300  mlirValuePrintAsOperand(self.get(), valueState,
4301  printAccum.getCallback(),
4302  printAccum.getUserData());
4304  mlirAsmStateDestroy(valueState);
4305  return printAccum.join();
4306  },
4307  nb::arg("use_local_scope") = false,
4308  nb::arg("use_name_loc_as_prefix") = false)
4309  .def(
4310  "get_name",
4311  [](PyValue &self, PyAsmState &state) {
4312  PyPrintAccumulator printAccum;
4313  MlirAsmState valueState = state.get();
4314  mlirValuePrintAsOperand(self.get(), valueState,
4315  printAccum.getCallback(),
4316  printAccum.getUserData());
4317  return printAccum.join();
4318  },
4319  nb::arg("state"), kGetNameAsOperand)
4320  .def_prop_ro("type",
4321  [](PyValue &self) -> nb::typed<nb::object, PyType> {
4322  return PyType(self.getParentOperation()->getContext(),
4323  mlirValueGetType(self.get()))
4324  .maybeDownCast();
4325  })
4326  .def(
4327  "set_type",
4328  [](PyValue &self, const PyType &type) {
4329  return mlirValueSetType(self.get(), type);
4330  },
4331  nb::arg("type"))
4332  .def(
4333  "replace_all_uses_with",
4334  [](PyValue &self, PyValue &with) {
4335  mlirValueReplaceAllUsesOfWith(self.get(), with.get());
4336  },
4338  .def(
4339  "replace_all_uses_except",
4340  [](PyValue &self, PyValue &with, PyOperation &exception) {
4341  MlirOperation exceptedUser = exception.get();
4342  mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4343  },
4344  nb::arg("with_"), nb::arg("exceptions"),
4346  .def(
4347  "replace_all_uses_except",
4348  [](PyValue &self, PyValue &with, const nb::list &exceptions) {
4349  // Convert Python list to a SmallVector of MlirOperations
4350  llvm::SmallVector<MlirOperation> exceptionOps;
4351  for (nb::handle exception : exceptions) {
4352  exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
4353  }
4354 
4356  self, with, static_cast<intptr_t>(exceptionOps.size()),
4357  exceptionOps.data());
4358  },
4359  nb::arg("with_"), nb::arg("exceptions"),
4361  .def(
4362  "replace_all_uses_except",
4363  [](PyValue &self, PyValue &with, PyOperation &exception) {
4364  MlirOperation exceptedUser = exception.get();
4365  mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4366  },
4367  nb::arg("with_"), nb::arg("exceptions"),
4369  .def(
4370  "replace_all_uses_except",
4371  [](PyValue &self, PyValue &with,
4372  std::vector<PyOperation> &exceptions) {
4373  // Convert Python list to a SmallVector of MlirOperations
4374  llvm::SmallVector<MlirOperation> exceptionOps;
4375  for (PyOperation &exception : exceptions)
4376  exceptionOps.push_back(exception);
4378  self, with, static_cast<intptr_t>(exceptionOps.size()),
4379  exceptionOps.data());
4380  },
4381  nb::arg("with_"), nb::arg("exceptions"),
4384  [](PyValue &self) -> nb::typed<nb::object, PyValue> {
4385  return self.maybeDownCast();
4386  })
4387  .def_prop_ro(
4388  "location",
4389  [](MlirValue self) {
4390  return PyLocation(
4391  PyMlirContext::forContext(mlirValueGetContext(self)),
4392  mlirValueGetLocation(self));
4393  },
4394  "Returns the source location the value");
4395 
4396  PyBlockArgument::bind(m);
4397  PyOpResult::bind(m);
4398  PyOpOperand::bind(m);
4399 
4400  nb::class_<PyAsmState>(m, "AsmState")
4401  .def(nb::init<PyValue &, bool>(), nb::arg("value"),
4402  nb::arg("use_local_scope") = false)
4403  .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
4404  nb::arg("use_local_scope") = false);
4405 
4406  //----------------------------------------------------------------------------
4407  // Mapping of SymbolTable.
4408  //----------------------------------------------------------------------------
4409  nb::class_<PySymbolTable>(m, "SymbolTable")
4410  .def(nb::init<PyOperationBase &>())
4411  .def("__getitem__",
4412  [](PySymbolTable &self,
4413  const std::string &name) -> nb::typed<nb::object, PyOpView> {
4414  return self.dunderGetItem(name);
4415  })
4416  .def("insert", &PySymbolTable::insert, nb::arg("operation"))
4417  .def("erase", &PySymbolTable::erase, nb::arg("operation"))
4418  .def("__delitem__", &PySymbolTable::dunderDel)
4419  .def("__contains__",
4420  [](PySymbolTable &table, const std::string &name) {
4422  table, mlirStringRefCreate(name.data(), name.length())));
4423  })
4424  // Static helpers.
4425  .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
4426  nb::arg("symbol"), nb::arg("name"))
4427  .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
4428  nb::arg("symbol"))
4429  .def_static("get_visibility", &PySymbolTable::getVisibility,
4430  nb::arg("symbol"))
4431  .def_static("set_visibility", &PySymbolTable::setVisibility,
4432  nb::arg("symbol"), nb::arg("visibility"))
4433  .def_static("replace_all_symbol_uses",
4434  &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
4435  nb::arg("new_symbol"), nb::arg("from_op"))
4436  .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
4437  nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
4438  nb::arg("callback"));
4439 
4440  // Container bindings.
4441  PyBlockArgumentList::bind(m);
4442  PyBlockIterator::bind(m);
4443  PyBlockList::bind(m);
4444  PyBlockSuccessors::bind(m);
4445  PyBlockPredecessors::bind(m);
4446  PyOperationIterator::bind(m);
4447  PyOperationList::bind(m);
4448  PyOpAttributeMap::bind(m);
4449  PyOpOperandIterator::bind(m);
4450  PyOpOperandList::bind(m);
4452  PyOpSuccessors::bind(m);
4453  PyRegionIterator::bind(m);
4454  PyRegionList::bind(m);
4455 
4456  // Debug bindings.
4458 
4459  // Attribute builder getter.
4461 
4462  nb::register_exception_translator([](const std::exception_ptr &p,
4463  void *payload) {
4464  // We can't define exceptions with custom fields through pybind, so instead
4465  // the exception class is defined in python and imported here.
4466  try {
4467  if (p)
4468  std::rethrow_exception(p);
4469  } catch (const MLIRError &e) {
4470  nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
4471  .attr("MLIRError")(e.message, e.errorDiagnostics);
4472  PyErr_SetObject(PyExc_Exception, obj.ptr());
4473  }
4474  });
4475 }
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition: Debug.cpp:20
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Sets multiple current debug types, similarly to `-debug-only=type1,type2` in the command-line tools.
Definition: Debug.cpp:28
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition: Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition: Debug.cpp:16
static bool isBeforeInBlock(Block *block, Block::iterator a, Block::iterator b)
Given two iterators into the same block, return "true" if `a` is before `b`.
Definition: Dominance.cpp:241
static const char kOperationPrintStateDocstring[]
Definition: IRCore.cpp:126
static const char kValueReplaceAllUsesWithDocstring[]
Definition: IRCore.cpp:188
static const char kContextGetNameLocationDocString[]
Definition: IRCore.cpp:59
static const char kGetNameAsOperand[]
Definition: IRCore.cpp:184
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:223
static const char kModuleParseDocstring[]
Definition: IRCore.cpp:62
static const char kOperationStrDunderDocstring[]
Definition: IRCore.cpp:158
static const char kOperationPrintDocstring[]
Definition: IRCore.cpp:95
static nb::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition: IRCore.cpp:205
static MlirValue getUniqueResult(MlirOperation operation)
Definition: IRCore.cpp:1807
#define Py_XNewRef(obj)
Definition: IRCore.cpp:2775
static const char kContextGetFileLocationDocstring[]
Definition: IRCore.cpp:50
static const char kDumpDocstring[]
Definition: IRCore.cpp:166
#define _Py_CAST(type, expr)
Definition: IRCore.cpp:2751
static const char kAppendBlockDocstring[]
Definition: IRCore.cpp:169
static const char kModuleCAPICreate[]
Definition: IRCore.cpp:70
static MlirValue getOpResultOrValue(nb::handle operand)
Definition: IRCore.cpp:1822
static std::vector< nb::typed< nb::object, PyType > > getValueTypes(Container &container, PyMlirContextRef &context)
Returns the list of types of the values held by container.
Definition: IRCore.cpp:1645
#define Py_NewRef(obj)
Definition: IRCore.cpp:2784
#define _Py_NULL
Definition: IRCore.cpp:2762
static const char kContextGetFusedLocationDocstring[]
Definition: IRCore.cpp:56
static const char kContextGetFileRangeDocstring[]
Definition: IRCore.cpp:53
static void maybeInsertOperation(PyOperationRef &op, const nb::object &maybeIp)
Definition: IRCore.cpp:1407
static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor)
Definition: IRCore.cpp:211
static const char kOperationPrintBytecodeDocstring[]
Definition: IRCore.cpp:148
static const char kOperationGetAsmDocstring[]
Definition: IRCore.cpp:135
static MlirBlock createBlock(const nb::sequence &pyArgTypes, const std::optional< nb::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition: IRCore.cpp:237
static const char kOperationCreateDocstring[]
Definition: IRCore.cpp:76
static const char kContextParseTypeDocstring[]
Definition: IRCore.cpp:39
static void populateResultTypes(StringRef name, nb::list resultTypeList, const nb::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition: IRCore.cpp:1710
static const char kContextGetCallSiteLocationDocstring[]
Definition: IRCore.cpp:47
static const char kValueDunderStrDocstring[]
Definition: IRCore.cpp:176
static const char kValueReplaceAllUsesExceptDocstring[]
Definition: IRCore.cpp:193
static MLIRContext * getContext(OpFoldResult val)
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition: Interop.h:273
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition: Interop.h:118
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition: Interop.h:348
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition: Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition: Interop.h:189
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition: Interop.h:180
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition: Interop.h:255
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition: Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition: Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts an MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition: Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition: Interop.h:357
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition: Interop.h:235
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition: Interop.h:367
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition: Interop.h:245
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition: Interop.h:57
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition: Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition: Interop.h:454
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition: Interop.h:198
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition: Interop.h:330
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition: Interop.h:264
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition: Interop.h:445
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition: Interop.h:216
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::context getDefaultContext()
ArrayRef< float > table
A list of operation results.
Definition: IRCore.cpp:1660
PyOperationRef & getOperation()
Definition: IRCore.cpp:1683
static void bindDerived(ClassTy &c)
Definition: IRCore.cpp:1673
PyOpResultList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition: IRCore.cpp:1665
Python wrapper for MlirOpResult.
Definition: IRCore.cpp:1621
static void bindDerived(ClassTy &c)
Definition: IRCore.cpp:1627
Accumulates into a file, either writing text (default) or binary.
MlirStringCallback getCallback()
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
static void bind(nanobind::module_ &m)
Binds the indexing and length methods in the Python class.
nanobind::class_< PyOpResultList > ClassTy
Base class for all objects that directly or indirectly depend on an MlirContext.
Definition: IRModule.h:284
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:292
static PyLocation & resolve()
Definition: IRCore.cpp:1064
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:273
static PyMlirContext & resolve()
Definition: IRCore.cpp:785
ReferrentTy * get() const
Definition: NanobindUtils.h:60
Wrapper around an MlirAsmState.
Definition: IRModule.h:778
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:1008
static PyAttribute createFromCapsule(const nanobind::object &capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition: IRCore.cpp:2135
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition: IRModule.h:1010
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition: IRCore.cpp:2131
nanobind::object maybeDownCast()
Definition: IRCore.cpp:2143
MlirAttribute get() const
Definition: IRModule.h:1014
bool operator==(const PyAttribute &other) const
Definition: IRCore.cpp:2127
Wrapper around an MlirBlock.
Definition: IRModule.h:813
MlirBlock get()
Definition: IRModule.h:820
PyOperationRef & getParentOperation()
Definition: IRModule.h:821
Represents a diagnostic handler attached to the context.
Definition: IRModule.h:380
PyDiagnosticHandler(MlirContext context, nanobind::object callback)
Definition: IRCore.cpp:941
void detach()
Detaches the handler. Does nothing if not attached.
Definition: IRCore.cpp:947
Python class mirroring the C MlirDiagnostic struct.
Definition: IRModule.h:330
PyLocation getLocation()
Definition: IRCore.cpp:970
nanobind::tuple getNotes()
Definition: IRCore.cpp:985
nanobind::str getMessage()
Definition: IRCore.cpp:977
DiagnosticInfo getInfo()
Definition: IRCore.cpp:1001
PyDiagnostic(MlirDiagnostic diagnostic)
Definition: IRModule.h:332
MlirDiagnosticSeverity getSeverity()
Definition: IRCore.cpp:965
Wrapper around an MlirDialect.
Definition: IRModule.h:435
Wrapper around an MlirDialectRegistry.
Definition: IRModule.h:472
nanobind::object getCapsule()
Definition: IRCore.cpp:1026
static PyDialectRegistry createFromCapsule(nanobind::object capsule)
Definition: IRCore.cpp:1030
User-level dialect object.
Definition: IRModule.h:459
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition: IRModule.h:448
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition: IRCore.cpp:1013
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:155
std::optional< nanobind::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition: IRModule.cpp:184
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition: Globals.h:39
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:142
An insertion point maintains a pointer to a Block and a reference operation.
Definition: IRModule.h:837
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition: IRCore.cpp:2091
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition: IRCore.cpp:2078
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition: IRCore.cpp:2117
static PyInsertionPoint after(PyOperationBase &op)
Shortcut to create an insertion point to the node after the specified operation.
Definition: IRCore.cpp:2100
PyInsertionPoint(const PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition: IRCore.cpp:2043
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition: IRCore.cpp:2052
static nanobind::object contextEnter(nanobind::object insertionPoint)
Enter and exit the context manager.
Definition: IRCore.cpp:2113
Wrapper around an MlirLocation.
Definition: IRModule.h:299
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition: IRCore.cpp:1042
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition: IRModule.h:301
static PyLocation createFromCapsule(nanobind::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition: IRCore.cpp:1046
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition: IRCore.cpp:1058
static nanobind::object contextEnter(nanobind::object location)
Enter and exit the context manager.
Definition: IRCore.cpp:1054
MlirLocation get() const
Definition: IRModule.h:305
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:204
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition: IRModule.h:208
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition: IRCore.cpp:706
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition: IRCore.cpp:2111
nanobind::object attachDiagnosticHandler(nanobind::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition: IRCore.cpp:721
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition: IRCore.cpp:670
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition: IRCore.cpp:715
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition: IRCore.cpp:681
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition: IRCore.cpp:674
static nanobind::object contextEnter(nanobind::object context)
Enter and exit the context manager.
Definition: IRCore.cpp:711
MlirModule get()
Gets the backing MlirModule.
Definition: IRModule.h:522
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition: IRCore.cpp:1091
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition: IRCore.cpp:1116
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition: IRCore.cpp:1123
PyModule(PyModule &)=delete
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition: IRModule.h:1034
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition: IRCore.cpp:2161
MlirNamedAttribute namedAttr
Definition: IRModule.h:1043
nanobind::object getObject()
Definition: IRModule.h:91
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
Definition: IRModule.h:79
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition: IRModule.h:723
static nanobind::object constructDerived(const nanobind::object &cls, const nanobind::object &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition: IRCore.cpp:2025
static nanobind::object buildGeneric(std::string_view name, std::tuple< int, bool > opRegionSpec, nanobind::object operandSegmentSpecObj, nanobind::object resultSegmentSpecObj, std::optional< nanobind::list > resultTypeList, nanobind::list operandList, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * >> successors, std::optional< int > regions, PyLocation &location, const nanobind::object &maybeIp)
Definition: IRCore.cpp:1841
PyOpView(const nanobind::object &operationObject)
Definition: IRCore.cpp:2033
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition: IRModule.h:552
void walk(std::function< MlirWalkResult(MlirOperation)> callback, MlirWalkOrder walkOrder)
Definition: IRCore.cpp:1280
bool isBeforeInBlock(PyOperationBase &other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: IRCore.cpp:1358
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
nanobind::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions)
Definition: IRCore.cpp:1312
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition: IRCore.cpp:1340
void print(std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
void writeBytecode(const nanobind::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition: IRCore.cpp:1258
void moveBefore(PyOperationBase &other)
Definition: IRCore.cpp:1349
bool verify()
Verify the operation.
Definition: IRCore.cpp:1366
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition: IRModule.h:630
PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
Definition: IRCore.cpp:1131
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition: IRCore.cpp:1557
static nanobind::object createFromCapsule(const nanobind::object &capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition: IRCore.cpp:1398
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition: IRCore.cpp:1393
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Creates a detached operation.
Definition: IRCore.cpp:1183
nanobind::object clone(const nanobind::object &ip)
Clones this operation.
Definition: IRCore.cpp:1537
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition: IRModule.h:608
PyOperationRef getRef()
Definition: IRModule.h:643
MlirOperation get() const
Definition: IRModule.h:638
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition: IRCore.cpp:1176
void setAttached(const nanobind::object &parent=nanobind::object())
Definition: IRModule.h:648
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition: IRCore.cpp:1374
static nanobind::object create(std::string_view name, std::optional< std::vector< PyType * >> results, llvm::ArrayRef< MlirValue > operands, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * >> successors, int regions, PyLocation &location, const nanobind::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition: IRCore.cpp:1422
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition: IRCore.cpp:1546
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition: IRCore.cpp:1384
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition: IRCore.cpp:1192
void checkValid() const
Definition: IRCore.cpp:1204
void setInvalid()
Invalidate the operation.
Definition: IRModule.h:690
Wrapper around an MlirRegion.
Definition: IRModule.h:759
PyOperationRef & getParentOperation()
Definition: IRModule.h:768
MlirRegion get()
Definition: IRModule.h:767
Bindings for MLIR symbol tables.
Definition: IRModule.h:1266
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition: IRCore.cpp:2296
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, nanobind::object callback)
Walks all symbol tables under and including 'from'.
Definition: IRCore.cpp:2382
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition: IRCore.cpp:2370
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition: IRCore.cpp:2352
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition: IRCore.cpp:2326
PyStringAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition: IRCore.cpp:2301
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition: IRCore.cpp:2286
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition: IRCore.cpp:2265
static PyStringAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition: IRCore.cpp:2313
static PyStringAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition: IRCore.cpp:2341
nanobind::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition: IRCore.cpp:2273
Tracks an entry in the thread context stack.
Definition: IRModule.h:112
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition: IRCore.cpp:805
static void popLocation(PyLocation &location)
Definition: IRCore.cpp:917
static nanobind::object pushLocation(nanobind::object location)
Definition: IRCore.cpp:908
static nanobind::object pushContext(nanobind::object context)
Definition: IRCore.cpp:867
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition: IRCore.cpp:862
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition: IRCore.cpp:897
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint)
Definition: IRCore.cpp:885
static void popContext(PyMlirContext &context)
Definition: IRCore.cpp:874
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top-of-stack insertion point and returns nullptr if not defined.
Definition: IRCore.cpp:857
PyMlirContext * getContext()
Definition: IRCore.cpp:834
static PyMlirContext * getDefaultContext()
Gets the top-of-stack context and returns nullptr if not defined.
Definition: IRCore.cpp:852
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition: IRCore.cpp:800
PyInsertionPoint * getInsertionPoint()
Definition: IRCore.cpp:840
Wrapper around MlirLlvmThreadPool. The Python object owns the C++ thread pool.
Definition: IRModule.h:168
MlirLlvmThreadPool get()
Definition: IRModule.h:177
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition: IRModule.h:904
static PyTypeID createFromCapsule(nanobind::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition: IRCore.cpp:2211
bool operator==(const PyTypeID &other) const
Definition: IRCore.cpp:2217
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition: IRCore.cpp:2207
PyTypeID(MlirTypeID typeID)
Definition: IRModule.h:906
Wrapper around the generic MlirType.
Definition: IRModule.h:878
PyType(PyMlirContextRef contextRef, MlirType type)
Definition: IRModule.h:880
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition: IRCore.cpp:2177
MlirType get() const
Definition: IRModule.h:884
static PyType createFromCapsule(nanobind::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition: IRCore.cpp:2181
nanobind::object maybeDownCast()
Definition: IRCore.cpp:2189
bool operator==(const PyType &other) const
Definition: IRCore.cpp:2173
Wrapper around the generic MlirValue.
Definition: IRModule.h:1167
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition: IRModule.h:1173
static PyValue createFromCapsule(nanobind::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition: IRCore.cpp:2244
nanobind::object maybeDownCast()
Definition: IRCore.cpp:2229
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition: IRCore.cpp:2225
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
Definition: Diagnostics.cpp:44
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
Definition: Diagnostics.cpp:28
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
Definition: Diagnostics.cpp:18
MlirDiagnosticSeverity
Severity of a diagnostic.
Definition: Diagnostics.h:32
@ MlirDiagnosticNote
Definition: Diagnostics.h:35
@ MlirDiagnosticRemark
Definition: Diagnostics.h:36
@ MlirDiagnosticWarning
Definition: Diagnostics.h:34
@ MlirDiagnosticError
Definition: Diagnostics.h:33
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
Definition: Diagnostics.cpp:50
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
Definition: Diagnostics.cpp:78
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
Definition: Diagnostics.cpp:56
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
Definition: Diagnostics.cpp:72
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition: Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
Definition: Diagnostics.cpp:24
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition: Diagnostics.h:24
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition: IR.cpp:265
MLIR_CAPI_EXPORTED intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Returns the position of the value in the argument list of its block.
Definition: IR.cpp:1122
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1183
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition: IR.h:851
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v)
Gets the location of the value.
Definition: IR.cpp:1192
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context)
Gets the number of threads of the thread pool of the context when multithreading is enabled.
Definition: IR.cpp:117
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition: IR.cpp:841
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
Definition: IR.cpp:669
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates an File/Line/Column location owned by the given context.
Definition: IR.cpp:273
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition: IR.cpp:1374
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition: IR.cpp:137
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op)
Returns the number of results of the operation.
Definition: IR.cpp:728
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location)
Getter for end_column of FileLineColRange.
Definition: IR.cpp:311
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition: IR.cpp:1353
MlirWalkOrder
Traversal order for operation walk.
Definition: IR.h:844
@ MlirWalkPreOrder
Definition: IR.h:845
@ MlirWalkPostOrder
Definition: IR.h:846
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags)
Print the name and location, if NamedLoc, as a prefix to the SSA ID.
Definition: IR.cpp:229
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Return pos-th attribute of the operation.
Definition: IR.cpp:802
MLIR_CAPI_EXPORTED void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition: IR.cpp:509
MLIR_CAPI_EXPORTED void mlirModuleDestroy(MlirModule module)
Takes a module owned by the caller and deletes it.
Definition: IR.cpp:454
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition: IR.cpp:1301
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGetChildLoc(MlirLocation location)
Getter for childLoc of Name.
Definition: IR.cpp:392
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition: IR.cpp:1358
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Use local scope when printing the operation.
Definition: IR.cpp:233
MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value)
Returns 1 if the value is a block argument, 0 otherwise.
Definition: IR.cpp:1110
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition: IR.cpp:84
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:1322
static bool mlirModuleIsNull(MlirModule module)
Checks whether a module is null.
Definition: IR.h:406
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition: IR.cpp:1235
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition: IR.cpp:1218
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute)
Gets the type of this attribute.
Definition: IR.cpp:1274
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition: IR.cpp:73
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition: IR.cpp:946
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location)
Checks whether the given location is an FileLineColRange.
Definition: IR.cpp:321
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos)
Returns pos-th successor of the block.
Definition: IR.cpp:1086
MLIR_CAPI_EXPORTED unsigned mlirLocationFusedGetNumLocations(MlirLocation location)
Getter for number of locations fused together.
Definition: IR.cpp:355
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition: IR.cpp:1174
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Returns pos-th successor of the operation.
Definition: IR.cpp:740
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition: IR.cpp:1157
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition: IR.cpp:403
MLIR_CAPI_EXPORTED void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding userData tocallback`.
Definition: IR.cpp:1255
MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Sets an attribute by name, replacing the existing if it exists or adding a new one otherwise.
Definition: IR.cpp:812
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition: IR.cpp:1206
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
Enables the elision of large resources strings by omitting them from the dialect_resources section.
Definition: IR.cpp:215
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location)
Getter for filename of FileLineColRange.
Definition: IR.cpp:289
MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute)
Gets the dialect of the attribute.
Definition: IR.cpp:1285
MLIR_CAPI_EXPORTED void mlirLocationFusedGetLocations(MlirLocation location, MlirLocation *locationsCPtr)
Getter for locations of Fused.
Definition: IR.cpp:361
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding userData tocallback...
Definition: IR.cpp:1293
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition: IR.cpp:985
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition: IR.cpp:865
static bool mlirValueIsNull(MlirValue value)
Returns whether the value is null.
Definition: IR.h:1032
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
Definition: IR.cpp:1178
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition: IR.cpp:833
MlirWalkResult
Operation walk result.
Definition: IR.h:837
@ MlirWalkResultInterrupt
Definition: IR.h:839
@ MlirWalkResultSkip
Definition: IR.h:840
@ MlirWalkResultAdvance
Definition: IR.h:838
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition: IR.cpp:926
MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Returns an attribute attached to the operation given its name.
Definition: IR.cpp:807
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition: IR.h:1148
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos)
Returns pos-th predecessor of the block.
Definition: IR.cpp:1095
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e.
Definition: IR.cpp:100
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Returns the number of successor blocks of the operation.
Definition: IR.cpp:736
MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op)
Creates a deep copy of an operation.
Definition: IR.cpp:635
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition: IR.cpp:1054
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartLine(MlirLocation location)
Getter for start_line of FileLineColRange.
Definition: IR.cpp:293
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Always print operations in the generic form.
Definition: IR.cpp:225
MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs)
Checks if two modules are equal.
Definition: IR.cpp:468
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition: IR.cpp:347
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition: IR.cpp:1035
MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state)
Destroys printing flags created with mlirAsmStateCreate.
Definition: IR.cpp:196
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition: IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition: IR.cpp:95
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location)
Checks whether the given location is an CallSite.
Definition: IR.cpp:343
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningl...
Definition: IR.cpp:210
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition: IR.cpp:932
MLIR_CAPI_EXPORTED void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Sets the type of the block argument to the given type.
Definition: IR.cpp:1127
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op)
Gets the context this operation is associated with.
Definition: IR.cpp:651
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition: IR.cpp:969
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition: IR.h:933
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition: IR.cpp:1010
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition: IR.cpp:1072
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition: IR.cpp:1348
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition: IR.cpp:1239
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, unsigned start_line, unsigned start_col, unsigned end_line, unsigned end_col)
Creates an File/Line/Column range location owned by the given context.
Definition: IR.cpp:281
MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value)
Prints the value to the standard error stream.
Definition: IR.cpp:1149
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location)
Creates a new, empty module and transfers ownership to the caller.
Definition: IR.cpp:425
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition: IR.cpp:1204
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition: IR.cpp:1338
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition: IR.cpp:407
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location)
Getter for start_column of FileLineColRange.
Definition: IR.cpp:299
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location)
Checks whether the given location is an Fused.
Definition: IR.cpp:375
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetBlock(MlirOperation op)
Gets the block that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:673
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition: IR.h:370
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block)
Returns the number of predecessor blocks of the block.
Definition: IR.cpp:1090
MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Checks whether two operation handles point to the same operation.
Definition: IR.cpp:643
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition: IR.cpp:1058
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition: IR.cpp:827
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition: IR.cpp:1164
MLIR_CAPI_EXPORTED void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding userData tocallback`.
Definition: IR.cpp:415
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool)
Sets the thread pool of the context explicitly, enabling multithreading in the process.
Definition: IR.cpp:112
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition: IR.cpp:857
MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module)
Views the module as a generic operation.
Definition: IR.cpp:460
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition: IR.cpp:1251
MLIR_CAPI_EXPORTED MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Constructs an operation state from a name and a location.
Definition: IR.cpp:480
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition: IR.cpp:1214
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location)
Getter for caller of CallSite.
Definition: IR.cpp:334
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition: IR.cpp:1000
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Skip printing regions.
Definition: IR.cpp:241
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationNameGetName(MlirLocation location)
Getter for name of Name.
Definition: IR.cpp:388
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Returns an operation immediately following the given operation it its enclosing block.
Definition: IR.cpp:705
MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: IR.cpp:869
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Gets the operation that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:677
MLIR_CAPI_EXPORTED void mlirOperationSetLocation(MlirOperation op, MlirLocation loc)
Sets the location of the operation.
Definition: IR.cpp:659
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module)
Gets the context that a module was created with.
Definition: IR.cpp:446
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition: IR.cpp:269
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Do not verify the operation when using custom operation printers.
Definition: IR.cpp:237
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition: IR.cpp:1243
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition: IR.cpp:1334
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition: IR.h:182
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Destroys printing flags created with mlirBytecodeWriterConfigCreate.
Definition: IR.cpp:252
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Returns pos-th operand of the operation.
Definition: IR.cpp:713
MLIR_CAPI_EXPORTED void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition: IR.cpp:521
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition: IR.cpp:989
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition: IR.cpp:325
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location)
Checks whether the given location is an Name.
Definition: IR.cpp:399
MLIR_CAPI_EXPORTED MlirOperation mlirOpResultGetOwner(MlirValue value)
Returns an operation that produced this value as its result.
Definition: IR.cpp:1132
MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value)
Returns 1 if the value is an operation result, 0 otherwise.
Definition: IR.cpp:1114
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op)
Returns the number of operands of the operation.
Definition: IR.cpp:709
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition: IR.h:244
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition: IR.cpp:887
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
Definition: IR.cpp:55
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition: IR.cpp:981
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumRegions(MlirOperation op)
Returns the number of regions attached to the given operation.
Definition: IR.cpp:681
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition: IR.cpp:411
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition: IR.cpp:1063
MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Removes an attribute by name.
Definition: IR.cpp:817
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition: IR.cpp:1299
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition: IR.cpp:1363
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Parses an attribute. The attribute is owned by the context.
Definition: IR.cpp:1266
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Parses a module from the string and transfers ownership to the caller.
Definition: IR.cpp:429
MLIR_CAPI_EXPORTED size_t mlirOperationHashValue(MlirOperation op)
Compute a hash for the given operation.
Definition: IR.cpp:647
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition: IR.cpp:922
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition: IR.cpp:993
MLIR_CAPI_EXPORTED void mlirTypeDump(MlirType type)
Prints the type to the standard error stream.
Definition: IR.cpp:1260
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Returns pos-th result of the operation.
Definition: IR.cpp:732
MLIR_CAPI_EXPORTED MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate(void)
Creates new printing flags with defaults, intended for customization.
Definition: IR.cpp:248
MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Gets the context that an attribute was created with.
Definition: IR.cpp:1270
MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Returns the block in which this value is defined as an argument.
Definition: IR.cpp:1118
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition: IR.h:872
MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op)
Takes an operation owned by the caller and destroys it.
Definition: IR.cpp:639
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block)
Returns the number of successor blocks of the block.
Definition: IR.cpp:1082
MLIR_CAPI_EXPORTED MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Returns pos-th region attached to the operation.
Definition: IR.cpp:685
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition: IR.cpp:1247
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition: IR.cpp:1310
MLIR_CAPI_EXPORTED void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Set pos-th successor of the operation.
Definition: IR.cpp:793
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition: IR.cpp:108
MLIR_CAPI_EXPORTED void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition: IR.cpp:513
MLIR_CAPI_EXPORTED void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition: IR.cpp:517
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context)
Gets the thread pool of the context when enabled multithreading, otherwise an assertion is raised.
Definition: IR.cpp:121
MLIR_CAPI_EXPORTED MlirBlock mlirModuleGetBody(MlirModule module)
Gets the body of the module, i.e. the only block it contains.
Definition: IR.cpp:450
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndLine(MlirLocation location)
Getter for end_line of FileLineColRange.
Definition: IR.cpp:305
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Destroys printing flags created with mlirOpPrintingFlagsCreate.
Definition: IR.cpp:206
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition: IR.cpp:379
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Set threading mode (must be set to false to mlir-print-ir-after-all).
Definition: IR.cpp:104
MLIR_CAPI_EXPORTED void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Prints a block by sending chunks of the string representation and forwarding userData tocallback`.
Definition: IR.cpp:1076
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location)
Getter for callee of CallSite.
Definition: IR.cpp:329
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v)
Gets the context that a value was created with.
Definition: IR.cpp:1196
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Sets the version to emit in the writer config.
Definition: IR.cpp:256
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition: IR.cpp:1330
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition: IR.cpp:909
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName)
Parses a module from file and transfers ownership to the caller.
Definition: IR.cpp:437
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition: IR.cpp:1049
MLIR_CAPI_EXPORTED void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Adds a list of components to the operation state.
Definition: IR.cpp:504
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Enable or disable printing of debug information (based on enable).
Definition: IR.cpp:220
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op)
Gets the location of the operation.
Definition: IR.cpp:655
MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute)
Gets the type id of the attribute.
Definition: IR.cpp:1281
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Sets the pos-th operand of the operation.
Definition: IR.cpp:717
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition: IR.cpp:855
MLIR_CAPI_EXPORTED intptr_t mlirOpResultGetResultNumber(MlirValue value)
Returns the position of the value in the list of results of the operation that produced it.
Definition: IR.cpp:1136
MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void)
Creates new printing flags with defaults, intended for customization.
Definition: IR.cpp:202
MLIR_CAPI_EXPORTED MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Creates new AsmState from value.
Definition: IR.cpp:178
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state)
Creates an operation and transfers ownership to the caller.
Definition: IR.cpp:589
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition: IR.h:1238
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition: IR.cpp:77
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition: IR.cpp:861
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Returns the number of attributes attached to the operation.
Definition: IR.cpp:798
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1151
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecode could not be honored.
Definition: IR.cpp:848
MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type)
Set the type of the value.
Definition: IR.cpp:1145
MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value)
Returns the type of the value.
Definition: IR.cpp:1141
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition: IR.cpp:71
MLIR_CAPI_EXPORTED size_t mlirModuleHashValue(MlirModule mod)
Compute a hash for the given module.
Definition: IR.cpp:472
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Parses an operation, giving ownership to the caller.
Definition: IR.cpp:626
MLIR_CAPI_EXPORTED bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Checks if two attributes are equal.
Definition: IR.cpp:1289
static bool mlirOperationIsNull(MlirOperation op)
Checks whether the underlying operation is null.
Definition: IR.h:621
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition: IR.cpp:915
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition: Support.h:138
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition: Support.cpp:51
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition: Support.h:132
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition: Support.h:163
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition: Support.cpp:47
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition: IRModule.h:190
PyObjectRef< PyModule > PyModuleRef
Definition: IRModule.h:511
void populateIRCore(nanobind::module_ &m)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition: Diagnostics.h:26
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
Named MLIR attribute.
Definition: IR.h:76
MlirAttribute attribute
Definition: IR.h:78
MlirIdentifier name
Definition: IR.h:77
An auxiliary class for constructing operations.
Definition: IR.h:438
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
static bool dunderContains(const std::string &attributeKind)
Definition: IRCore.cpp:302
static void dunderSetItemNamed(const std::string &attributeKind, nb::callable func, bool replace)
Definition: IRCore.cpp:311
static nb::callable dunderGetItemNamed(const std::string &attributeKind)
Definition: IRCore.cpp:305
static void bind(nb::module_ &m)
Definition: IRCore.cpp:317
Wrapper for the global LLVM debugging flag.
Definition: IRCore.cpp:262
static void bind(nb::module_ &m)
Definition: IRCore.cpp:273
static void set(nb::object &o, bool enable)
Definition: IRCore.cpp:263
static bool get(const nb::object &)
Definition: IRCore.cpp:268
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1318
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Definition: IRModule.h:1323
Materialized diagnostic information.
Definition: IRModule.h:342
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:408
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition: IRModule.h:418
ErrorCapture(PyMlirContextRef ctx)
Definition: IRModule.h:409