MLIR  21.0.0git
IRCore.cpp
Go to the documentation of this file.
//===- IRCore.cpp - IR Submodules of nanobind module ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Globals.h"
10 #include "IRModule.h"
11 #include "NanobindUtils.h"
12 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
14 #include "mlir-c/Debug.h"
15 #include "mlir-c/Diagnostics.h"
16 #include "mlir-c/IR.h"
17 #include "mlir-c/Support.h"
20 #include "nanobind/nanobind.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/raw_ostream.h"
24 
25 #include <optional>
26 #include <system_error>
27 #include <utility>
28 
29 namespace nb = nanobind;
30 using namespace nb::literals;
31 using namespace mlir;
32 using namespace mlir::python;
33 
34 using llvm::SmallVector;
35 using llvm::StringRef;
36 using llvm::Twine;
37 
38 //------------------------------------------------------------------------------
39 // Docstrings (trivial, non-duplicated docstrings are included inline).
40 //------------------------------------------------------------------------------
41 
// Python docstrings for the Context type-parsing and Location-factory
// bindings.

static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFileRangeDocstring[] =
    R"(Gets a Location representing a file, line and column range)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";
64 
// Python docstrings for Module parsing and the Operation creation/printing
// bindings. Keyword names mentioned here must match the names used in the
// corresponding nb::arg(...) declarations.

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
  infer_type: Whether to infer result types.
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
  skip_regions: Whether to skip printing regions. Defaults to False.
)";

static const char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  state: AsmState capturing the operation numbering and flags.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
  file: The file like object to write to.
  desired_version: The version of bytecode to emit.
Returns:
  The bytecode writer status.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";
161 
// Python docstrings for block appending and the Value bindings.

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

static const char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

static const char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";

// The raw string must not start with a literal '"': it would leak into the
// Python-visible docstring.
static const char kValueReplaceAllUsesExceptDocstring[] =
    R"(Replace all uses of this value with the 'with' value, except for those
in 'exceptions'. 'exceptions' can be either a single operation or a list of
operations.
)";
191 
192 //------------------------------------------------------------------------------
193 // Utilities.
194 //------------------------------------------------------------------------------
195 
196 /// Helper for creating an @classmethod.
197 template <class Func, typename... Args>
198 nb::object classmethod(Func f, Args... args) {
199  nb::object cf = nb::cpp_function(f, args...);
200  return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
201 }
202 
203 static nb::object
204 createCustomDialectWrapper(const std::string &dialectNamespace,
205  nb::object dialectDescriptor) {
206  auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
207  if (!dialectClass) {
208  // Use the base class.
209  return nb::cast(PyDialect(std::move(dialectDescriptor)));
210  }
211 
212  // Create the custom implementation.
213  return (*dialectClass)(std::move(dialectDescriptor));
214 }
215 
216 static MlirStringRef toMlirStringRef(const std::string &s) {
217  return mlirStringRefCreate(s.data(), s.size());
218 }
219 
220 static MlirStringRef toMlirStringRef(std::string_view s) {
221  return mlirStringRefCreate(s.data(), s.size());
222 }
223 
224 static MlirStringRef toMlirStringRef(const nb::bytes &s) {
225  return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
226 }
227 
228 /// Create a block, using the current location context if no locations are
229 /// specified.
230 static MlirBlock createBlock(const nb::sequence &pyArgTypes,
231  const std::optional<nb::sequence> &pyArgLocs) {
232  SmallVector<MlirType> argTypes;
233  argTypes.reserve(nb::len(pyArgTypes));
234  for (const auto &pyType : pyArgTypes)
235  argTypes.push_back(nb::cast<PyType &>(pyType));
236 
238  if (pyArgLocs) {
239  argLocs.reserve(nb::len(*pyArgLocs));
240  for (const auto &pyLoc : *pyArgLocs)
241  argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
242  } else if (!argTypes.empty()) {
243  argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
244  }
245 
246  if (argTypes.size() != argLocs.size())
247  throw nb::value_error(("Expected " + Twine(argTypes.size()) +
248  " locations, got: " + Twine(argLocs.size()))
249  .str()
250  .c_str());
251  return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
252 }
253 
254 /// Wrapper for the global LLVM debugging flag.
256  static void set(nb::object &o, bool enable) {
257  nb::ft_lock_guard lock(mutex);
258  mlirEnableGlobalDebug(enable);
259  }
260 
261  static bool get(const nb::object &) {
262  nb::ft_lock_guard lock(mutex);
263  return mlirIsGlobalDebugEnabled();
264  }
265 
266  static void bind(nb::module_ &m) {
267  // Debug flags.
268  nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
269  .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
270  &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
271  .def_static(
272  "set_types",
273  [](const std::string &type) {
274  nb::ft_lock_guard lock(mutex);
275  mlirSetGlobalDebugType(type.c_str());
276  },
277  "types"_a, "Sets specific debug types to be produced by LLVM")
278  .def_static("set_types", [](const std::vector<std::string> &types) {
279  std::vector<const char *> pointers;
280  pointers.reserve(types.size());
281  for (const std::string &str : types)
282  pointers.push_back(str.c_str());
283  nb::ft_lock_guard lock(mutex);
284  mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
285  });
286  }
287 
288 private:
289  static nb::ft_mutex mutex;
290 };
291 
292 nb::ft_mutex PyGlobalDebugFlag::mutex;
293 
295  static bool dunderContains(const std::string &attributeKind) {
296  return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
297  }
298  static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
299  auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
300  if (!builder)
301  throw nb::key_error(attributeKind.c_str());
302  return *builder;
303  }
304  static void dunderSetItemNamed(const std::string &attributeKind,
305  nb::callable func, bool replace) {
306  PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
307  replace);
308  }
309 
310  static void bind(nb::module_ &m) {
311  nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
312  .def_static("contains", &PyAttrBuilderMap::dunderContains)
313  .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed)
314  .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
315  "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
316  "Register an attribute builder for building MLIR "
317  "attributes from python values.");
318  }
319 };
320 
321 //------------------------------------------------------------------------------
322 // PyBlock
323 //------------------------------------------------------------------------------
324 
325 nb::object PyBlock::getCapsule() {
326  return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
327 }
328 
329 //------------------------------------------------------------------------------
330 // Collections.
331 //------------------------------------------------------------------------------
332 
333 namespace {
334 
335 class PyRegionIterator {
336 public:
337  PyRegionIterator(PyOperationRef operation)
338  : operation(std::move(operation)) {}
339 
340  PyRegionIterator &dunderIter() { return *this; }
341 
342  PyRegion dunderNext() {
343  operation->checkValid();
344  if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
345  throw nb::stop_iteration();
346  }
347  MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
348  return PyRegion(operation, region);
349  }
350 
351  static void bind(nb::module_ &m) {
352  nb::class_<PyRegionIterator>(m, "RegionIterator")
353  .def("__iter__", &PyRegionIterator::dunderIter)
354  .def("__next__", &PyRegionIterator::dunderNext);
355  }
356 
357 private:
358  PyOperationRef operation;
359  int nextIndex = 0;
360 };
361 
362 /// Regions of an op are fixed length and indexed numerically so are represented
363 /// with a sequence-like container.
364 class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
365 public:
366  static constexpr const char *pyClassName = "RegionSequence";
367 
368  PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
369  intptr_t length = -1, intptr_t step = 1)
370  : Sliceable(startIndex,
371  length == -1 ? mlirOperationGetNumRegions(operation->get())
372  : length,
373  step),
374  operation(std::move(operation)) {}
375 
376  PyRegionIterator dunderIter() {
377  operation->checkValid();
378  return PyRegionIterator(operation);
379  }
380 
381  static void bindDerived(ClassTy &c) {
382  c.def("__iter__", &PyRegionList::dunderIter);
383  }
384 
385 private:
386  /// Give the parent CRTP class access to hook implementations below.
387  friend class Sliceable<PyRegionList, PyRegion>;
388 
389  intptr_t getRawNumElements() {
390  operation->checkValid();
391  return mlirOperationGetNumRegions(operation->get());
392  }
393 
394  PyRegion getRawElement(intptr_t pos) {
395  operation->checkValid();
396  return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
397  }
398 
399  PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
400  return PyRegionList(operation, startIndex, length, step);
401  }
402 
403  PyOperationRef operation;
404 };
405 
406 class PyBlockIterator {
407 public:
408  PyBlockIterator(PyOperationRef operation, MlirBlock next)
409  : operation(std::move(operation)), next(next) {}
410 
411  PyBlockIterator &dunderIter() { return *this; }
412 
413  PyBlock dunderNext() {
414  operation->checkValid();
415  if (mlirBlockIsNull(next)) {
416  throw nb::stop_iteration();
417  }
418 
419  PyBlock returnBlock(operation, next);
420  next = mlirBlockGetNextInRegion(next);
421  return returnBlock;
422  }
423 
424  static void bind(nb::module_ &m) {
425  nb::class_<PyBlockIterator>(m, "BlockIterator")
426  .def("__iter__", &PyBlockIterator::dunderIter)
427  .def("__next__", &PyBlockIterator::dunderNext);
428  }
429 
430 private:
431  PyOperationRef operation;
432  MlirBlock next;
433 };
434 
435 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
436 /// we present them as a more full-featured list-like container but optimize
437 /// it for forward iteration. Blocks are always owned by a region.
438 class PyBlockList {
439 public:
440  PyBlockList(PyOperationRef operation, MlirRegion region)
441  : operation(std::move(operation)), region(region) {}
442 
443  PyBlockIterator dunderIter() {
444  operation->checkValid();
445  return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
446  }
447 
448  intptr_t dunderLen() {
449  operation->checkValid();
450  intptr_t count = 0;
451  MlirBlock block = mlirRegionGetFirstBlock(region);
452  while (!mlirBlockIsNull(block)) {
453  count += 1;
454  block = mlirBlockGetNextInRegion(block);
455  }
456  return count;
457  }
458 
459  PyBlock dunderGetItem(intptr_t index) {
460  operation->checkValid();
461  if (index < 0) {
462  index += dunderLen();
463  }
464  if (index < 0) {
465  throw nb::index_error("attempt to access out of bounds block");
466  }
467  MlirBlock block = mlirRegionGetFirstBlock(region);
468  while (!mlirBlockIsNull(block)) {
469  if (index == 0) {
470  return PyBlock(operation, block);
471  }
472  block = mlirBlockGetNextInRegion(block);
473  index -= 1;
474  }
475  throw nb::index_error("attempt to access out of bounds block");
476  }
477 
478  PyBlock appendBlock(const nb::args &pyArgTypes,
479  const std::optional<nb::sequence> &pyArgLocs) {
480  operation->checkValid();
481  MlirBlock block =
482  createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
483  mlirRegionAppendOwnedBlock(region, block);
484  return PyBlock(operation, block);
485  }
486 
487  static void bind(nb::module_ &m) {
488  nb::class_<PyBlockList>(m, "BlockList")
489  .def("__getitem__", &PyBlockList::dunderGetItem)
490  .def("__iter__", &PyBlockList::dunderIter)
491  .def("__len__", &PyBlockList::dunderLen)
492  .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
493  nb::arg("args"), nb::kw_only(),
494  nb::arg("arg_locs") = std::nullopt);
495  }
496 
497 private:
498  PyOperationRef operation;
499  MlirRegion region;
500 };
501 
502 class PyOperationIterator {
503 public:
504  PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
505  : parentOperation(std::move(parentOperation)), next(next) {}
506 
507  PyOperationIterator &dunderIter() { return *this; }
508 
509  nb::object dunderNext() {
510  parentOperation->checkValid();
511  if (mlirOperationIsNull(next)) {
512  throw nb::stop_iteration();
513  }
514 
515  PyOperationRef returnOperation =
516  PyOperation::forOperation(parentOperation->getContext(), next);
517  next = mlirOperationGetNextInBlock(next);
518  return returnOperation->createOpView();
519  }
520 
521  static void bind(nb::module_ &m) {
522  nb::class_<PyOperationIterator>(m, "OperationIterator")
523  .def("__iter__", &PyOperationIterator::dunderIter)
524  .def("__next__", &PyOperationIterator::dunderNext);
525  }
526 
527 private:
528  PyOperationRef parentOperation;
529  MlirOperation next;
530 };
531 
532 /// Operations are exposed by the C-API as a forward-only linked list. In
533 /// Python, we present them as a more full-featured list-like container but
534 /// optimize it for forward iteration. Iterable operations are always owned
535 /// by a block.
536 class PyOperationList {
537 public:
538  PyOperationList(PyOperationRef parentOperation, MlirBlock block)
539  : parentOperation(std::move(parentOperation)), block(block) {}
540 
541  PyOperationIterator dunderIter() {
542  parentOperation->checkValid();
543  return PyOperationIterator(parentOperation,
545  }
546 
547  intptr_t dunderLen() {
548  parentOperation->checkValid();
549  intptr_t count = 0;
550  MlirOperation childOp = mlirBlockGetFirstOperation(block);
551  while (!mlirOperationIsNull(childOp)) {
552  count += 1;
553  childOp = mlirOperationGetNextInBlock(childOp);
554  }
555  return count;
556  }
557 
558  nb::object dunderGetItem(intptr_t index) {
559  parentOperation->checkValid();
560  if (index < 0) {
561  index += dunderLen();
562  }
563  if (index < 0) {
564  throw nb::index_error("attempt to access out of bounds operation");
565  }
566  MlirOperation childOp = mlirBlockGetFirstOperation(block);
567  while (!mlirOperationIsNull(childOp)) {
568  if (index == 0) {
569  return PyOperation::forOperation(parentOperation->getContext(), childOp)
570  ->createOpView();
571  }
572  childOp = mlirOperationGetNextInBlock(childOp);
573  index -= 1;
574  }
575  throw nb::index_error("attempt to access out of bounds operation");
576  }
577 
578  static void bind(nb::module_ &m) {
579  nb::class_<PyOperationList>(m, "OperationList")
580  .def("__getitem__", &PyOperationList::dunderGetItem)
581  .def("__iter__", &PyOperationList::dunderIter)
582  .def("__len__", &PyOperationList::dunderLen);
583  }
584 
585 private:
586  PyOperationRef parentOperation;
587  MlirBlock block;
588 };
589 
590 class PyOpOperand {
591 public:
592  PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
593 
594  nb::object getOwner() {
595  MlirOperation owner = mlirOpOperandGetOwner(opOperand);
596  PyMlirContextRef context =
597  PyMlirContext::forContext(mlirOperationGetContext(owner));
598  return PyOperation::forOperation(context, owner)->createOpView();
599  }
600 
601  size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
602 
603  static void bind(nb::module_ &m) {
604  nb::class_<PyOpOperand>(m, "OpOperand")
605  .def_prop_ro("owner", &PyOpOperand::getOwner)
606  .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber);
607  }
608 
609 private:
610  MlirOpOperand opOperand;
611 };
612 
613 class PyOpOperandIterator {
614 public:
615  PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
616 
617  PyOpOperandIterator &dunderIter() { return *this; }
618 
619  PyOpOperand dunderNext() {
620  if (mlirOpOperandIsNull(opOperand))
621  throw nb::stop_iteration();
622 
623  PyOpOperand returnOpOperand(opOperand);
624  opOperand = mlirOpOperandGetNextUse(opOperand);
625  return returnOpOperand;
626  }
627 
628  static void bind(nb::module_ &m) {
629  nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
630  .def("__iter__", &PyOpOperandIterator::dunderIter)
631  .def("__next__", &PyOpOperandIterator::dunderNext);
632  }
633 
634 private:
635  MlirOpOperand opOperand;
636 };
637 
638 } // namespace
639 
640 //------------------------------------------------------------------------------
641 // PyMlirContext
642 //------------------------------------------------------------------------------
643 
644 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
645  nb::gil_scoped_acquire acquire;
646  nb::ft_lock_guard lock(live_contexts_mutex);
647  auto &liveContexts = getLiveContexts();
648  liveContexts[context.ptr] = this;
649 }
650 
652  // Note that the only public way to construct an instance is via the
653  // forContext method, which always puts the associated handle into
654  // liveContexts.
655  nb::gil_scoped_acquire acquire;
656  {
657  nb::ft_lock_guard lock(live_contexts_mutex);
658  getLiveContexts().erase(context.ptr);
659  }
660  mlirContextDestroy(context);
661 }
662 
664  return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
665 }
666 
667 nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
668  MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
669  if (mlirContextIsNull(rawContext))
670  throw nb::python_error();
671  return forContext(rawContext).releaseObject();
672 }
673 
675  nb::gil_scoped_acquire acquire;
676  nb::ft_lock_guard lock(live_contexts_mutex);
677  auto &liveContexts = getLiveContexts();
678  auto it = liveContexts.find(context.ptr);
679  if (it == liveContexts.end()) {
680  // Create.
681  PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
682  nb::object pyRef = nb::cast(unownedContextWrapper);
683  assert(pyRef && "cast to nb::object failed");
684  liveContexts[context.ptr] = unownedContextWrapper;
685  return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
686  }
687  // Use existing.
688  nb::object pyRef = nb::cast(it->second);
689  return PyMlirContextRef(it->second, std::move(pyRef));
690 }
691 
692 nb::ft_mutex PyMlirContext::live_contexts_mutex;
693 
694 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
695  static LiveContextMap liveContexts;
696  return liveContexts;
697 }
698 
700  nb::ft_lock_guard lock(live_contexts_mutex);
701  return getLiveContexts().size();
702 }
703 
705  nb::ft_lock_guard lock(liveOperationsMutex);
706  return liveOperations.size();
707 }
708 
709 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
710  std::vector<PyOperation *> liveObjects;
711  nb::ft_lock_guard lock(liveOperationsMutex);
712  for (auto &entry : liveOperations)
713  liveObjects.push_back(entry.second.second);
714  return liveObjects;
715 }
716 
718 
719  LiveOperationMap operations;
720  {
721  nb::ft_lock_guard lock(liveOperationsMutex);
722  std::swap(operations, liveOperations);
723  }
724  for (auto &op : operations)
725  op.second.second->setInvalid();
726  size_t numInvalidated = operations.size();
727  return numInvalidated;
728 }
729 
730 void PyMlirContext::clearOperation(MlirOperation op) {
731  PyOperation *py_op;
732  {
733  nb::ft_lock_guard lock(liveOperationsMutex);
734  auto it = liveOperations.find(op.ptr);
735  if (it == liveOperations.end()) {
736  return;
737  }
738  py_op = it->second.second;
739  liveOperations.erase(it);
740  }
741  py_op->setInvalid();
742 }
743 
745  typedef struct {
746  PyOperation &rootOp;
747  bool rootSeen;
748  } callBackData;
749  callBackData data{op.getOperation(), false};
750  // Mark all ops below the op that the passmanager will be rooted
751  // at (but not op itself - note the preorder) as invalid.
752  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
753  void *userData) {
754  callBackData *data = static_cast<callBackData *>(userData);
755  if (LLVM_LIKELY(data->rootSeen))
756  data->rootOp.getOperation().getContext()->clearOperation(op);
757  else
758  data->rootSeen = true;
760  };
761  mlirOperationWalk(op.getOperation(), invalidatingCallback,
762  static_cast<void *>(&data), MlirWalkPreOrder);
763 }
764 void PyMlirContext::clearOperationsInside(MlirOperation op) {
767 }
768 
770  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
771  void *userData) {
772  PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
773  contextRef->clearOperation(op);
775  };
776  mlirOperationWalk(op.getOperation(), invalidatingCallback,
778 }
779 
780 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
781 
782 nb::object PyMlirContext::contextEnter(nb::object context) {
783  return PyThreadContextEntry::pushContext(context);
784 }
785 
786 void PyMlirContext::contextExit(const nb::object &excType,
787  const nb::object &excVal,
788  const nb::object &excTb) {
790 }
791 
792 nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
793  // Note that ownership is transferred to the delete callback below by way of
794  // an explicit inc_ref (borrow).
795  PyDiagnosticHandler *pyHandler =
796  new PyDiagnosticHandler(get(), std::move(callback));
797  nb::object pyHandlerObject =
798  nb::cast(pyHandler, nb::rv_policy::take_ownership);
799  pyHandlerObject.inc_ref();
800 
801  // In these C callbacks, the userData is a PyDiagnosticHandler* that is
802  // guaranteed to be known to pybind.
803  auto handlerCallback =
804  +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
805  PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
806  nb::object pyDiagnosticObject =
807  nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
808 
809  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
810  bool result = false;
811  {
812  // Since this can be called from arbitrary C++ contexts, always get the
813  // gil.
814  nb::gil_scoped_acquire gil;
815  try {
816  result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
817  } catch (std::exception &e) {
818  fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
819  e.what());
820  pyHandler->hadError = true;
821  }
822  }
823 
824  pyDiagnostic->invalidate();
826  };
827  auto deleteCallback = +[](void *userData) {
828  auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
829  assert(pyHandler->registeredID && "handler is not registered");
830  pyHandler->registeredID.reset();
831 
832  // Decrement reference, balancing the inc_ref() above.
833  nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
834  pyHandlerObject.dec_ref();
835  };
836 
837  pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
838  get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
839  return pyHandlerObject;
840 }
841 
842 MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
843  void *userData) {
844  auto *self = static_cast<ErrorCapture *>(userData);
845  // Check if the context requested we emit errors instead of capturing them.
846  if (self->ctx->emitErrorDiagnostics)
847  return mlirLogicalResultFailure();
848 
850  return mlirLogicalResultFailure();
851 
852  self->errors.emplace_back(PyDiagnostic(diag).getInfo());
853  return mlirLogicalResultSuccess();
854 }
855 
858  if (!context) {
859  throw std::runtime_error(
860  "An MLIR function requires a Context but none was provided in the call "
861  "or from the surrounding environment. Either pass to the function with "
862  "a 'context=' argument or establish a default using 'with Context():'");
863  }
864  return *context;
865 }
866 
867 //------------------------------------------------------------------------------
868 // PyThreadContextEntry management
869 //------------------------------------------------------------------------------
870 
871 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
872  static thread_local std::vector<PyThreadContextEntry> stack;
873  return stack;
874 }
875 
877  auto &stack = getStack();
878  if (stack.empty())
879  return nullptr;
880  return &stack.back();
881 }
882 
883 void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
884  nb::object insertionPoint,
885  nb::object location) {
886  auto &stack = getStack();
887  stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
888  std::move(location));
889  // If the new stack has more than one entry and the context of the new top
890  // entry matches the previous, copy the insertionPoint and location from the
891  // previous entry if missing from the new top entry.
892  if (stack.size() > 1) {
893  auto &prev = *(stack.rbegin() + 1);
894  auto &current = stack.back();
895  if (current.context.is(prev.context)) {
896  // Default non-context objects from the previous entry.
897  if (!current.insertionPoint)
898  current.insertionPoint = prev.insertionPoint;
899  if (!current.location)
900  current.location = prev.location;
901  }
902  }
903 }
904 
906  if (!context)
907  return nullptr;
908  return nb::cast<PyMlirContext *>(context);
909 }
910 
912  if (!insertionPoint)
913  return nullptr;
914  return nb::cast<PyInsertionPoint *>(insertionPoint);
915 }
916 
918  if (!location)
919  return nullptr;
920  return nb::cast<PyLocation *>(location);
921 }
922 
924  auto *tos = getTopOfStack();
925  return tos ? tos->getContext() : nullptr;
926 }
927 
929  auto *tos = getTopOfStack();
930  return tos ? tos->getInsertionPoint() : nullptr;
931 }
932 
934  auto *tos = getTopOfStack();
935  return tos ? tos->getLocation() : nullptr;
936 }
937 
938 nb::object PyThreadContextEntry::pushContext(nb::object context) {
939  push(FrameKind::Context, /*context=*/context,
940  /*insertionPoint=*/nb::object(),
941  /*location=*/nb::object());
942  return context;
943 }
944 
946  auto &stack = getStack();
947  if (stack.empty())
948  throw std::runtime_error("Unbalanced Context enter/exit");
949  auto &tos = stack.back();
950  if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
951  throw std::runtime_error("Unbalanced Context enter/exit");
952  stack.pop_back();
953 }
954 
955 nb::object
956 PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
957  PyInsertionPoint &insertionPoint =
958  nb::cast<PyInsertionPoint &>(insertionPointObj);
959  nb::object contextObj =
960  insertionPoint.getBlock().getParentOperation()->getContext().getObject();
961  push(FrameKind::InsertionPoint,
962  /*context=*/contextObj,
963  /*insertionPoint=*/insertionPointObj,
964  /*location=*/nb::object());
965  return insertionPointObj;
966 }
967 
// NOTE(review): signature line missing from this listing; presumably
// PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint).
// Pops the top frame when leaving a `with insertion_point:` block. Throws
// std::runtime_error on an empty stack or a detected imbalance.
969  auto &stack = getStack();
970  if (stack.empty())
971  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
972  auto &tos = stack.back();
// NOTE(review): `&&` means a mismatch is reported only when both the frame
// kind and the insertion point differ; `||` looks intended. Confirm.
973  if (tos.frameKind != FrameKind::InsertionPoint &&
974  tos.getInsertionPoint() != &insertionPoint)
975  throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
976  stack.pop_back();
977 }
978 
979 nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
980  PyLocation &location = nb::cast<PyLocation &>(locationObj);
981  nb::object contextObj = location.getContext().getObject();
982  push(FrameKind::Location, /*context=*/contextObj,
983  /*insertionPoint=*/nb::object(),
984  /*location=*/locationObj);
985  return locationObj;
986 }
987 
// NOTE(review): signature line missing from this listing; presumably
// PyThreadContextEntry::popLocation(PyLocation &location).
// Pops the top frame when leaving a `with location:` block. Throws
// std::runtime_error on an empty stack or a detected imbalance.
989  auto &stack = getStack();
990  if (stack.empty())
991  throw std::runtime_error("Unbalanced Location enter/exit");
992  auto &tos = stack.back();
// NOTE(review): `&&` means a mismatch is reported only when both the frame
// kind and the location differ; `||` looks intended. Confirm.
993  if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
994  throw std::runtime_error("Unbalanced Location enter/exit");
995  stack.pop_back();
996 }
997 
998 //------------------------------------------------------------------------------
999 // PyDiagnostic*
1000 //------------------------------------------------------------------------------
1001 
// NOTE(review): signature line missing from this listing; the recursive
// `note->invalidate()` call identifies this as PyDiagnostic::invalidate().
// Marks this diagnostic as no longer usable. Any note wrappers already
// materialized by getNotes() are invalidated recursively, so Python
// references to them also fail checkValid().
1003  valid = false;
1004  if (materializedNotes) {
1005  for (nb::handle noteObject : *materializedNotes) {
1006  PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
1007  note->invalidate();
1008  }
1009  }
1010 }
1011 
// NOTE(review): the first line of this constructor's signature is missing from
// this listing (presumably the PyDiagnosticHandler constructor).
// Stores the context and takes the Python callback object by move.
1013  nb::object callback)
1014  : context(context), callback(std::move(callback)) {}
1015 
1017 
// NOTE(review): signature line missing from this listing; presumably
// PyDiagnosticHandler::detach().
// Detaches the registered diagnostic handler from its context. This is a no-op
// if no handler is registered.
1019  if (!registeredID)
1020  return;
// Copy the ID first: detaching is expected to clear `registeredID` (see the
// assert below), so it must not be read through after the call.
1021  MlirDiagnosticHandlerID localID = *registeredID;
1022  mlirContextDetachDiagnosticHandler(context, localID);
1023  assert(!registeredID && "should have unregistered");
1024  // Not strictly necessary but keeps stale pointers from being around to cause
1025  // issues.
1026  context = {nullptr};
1027 }
1028 
1029 void PyDiagnostic::checkValid() {
1030  if (!valid) {
1031  throw std::invalid_argument(
1032  "Diagnostic is invalid (used outside of callback)");
1033  }
1034 }
1035 
// NOTE(review): signature line missing from this listing; presumably
// PyDiagnostic::getSeverity().
// Returns the diagnostic's severity; throws if the diagnostic is no longer
// valid.
1037  checkValid();
1038  return mlirDiagnosticGetSeverity(diagnostic);
1039 }
1040 
// NOTE(review): signature line missing from this listing; presumably
// PyDiagnostic::getLocation().
// Returns the diagnostic's location wrapped as a PyLocation bound to the
// Python wrapper of the location's own context.
1042  checkValid();
1043  MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
1044  MlirContext context = mlirLocationGetContext(loc);
1045  return PyLocation(PyMlirContext::forContext(context), loc);
1046 }
1047 
// NOTE(review): signature line missing from this listing; presumably
// PyDiagnostic::getMessage().
// Prints the diagnostic into an io.StringIO and returns the text as a Python
// str.
1049  checkValid();
1050  nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
1051  PyFileAccumulator accum(fileObject, /*binary=*/false);
1052  mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
1053  return nb::cast<nb::str>(fileObject.attr("getvalue")());
1054 }
1055 
// NOTE(review): signature line missing from this listing; presumably
// PyDiagnostic::getNotes().
// Returns a tuple of PyDiagnostic wrappers for the attached notes. The tuple
// is built once and cached in `materializedNotes`, so later calls return the
// same objects (which invalidate() also reaches).
1057  checkValid();
1058  if (materializedNotes)
1059  return *materializedNotes;
1060  intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
// NOTE(review): PyTuple_New can return NULL (e.g. on allocation failure); that
// case is not checked here.
1061  nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
1062  for (intptr_t i = 0; i < numNotes; ++i) {
1063  MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
// NOTE(review): this local `diagnostic` shadows the member of the same name
// used on the previous line.
1064  nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
// PyTuple_SET_ITEM steals a reference, hence release() rather than a copy.
1065  PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
1066  }
1067  materializedNotes = std::move(notes);
1068 
1069  return *materializedNotes;
1070 }
1071 
// NOTE(review): signature line missing from this listing; presumably
// PyDiagnostic::getInfo().
// Snapshots this diagnostic (severity, location, message and, recursively,
// its notes) into a DiagnosticInfo value that does not depend on the
// diagnostic staying valid.
1073  std::vector<DiagnosticInfo> notes;
1074  for (nb::handle n : getNotes())
1075  notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
1076  return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
1077  std::move(notes)};
1078 }
1079 
1080 //------------------------------------------------------------------------------
1081 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
1082 //------------------------------------------------------------------------------
1083 
1084 MlirDialect PyDialects::getDialectForKey(const std::string &key,
1085  bool attrError) {
1086  MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1087  {key.data(), key.size()});
1088  if (mlirDialectIsNull(dialect)) {
1089  std::string msg = (Twine("Dialect '") + key + "' not found").str();
1090  if (attrError)
1091  throw nb::attribute_error(msg.c_str());
1092  throw nb::index_error(msg.c_str());
1093  }
1094  return dialect;
1095 }
1096 
// NOTE(review): signature line missing from this listing; presumably
// PyDialectRegistry::getCapsule().
// Returns a new Python capsule wrapping this registry (ownership of the new
// reference is taken via nb::steal).
1098  return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
1099 }
1100 
// NOTE(review): signature line missing from this listing; presumably
// PyDialectRegistry::createFromCapsule(nb::object capsule).
// Unwraps a dialect-registry capsule. A null result is treated as a failed
// conversion, and the pending Python error is rethrown.
1102  MlirDialectRegistry rawRegistry =
1103  mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1104  if (mlirDialectRegistryIsNull(rawRegistry))
1105  throw nb::python_error();
1106  return PyDialectRegistry(rawRegistry);
1107 }
1108 
1109 //------------------------------------------------------------------------------
1110 // PyLocation
1111 //------------------------------------------------------------------------------
1112 
// NOTE(review): signature line missing from this listing; presumably
// PyLocation::getCapsule().
// Returns a new Python capsule wrapping this location.
1114  return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
1115 }
1116 
// NOTE(review): signature line missing from this listing; presumably
// PyLocation::createFromCapsule(nb::object capsule).
// Unwraps a location capsule; a null result rethrows the pending Python error.
1118  MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1119  if (mlirLocationIsNull(rawLoc))
1120  throw nb::python_error();
// NOTE(review): the line that begins this return statement (original line
// 1121) is missing from this listing.
1122  rawLoc);
1123 }
1124 
1125 nb::object PyLocation::contextEnter(nb::object locationObj) {
1126  return PyThreadContextEntry::pushLocation(locationObj);
1127 }
1128 
// `__exit__` for `with location:`. The exception arguments are accepted to
// satisfy the Python protocol.
// NOTE(review): the body line (original line 1132) is missing from this
// listing; presumably it pops this location, mirroring contextEnter.
1129 void PyLocation::contextExit(const nb::object &excType,
1130  const nb::object &excVal,
1131  const nb::object &excTb) {
1133 }
1134 
// NOTE(review): signature line missing from this listing; presumably
// DefaultingPyLocation::resolve().
// Returns the location from the innermost `with loc:` frame. Throws
// std::runtime_error with guidance when no default location is established.
1136  auto *location = PyThreadContextEntry::getDefaultLocation();
1137  if (!location) {
1138  throw std::runtime_error(
1139  "An MLIR function requires a Location but none was provided in the "
1140  "call or from the surrounding environment. Either pass to the function "
1141  "with a 'loc=' argument or establish a default using 'with loc:'");
1142  }
1143  return *location;
1144 }
1145 
1146 //------------------------------------------------------------------------------
1147 // PyModule
1148 //------------------------------------------------------------------------------
1149 
1150 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1151  : BaseContextObject(std::move(contextRef)), module(module) {}
1152 
// NOTE(review): signature line missing from this listing; presumably
// PyModule::~PyModule().
// Removes this module from its context's live-module map and destroys the
// underlying MlirModule. Takes the GIL first, because the live map holds
// Python object handles.
1154  nb::gil_scoped_acquire acquire;
1155  auto &liveModules = getContext()->liveModules;
1156  assert(liveModules.count(module.ptr) == 1 &&
1157  "destroying module not in live map");
1158  liveModules.erase(module.ptr);
1159  mlirModuleDestroy(module);
1160 }
1161 
1162 PyModuleRef PyModule::forModule(MlirModule module) {
1163  MlirContext context = mlirModuleGetContext(module);
1164  PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1165 
1166  nb::gil_scoped_acquire acquire;
1167  auto &liveModules = contextRef->liveModules;
1168  auto it = liveModules.find(module.ptr);
1169  if (it == liveModules.end()) {
1170  // Create.
1171  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1172  // Note that the default return value policy on cast is automatic_reference,
1173  // which does not take ownership (delete will not be called).
1174  // Just be explicit.
1175  nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1176  unownedModule->handle = pyRef;
1177  liveModules[module.ptr] =
1178  std::make_pair(unownedModule->handle, unownedModule);
1179  return PyModuleRef(unownedModule, std::move(pyRef));
1180  }
1181  // Use existing.
1182  PyModule *existing = it->second.second;
1183  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1184  return PyModuleRef(existing, std::move(pyRef));
1185 }
1186 
1187 nb::object PyModule::createFromCapsule(nb::object capsule) {
1188  MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1189  if (mlirModuleIsNull(rawModule))
1190  throw nb::python_error();
1191  return forModule(rawModule).releaseObject();
1192 }
1193 
1194 nb::object PyModule::getCapsule() {
1195  return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
1196 }
1197 
1198 //------------------------------------------------------------------------------
1199 // PyOperation
1200 //------------------------------------------------------------------------------
1201 
1202 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1203  : BaseContextObject(std::move(contextRef)), operation(operation) {}
1204 
// NOTE(review): signature line missing from this listing; presumably
// PyOperation::~PyOperation().
// Tears down the wrapper. Attached operations are only dropped from the live
// map, while detached (Python-owned) operations are erased.
1206  // If the operation has already been invalidated there is nothing to do.
1207  if (!valid)
1208  return;
1209 
1210  // Otherwise, invalidate the operation and remove it from live map when it is
1211  // attached.
1212  if (isAttached()) {
1213  getContext()->clearOperation(*this);
1214  } else {
1215  // And destroy it when it is detached, i.e. owned by Python, in which case
1216  // all nested operations must be invalidated and removed from the live map as
1217  // well.
1218  erase();
1219  }
1220 }
1221 
1222 namespace {
1223 
1224 // Constructs a new object of type T in-place on the Python heap, returning a
1225 // PyObjectRef to it, loosely analogous to std::make_shared<T>().
1226 template <typename T, class... Args>
1227 PyObjectRef<T> makeObjectRef(Args &&...args) {
1228  nb::handle type = nb::type<T>();
1229  nb::object instance = nb::inst_alloc(type);
1230  T *ptr = nb::inst_ptr<T>(instance);
1231  new (ptr) T(std::forward<Args>(args)...);
1232  nb::inst_mark_ready(instance);
1233  return PyObjectRef<T>(ptr, std::move(instance));
1234 }
1235 
1236 } // namespace
1237 
1238 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1239  MlirOperation operation,
1240  nb::object parentKeepAlive) {
1241  // Create.
1242  PyOperationRef unownedOperation =
1243  makeObjectRef<PyOperation>(std::move(contextRef), operation);
1244  unownedOperation->handle = unownedOperation.getObject();
1245  if (parentKeepAlive) {
1246  unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1247  }
1248  return unownedOperation;
1249 }
1250 
// NOTE(review): the first signature line is missing from this listing;
// presumably PyOperation::forOperation(PyMlirContextRef contextRef, ...).
// Returns the live Python wrapper for `operation` if one exists, otherwise
// creates and registers a new one. Lookup and insertion happen while holding
// liveOperationsMutex.
1252  MlirOperation operation,
1253  nb::object parentKeepAlive) {
1254  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
// This references the context object's map (not `contextRef` itself), so it
// stays valid after `contextRef` is moved below.
1255  auto &liveOperations = contextRef->liveOperations;
1256  auto it = liveOperations.find(operation.ptr);
1257  if (it == liveOperations.end()) {
1258  // Create.
1259  PyOperationRef result = createInstance(std::move(contextRef), operation,
1260  std::move(parentKeepAlive));
1261  liveOperations[operation.ptr] =
1262  std::make_pair(result.getObject(), result.get());
1263  return result;
1264  }
1265  // Use existing.
// NOTE(review): on this path `parentKeepAlive` is not applied to the existing
// wrapper.
1266  PyOperation *existing = it->second.second;
1267  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1268  return PyOperationRef(existing, std::move(pyRef));
1269 }
1270 
// NOTE(review): the first signature line is missing from this listing;
// presumably PyOperation::createDetached(PyMlirContextRef contextRef, ...).
// Creates and registers a wrapper for a newly created operation that is not
// yet attached to any block (so Python owns it). The operation must not
// already be live.
1272  MlirOperation operation,
1273  nb::object parentKeepAlive) {
1274  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1275  auto &liveOperations = contextRef->liveOperations;
1276  assert(liveOperations.count(operation.ptr) == 0 &&
1277  "cannot create detached operation that already exists");
// NOTE(review): this cast is redundant, since liveOperations is used below
// regardless of NDEBUG.
1278  (void)liveOperations;
1279  PyOperationRef created = createInstance(std::move(contextRef), operation,
1280  std::move(parentKeepAlive));
1281  liveOperations[operation.ptr] =
1282  std::make_pair(created.getObject(), created.get());
1283  created->attached = false;
1284  return created;
1285 }
1286 
// NOTE(review): the first signature line is missing from this listing;
// presumably PyOperation::parse(PyMlirContextRef contextRef, ...).
// Parses `sourceStr` (reported under `sourceName`) into a detached operation.
// On failure, throws MLIRError carrying the diagnostics captured during
// parsing.
1288  const std::string &sourceStr,
1289  const std::string &sourceName) {
1290  PyMlirContext::ErrorCapture errors(contextRef);
1291  MlirOperation op =
1292  mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1293  toMlirStringRef(sourceName));
1294  if (mlirOperationIsNull(op))
1295  throw MLIRError("Unable to parse operation assembly", errors.take());
1296  return PyOperation::createDetached(std::move(contextRef), op);
1297 }
1298 
// NOTE(review): signature line missing from this listing; presumably
// PyOperation::checkValid().
// Throws std::runtime_error if this wrapper's operation has been invalidated.
1300  if (!valid) {
1301  throw std::runtime_error("the operation has been invalidated");
1302  }
1303 }
1304 
// Prints the operation to `fileObject` (sys.stdout when None) using printing
// flags built from the arguments. `binary` selects byte vs. text output.
1305 void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1306  bool enableDebugInfo, bool prettyDebugInfo,
1307  bool printGenericOpForm, bool useLocalScope,
1308  bool useNameLocAsPrefix, bool assumeVerified,
1309  nb::object fileObject, bool binary,
1310  bool skipRegions) {
1311  PyOperation &operation = getOperation();
1312  operation.checkValid();
1313  if (fileObject.is_none())
1314  fileObject = nb::module_::import_("sys").attr("stdout");
1315 
1316  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1317  if (largeElementsLimit) {
1318  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1319  mlirOpPrintingFlagsElideLargeResourceString(flags, *largeElementsLimit);
1320  }
1321  if (enableDebugInfo)
1322  mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1323  /*prettyForm=*/prettyDebugInfo);
// NOTE(review): the flag-setting statements under each of the following `if`s
// (original lines 1325, 1327, 1329, 1331, 1333) are missing from this listing.
// As shown, each `if` would wrongly guard the next one.
1324  if (printGenericOpForm)
1326  if (useLocalScope)
1328  if (assumeVerified)
1330  if (skipRegions)
1332  if (useNameLocAsPrefix)
1334 
1335  PyFileAccumulator accum(fileObject, binary);
1336  mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1337  accum.getUserData());
// NOTE(review): original line 1338 is missing; `flags` from
// mlirOpPrintingFlagsCreate presumably needs to be destroyed there.
1339 }
1340 
1341 void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1342  bool binary) {
1343  PyOperation &operation = getOperation();
1344  operation.checkValid();
1345  if (fileObject.is_none())
1346  fileObject = nb::module_::import_("sys").attr("stdout");
1347  PyFileAccumulator accum(fileObject, binary);
1348  mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1349  accum.getUserData());
1350 }
1351 
// Writes the operation as MLIR bytecode to `fileOrStringObject`. When
// `bytecodeVersion` is given, that version is requested and nb::value_error is
// thrown if the writer cannot honor it.
1352 void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
1353  std::optional<int64_t> bytecodeVersion) {
1354  PyOperation &operation = getOperation();
1355  operation.checkValid();
1356  PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
1357 
// Returning a void expression from a void function: default-version path.
1358  if (!bytecodeVersion.has_value())
1359  return mlirOperationWriteBytecode(operation, accum.getCallback(),
1360  accum.getUserData());
1361 
// NOTE(review): original lines 1363-1364 (presumably setting the desired
// version on `config` and starting the `res` assignment) and 1366 (presumably
// destroying `config`) are missing from this listing.
1362  MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1365  operation, config, accum.getCallback(), accum.getUserData());
1367  if (mlirLogicalResultIsFailure(res))
1368  throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
1369  Twine(*bytecodeVersion))
1370  .str()
1371  .c_str());
1372 }
1373 
// NOTE(review): the first signature line is missing from this listing;
// presumably PyOperationBase::walk(...).
// Walks the operation in `walkOrder`, invoking `callback` on each operation.
// A Python exception raised by the callback is recorded, and after the walk it
// is rethrown as std::runtime_error carrying the original message.
1375  std::function<MlirWalkResult(MlirOperation)> callback,
1376  MlirWalkOrder walkOrder) {
1377  PyOperation &operation = getOperation();
1378  operation.checkValid();
1379  struct UserData {
1380  std::function<MlirWalkResult(MlirOperation)> callback;
1381  bool gotException;
1382  std::string exceptionWhat;
1383  nb::object exceptionType;
1384  };
1385  UserData userData{callback, false, {}, {}};
1386  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1387  void *userData) {
1388  UserData *calleeUserData = static_cast<UserData *>(userData);
// Exceptions must not escape through the C walk API, so they are captured
// here. NOTE(review): only nb::python_error is caught; any other C++ exception
// from the callback would still unwind through mlirOperationWalk.
1389  try {
1390  return (calleeUserData->callback)(op);
1391  } catch (nb::python_error &e) {
1392  calleeUserData->gotException = true;
1393  calleeUserData->exceptionWhat = std::string(e.what());
1394  calleeUserData->exceptionType = nb::borrow(e.type());
// NOTE(review): original line 1395 (presumably returning an interrupt walk
// result) is missing from this listing.
1396  }
1397  };
1398  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
// NOTE(review): `exceptionType` is captured but never used; the original
// Python exception type is lost in the runtime_error below.
1399  if (userData.gotException) {
1400  std::string message("Exception raised in callback: ");
1401  message.append(userData.exceptionWhat);
1402  throw std::runtime_error(message);
1403  }
1404 }
1405 
1406 nb::object PyOperationBase::getAsm(bool binary,
1407  std::optional<int64_t> largeElementsLimit,
1408  bool enableDebugInfo, bool prettyDebugInfo,
1409  bool printGenericOpForm, bool useLocalScope,
1410  bool useNameLocAsPrefix, bool assumeVerified,
1411  bool skipRegions) {
1412  nb::object fileObject;
1413  if (binary) {
1414  fileObject = nb::module_::import_("io").attr("BytesIO")();
1415  } else {
1416  fileObject = nb::module_::import_("io").attr("StringIO")();
1417  }
1418  print(/*largeElementsLimit=*/largeElementsLimit,
1419  /*enableDebugInfo=*/enableDebugInfo,
1420  /*prettyDebugInfo=*/prettyDebugInfo,
1421  /*printGenericOpForm=*/printGenericOpForm,
1422  /*useLocalScope=*/useLocalScope,
1423  /*useNameLocAsPrefix=*/useNameLocAsPrefix,
1424  /*assumeVerified=*/assumeVerified,
1425  /*fileObject=*/fileObject,
1426  /*binary=*/binary,
1427  /*skipRegions=*/skipRegions);
1428 
1429  return fileObject.attr("getvalue")();
1430 }
1431 
// NOTE(review): signature line missing from this listing; presumably
// PyOperationBase::moveAfter(PyOperationBase &other).
// Moves this operation to just after `other` and adopts `other`'s
// parentKeepAlive, since both now live under the same parent.
1433  PyOperation &operation = getOperation();
1434  PyOperation &otherOp = other.getOperation();
1435  operation.checkValid();
1436  otherOp.checkValid();
1437  mlirOperationMoveAfter(operation, otherOp);
1438  operation.parentKeepAlive = otherOp.parentKeepAlive;
1439 }
1440 
// NOTE(review): signature line missing from this listing; presumably
// PyOperationBase::moveBefore(PyOperationBase &other).
// Moves this operation to just before `other` and adopts `other`'s
// parentKeepAlive, since both now live under the same parent.
1442  PyOperation &operation = getOperation();
1443  PyOperation &otherOp = other.getOperation();
1444  operation.checkValid();
1445  otherOp.checkValid();
1446  mlirOperationMoveBefore(operation, otherOp);
1447  operation.parentKeepAlive = otherOp.parentKeepAlive;
1448 }
1449 
// NOTE(review): signature line missing from this listing; presumably
// PyOperationBase::verify().
// Runs the verifier. Returns true on success; otherwise throws MLIRError with
// the captured diagnostics.
1451  PyOperation &op = getOperation();
// NOTE(review): original line 1452, which declares the `errors` capture used
// below, is missing from this listing.
1453  if (!mlirOperationVerify(op.get()))
1454  throw MLIRError("Verification failed", errors.take());
1455  return true;
1456 }
1457 
1458 std::optional<PyOperationRef> PyOperation::getParentOperation() {
1459  checkValid();
1460  if (!isAttached())
1461  throw nb::value_error("Detached operations have no parent");
1462  MlirOperation operation = mlirOperationGetParentOperation(get());
1463  if (mlirOperationIsNull(operation))
1464  return {};
1465  return PyOperation::forOperation(getContext(), operation);
1466 }
1467 
// NOTE(review): signature line missing from this listing; presumably
// PyOperation::getBlock().
// Returns the block that contains this operation, paired with the parent
// operation's wrapper. getParentOperation() throws for detached operations.
1469  checkValid();
1470  std::optional<PyOperationRef> parentOperation = getParentOperation();
1471  MlirBlock block = mlirOperationGetBlock(get());
1472  assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
// NOTE(review): with assertions disabled, an empty `parentOperation` is
// dereferenced on the next line (undefined behavior). Consider an explicit
// check.
1473  assert(parentOperation && "Operation has no parent");
1474  return PyBlock{std::move(*parentOperation), block};
1475 }
1476 
// NOTE(review): signature line missing from this listing; presumably
// PyOperation::getCapsule().
// Returns a new Python capsule wrapping this (still valid) operation.
1478  checkValid();
1479  return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1480 }
1481 
1482 nb::object PyOperation::createFromCapsule(nb::object capsule) {
1483  MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1484  if (mlirOperationIsNull(rawOperation))
1485  throw nb::python_error();
1486  MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1487  return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1488  .releaseObject();
1489 }
1490 
// NOTE(review): the first signature line is missing from this listing;
// presumably a static maybeInsertOperation(PyOperationRef &op, ...).
// Inserts `op` according to `maybeIp`:
//   - Python False: no insertion;
//   - None: uses a default insertion point (original line 1497, which sets
//     `ip` in this branch, is missing from this listing);
//   - otherwise: `maybeIp` is cast to a PyInsertionPoint.
1492  const nb::object &maybeIp) {
1493  // InsertPoint active?
1494  if (!maybeIp.is(nb::cast(false))) {
1495  PyInsertionPoint *ip;
1496  if (maybeIp.is_none()) {
1498  } else {
1499  ip = nb::cast<PyInsertionPoint *>(maybeIp);
1500  }
1501  if (ip)
1502  ip->insert(*op.get());
1503  }
1504 }
1505 
// Creates operation `name` from the given result types, operands, attributes,
// successors and region count, at `location`. Result type inference is
// enabled when `inferType` is set. The new operation is optionally inserted
// according to `maybeIp`, and its Python wrapper is returned.
// Raises ValueError/TypeError for invalid arguments (see the checks below).
1506 nb::object PyOperation::create(std::string_view name,
1507  std::optional<std::vector<PyType *>> results,
1508  llvm::ArrayRef<MlirValue> operands,
1509  std::optional<nb::dict> attributes,
1510  std::optional<std::vector<PyBlock *>> successors,
1511  int regions, DefaultingPyLocation location,
1512  const nb::object &maybeIp, bool inferType) {
1513  llvm::SmallVector<MlirType, 4> mlirResults;
1514  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
// NOTE(review): original line 1515, which declares `mlirAttributes` (used
// below), is missing from this listing.
1516 
1517  // General parameter validation.
1518  if (regions < 0)
1519  throw nb::value_error("number of regions must be >= 0");
1520 
1521  // Unpack/validate results.
1522  if (results) {
1523  mlirResults.reserve(results->size());
1524  for (PyType *result : *results) {
1525  // TODO: Verify result type originate from the same context.
1526  if (!result)
1527  throw nb::value_error("result type cannot be None");
1528  mlirResults.push_back(*result);
1529  }
1530  }
1531  // Unpack/validate attributes.
1532  if (attributes) {
1533  mlirAttributes.reserve(attributes->size());
1534  for (std::pair<nb::handle, nb::handle> it : *attributes) {
1535  std::string key;
1536  try {
1537  key = nb::cast<std::string>(it.first);
1538  } catch (nb::cast_error &err) {
1539  std::string msg = "Invalid attribute key (not a string) when "
1540  "attempting to create the operation \"" +
1541  std::string(name) + "\" (" + err.what() + ")";
1542  throw nb::type_error(msg.c_str());
1543  }
1544  try {
1545  auto &attribute = nb::cast<PyAttribute &>(it.second);
1546  // TODO: Verify attribute originates from the same context.
1547  mlirAttributes.emplace_back(std::move(key), attribute);
1548  } catch (nb::cast_error &err) {
1549  std::string msg = "Invalid attribute value for the key \"" + key +
1550  "\" when attempting to create the operation \"" +
1551  std::string(name) + "\" (" + err.what() + ")";
1552  throw nb::type_error(msg.c_str());
1553  } catch (std::runtime_error &) {
1554  // This exception appears to be thrown when the value is "None".
1555  std::string msg =
1556  "Found an invalid (`None`?) attribute value for the key \"" + key +
1557  "\" when attempting to create the operation \"" +
1558  std::string(name) + "\"";
1559  throw std::runtime_error(msg);
1560  }
1561  }
1562  }
1563  // Unpack/validate successors.
1564  if (successors) {
1565  mlirSuccessors.reserve(successors->size());
1566  for (auto *successor : *successors) {
1567  // TODO: Verify successor originate from the same context.
1568  if (!successor)
1569  throw nb::value_error("successor block cannot be None");
1570  mlirSuccessors.push_back(successor->get());
1571  }
1572  }
1573 
1574  // Apply unpacked/validated to the operation state. Beyond this
1575  // point, exceptions cannot be thrown or else the state will leak.
1576  MlirOperationState state =
1577  mlirOperationStateGet(toMlirStringRef(name), location);
1578  if (!operands.empty())
1579  mlirOperationStateAddOperands(&state, operands.size(), operands.data());
1580  state.enableResultTypeInference = inferType;
1581  if (!mlirResults.empty())
1582  mlirOperationStateAddResults(&state, mlirResults.size(),
1583  mlirResults.data());
1584  if (!mlirAttributes.empty()) {
1585  // Note that the attribute names directly reference bytes in
1586  // mlirAttributes, so that vector must not be changed from here
1587  // on.
1588  llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1589  mlirNamedAttributes.reserve(mlirAttributes.size());
1590  for (auto &it : mlirAttributes)
1591  mlirNamedAttributes.push_back(mlirNamedAttributeGet(
// NOTE(review): original line 1592 (the start of the identifier argument) is
// missing from this listing.
1593  toMlirStringRef(it.first)),
1594  it.second));
1595  mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1596  mlirNamedAttributes.data());
1597  }
1598  if (!mlirSuccessors.empty())
1599  mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1600  mlirSuccessors.data());
1601  if (regions) {
// NOTE(review): original line 1602, which declares `mlirRegions`, is missing
// from this listing.
1603  mlirRegions.resize(regions);
1604  for (int i = 0; i < regions; ++i)
1605  mlirRegions[i] = mlirRegionCreate();
1606  mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1607  mlirRegions.data());
1608  }
1609 
1610  // Construct the operation.
1611  MlirOperation operation = mlirOperationCreate(&state);
1612  if (!operation.ptr)
1613  throw nb::value_error("Operation creation failed");
1614  PyOperationRef created =
1615  PyOperation::createDetached(location->getContext(), operation);
1616  maybeInsertOperation(created, maybeIp);
1617 
1618  return created.getObject();
1619 }
1620 
1621 nb::object PyOperation::clone(const nb::object &maybeIp) {
1622  MlirOperation clonedOperation = mlirOperationClone(operation);
1623  PyOperationRef cloned =
1624  PyOperation::createDetached(getContext(), clonedOperation);
1625  maybeInsertOperation(cloned, maybeIp);
1626 
1627  return cloned->createOpView();
1628 }
1629 
// NOTE(review): signature line missing from this listing; presumably
// PyOperation::createOpView().
// Returns an OpView for this operation. The OpView subclass registered for the
// operation's name is used when one exists; otherwise a generic PyOpView is
// returned.
1631  checkValid();
1632  MlirIdentifier ident = mlirOperationGetName(get());
1633  MlirStringRef identStr = mlirIdentifierStr(ident);
1634  auto operationCls = PyGlobals::get().lookupOperationClass(
1635  StringRef(identStr.data, identStr.length));
1636  if (operationCls)
1637  return PyOpView::constructDerived(*operationCls, getRef().getObject());
1638  return nb::cast(PyOpView(getRef().getObject()));
1639 }
1640 
// NOTE(review): signature line missing from this listing; presumably
// PyOperation::erase().
// Destroys the underlying operation after checking validity.
// NOTE(review): original line 1643 (presumably invalidating this wrapper and
// any nested ones before destruction) is missing from this listing.
1642  checkValid();
1644  mlirOperationDestroy(operation);
1645 }
1646 
1647 namespace {
1648 /// CRTP base class for Python MLIR values that subclass Value and should be
1649 /// castable from it. The value hierarchy is one level deep and is not supposed
1650 /// to accommodate other levels unless core MLIR changes.
1651 template <typename DerivedTy>
1652 class PyConcreteValue : public PyValue {
1653 public:
1654  // Derived classes must define statics for:
1655  // IsAFunctionTy isaFunction
1656  // const char *pyClassName
1657  // and redefine bindDerived.
1658  using ClassTy = nb::class_<DerivedTy, PyValue>;
1659  using IsAFunctionTy = bool (*)(MlirValue);
1660 
1661  PyConcreteValue() = default;
1662  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1663  : PyValue(operationRef, value) {}
// Downcasting constructor: keeps the same parent operation and throws
// nb::value_error (from castFrom) if `orig` is not of the derived kind.
// NOTE(review): not `explicit`, so it also acts as an implicit conversion.
1664  PyConcreteValue(PyValue &orig)
1665  : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1666 
1667  /// Attempts to cast the original value to the derived type and throws on
1668  /// type mismatches.
1669  static MlirValue castFrom(PyValue &orig) {
1670  if (!DerivedTy::isaFunction(orig.get())) {
1671  auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
1672  throw nb::value_error((Twine("Cannot cast value to ") +
1673  DerivedTy::pyClassName + " (from " + origRepr +
1674  ")")
1675  .str()
1676  .c_str());
1677  }
1678  return orig.get();
1679  }
1680 
1681  /// Registers the Python class for DerivedTy on module `m`: a constructor
///  from Value, a static `isinstance` check, a downcast hook, and whatever
///  DerivedTy::bindDerived adds.
1682  static void bind(nb::module_ &m) {
1683  auto cls = ClassTy(m, DerivedTy::pyClassName);
1684  cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
1685  cls.def_static(
1686  "isinstance",
1687  [](PyValue &otherValue) -> bool {
1688  return DerivedTy::isaFunction(otherValue);
1689  },
1690  nb::arg("other_value"));
// NOTE(review): original line 1691 (the start of the `cls.def(` call that
// registers the downcast lambda below) is missing from this listing.
1692  [](DerivedTy &self) { return self.maybeDownCast(); });
1693  DerivedTy::bindDerived(cls);
1694  }
1695 
1696  /// Implemented by derived classes to add methods to the Python subclass.
1697  static void bindDerived(ClassTy &m) {}
1698 };
1699 
1700 } // namespace
1701 
1702 /// Python wrapper for MlirOpResult.
1703 class PyOpResult : public PyConcreteValue<PyOpResult> {
1704 public:
1705  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1706  static constexpr const char *pyClassName = "OpResult";
1707  using PyConcreteValue::PyConcreteValue;
1708 
1709  static void bindDerived(ClassTy &c) {
1710  c.def_prop_ro("owner", [](PyOpResult &self) {
1711  assert(
1712  mlirOperationEqual(self.getParentOperation()->get(),
1713  mlirOpResultGetOwner(self.get())) &&
1714  "expected the owner of the value in Python to match that in the IR");
1715  return self.getParentOperation().getObject();
1716  });
1717  c.def_prop_ro("result_number", [](PyOpResult &self) {
1718  return mlirOpResultGetResultNumber(self.get());
1719  });
1720  }
1721 };
1722 
1723 /// Returns the list of types of the values held by container.
1724 template <typename Container>
1725 static std::vector<MlirType> getValueTypes(Container &container,
1726  PyMlirContextRef &context) {
1727  std::vector<MlirType> result;
1728  result.reserve(container.size());
1729  for (int i = 0, e = container.size(); i < e; ++i) {
1730  result.push_back(mlirValueGetType(container.getElement(i).get()));
1731  }
1732  return result;
1733 }
1734 
1735 /// A list of operation results. Internally, these are stored as consecutive
1736 /// elements, random access is cheap. The (returned) result list is associated
1737 /// with the operation whose results these are, and thus extends the lifetime of
1738 /// this operation.
// Sliceable view over an operation's results. It holds a reference to the
// operation, and its Python bindings add `types` and `owner` properties.
1739 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1740 public:
1741  static constexpr const char *pyClassName = "OpResultList";
// NOTE(review): original line 1742 is missing from this listing (presumably a
// type alias used by the bindings).
1743 
// Constructs a view of `length` results starting at `startIndex` with stride
// `step`; `length == -1` means "all results". The base is initialized before
// the `operation` member, so reading the parameter here precedes the move.
// NOTE(review): unlike getRawNumElements, this path does not call checkValid().
1744  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1745  intptr_t length = -1, intptr_t step = 1)
1746  : Sliceable(startIndex,
1747  length == -1 ? mlirOperationGetNumResults(operation->get())
1748  : length,
1749  step),
1750  operation(std::move(operation)) {}
1751 
1752  static void bindDerived(ClassTy &c) {
1753  c.def_prop_ro("types", [](PyOpResultList &self) {
1754  return getValueTypes(self, self.operation->getContext());
1755  });
1756  c.def_prop_ro("owner", [](PyOpResultList &self) {
1757  return self.operation->createOpView();
1758  });
1759  }
1760 
1761  PyOperationRef &getOperation() { return operation; }
1762 
1763 private:
1764  /// Give the parent CRTP class access to hook implementations below.
1765  friend class Sliceable<PyOpResultList, PyOpResult>;
1766 
1767  intptr_t getRawNumElements() {
1768  operation->checkValid();
1769  return mlirOperationGetNumResults(operation->get());
1770  }
1771 
1772  PyOpResult getRawElement(intptr_t index) {
1773  PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1774  return PyOpResult(value);
1775  }
1776 
1777  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1778  return PyOpResultList(operation, startIndex, length, step);
1779  }
1780 
1781  PyOperationRef operation;
1782 };
1783 
1784 //------------------------------------------------------------------------------
1785 // PyOpView
1786 //------------------------------------------------------------------------------
1787 
// Converts the Python `resultTypeList` for operation `name` into `resultTypes`,
// appending to the output vectors.
//  - If `resultSegmentSpecObj` is None, every entry must be a Type.
//  - Otherwise it is a per-entry spec: 1 = exactly one Type, 0 = optional
//    (Type or None), -1 = a sequence of Types (or None, meaning empty). The
//    resulting per-segment counts are appended to `resultSegmentLengths`.
// Every failure is reported as nb::value_error naming the result index and
// the operation.
1788 static void populateResultTypes(StringRef name, nb::list resultTypeList,
1789  const nb::object &resultSegmentSpecObj,
1790  std::vector<int32_t> &resultSegmentLengths,
1791  std::vector<PyType *> &resultTypes) {
1792  resultTypes.reserve(resultTypeList.size());
1793  if (resultSegmentSpecObj.is_none()) {
1794  // Non-variadic result unpacking.
1795  for (const auto &it : llvm::enumerate(resultTypeList)) {
1796  try {
// A None entry casts to nullptr; it is turned into a cast_error so that it
// takes the same reporting path as a non-Type entry.
1797  resultTypes.push_back(nb::cast<PyType *>(it.value()));
1798  if (!resultTypes.back())
1799  throw nb::cast_error();
1800  } catch (nb::cast_error &err) {
1801  throw nb::value_error((llvm::Twine("Result ") +
1802  llvm::Twine(it.index()) + " of operation \"" +
1803  name + "\" must be a Type (" + err.what() + ")")
1804  .str()
1805  .c_str());
1806  }
1807  }
1808  } else {
1809  // Sized result unpacking.
1810  auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1811  if (resultSegmentSpec.size() != resultTypeList.size()) {
1812  throw nb::value_error((llvm::Twine("Operation \"") + name +
1813  "\" requires " +
1814  llvm::Twine(resultSegmentSpec.size()) +
1815  " result segments but was provided " +
1816  llvm::Twine(resultTypeList.size()))
1817  .str()
1818  .c_str());
1819  }
1820  resultSegmentLengths.reserve(resultTypeList.size());
1821  for (const auto &it :
1822  llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1823  int segmentSpec = std::get<1>(it.value());
1824  if (segmentSpec == 1 || segmentSpec == 0) {
1825  // Unpack unary element.
// The value_error thrown inside this try is not a cast_error, so it escapes
// the handler below unchanged.
1826  try {
1827  auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
1828  if (resultType) {
1829  resultTypes.push_back(resultType);
1830  resultSegmentLengths.push_back(1);
1831  } else if (segmentSpec == 0) {
1832  // Allowed to be optional.
1833  resultSegmentLengths.push_back(0);
1834  } else {
1835  throw nb::value_error(
1836  (llvm::Twine("Result ") + llvm::Twine(it.index()) +
1837  " of operation \"" + name +
1838  "\" must be a Type (was None and result is not optional)")
1839  .str()
1840  .c_str());
1841  }
1842  } catch (nb::cast_error &err) {
1843  throw nb::value_error((llvm::Twine("Result ") +
1844  llvm::Twine(it.index()) + " of operation \"" +
1845  name + "\" must be a Type (" + err.what() +
1846  ")")
1847  .str()
1848  .c_str());
1849  }
1850  } else if (segmentSpec == -1) {
1851  // Unpack sequence by appending.
1852  try {
1853  if (std::get<0>(it.value()).is_none()) {
1854  // Treat it as an empty list.
1855  resultSegmentLengths.push_back(0);
1856  } else {
1857  // Unpack the list.
1858  auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1859  for (nb::handle segmentItem : segment) {
1860  resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1861  if (!resultTypes.back()) {
1862  throw nb::type_error("contained a None item");
1863  }
1864  }
1865  resultSegmentLengths.push_back(nb::len(segment));
1866  }
1867  } catch (std::exception &err) {
1868  // NOTE: Sloppy to be using a catch-all here, but there are at least
1869  // three different unrelated exceptions that can be thrown in the
1870  // above "casts". Just keep the scope above small and catch them all.
1871  throw nb::value_error((llvm::Twine("Result ") +
1872  llvm::Twine(it.index()) + " of operation \"" +
1873  name + "\" must be a Sequence of Types (" +
1874  err.what() + ")")
1875  .str()
1876  .c_str());
1877  }
1878  } else {
1879  throw nb::value_error("Unexpected segment spec");
1880  }
1881  }
1882  }
1883 }
1884 
1885 static MlirValue getUniqueResult(MlirOperation operation) {
1886  auto numResults = mlirOperationGetNumResults(operation);
1887  if (numResults != 1) {
1888  auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1889  throw nb::value_error((Twine("Cannot call .result on operation ") +
1890  StringRef(name.data, name.length) + " which has " +
1891  Twine(numResults) +
1892  " results (it is only valid for operations with a "
1893  "single result)")
1894  .str()
1895  .c_str());
1896  }
1897  return mlirOperationGetResult(operation, 0);
1898 }
1899 
1900 static MlirValue getOpResultOrValue(nb::handle operand) {
1901  if (operand.is_none()) {
1902  throw nb::value_error("contained a None item");
1903  }
1904  PyOperationBase *op;
1905  if (nb::try_cast<PyOperationBase *>(operand, op)) {
1906  return getUniqueResult(op->getOperation());
1907  }
1908  PyOpResultList *opResultList;
1909  if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1910  return getUniqueResult(opResultList->getOperation()->get());
1911  }
1912  PyValue *value;
1913  if (nb::try_cast<PyValue *>(operand, value)) {
1914  return value->get();
1915  }
1916  throw nb::value_error("is not a Value");
1917 }
1918 
1920  std::string_view name, std::tuple<int, bool> opRegionSpec,
1921  nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1922  std::optional<nb::list> resultTypeList, nb::list operandList,
1923  std::optional<nb::dict> attributes,
1924  std::optional<std::vector<PyBlock *>> successors,
1925  std::optional<int> regions, DefaultingPyLocation location,
1926  const nb::object &maybeIp) {
  // Generic OpView builder. It does the following, in order:
  //  * validates the region count against opRegionSpec (a minimum, plus a
  //    maximum when the op has no variadic regions);
  //  * unpacks results and operands according to their segment specs;
  //  * records resultSegmentSizes/operandSegmentSizes attributes when
  //    segments are used;
  //  * delegates to PyOperation::create.
  // Most validation failures raise ValueError.
1927  PyMlirContextRef context = location->getContext();
1928
1929  // Class level operation construction metadata.
1930  // Operand and result segment specs are either none, which does no
1931  // variadic unpacking, or a list of ints with one entry per position:
1932  // 1 for a single required value, 0 for a single optional value (None is
1933  // allowed), or -1 for a sequence whose length becomes the segment size
1934  // (None is treated as empty). Any other entry raises ValueError.
1935  std::vector<int32_t> operandSegmentLengths;
1936  std::vector<int32_t> resultSegmentLengths;
1937
1938  // Validate/determine region count.
  // When `regions` is not given it defaults to the op's minimum region count.
1939  int opMinRegionCount = std::get<0>(opRegionSpec);
1940  bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1941  if (!regions) {
1942  regions = opMinRegionCount;
1943  }
1944  if (*regions < opMinRegionCount) {
1945  throw nb::value_error(
1946  (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1947  llvm::Twine(opMinRegionCount) +
1948  " regions but was built with regions=" + llvm::Twine(*regions))
1949  .str()
1950  .c_str());
1951  }
  // Without variadic regions the minimum is also the maximum.
1952  if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1953  throw nb::value_error(
1954  (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1955  llvm::Twine(opMinRegionCount) +
1956  " regions but was built with regions=" + llvm::Twine(*regions))
1957  .str()
1958  .c_str());
1959  }
1960
1961  // Unpack results.
  // Result types are only unpacked when the caller supplied a list; the
  // same presence flag is forwarded to create() below.
1962  std::vector<PyType *> resultTypes;
1963  if (resultTypeList.has_value()) {
1964  populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1965  resultSegmentLengths, resultTypes);
1966  }
1967
1968  // Unpack operands.
  // NOTE(review): `operands.reserve(operands.size())` reserves the current
  // size of `operands`, which is presumably 0 at this point, so it
  // pre-allocates nothing. `operandList.size()` was presumably intended;
  // confirm.
1970  operands.reserve(operands.size());
1971  if (operandSegmentSpecObj.is_none()) {
1972  // Non-sized operand unpacking.
  // Each operand is resolved with getOpResultOrValue. Its errors are
  // rewrapped with the operand index and op name.
1973  for (const auto &it : llvm::enumerate(operandList)) {
1974  try {
1975  operands.push_back(getOpResultOrValue(it.value()));
1976  } catch (nb::builtin_exception &err) {
1977  throw nb::value_error((llvm::Twine("Operand ") +
1978  llvm::Twine(it.index()) + " of operation \"" +
1979  name + "\" must be a Value (" + err.what() + ")")
1980  .str()
1981  .c_str());
1982  }
1983  }
1984  } else {
1985  // Sized operand unpacking.
  // The spec must have exactly one entry per element of operandList.
1986  auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1987  if (operandSegmentSpec.size() != operandList.size()) {
  // NOTE(review): the literal "operand segments but was provided " lacks
  // the leading space that the matching result-segment message in
  // populateResultTypes has. The message therefore reads, e.g.,
  // "requires 3operand segments".
1988  throw nb::value_error((llvm::Twine("Operation \"") + name +
1989  "\" requires " +
1990  llvm::Twine(operandSegmentSpec.size()) +
1991  "operand segments but was provided " +
1992  llvm::Twine(operandList.size()))
1993  .str()
1994  .c_str());
1995  }
1996  operandSegmentLengths.reserve(operandList.size());
1997  for (const auto &it :
1998  llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1999  int segmentSpec = std::get<1>(it.value());
2000  if (segmentSpec == 1 || segmentSpec == 0) {
2001  // Unpack unary element.
2002  auto &operand = std::get<0>(it.value());
2003  if (!operand.is_none()) {
2004  try {
2005
2006  operands.push_back(getOpResultOrValue(operand));
2007  } catch (nb::builtin_exception &err) {
2008  throw nb::value_error((llvm::Twine("Operand ") +
2009  llvm::Twine(it.index()) +
2010  " of operation \"" + name +
2011  "\" must be a Value (" + err.what() + ")")
2012  .str()
2013  .c_str());
2014  }
2015
2016  operandSegmentLengths.push_back(1);
2017  } else if (segmentSpec == 0) {
2018  // Allowed to be optional.
2019  operandSegmentLengths.push_back(0);
2020  } else {
2021  throw nb::value_error(
2022  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
2023  " of operation \"" + name +
2024  "\" must be a Value (was None and operand is not optional)")
2025  .str()
2026  .c_str());
2027  }
2028  } else if (segmentSpec == -1) {
2029  // Unpack sequence by appending.
2030  try {
2031  if (std::get<0>(it.value()).is_none()) {
2032  // Treat it as an empty list.
2033  operandSegmentLengths.push_back(0);
2034  } else {
2035  // Unpack the list.
2036  auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
2037  for (nb::handle segmentItem : segment) {
2038  operands.push_back(getOpResultOrValue(segmentItem));
2039  }
2040  operandSegmentLengths.push_back(nb::len(segment));
2041  }
2042  } catch (std::exception &err) {
2043  // NOTE: Sloppy to be using a catch-all here, but there are at least
2044  // three different unrelated exceptions that can be thrown in the
2045  // above "casts". Just keep the scope above small and catch them all.
2046  throw nb::value_error((llvm::Twine("Operand ") +
2047  llvm::Twine(it.index()) + " of operation \"" +
2048  name + "\" must be a Sequence of Values (" +
2049  err.what() + ")")
2050  .str()
2051  .c_str());
2052  }
2053  } else {
2054  throw nb::value_error("Unexpected segment spec");
2055  }
2056  }
2057  }
2058
2059  // Merge operand/result segment lengths into attributes if needed.
2060  if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
2061  // Copy the caller's dict (or start a new one) so it is not mutated.
2062  if (attributes) {
2063  attributes = nb::dict(*attributes);
2064  } else {
2065  attributes = nb::dict();
2066  }
  // The segment-size attributes are computed here, so explicit user values
  // for them are rejected rather than silently overwritten.
2067  if (attributes->contains("resultSegmentSizes") ||
2068  attributes->contains("operandSegmentSizes")) {
2069  throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
2070  "'operandSegmentSizes' attribute is unsupported. "
2071  "Use Operation.create for such low-level access.");
2072  }
2073
2074  // Add resultSegmentSizes attribute.
2075  if (!resultSegmentLengths.empty()) {
2076  MlirAttribute segmentLengthAttr =
2077  mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
2078  resultSegmentLengths.data());
2079  (*attributes)["resultSegmentSizes"] =
2080  PyAttribute(context, segmentLengthAttr);
2081  }
2082
2083  // Add operandSegmentSizes attribute.
2084  if (!operandSegmentLengths.empty()) {
2085  MlirAttribute segmentLengthAttr =
2086  mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
2087  operandSegmentLengths.data());
2088  (*attributes)["operandSegmentSizes"] =
2089  PyAttribute(context, segmentLengthAttr);
2090  }
2091  }
2092
2093  // Delegate to create.
  // The last argument is true exactly when no result types were passed.
  // Presumably it asks create() to infer result types; confirm against
  // PyOperation::create.
2094  return PyOperation::create(name,
2095  /*results=*/std::move(resultTypes),
2096  /*operands=*/std::move(operands),
2097  /*attributes=*/std::move(attributes),
2098  /*successors=*/std::move(successors),
2099  /*regions=*/*regions, location, maybeIp,
2100  !resultTypeList);
2101 }
2102 
2103 nb::object PyOpView::constructDerived(const nb::object &cls,
2104  const nb::object &operation) {
2105  nb::handle opViewType = nb::type<PyOpView>();
2106  nb::object instance = cls.attr("__new__")(cls);
2107  opViewType.attr("__init__")(instance, operation);
2108  return instance;
2109 }
2110 
2111 PyOpView::PyOpView(const nb::object &operationObject)
2112  // Casting through the PyOperationBase base-class and then back to the
2113  // Operation lets us accept any PyOperationBase subclass.
2114  : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
2115  operationObject(operation.getRef().getObject()) {}
2116 
2117 //------------------------------------------------------------------------------
2118 // PyInsertionPoint.
2119 //------------------------------------------------------------------------------
2120 
2122 
// Positions the insertion point before `beforeOperationBase`: it keeps a
// reference to that operation and records the operation's block.
2124  : refOperation(beforeOperationBase.getOperation().getRef()),
2125  block((*refOperation)->getBlock()) {}
2126 
  // Inserts the detached operation `operationBase` at this insertion point
  // and marks it attached. The position is before `refOperation` if one is
  // set, otherwise at the end of `block`.
  // Raises ValueError if the operation is already attached. Raises
  // IndexError when appending to a block that already has a terminator.
2128  PyOperation &operation = operationBase.getOperation();
2129  if (operation.isAttached())
2130  throw nb::value_error(
2131  "Attempt to insert operation that is already attached");
  // The block's parent operation must still be valid before it is mutated.
2132  block.getParentOperation()->checkValid();
  // A null "before" operation means append at the end of the block.
2133  MlirOperation beforeOp = {nullptr};
2134  if (refOperation) {
2135  // Insert before operation.
2136  (*refOperation)->checkValid();
2137  beforeOp = (*refOperation)->get();
2138  } else {
2139  // Insert at end (before null) is only valid if the block does not
2140  // already end in a known terminator (violating this will cause assertion
2141  // failures later).
2142  if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
2143  throw nb::index_error("Cannot insert operation at the end of a block "
2144  "that already has a terminator. Did you mean to "
2145  "use 'InsertionPoint.at_block_terminator(block)' "
2146  "versus 'InsertionPoint(block)'?");
2147  }
2148  }
2149  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
2150  operation.setAttached();
2151 }
2152 
  // Returns an insertion point at the start of `block`. That is before its
  // first operation, or at the end when the block is empty.
2154  MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
2155  if (mlirOperationIsNull(firstOp)) {
2156  // Just insert at end.
2157  return PyInsertionPoint(block);
2158  }
2159
2160  // Insert before first op.
2162  block.getParentOperation()->getContext(), firstOp);
2163  return PyInsertionPoint{block, std::move(firstOpRef)};
2164 }
2165 
  // Returns an insertion point immediately before `block`'s terminator.
  // Raises ValueError if the block has no terminator.
2167  MlirOperation terminator = mlirBlockGetTerminator(block.get());
2168  if (mlirOperationIsNull(terminator))
2169  throw nb::value_error("Block has no terminator");
2170  PyOperationRef terminatorOpRef = PyOperation::forOperation(
2171  block.getParentOperation()->getContext(), terminator);
2172  return PyInsertionPoint{block, std::move(terminatorOpRef)};
2173 }
2174 
2175 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
2176  return PyThreadContextEntry::pushInsertionPoint(insertPoint);
2177 }
2178 
/// Context-manager exit, the counterpart of contextEnter. The three arguments
/// follow the Python `__exit__` protocol (exception type, value, traceback).
/// Presumably it pops the insertion point pushed on entry; confirm.
2179 void PyInsertionPoint::contextExit(const nb::object &excType,
2180  const nb::object &excVal,
2181  const nb::object &excTb) {
2183 }
2184 
2185 //------------------------------------------------------------------------------
2186 // PyAttribute.
2187 //------------------------------------------------------------------------------
2188 
2189 bool PyAttribute::operator==(const PyAttribute &other) const {
2190  return mlirAttributeEqual(attr, other.attr);
2191 }
2192 
  // Exports this attribute as a Python capsule. The reference returned by
  // mlirPythonAttributeToCapsule is adopted with nb::steal (not borrowed).
2194  return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
2195 }
2196 
  // Unwraps an attribute from a Python capsule. A null result is reported as
  // the pending Python error via nb::python_error. Presumably the capsule
  // helper sets that error on failure; confirm.
2198  MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
2199  if (mlirAttributeIsNull(rawAttr))
2200  throw nb::python_error();
2201  return PyAttribute(
2203 }
2204 
2205 //------------------------------------------------------------------------------
2206 // PyNamedAttribute.
2207 //------------------------------------------------------------------------------
2208 
/// Builds a named attribute from `attr` and `ownedName`.
/// The name is moved into heap storage held by `ownedName`, and the string
/// ref used to build the named attribute points into that storage.
/// Presumably this keeps the name alive for as long as this object exists.
2209 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
2210  : ownedName(new std::string(std::move(ownedName))) {
2213  toMlirStringRef(*this->ownedName)),
2214  attr);
2215 }
2216 
2217 //------------------------------------------------------------------------------
2218 // PyType.
2219 //------------------------------------------------------------------------------
2220 
2221 bool PyType::operator==(const PyType &other) const {
2222  return mlirTypeEqual(type, other.type);
2223 }
2224 
2225 nb::object PyType::getCapsule() {
2226  return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
2227 }
2228 
/// Unwraps a type from a Python capsule. A null result is reported as the
/// pending Python error via nb::python_error. Presumably the capsule helper
/// sets that error on failure; confirm.
2229 PyType PyType::createFromCapsule(nb::object capsule) {
2230  MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
2231  if (mlirTypeIsNull(rawType))
2232  throw nb::python_error();
2234  rawType);
2235 }
2236 
2237 //------------------------------------------------------------------------------
2238 // PyTypeID.
2239 //------------------------------------------------------------------------------
2240 
2241 nb::object PyTypeID::getCapsule() {
2242  return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
2243 }
2244 
  // Unwraps a TypeID from a Python capsule. A null result is reported as the
  // pending Python error via nb::python_error.
2246  MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
2247  if (mlirTypeIDIsNull(mlirTypeID))
2248  throw nb::python_error();
2249  return PyTypeID(mlirTypeID);
2250 }
2251 bool PyTypeID::operator==(const PyTypeID &other) const {
2252  return mlirTypeIDEqual(typeID, other.typeID);
2253 }
2254 
2255 //------------------------------------------------------------------------------
2256 // PyValue and subclasses.
2257 //------------------------------------------------------------------------------
2258 
2259 nb::object PyValue::getCapsule() {
2260  return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
2261 }
2262 
  // Returns a Python object for this value. If a value caster was found for
  // the value's type, the object is passed through it; otherwise it is
  // returned as is. The caster lookup is presumably keyed on the type's
  // TypeID; confirm.
  // Note: `*this` is moved into the new Python object (rv_policy::move below),
  // so this PyValue must not be relied upon afterwards.
2264  MlirType type = mlirValueGetType(get());
2265  MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2266  assert(!mlirTypeIDIsNull(mlirTypeID) &&
2267  "mlirTypeID was expected to be non-null.");
2268  std::optional<nb::callable> valueCaster =
2270  // nb::rv_policy::move means use std::move to move the return value
2271  // contents into a new instance that will be owned by Python.
2272  nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2273  if (!valueCaster)
2274  return thisObj;
2275  return valueCaster.value()(thisObj);
2276 }
2277 
  // Reconstructs a PyValue from a Python capsule. The value's owning
  // operation is wrapped as the PyValue's parent. For an OpResult that is the
  // defining op; for a block argument it is presumably the parent operation
  // of its block.
  // NOTE(review): `owner` is declared without an initializer. If `value` were
  // neither an OpResult nor a BlockArgument, mlirOperationIsNull(owner) would
  // read an indeterminate value; initializing it to {nullptr} would make the
  // null check reliable. Also, when `owner` is null no Python error appears
  // to be set, so nb::python_error may not carry a meaningful exception.
  // Confirm.
2279  MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2280  if (mlirValueIsNull(value))
2281  throw nb::python_error();
2282  MlirOperation owner;
2283  if (mlirValueIsAOpResult(value))
2284  owner = mlirOpResultGetOwner(value);
2285  if (mlirValueIsABlockArgument(value))
2287  if (mlirOperationIsNull(owner))
2288  throw nb::python_error();
2289  MlirContext ctx = mlirOperationGetContext(owner);
2290  PyOperationRef ownerRef =
2292  return PyValue(ownerRef, value);
2293 }
2294 
2295 //------------------------------------------------------------------------------
2296 // PySymbolTable.
2297 //------------------------------------------------------------------------------
2298 
  // Creates a symbol table over the given operation and keeps a reference to
  // it. In the body, `operation` is presumably the constructor's
  // PyOperationBase parameter (note the `.getOperation()` calls), which
  // shadows the member. Raises TypeError if the operation is not a symbol
  // table.
2300  : operation(operation.getOperation().getRef()) {
2301  symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2302  if (mlirSymbolTableIsNull(symbolTable)) {
2303  throw nb::type_error("Operation is not a Symbol Table.");
2304  }
2305 }
2306 
2307 nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2308  operation->checkValid();
2309  MlirOperation symbol = mlirSymbolTableLookup(
2310  symbolTable, mlirStringRefCreate(name.data(), name.length()));
2311  if (mlirOperationIsNull(symbol))
2312  throw nb::key_error(
2313  ("Symbol '" + name + "' not in the symbol table.").c_str());
2314 
2315  return PyOperation::forOperation(operation->getContext(), symbol,
2316  operation.getObject())
2317  ->createOpView();
2318 }
2319 
  // Erases `symbol` from this symbol table (and, per the comment below, the
  // operation itself), then marks its PyOperation invalid. Both the table's
  // operation and the symbol must be valid on entry.
2321  operation->checkValid();
2322  symbol.getOperation().checkValid();
2323  mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2324  // The operation is also erased, so we must invalidate it. There may be Python
2325  // references to this operation so we don't want to delete it from the list of
2326  // live operations here.
2327  symbol.getOperation().valid = false;
2328 }
2329 
2330 void PySymbolTable::dunderDel(const std::string &name) {
2331  nb::object operation = dunderGetItem(name);
2332  erase(nb::cast<PyOperationBase &>(operation));
2333 }
2334 
/// Inserts `symbol` into this symbol table and returns the attribute produced
/// by mlirSymbolTableInsert. Presumably this is the (possibly renamed) symbol
/// name; confirm.
/// Raises ValueError if the operation has no symbol-name attribute.
2335 MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2336  operation->checkValid();
2337  symbol.getOperation().checkValid();
2338  MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2340  if (mlirAttributeIsNull(symbolAttr))
2341  throw nb::value_error("Expected operation to have a symbol name.");
2342  return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2343 }
2344 
  // Returns the symbol-name attribute stored under `attrName` on `symbol`'s
  // operation. Raises ValueError if that attribute is absent.
2346  // Op must already be a symbol.
2347  PyOperation &operation = symbol.getOperation();
2348  operation.checkValid();
2350  MlirAttribute existingNameAttr =
2351  mlirOperationGetAttributeByName(operation.get(), attrName);
2352  if (mlirAttributeIsNull(existingNameAttr))
2353  throw nb::value_error("Expected operation to have a symbol name.");
2354  return existingNameAttr;
2355 }
2356 
2358  const std::string &name) {
  // Replaces the symbol-name attribute (stored under `attrName`) on `symbol`'s
  // operation with a string attribute holding `name`. The operation must
  // already carry a symbol name; otherwise ValueError is raised.
2359  // Op must already be a symbol.
2360  PyOperation &operation = symbol.getOperation();
2361  operation.checkValid();
2363  MlirAttribute existingNameAttr =
2364  mlirOperationGetAttributeByName(operation.get(), attrName);
2365  if (mlirAttributeIsNull(existingNameAttr))
2366  throw nb::value_error("Expected operation to have a symbol name.");
2367  MlirAttribute newNameAttr =
2368  mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2369  mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2370 }
2371 
  // Returns the symbol-visibility attribute stored under `attrName` on
  // `symbol`'s operation. Raises ValueError if that attribute is absent.
2373  PyOperation &operation = symbol.getOperation();
2374  operation.checkValid();
2376  MlirAttribute existingVisAttr =
2377  mlirOperationGetAttributeByName(operation.get(), attrName);
2378  if (mlirAttributeIsNull(existingVisAttr))
2379  throw nb::value_error("Expected operation to have a symbol visibility.");
2380  return existingVisAttr;
2381 }
2382 
2384  const std::string &visibility) {
  // Replaces the existing symbol-visibility attribute on `symbol`'s operation
  // with a string attribute holding `visibility`.
  // Raises ValueError for:
  //  * any visibility other than "public", "private" or "nested";
  //  * an operation that has no visibility attribute yet.
2385  if (visibility != "public" && visibility != "private" &&
2386  visibility != "nested")
2387  throw nb::value_error(
2388  "Expected visibility to be 'public', 'private' or 'nested'");
2389  PyOperation &operation = symbol.getOperation();
2390  operation.checkValid();
2392  MlirAttribute existingVisAttr =
2393  mlirOperationGetAttributeByName(operation.get(), attrName);
2394  if (mlirAttributeIsNull(existingVisAttr))
2395  throw nb::value_error("Expected operation to have a symbol visibility.");
2396  MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2397  toMlirStringRef(visibility));
2398  mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2399 }
2400 
/// Renames uses of `oldSymbol` to `newSymbol` within `from`. Both names and
/// `from`'s operation are passed to the C API symbol-use replacement routine.
/// Raises ValueError("Symbol rename failed") if that call reports failure.
2401 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2402  const std::string &newSymbol,
2403  PyOperationBase &from) {
2404  PyOperation &fromOperation = from.getOperation();
2405  fromOperation.checkValid();
2407  toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2408  from.getOperation())))
2409
2410  throw nb::value_error("Symbol rename failed");
2411 }
2412 
2414  bool allSymUsesVisible,
2415  nb::object callback) {
  // Calls `callback(op, isVisible)` for each operation reported by the C-API
  // symbol-table walk rooted at `from`; `allSymUsesVisible` is forwarded to
  // that walk.
  // Python exceptions raised by the callback are not propagated through the
  // C trampoline; presumably they must not unwind across the C API.
  // Instead, the first exception is recorded in `userData` and later
  // invocations skip the callback. After the walk, the exception is re-raised
  // as std::runtime_error carrying the original message.
2416  PyOperation &fromOperation = from.getOperation();
2417  fromOperation.checkValid();
2418  struct UserData {
2419  PyMlirContextRef context;
2420  nb::object callback;
2421  bool gotException;
2422  std::string exceptionWhat;
  // NOTE(review): recorded but never used when re-raising below, so the
  // original Python exception type is lost.
2423  nb::object exceptionType;
2424  };
2425  UserData userData{
2426  fromOperation.getContext(), std::move(callback), false, {}, {}};
2428  fromOperation.get(), allSymUsesVisible,
2429  [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2430  UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
  // NOTE(review): foundOp is wrapped even when an earlier exception means
  // the callback will be skipped. Checking gotException first would avoid
  // that work.
2431  auto pyFoundOp =
2432  PyOperation::forOperation(calleeUserData->context, foundOp);
2433  if (calleeUserData->gotException)
2434  return;
  // Only nb::python_error is caught here. Other C++ exceptions thrown by
  // the callback would escape into the C walk; confirm that none can occur.
2435  try {
2436  calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2437  } catch (nb::python_error &e) {
2438  calleeUserData->gotException = true;
2439  calleeUserData->exceptionWhat = e.what();
2440  calleeUserData->exceptionType = nb::borrow(e.type());
2441  }
2442  },
2443  static_cast<void *>(&userData));
2444  if (userData.gotException) {
2445  std::string message("Exception raised in callback: ");
2446  message.append(userData.exceptionWhat);
2447  throw std::runtime_error(message);
2448  }
2449 }
2450 
2451 namespace {
2452 
2453 /// Python wrapper for MlirBlockArgument.
2454 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2455 public:
2456  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2457  static constexpr const char *pyClassName = "BlockArgument";
2458  using PyConcreteValue::PyConcreteValue;
2459 
2460  static void bindDerived(ClassTy &c) {
2461  c.def_prop_ro("owner", [](PyBlockArgument &self) {
2462  return PyBlock(self.getParentOperation(),
2463  mlirBlockArgumentGetOwner(self.get()));
2464  });
2465  c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
2466  return mlirBlockArgumentGetArgNumber(self.get());
2467  });
2468  c.def(
2469  "set_type",
2470  [](PyBlockArgument &self, PyType type) {
2471  return mlirBlockArgumentSetType(self.get(), type);
2472  },
2473  nb::arg("type"));
2474  }
2475 };
2476 
2477 /// A list of block arguments. Internally, these are stored as consecutive
2478 /// elements, so random access is cheap. The argument list is associated with
2479 /// the operation that contains the block (detached blocks are not allowed in
2480 /// Python bindings) and extends its lifetime.
2481 class PyBlockArgumentList
2482  : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2483 public:
2484  static constexpr const char *pyClassName = "BlockArgumentList";
2486
  /// Creates a view over `block`'s arguments. A `length` of -1 means all
  /// arguments of the block.
2487  PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2488  intptr_t startIndex = 0, intptr_t length = -1,
2489  intptr_t step = 1)
2490  : Sliceable(startIndex,
2491  length == -1 ? mlirBlockGetNumArguments(block) : length,
2492  step),
2493  operation(std::move(operation)), block(block) {}
2494
  /// Adds a read-only `types` property listing the argument types.
2495  static void bindDerived(ClassTy &c) {
2496  c.def_prop_ro("types", [](PyBlockArgumentList &self) {
2497  return getValueTypes(self, self.operation->getContext());
2498  });
2499  }
2500
2501 private:
2502  /// Give the parent CRTP class access to hook implementations below.
2503  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2504
2505  /// Returns the number of arguments in the list, after checking that the
2505  /// owning operation is still valid.
2506  intptr_t getRawNumElements() {
2507  operation->checkValid();
2508  return mlirBlockGetNumArguments(block);
2509  }
2510
2511  /// Returns the element at raw position `pos` in the list.
2511  /// NOTE(review): unlike getRawNumElements, no validity check is made here;
2511  /// confirm that callers always go through a checked path first.
2512  PyBlockArgument getRawElement(intptr_t pos) {
2513  MlirValue argument = mlirBlockGetArgument(block, pos);
2514  return PyBlockArgument(operation, argument);
2515  }
2516
2517  /// Returns a sublist of this list.
2518  PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2519  intptr_t step) {
2520  return PyBlockArgumentList(operation, block, startIndex, length, step);
2521  }
2522
  /// Operation containing the block; held to keep it alive.
2523  PyOperationRef operation;
2524  MlirBlock block;
2525 };
2526 
2527 /// A list of operation operands. Internally, these are stored as consecutive
2528 /// elements, so random access is cheap. The (returned) operand list is associated
2529 /// with the operation whose operands these are, and thus extends the lifetime
2530 /// of this operation.
2531 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2532 public:
2533  static constexpr const char *pyClassName = "OpOperandList";
2534  using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2535
  /// Creates a view over `operation`'s operands. A `length` of -1 means all
  /// operands.
  /// NOTE(review): `operation(operation)` copies the ref where
  /// PyBlockArgumentList moves it. std::move would save a refcount round-trip
  /// (the base initializer runs first, so this is safe).
2536  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2537  intptr_t length = -1, intptr_t step = 1)
2538  : Sliceable(startIndex,
2539  length == -1 ? mlirOperationGetNumOperands(operation->get())
2540  : length,
2541  step),
2542  operation(operation) {}
2543
  /// Replaces the operand at `index` (normalized by Sliceable::wrapIndex)
  /// with `value`.
  /// NOTE(review): no operation->checkValid() before mutating; confirm.
2544  void dunderSetItem(intptr_t index, PyValue value) {
2545  index = wrapIndex(index);
2546  mlirOperationSetOperand(operation->get(), index, value.get());
2547  }
2548
2549  static void bindDerived(ClassTy &c) {
2550  c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2551  }
2552
2553 private:
2554  /// Give the parent CRTP class access to hook implementations below.
2555  friend class Sliceable<PyOpOperandList, PyValue>;
2556
  /// Returns the current operand count after checking operation validity.
2557  intptr_t getRawNumElements() {
2558  operation->checkValid();
2559  return mlirOperationGetNumOperands(operation->get());
2560  }
2561
  /// Returns the operand at raw position `pos`, wrapped with its defining
  /// owner operation (the op for a result; for a block argument, presumably
  /// the parent operation of its block).
  /// NOTE(review): `owner` is uninitialized. With asserts compiled out
  /// (NDEBUG), the final `else` leaves it indeterminate before it is used.
2562  PyValue getRawElement(intptr_t pos) {
2563  MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2564  MlirOperation owner;
2565  if (mlirValueIsAOpResult(operand))
2566  owner = mlirOpResultGetOwner(operand);
2567  else if (mlirValueIsABlockArgument(operand))
2569  else
2570  assert(false && "Value must be an block arg or op result.");
2571  PyOperationRef pyOwner =
2572  PyOperation::forOperation(operation->getContext(), owner);
2573  return PyValue(pyOwner, operand);
2574  }
2575
  /// Returns a sliced view sharing the same operation.
2576  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2577  return PyOpOperandList(operation, startIndex, length, step);
2578  }
2579
2580  PyOperationRef operation;
2581 };
2582 
2583 /// A list of operation successors. Internally, these are stored as consecutive
2584 /// elements, random access is cheap. The (returned) successor list is
2585 /// associated with the operation whose successors these are, and thus extends
2586 /// the lifetime of this operation.
2587 class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2588 public:
2589  static constexpr const char *pyClassName = "OpSuccessors";
2590 
2591  PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2592  intptr_t length = -1, intptr_t step = 1)
2593  : Sliceable(startIndex,
2594  length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2595  : length,
2596  step),
2597  operation(operation) {}
2598 
2599  void dunderSetItem(intptr_t index, PyBlock block) {
2600  index = wrapIndex(index);
2601  mlirOperationSetSuccessor(operation->get(), index, block.get());
2602  }
2603 
2604  static void bindDerived(ClassTy &c) {
2605  c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2606  }
2607 
2608 private:
2609  /// Give the parent CRTP class access to hook implementations below.
2610  friend class Sliceable<PyOpSuccessors, PyBlock>;
2611 
2612  intptr_t getRawNumElements() {
2613  operation->checkValid();
2614  return mlirOperationGetNumSuccessors(operation->get());
2615  }
2616 
2617  PyBlock getRawElement(intptr_t pos) {
2618  MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2619  return PyBlock(operation, block);
2620  }
2621 
2622  PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2623  return PyOpSuccessors(operation, startIndex, length, step);
2624  }
2625 
2626  PyOperationRef operation;
2627 };
2628 
2629 /// A list of operation attributes. Can be indexed by name, producing
2630 /// attributes, or by index, producing named attributes.
// Exposed to Python as `OpAttributeMap`. It holds a reference to the owning
// operation for its own lifetime.
// NOTE(review): unlike the Sliceable-based list classes in this file, none of
// these accessors call operation->checkValid() before using operation->get().
// Confirm whether access to an invalidated operation is guarded elsewhere.
2631 class PyOpAttributeMap {
2632 public:
2633  PyOpAttributeMap(PyOperationRef operation)
2634  : operation(std::move(operation)) {}
2635
  /// Returns the attribute named `name`. Raises KeyError if it is absent.
2636  MlirAttribute dunderGetItemNamed(const std::string &name) {
2637  MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2638  toMlirStringRef(name));
2639  if (mlirAttributeIsNull(attr)) {
2640  throw nb::key_error("attempt to access a non-existent attribute");
2641  }
2642  return attr;
2643  }
2644
  /// Returns the named attribute at position `index`. Negative indices count
  /// from the end; out-of-range indices raise IndexError.
2645  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2646  if (index < 0) {
2647  index += dunderLen();
2648  }
2649  if (index < 0 || index >= dunderLen()) {
2650  throw nb::index_error("attempt to access out of bounds attribute");
2651  }
2652  MlirNamedAttribute namedAttr =
2653  mlirOperationGetAttribute(operation->get(), index);
2654  return PyNamedAttribute(
2655  namedAttr.attribute,
2656  std::string(mlirIdentifierStr(namedAttr.name).data,
2657  mlirIdentifierStr(namedAttr.name).length));
2658  }
2659
  /// Sets the attribute `name` to `attr`, presumably through
  /// mlirOperationSetAttributeByName (adding or replacing it); confirm.
2660  void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2662  attr);
2663  }
2664
  /// Removes the attribute `name`. Raises KeyError if it was not present.
2665  void dunderDelItem(const std::string &name) {
2666  int removed = mlirOperationRemoveAttributeByName(operation->get(),
2667  toMlirStringRef(name));
2668  if (!removed)
2669  throw nb::key_error("attempt to delete a non-existent attribute");
2670  }
2671
  /// Number of attributes currently on the operation.
2672  intptr_t dunderLen() {
2673  return mlirOperationGetNumAttributes(operation->get());
2674  }
2675
  /// Reports whether an attribute named `name` is present.
2676  bool dunderContains(const std::string &name) {
2678  operation->get(), toMlirStringRef(name)));
2679  }
2680
  /// Registers the Python class. `__getitem__` is bound twice: once for
  /// string keys (by name) and once for integer keys (by index). Presumably
  /// nanobind's overload resolution in registration order selects between
  /// them.
2681  static void bind(nb::module_ &m) {
2682  nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2683  .def("__contains__", &PyOpAttributeMap::dunderContains)
2684  .def("__len__", &PyOpAttributeMap::dunderLen)
2685  .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2686  .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2687  .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2688  .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2689  }
2690
2691 private:
2692  PyOperationRef operation;
2693 };
2694 
2695 } // namespace
2696 
2697 //------------------------------------------------------------------------------
2698 // Populates the core exports of the 'ir' submodule.
2699 //------------------------------------------------------------------------------
2700 
2701 void mlir::python::populateIRCore(nb::module_ &m) {
2702  // disable leak warnings which tend to be false positives.
2703  nb::set_leak_warnings(false);
2704  //----------------------------------------------------------------------------
2705  // Enums.
2706  //----------------------------------------------------------------------------
2707  nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
2708  .value("ERROR", MlirDiagnosticError)
2709  .value("WARNING", MlirDiagnosticWarning)
2710  .value("NOTE", MlirDiagnosticNote)
2711  .value("REMARK", MlirDiagnosticRemark);
2712 
2713  nb::enum_<MlirWalkOrder>(m, "WalkOrder")
2714  .value("PRE_ORDER", MlirWalkPreOrder)
2715  .value("POST_ORDER", MlirWalkPostOrder);
2716 
2717  nb::enum_<MlirWalkResult>(m, "WalkResult")
2718  .value("ADVANCE", MlirWalkResultAdvance)
2719  .value("INTERRUPT", MlirWalkResultInterrupt)
2720  .value("SKIP", MlirWalkResultSkip);
2721 
2722  //----------------------------------------------------------------------------
2723  // Mapping of Diagnostics.
2724  //----------------------------------------------------------------------------
2725  nb::class_<PyDiagnostic>(m, "Diagnostic")
2726  .def_prop_ro("severity", &PyDiagnostic::getSeverity)
2727  .def_prop_ro("location", &PyDiagnostic::getLocation)
2728  .def_prop_ro("message", &PyDiagnostic::getMessage)
2729  .def_prop_ro("notes", &PyDiagnostic::getNotes)
2730  .def("__str__", [](PyDiagnostic &self) -> nb::str {
2731  if (!self.isValid())
2732  return nb::str("<Invalid Diagnostic>");
2733  return self.getMessage();
2734  });
2735 
2736  nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2737  .def("__init__",
2739  new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2740  })
2741  .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity)
2742  .def_ro("location", &PyDiagnostic::DiagnosticInfo::location)
2743  .def_ro("message", &PyDiagnostic::DiagnosticInfo::message)
2744  .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes)
2745  .def("__str__",
2746  [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2747 
2748  nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
2749  .def("detach", &PyDiagnosticHandler::detach)
2750  .def_prop_ro("attached", &PyDiagnosticHandler::isAttached)
2751  .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError)
2752  .def("__enter__", &PyDiagnosticHandler::contextEnter)
2753  .def("__exit__", &PyDiagnosticHandler::contextExit,
2754  nb::arg("exc_type").none(), nb::arg("exc_value").none(),
2755  nb::arg("traceback").none());
2756 
2757  //----------------------------------------------------------------------------
2758  // Mapping of MlirContext.
2759  // Note that this is exported as _BaseContext. The containing, Python level
2760  // __init__.py will subclass it with site-specific functionality and set a
2761  // "Context" attribute on this module.
2762  //----------------------------------------------------------------------------
2763 
2764  // Expose DefaultThreadPool to python
2765  nb::class_<PyThreadPool>(m, "ThreadPool")
2766  .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
2767  .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency)
2768  .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr);
2769 
2770  nb::class_<PyMlirContext>(m, "_BaseContext")
2771  .def("__init__",
2772  [](PyMlirContext &self) {
2773  MlirContext context = mlirContextCreateWithThreading(false);
2774  new (&self) PyMlirContext(context);
2775  })
2776  .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2777  .def("_get_context_again",
2778  [](PyMlirContext &self) {
2779  PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2780  return ref.releaseObject();
2781  })
2782  .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2783  .def("_get_live_operation_objects",
2784  &PyMlirContext::getLiveOperationObjects)
2785  .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2786  .def("_clear_live_operations_inside",
2787  nb::overload_cast<MlirOperation>(
2788  &PyMlirContext::clearOperationsInside))
2789  .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2790  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
2791  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2792  .def("__enter__", &PyMlirContext::contextEnter)
2793  .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
2794  nb::arg("exc_value").none(), nb::arg("traceback").none())
2795  .def_prop_ro_static(
2796  "current",
2797  [](nb::object & /*class*/) {
2798  auto *context = PyThreadContextEntry::getDefaultContext();
2799  if (!context)
2800  return nb::none();
2801  return nb::cast(context);
2802  },
2803  "Gets the Context bound to the current thread or raises ValueError")
2804  .def_prop_ro(
2805  "dialects",
2806  [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2807  "Gets a container for accessing dialects by name")
2808  .def_prop_ro(
2809  "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2810  "Alias for 'dialect'")
2811  .def(
2812  "get_dialect_descriptor",
2813  [=](PyMlirContext &self, std::string &name) {
2814  MlirDialect dialect = mlirContextGetOrLoadDialect(
2815  self.get(), {name.data(), name.size()});
2816  if (mlirDialectIsNull(dialect)) {
2817  throw nb::value_error(
2818  (Twine("Dialect '") + name + "' not found").str().c_str());
2819  }
2820  return PyDialectDescriptor(self.getRef(), dialect);
2821  },
2822  nb::arg("dialect_name"),
2823  "Gets or loads a dialect by name, returning its descriptor object")
2824  .def_prop_rw(
2825  "allow_unregistered_dialects",
2826  [](PyMlirContext &self) -> bool {
2828  },
2829  [](PyMlirContext &self, bool value) {
2831  })
2832  .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2833  nb::arg("callback"),
2834  "Attaches a diagnostic handler that will receive callbacks")
2835  .def(
2836  "enable_multithreading",
2837  [](PyMlirContext &self, bool enable) {
2838  mlirContextEnableMultithreading(self.get(), enable);
2839  },
2840  nb::arg("enable"))
2841  .def("set_thread_pool",
2842  [](PyMlirContext &self, PyThreadPool &pool) {
2843  // we should disable multi-threading first before setting
2844  // new thread pool otherwise the assert in
2845  // MLIRContext::setThreadPool will be raised.
2846  mlirContextEnableMultithreading(self.get(), false);
2847  mlirContextSetThreadPool(self.get(), pool.get());
2848  })
2849  .def("get_num_threads",
2850  [](PyMlirContext &self) {
2851  return mlirContextGetNumThreads(self.get());
2852  })
2853  .def("_mlir_thread_pool_ptr",
2854  [](PyMlirContext &self) {
2855  MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
2856  std::stringstream ss;
2857  ss << pool.ptr;
2858  return ss.str();
2859  })
2860  .def(
2861  "is_registered_operation",
2862  [](PyMlirContext &self, std::string &name) {
2864  self.get(), MlirStringRef{name.data(), name.size()});
2865  },
2866  nb::arg("operation_name"))
2867  .def(
2868  "append_dialect_registry",
2869  [](PyMlirContext &self, PyDialectRegistry &registry) {
2870  mlirContextAppendDialectRegistry(self.get(), registry);
2871  },
2872  nb::arg("registry"))
2873  .def_prop_rw("emit_error_diagnostics", nullptr,
2874  &PyMlirContext::setEmitErrorDiagnostics,
2875  "Emit error diagnostics to diagnostic handlers. By default "
2876  "error diagnostics are captured and reported through "
2877  "MLIRError exceptions.")
2878  .def("load_all_available_dialects", [](PyMlirContext &self) {
2880  });
2881 
2882  //----------------------------------------------------------------------------
2883  // Mapping of PyDialectDescriptor
2884  //----------------------------------------------------------------------------
2885  nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
2886  .def_prop_ro("namespace",
2887  [](PyDialectDescriptor &self) {
2889  return nb::str(ns.data, ns.length);
2890  })
2891  .def("__repr__", [](PyDialectDescriptor &self) {
2893  std::string repr("<DialectDescriptor ");
2894  repr.append(ns.data, ns.length);
2895  repr.append(">");
2896  return repr;
2897  });
2898 
2899  //----------------------------------------------------------------------------
2900  // Mapping of PyDialects
2901  //----------------------------------------------------------------------------
2902  nb::class_<PyDialects>(m, "Dialects")
2903  .def("__getitem__",
2904  [=](PyDialects &self, std::string keyName) {
2905  MlirDialect dialect =
2906  self.getDialectForKey(keyName, /*attrError=*/false);
2907  nb::object descriptor =
2908  nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2909  return createCustomDialectWrapper(keyName, std::move(descriptor));
2910  })
2911  .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2912  MlirDialect dialect =
2913  self.getDialectForKey(attrName, /*attrError=*/true);
2914  nb::object descriptor =
2915  nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2916  return createCustomDialectWrapper(attrName, std::move(descriptor));
2917  });
2918 
2919  //----------------------------------------------------------------------------
2920  // Mapping of PyDialect
2921  //----------------------------------------------------------------------------
2922  nb::class_<PyDialect>(m, "Dialect")
2923  .def(nb::init<nb::object>(), nb::arg("descriptor"))
2924  .def_prop_ro("descriptor",
2925  [](PyDialect &self) { return self.getDescriptor(); })
2926  .def("__repr__", [](nb::object self) {
2927  auto clazz = self.attr("__class__");
2928  return nb::str("<Dialect ") +
2929  self.attr("descriptor").attr("namespace") + nb::str(" (class ") +
2930  clazz.attr("__module__") + nb::str(".") +
2931  clazz.attr("__name__") + nb::str(")>");
2932  });
2933 
2934  //----------------------------------------------------------------------------
2935  // Mapping of PyDialectRegistry
2936  //----------------------------------------------------------------------------
2937  nb::class_<PyDialectRegistry>(m, "DialectRegistry")
2938  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
2939  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2940  .def(nb::init<>());
2941 
2942  //----------------------------------------------------------------------------
2943  // Mapping of Location
2944  //----------------------------------------------------------------------------
2945  nb::class_<PyLocation>(m, "Location")
2946  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2947  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2948  .def("__enter__", &PyLocation::contextEnter)
2949  .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
2950  nb::arg("exc_value").none(), nb::arg("traceback").none())
2951  .def("__eq__",
2952  [](PyLocation &self, PyLocation &other) -> bool {
2953  return mlirLocationEqual(self, other);
2954  })
2955  .def("__eq__", [](PyLocation &self, nb::object other) { return false; })
2956  .def_prop_ro_static(
2957  "current",
2958  [](nb::object & /*class*/) {
2959  auto *loc = PyThreadContextEntry::getDefaultLocation();
2960  if (!loc)
2961  throw nb::value_error("No current Location");
2962  return loc;
2963  },
2964  "Gets the Location bound to the current thread or raises ValueError")
2965  .def_static(
2966  "unknown",
2967  [](DefaultingPyMlirContext context) {
2968  return PyLocation(context->getRef(),
2969  mlirLocationUnknownGet(context->get()));
2970  },
2971  nb::arg("context").none() = nb::none(),
2972  "Gets a Location representing an unknown location")
2973  .def_static(
2974  "callsite",
2975  [](PyLocation callee, const std::vector<PyLocation> &frames,
2976  DefaultingPyMlirContext context) {
2977  if (frames.empty())
2978  throw nb::value_error("No caller frames provided");
2979  MlirLocation caller = frames.back().get();
2980  for (const PyLocation &frame :
2981  llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2982  caller = mlirLocationCallSiteGet(frame.get(), caller);
2983  return PyLocation(context->getRef(),
2984  mlirLocationCallSiteGet(callee.get(), caller));
2985  },
2986  nb::arg("callee"), nb::arg("frames"),
2987  nb::arg("context").none() = nb::none(),
2989  .def("is_a_callsite", mlirLocationIsACallSite)
2990  .def_prop_ro("callee", mlirLocationCallSiteGetCallee)
2991  .def_prop_ro("caller", mlirLocationCallSiteGetCaller)
2992  .def_static(
2993  "file",
2994  [](std::string filename, int line, int col,
2995  DefaultingPyMlirContext context) {
2996  return PyLocation(
2997  context->getRef(),
2999  context->get(), toMlirStringRef(filename), line, col));
3000  },
3001  nb::arg("filename"), nb::arg("line"), nb::arg("col"),
3002  nb::arg("context").none() = nb::none(),
3004  .def_static(
3005  "file",
3006  [](std::string filename, int startLine, int startCol, int endLine,
3007  int endCol, DefaultingPyMlirContext context) {
3008  return PyLocation(context->getRef(),
3010  context->get(), toMlirStringRef(filename),
3011  startLine, startCol, endLine, endCol));
3012  },
3013  nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
3014  nb::arg("end_line"), nb::arg("end_col"),
3015  nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring)
3016  .def("is_a_file", mlirLocationIsAFileLineColRange)
3017  .def_prop_ro("filename",
3018  [](MlirLocation loc) {
3019  return mlirIdentifierStr(
3021  })
3022  .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
3023  .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
3024  .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
3025  .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn)
3026  .def_static(
3027  "fused",
3028  [](const std::vector<PyLocation> &pyLocations,
3029  std::optional<PyAttribute> metadata,
3030  DefaultingPyMlirContext context) {
3032  locations.reserve(pyLocations.size());
3033  for (auto &pyLocation : pyLocations)
3034  locations.push_back(pyLocation.get());
3035  MlirLocation location = mlirLocationFusedGet(
3036  context->get(), locations.size(), locations.data(),
3037  metadata ? metadata->get() : MlirAttribute{0});
3038  return PyLocation(context->getRef(), location);
3039  },
3040  nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
3041  nb::arg("context").none() = nb::none(),
3043  .def("is_a_fused", mlirLocationIsAFused)
3044  .def_prop_ro("locations",
3045  [](MlirLocation loc) {
3046  unsigned numLocations =
3048  std::vector<MlirLocation> locations(numLocations);
3049  if (numLocations)
3050  mlirLocationFusedGetLocations(loc, locations.data());
3051  return locations;
3052  })
3053  .def_static(
3054  "name",
3055  [](std::string name, std::optional<PyLocation> childLoc,
3056  DefaultingPyMlirContext context) {
3057  return PyLocation(
3058  context->getRef(),
3060  context->get(), toMlirStringRef(name),
3061  childLoc ? childLoc->get()
3062  : mlirLocationUnknownGet(context->get())));
3063  },
3064  nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
3065  nb::arg("context").none() = nb::none(),
3067  .def("is_a_name", mlirLocationIsAName)
3068  .def_prop_ro("name_str",
3069  [](MlirLocation loc) {
3071  })
3072  .def_prop_ro("child_loc", mlirLocationNameGetChildLoc)
3073  .def_static(
3074  "from_attr",
3075  [](PyAttribute &attribute, DefaultingPyMlirContext context) {
3076  return PyLocation(context->getRef(),
3077  mlirLocationFromAttribute(attribute));
3078  },
3079  nb::arg("attribute"), nb::arg("context").none() = nb::none(),
3080  "Gets a Location from a LocationAttr")
3081  .def_prop_ro(
3082  "context",
3083  [](PyLocation &self) { return self.getContext().getObject(); },
3084  "Context that owns the Location")
3085  .def_prop_ro(
3086  "attr",
3087  [](PyLocation &self) { return mlirLocationGetAttribute(self); },
3088  "Get the underlying LocationAttr")
3089  .def(
3090  "emit_error",
3091  [](PyLocation &self, std::string message) {
3092  mlirEmitError(self, message.c_str());
3093  },
3094  nb::arg("message"), "Emits an error at this location")
3095  .def("__repr__", [](PyLocation &self) {
3096  PyPrintAccumulator printAccum;
3097  mlirLocationPrint(self, printAccum.getCallback(),
3098  printAccum.getUserData());
3099  return printAccum.join();
3100  });
3101 
3102  //----------------------------------------------------------------------------
3103  // Mapping of Module
3104  //----------------------------------------------------------------------------
3105  nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3106  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3107  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3108  .def_static(
3109  "parse",
3110  [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
3111  PyMlirContext::ErrorCapture errors(context->getRef());
3112  MlirModule module = mlirModuleCreateParse(
3113  context->get(), toMlirStringRef(moduleAsm));
3114  if (mlirModuleIsNull(module))
3115  throw MLIRError("Unable to parse module assembly", errors.take());
3116  return PyModule::forModule(module).releaseObject();
3117  },
3118  nb::arg("asm"), nb::arg("context").none() = nb::none(),
3120  .def_static(
3121  "parse",
3122  [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
3123  PyMlirContext::ErrorCapture errors(context->getRef());
3124  MlirModule module = mlirModuleCreateParse(
3125  context->get(), toMlirStringRef(moduleAsm));
3126  if (mlirModuleIsNull(module))
3127  throw MLIRError("Unable to parse module assembly", errors.take());
3128  return PyModule::forModule(module).releaseObject();
3129  },
3130  nb::arg("asm"), nb::arg("context").none() = nb::none(),
3132  .def_static(
3133  "parseFile",
3134  [](const std::string &path, DefaultingPyMlirContext context) {
3135  PyMlirContext::ErrorCapture errors(context->getRef());
3136  MlirModule module = mlirModuleCreateParseFromFile(
3137  context->get(), toMlirStringRef(path));
3138  if (mlirModuleIsNull(module))
3139  throw MLIRError("Unable to parse module assembly", errors.take());
3140  return PyModule::forModule(module).releaseObject();
3141  },
3142  nb::arg("path"), nb::arg("context").none() = nb::none(),
3144  .def_static(
3145  "create",
3146  [](DefaultingPyLocation loc) {
3147  MlirModule module = mlirModuleCreateEmpty(loc);
3148  return PyModule::forModule(module).releaseObject();
3149  },
3150  nb::arg("loc").none() = nb::none(), "Creates an empty module")
3151  .def_prop_ro(
3152  "context",
3153  [](PyModule &self) { return self.getContext().getObject(); },
3154  "Context that created the Module")
3155  .def_prop_ro(
3156  "operation",
3157  [](PyModule &self) {
3158  return PyOperation::forOperation(self.getContext(),
3159  mlirModuleGetOperation(self.get()),
3160  self.getRef().releaseObject())
3161  .releaseObject();
3162  },
3163  "Accesses the module as an operation")
3164  .def_prop_ro(
3165  "body",
3166  [](PyModule &self) {
3167  PyOperationRef moduleOp = PyOperation::forOperation(
3168  self.getContext(), mlirModuleGetOperation(self.get()),
3169  self.getRef().releaseObject());
3170  PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3171  return returnBlock;
3172  },
3173  "Return the block for this module")
3174  .def(
3175  "dump",
3176  [](PyModule &self) {
3178  },
3180  .def(
3181  "__str__",
3182  [](nb::object self) {
3183  // Defer to the operation's __str__.
3184  return self.attr("operation").attr("__str__")();
3185  },
3187 
3188  //----------------------------------------------------------------------------
3189  // Mapping of Operation.
3190  //----------------------------------------------------------------------------
3191  nb::class_<PyOperationBase>(m, "_OperationBase")
3192  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
3193  [](PyOperationBase &self) {
3194  return self.getOperation().getCapsule();
3195  })
3196  .def("__eq__",
3197  [](PyOperationBase &self, PyOperationBase &other) {
3198  return &self.getOperation() == &other.getOperation();
3199  })
3200  .def("__eq__",
3201  [](PyOperationBase &self, nb::object other) { return false; })
3202  .def("__hash__",
3203  [](PyOperationBase &self) {
3204  return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
3205  })
3206  .def_prop_ro("attributes",
3207  [](PyOperationBase &self) {
3208  return PyOpAttributeMap(self.getOperation().getRef());
3209  })
3210  .def_prop_ro(
3211  "context",
3212  [](PyOperationBase &self) {
3213  PyOperation &concreteOperation = self.getOperation();
3214  concreteOperation.checkValid();
3215  return concreteOperation.getContext().getObject();
3216  },
3217  "Context that owns the Operation")
3218  .def_prop_ro("name",
3219  [](PyOperationBase &self) {
3220  auto &concreteOperation = self.getOperation();
3221  concreteOperation.checkValid();
3222  MlirOperation operation = concreteOperation.get();
3223  return mlirIdentifierStr(mlirOperationGetName(operation));
3224  })
3225  .def_prop_ro("operands",
3226  [](PyOperationBase &self) {
3227  return PyOpOperandList(self.getOperation().getRef());
3228  })
3229  .def_prop_ro("regions",
3230  [](PyOperationBase &self) {
3231  return PyRegionList(self.getOperation().getRef());
3232  })
3233  .def_prop_ro(
3234  "results",
3235  [](PyOperationBase &self) {
3236  return PyOpResultList(self.getOperation().getRef());
3237  },
3238  "Returns the list of Operation results.")
3239  .def_prop_ro(
3240  "result",
3241  [](PyOperationBase &self) {
3242  auto &operation = self.getOperation();
3243  return PyOpResult(operation.getRef(), getUniqueResult(operation))
3244  .maybeDownCast();
3245  },
3246  "Shortcut to get an op result if it has only one (throws an error "
3247  "otherwise).")
3248  .def_prop_ro(
3249  "location",
3250  [](PyOperationBase &self) {
3251  PyOperation &operation = self.getOperation();
3252  return PyLocation(operation.getContext(),
3253  mlirOperationGetLocation(operation.get()));
3254  },
3255  "Returns the source location the operation was defined or derived "
3256  "from.")
3257  .def_prop_ro("parent",
3258  [](PyOperationBase &self) -> nb::object {
3259  auto parent = self.getOperation().getParentOperation();
3260  if (parent)
3261  return parent->getObject();
3262  return nb::none();
3263  })
3264  .def(
3265  "__str__",
3266  [](PyOperationBase &self) {
3267  return self.getAsm(/*binary=*/false,
3268  /*largeElementsLimit=*/std::nullopt,
3269  /*enableDebugInfo=*/false,
3270  /*prettyDebugInfo=*/false,
3271  /*printGenericOpForm=*/false,
3272  /*useLocalScope=*/false,
3273  /*useNameLocAsPrefix=*/false,
3274  /*assumeVerified=*/false,
3275  /*skipRegions=*/false);
3276  },
3277  "Returns the assembly form of the operation.")
3278  .def("print",
3279  nb::overload_cast<PyAsmState &, nb::object, bool>(
3281  nb::arg("state"), nb::arg("file").none() = nb::none(),
3282  nb::arg("binary") = false, kOperationPrintStateDocstring)
3283  .def("print",
3284  nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3285  bool, bool, nb::object, bool, bool>(
3287  // Careful: Lots of arguments must match up with print method.
3288  nb::arg("large_elements_limit").none() = nb::none(),
3289  nb::arg("enable_debug_info") = false,
3290  nb::arg("pretty_debug_info") = false,
3291  nb::arg("print_generic_op_form") = false,
3292  nb::arg("use_local_scope") = false,
3293  nb::arg("use_name_loc_as_prefix") = false,
3294  nb::arg("assume_verified") = false,
3295  nb::arg("file").none() = nb::none(), nb::arg("binary") = false,
3296  nb::arg("skip_regions") = false, kOperationPrintDocstring)
3297  .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
3298  nb::arg("desired_version").none() = nb::none(),
3300  .def("get_asm", &PyOperationBase::getAsm,
3301  // Careful: Lots of arguments must match up with get_asm method.
3302  nb::arg("binary") = false,
3303  nb::arg("large_elements_limit").none() = nb::none(),
3304  nb::arg("enable_debug_info") = false,
3305  nb::arg("pretty_debug_info") = false,
3306  nb::arg("print_generic_op_form") = false,
3307  nb::arg("use_local_scope") = false,
3308  nb::arg("use_name_loc_as_prefix") = false,
3309  nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
3311  .def("verify", &PyOperationBase::verify,
3312  "Verify the operation. Raises MLIRError if verification fails, and "
3313  "returns true otherwise.")
3314  .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
3315  "Puts self immediately after the other operation in its parent "
3316  "block.")
3317  .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
3318  "Puts self immediately before the other operation in its parent "
3319  "block.")
3320  .def(
3321  "clone",
3322  [](PyOperationBase &self, nb::object ip) {
3323  return self.getOperation().clone(ip);
3324  },
3325  nb::arg("ip").none() = nb::none())
3326  .def(
3327  "detach_from_parent",
3328  [](PyOperationBase &self) {
3329  PyOperation &operation = self.getOperation();
3330  operation.checkValid();
3331  if (!operation.isAttached())
3332  throw nb::value_error("Detached operation has no parent.");
3333 
3334  operation.detachFromParent();
3335  return operation.createOpView();
3336  },
3337  "Detaches the operation from its parent block.")
3338  .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3339  .def("walk", &PyOperationBase::walk, nb::arg("callback"),
3340  nb::arg("walk_order") = MlirWalkPostOrder);
3341 
3342  nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3343  .def_static(
3344  "create",
3345  [](std::string_view name,
3346  std::optional<std::vector<PyType *>> results,
3347  std::optional<std::vector<PyValue *>> operands,
3348  std::optional<nb::dict> attributes,
3349  std::optional<std::vector<PyBlock *>> successors, int regions,
3350  DefaultingPyLocation location, const nb::object &maybeIp,
3351  bool inferType) {
3352  // Unpack/validate operands.
3353  llvm::SmallVector<MlirValue, 4> mlirOperands;
3354  if (operands) {
3355  mlirOperands.reserve(operands->size());
3356  for (PyValue *operand : *operands) {
3357  if (!operand)
3358  throw nb::value_error("operand value cannot be None");
3359  mlirOperands.push_back(operand->get());
3360  }
3361  }
3362 
3363  return PyOperation::create(name, results, mlirOperands, attributes,
3364  successors, regions, location, maybeIp,
3365  inferType);
3366  },
3367  nb::arg("name"), nb::arg("results").none() = nb::none(),
3368  nb::arg("operands").none() = nb::none(),
3369  nb::arg("attributes").none() = nb::none(),
3370  nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0,
3371  nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3372  nb::arg("infer_type") = false, kOperationCreateDocstring)
3373  .def_static(
3374  "parse",
3375  [](const std::string &sourceStr, const std::string &sourceName,
3376  DefaultingPyMlirContext context) {
3377  return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3378  ->createOpView();
3379  },
3380  nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
3381  nb::arg("context").none() = nb::none(),
3382  "Parses an operation. Supports both text assembly format and binary "
3383  "bytecode format.")
3384  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
3385  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3386  .def_prop_ro("operation", [](nb::object self) { return self; })
3387  .def_prop_ro("opview", &PyOperation::createOpView)
3388  .def_prop_ro(
3389  "successors",
3390  [](PyOperationBase &self) {
3391  return PyOpSuccessors(self.getOperation().getRef());
3392  },
3393  "Returns the list of Operation successors.");
3394 
3395  auto opViewClass =
3396  nb::class_<PyOpView, PyOperationBase>(m, "OpView")
3397  .def(nb::init<nb::object>(), nb::arg("operation"))
3398  .def(
3399  "__init__",
3400  [](PyOpView *self, std::string_view name,
3401  std::tuple<int, bool> opRegionSpec,
3402  nb::object operandSegmentSpecObj,
3403  nb::object resultSegmentSpecObj,
3404  std::optional<nb::list> resultTypeList, nb::list operandList,
3405  std::optional<nb::dict> attributes,
3406  std::optional<std::vector<PyBlock *>> successors,
3407  std::optional<int> regions, DefaultingPyLocation location,
3408  const nb::object &maybeIp) {
3409  new (self) PyOpView(PyOpView::buildGeneric(
3410  name, opRegionSpec, operandSegmentSpecObj,
3411  resultSegmentSpecObj, resultTypeList, operandList,
3412  attributes, successors, regions, location, maybeIp));
3413  },
3414  nb::arg("name"), nb::arg("opRegionSpec"),
3415  nb::arg("operandSegmentSpecObj").none() = nb::none(),
3416  nb::arg("resultSegmentSpecObj").none() = nb::none(),
3417  nb::arg("results").none() = nb::none(),
3418  nb::arg("operands").none() = nb::none(),
3419  nb::arg("attributes").none() = nb::none(),
3420  nb::arg("successors").none() = nb::none(),
3421  nb::arg("regions").none() = nb::none(),
3422  nb::arg("loc").none() = nb::none(),
3423  nb::arg("ip").none() = nb::none())
3424 
3425  .def_prop_ro("operation", &PyOpView::getOperationObject)
3426  .def_prop_ro("opview", [](nb::object self) { return self; })
3427  .def(
3428  "__str__",
3429  [](PyOpView &self) { return nb::str(self.getOperationObject()); })
3430  .def_prop_ro(
3431  "successors",
3432  [](PyOperationBase &self) {
3433  return PyOpSuccessors(self.getOperation().getRef());
3434  },
3435  "Returns the list of Operation successors.");
3436  opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
3437  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
3438  opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
3439  // It is faster to pass the operation_name, ods_regions, and
3440  // ods_operand_segments/ods_result_segments as arguments to the constructor,
3441  // rather than to access them as attributes.
3442  opViewClass.attr("build_generic") = classmethod(
3443  [](nb::handle cls, std::optional<nb::list> resultTypeList,
3444  nb::list operandList, std::optional<nb::dict> attributes,
3445  std::optional<std::vector<PyBlock *>> successors,
3446  std::optional<int> regions, DefaultingPyLocation location,
3447  const nb::object &maybeIp) {
3448  std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3449  std::tuple<int, bool> opRegionSpec =
3450  nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
3451  nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
3452  nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3453  return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3454  resultSegmentSpec, resultTypeList,
3455  operandList, attributes, successors,
3456  regions, location, maybeIp);
3457  },
3458  nb::arg("cls"), nb::arg("results").none() = nb::none(),
3459  nb::arg("operands").none() = nb::none(),
3460  nb::arg("attributes").none() = nb::none(),
3461  nb::arg("successors").none() = nb::none(),
3462  nb::arg("regions").none() = nb::none(),
3463  nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3464  "Builds a specific, generated OpView based on class level attributes.");
3465  opViewClass.attr("parse") = classmethod(
3466  [](const nb::object &cls, const std::string &sourceStr,
3467  const std::string &sourceName, DefaultingPyMlirContext context) {
3468  PyOperationRef parsed =
3469  PyOperation::parse(context->getRef(), sourceStr, sourceName);
3470 
3471  // Check if the expected operation was parsed, and cast to to the
3472  // appropriate `OpView` subclass if successful.
3473  // NOTE: This accesses attributes that have been automatically added to
3474  // `OpView` subclasses, and is not intended to be used on `OpView`
3475  // directly.
3476  std::string clsOpName =
3477  nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3478  MlirStringRef identifier =
3480  std::string_view parsedOpName(identifier.data, identifier.length);
3481  if (clsOpName != parsedOpName)
3482  throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3483  parsedOpName + "'");
3484  return PyOpView::constructDerived(cls, parsed.getObject());
3485  },
3486  nb::arg("cls"), nb::arg("source"), nb::kw_only(),
3487  nb::arg("source_name") = "", nb::arg("context").none() = nb::none(),
3488  "Parses a specific, generated OpView based on class level attributes");
3489 
3490  //----------------------------------------------------------------------------
3491  // Mapping of PyRegion.
3492  //----------------------------------------------------------------------------
3493  nb::class_<PyRegion>(m, "Region")
3494  .def_prop_ro(
3495  "blocks",
3496  [](PyRegion &self) {
3497  return PyBlockList(self.getParentOperation(), self.get());
3498  },
3499  "Returns a forward-optimized sequence of blocks.")
3500  .def_prop_ro(
3501  "owner",
3502  [](PyRegion &self) {
3503  return self.getParentOperation()->createOpView();
3504  },
3505  "Returns the operation owning this region.")
3506  .def(
3507  "__iter__",
3508  [](PyRegion &self) {
3509  self.checkValid();
3510  MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3511  return PyBlockIterator(self.getParentOperation(), firstBlock);
3512  },
3513  "Iterates over blocks in the region.")
3514  .def("__eq__",
3515  [](PyRegion &self, PyRegion &other) {
3516  return self.get().ptr == other.get().ptr;
3517  })
3518  .def("__eq__", [](PyRegion &self, nb::object &other) { return false; });
3519 
3520  //----------------------------------------------------------------------------
3521  // Mapping of PyBlock.
3522  //----------------------------------------------------------------------------
// Python `Block`: wraps an MlirBlock together with its parent operation
// (every wrapper below is constructed from self.getParentOperation()).
// It exposes owner/region/arguments/operations views, block-creation
// helpers, iteration, pointer-based equality/hashing, printing and `append`.
3523  nb::class_<PyBlock>(m, "Block")
3524  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3525  .def_prop_ro(
3526  "owner",
3527  [](PyBlock &self) {
3528  return self.getParentOperation()->createOpView();
3529  },
3530  "Returns the owning operation of this block.")
3531  .def_prop_ro(
3532  "region",
3533  [](PyBlock &self) {
3534  MlirRegion region = mlirBlockGetParentRegion(self.get());
3535  return PyRegion(self.getParentOperation(), region);
3536  },
3537  "Returns the owning region of this block.")
3538  .def_prop_ro(
3539  "arguments",
3540  [](PyBlock &self) {
3541  return PyBlockArgumentList(self.getParentOperation(), self.get());
3542  },
3543  "Returns a list of block arguments.")
// Returns the MlirValue produced by mlirBlockAddArgument for the new argument.
3544  .def(
3545  "add_argument",
3546  [](PyBlock &self, const PyType &type, const PyLocation &loc) {
3547  return mlirBlockAddArgument(self.get(), type, loc);
3548  },
3549  "Append an argument of the specified type to the block and returns "
3550  "the newly added argument.")
// NOTE(review): no bounds check on `index` is visible here; confirm how
// mlirBlockEraseArgument handles out-of-range indices.
3551  .def(
3552  "erase_argument",
3553  [](PyBlock &self, unsigned index) {
3554  return mlirBlockEraseArgument(self.get(), index);
3555  },
3556  "Erase the argument at 'index' and remove it from the argument list.")
3557  .def_prop_ro(
3558  "operations",
3559  [](PyBlock &self) {
3560  return PyOperationList(self.getParentOperation(), self.get());
3561  },
3562  "Returns a forward-optimized sequence of operations.")
// Static factory: builds a block via createBlock() and inserts it at index 0
// of `parent`. Argument types come as a sequence (default: empty list).
3563  .def_static(
3564  "create_at_start",
3565  [](PyRegion &parent, const nb::sequence &pyArgTypes,
3566  const std::optional<nb::sequence> &pyArgLocs) {
3567  parent.checkValid();
3568  MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3569  mlirRegionInsertOwnedBlock(parent, 0, block);
3570  return PyBlock(parent.getParentOperation(), block);
3571  },
3572  nb::arg("parent"), nb::arg("arg_types") = nb::list(),
3573  nb::arg("arg_locs") = std::nullopt,
3574  "Creates and returns a new Block at the beginning of the given "
3575  "region (with given argument types and locations).")
3576  .def(
3577  "append_to",
3578  [](PyBlock &self, PyRegion &region) {
3579  MlirBlock b = self.get();
// NOTE(review): the Doxygen numbering jumps from 3579 to 3581, so a source
// line (presumably a condition guarding the detach below) is not shown in
// this rendering; confirm against the actual source.
3581  mlirBlockDetach(b);
3582  mlirRegionAppendOwnedBlock(region.get(), b);
3583  },
3584  "Append this block to a region, transferring ownership if necessary")
// create_before/create_after take the argument types as variadic positional
// arguments (nb::args) and `arg_locs` as keyword-only, unlike
// create_at_start, which takes a sequence. The new block is inserted into
// this block's parent region.
3585  .def(
3586  "create_before",
3587  [](PyBlock &self, const nb::args &pyArgTypes,
3588  const std::optional<nb::sequence> &pyArgLocs) {
3589  self.checkValid();
3590  MlirBlock block =
3591  createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3592  MlirRegion region = mlirBlockGetParentRegion(self.get());
3593  mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3594  return PyBlock(self.getParentOperation(), block);
3595  },
3596  nb::arg("arg_types"), nb::kw_only(),
3597  nb::arg("arg_locs") = std::nullopt,
3598  "Creates and returns a new Block before this block "
3599  "(with given argument types and locations).")
3600  .def(
3601  "create_after",
3602  [](PyBlock &self, const nb::args &pyArgTypes,
3603  const std::optional<nb::sequence> &pyArgLocs) {
3604  self.checkValid();
3605  MlirBlock block =
3606  createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3607  MlirRegion region = mlirBlockGetParentRegion(self.get());
3608  mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3609  return PyBlock(self.getParentOperation(), block);
3610  },
3611  nb::arg("arg_types"), nb::kw_only(),
3612  nb::arg("arg_locs") = std::nullopt,
3613  "Creates and returns a new Block after this block "
3614  "(with given argument types and locations).")
3615  .def(
3616  "__iter__",
3617  [](PyBlock &self) {
3618  self.checkValid();
// NOTE(review): source line 3620 (the initializer of `firstOperation`) is
// missing from this rendering.
3619  MlirOperation firstOperation =
3621  return PyOperationIterator(self.getParentOperation(),
3622  firstOperation);
3623  },
3624  "Iterates over operations in the block.")
// Equality and hashing both use the underlying MlirBlock pointer, so
// distinct wrappers of the same block compare equal and hash alike;
// comparison with a non-Block object yields False.
3625  .def("__eq__",
3626  [](PyBlock &self, PyBlock &other) {
3627  return self.get().ptr == other.get().ptr;
3628  })
3629  .def("__eq__", [](PyBlock &self, nb::object &other) { return false; })
3630  .def("__hash__",
3631  [](PyBlock &self) {
3632  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3633  })
3634  .def(
3635  "__str__",
3636  [](PyBlock &self) {
3637  self.checkValid();
3638  PyPrintAccumulator printAccum;
3639  mlirBlockPrint(self.get(), printAccum.getCallback(),
3640  printAccum.getUserData());
3641  return printAccum.join();
3642  },
3643  "Returns the assembly form of the block.")
// Moves `operation` into this block: detach it if it is attached elsewhere,
// append it, then record this block's parent operation as its new parent.
3644  .def(
3645  "append",
3646  [](PyBlock &self, PyOperationBase &operation) {
3647  if (operation.getOperation().isAttached())
3648  operation.getOperation().detachFromParent();
3649 
3650  MlirOperation mlirOperation = operation.getOperation().get();
3651  mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3652  operation.getOperation().setAttached(
3653  self.getParentOperation().getObject());
3654  },
3655  nb::arg("operation"),
3656  "Appends an operation to this block. If the operation is currently "
3657  "in another block, it will be moved.");
3658 
3659  //----------------------------------------------------------------------------
3660  // Mapping of PyInsertionPoint.
3661  //----------------------------------------------------------------------------
3662 
3663  nb::class_<PyInsertionPoint>(m, "InsertionPoint")
3664  .def(nb::init<PyBlock &>(), nb::arg("block"),
3665  "Inserts after the last operation but still inside the block.")
3666  .def("__enter__", &PyInsertionPoint::contextEnter)
3667  .def("__exit__", &PyInsertionPoint::contextExit,
3668  nb::arg("exc_type").none(), nb::arg("exc_value").none(),
3669  nb::arg("traceback").none())
3670  .def_prop_ro_static(
3671  "current",
3672  [](nb::object & /*class*/) {
3673  auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3674  if (!ip)
3675  throw nb::value_error("No current InsertionPoint");
3676  return ip;
3677  },
3678  "Gets the InsertionPoint bound to the current thread or raises "
3679  "ValueError if none has been set")
3680  .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
3681  "Inserts before a referenced operation.")
3682  .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3683  nb::arg("block"), "Inserts at the beginning of the block.")
3684  .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3685  nb::arg("block"), "Inserts before the block terminator.")
3686  .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
3687  "Inserts an operation.")
3688  .def_prop_ro(
3689  "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3690  "Returns the block that this InsertionPoint points to.")
3691  .def_prop_ro(
3692  "ref_operation",
3693  [](PyInsertionPoint &self) -> nb::object {
3694  auto refOperation = self.getRefOperation();
3695  if (refOperation)
3696  return refOperation->getObject();
3697  return nb::none();
3698  },
3699  "The reference operation before which new operations are "
3700  "inserted, or None if the insertion point is at the end of "
3701  "the block");
3702 
3703  //----------------------------------------------------------------------------
3704  // Mapping of PyAttribute.
3705  //----------------------------------------------------------------------------
// Python `Attribute`: the generic wrapper for MlirAttribute. Provides
// parsing, context/type accessors, naming, pointer-based eq/hash, printing
// and the typeid accessor.
3706  nb::class_<PyAttribute>(m, "Attribute")
3707  // Delegate to the PyAttribute copy constructor, which will also lifetime
3708  // extend the backing context which owns the MlirAttribute.
3709  .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
3710  "Casts the passed attribute to the generic Attribute")
3711  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
3712  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
// Parses under an ErrorCapture so the collected diagnostics are attached to
// the MLIRError thrown when parsing yields a null attribute.
3713  .def_static(
3714  "parse",
3715  [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3716  PyMlirContext::ErrorCapture errors(context->getRef());
3717  MlirAttribute attr = mlirAttributeParseGet(
3718  context->get(), toMlirStringRef(attrSpec));
3719  if (mlirAttributeIsNull(attr))
3720  throw MLIRError("Unable to parse attribute", errors.take());
3721  return attr;
3722  },
3723  nb::arg("asm"), nb::arg("context").none() = nb::none(),
3724  "Parses an attribute from an assembly form. Raises an MLIRError on "
3725  "failure.")
3726  .def_prop_ro(
3727  "context",
3728  [](PyAttribute &self) { return self.getContext().getObject(); },
3729  "Context that owns the Attribute")
3730  .def_prop_ro("type",
3731  [](PyAttribute &self) { return mlirAttributeGetType(self); })
// keep_alive<0, 1>: the returned NamedAttribute keeps `self` alive.
3732  .def(
3733  "get_named",
3734  [](PyAttribute &self, std::string name) {
3735  return PyNamedAttribute(self, std::move(name));
3736  },
3737  nb::keep_alive<0, 1>(), "Binds a name to the attribute")
3738  .def("__eq__",
3739  [](PyAttribute &self, PyAttribute &other) { return self == other; })
3740  .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; })
3741  .def("__hash__",
3742  [](PyAttribute &self) {
3743  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3744  })
// NOTE(review): the numbering jumps from 3746 to 3748; the trailing argument
// of this `dump` .def (presumably its docstring) is not shown here.
3745  .def(
3746  "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3748  .def(
3749  "__str__",
3750  [](PyAttribute &self) {
3751  PyPrintAccumulator printAccum;
3752  mlirAttributePrint(self, printAccum.getCallback(),
3753  printAccum.getUserData());
3754  return printAccum.join();
3755  },
3756  "Returns the assembly form of the Attribute.")
3757  .def("__repr__",
3758  [](PyAttribute &self) {
3759  // Generally, assembly formats are not printed for __repr__ because
3760  // this can cause exceptionally long debug output and exceptions.
3761  // However, attribute values are generally considered useful and
3762  // are printed. This may need to be re-evaluated if debug dumps end
3763  // up being excessive.
3764  PyPrintAccumulator printAccum;
3765  printAccum.parts.append("Attribute(");
3766  mlirAttributePrint(self, printAccum.getCallback(),
3767  printAccum.getUserData());
3768  printAccum.parts.append(")");
3769  return printAccum.join();
3770  })
// NOTE(review): a null TypeID is only caught by assert() here, whereas
// Type.typeid raises ValueError; in builds without asserts a null TypeID is
// returned.
3771  .def_prop_ro("typeid",
3772  [](PyAttribute &self) -> MlirTypeID {
3773  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3774  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3775  "mlirTypeID was expected to be non-null.");
3776  return mlirTypeID;
3777  })
// NOTE(review): source line 3778, which opens the .def whose body follows
// (including the lambda's parameter list), is missing from this rendering.
// The body looks up a registered type caster for this attribute's TypeID and
// dialect, returning the caster's result or the attribute itself.
3779  MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3780  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3781  "mlirTypeID was expected to be non-null.");
3782  std::optional<nb::callable> typeCaster =
3783  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3784  mlirAttributeGetDialect(self));
3785  if (!typeCaster)
3786  return nb::cast(self);
3787  return typeCaster.value()(self);
3788  });
3789 
3790  //----------------------------------------------------------------------------
3791  // Mapping of PyNamedAttribute
3792  //----------------------------------------------------------------------------
3793  nb::class_<PyNamedAttribute>(m, "NamedAttribute")
3794  .def("__repr__",
3795  [](PyNamedAttribute &self) {
3796  PyPrintAccumulator printAccum;
3797  printAccum.parts.append("NamedAttribute(");
3798  printAccum.parts.append(
3799  nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3800  mlirIdentifierStr(self.namedAttr.name).length));
3801  printAccum.parts.append("=");
3802  mlirAttributePrint(self.namedAttr.attribute,
3803  printAccum.getCallback(),
3804  printAccum.getUserData());
3805  printAccum.parts.append(")");
3806  return printAccum.join();
3807  })
3808  .def_prop_ro(
3809  "name",
3810  [](PyNamedAttribute &self) {
3811  return mlirIdentifierStr(self.namedAttr.name);
3812  },
3813  "The name of the NamedAttribute binding")
3814  .def_prop_ro(
3815  "attr",
3816  [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3817  nb::keep_alive<0, 1>(),
3818  "The underlying generic attribute of the NamedAttribute binding");
3819 
3820  //----------------------------------------------------------------------------
3821  // Mapping of PyType.
3822  //----------------------------------------------------------------------------
// Python `Type`: the generic wrapper for MlirType. Provides parsing, the
// owning context, pointer-based eq/hash, printing, and the typeid accessor.
3823  nb::class_<PyType>(m, "Type")
3824  // Delegate to the PyType copy constructor, which will also lifetime
3825  // extend the backing context which owns the MlirType.
3826  .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
3827  "Casts the passed type to the generic Type")
3828  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3829  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
// Parses under an ErrorCapture so the collected diagnostics are attached to
// the MLIRError thrown when parsing yields a null type.
3830  .def_static(
3831  "parse",
3832  [](std::string typeSpec, DefaultingPyMlirContext context) {
3833  PyMlirContext::ErrorCapture errors(context->getRef());
3834  MlirType type =
3835  mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3836  if (mlirTypeIsNull(type))
3837  throw MLIRError("Unable to parse type", errors.take());
3838  return type;
3839  },
// NOTE(review): source line 3841 (the trailing argument(s) and closing
// parenthesis of this `parse` .def_static) is missing from this rendering.
3840  nb::arg("asm"), nb::arg("context").none() = nb::none(),
3842  .def_prop_ro(
3843  "context", [](PyType &self) { return self.getContext().getObject(); },
3844  "Context that owns the Type")
// The fallback overload is annotated `.none()` so comparing with None is
// routed here and returns False.
3845  .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3846  .def(
3847  "__eq__", [](PyType &self, nb::object &other) { return false; },
3848  nb::arg("other").none())
3849  .def("__hash__",
3850  [](PyType &self) {
3851  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3852  })
3853  .def(
3854  "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3855  .def(
3856  "__str__",
3857  [](PyType &self) {
3858  PyPrintAccumulator printAccum;
3859  mlirTypePrint(self, printAccum.getCallback(),
3860  printAccum.getUserData());
3861  return printAccum.join();
3862  },
3863  "Returns the assembly form of the type.")
3864  .def("__repr__",
3865  [](PyType &self) {
3866  // Generally, assembly formats are not printed for __repr__ because
3867  // this can cause exceptionally long debug output and exceptions.
3868  // However, types are an exception as they typically have compact
3869  // assembly forms and printing them is useful.
3870  PyPrintAccumulator printAccum;
3871  printAccum.parts.append("Type(");
3872  mlirTypePrint(self, printAccum.getCallback(),
3873  printAccum.getUserData());
3874  printAccum.parts.append(")");
3875  return printAccum.join();
3876  })
// NOTE(review): source line 3877, which opens the .def whose lambda follows,
// is missing from this rendering. The lambda returns the result of a
// registered type caster for this type's TypeID/dialect, or the type itself.
3878  [](PyType &self) {
3879  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3880  assert(!mlirTypeIDIsNull(mlirTypeID) &&
3881  "mlirTypeID was expected to be non-null.");
3882  std::optional<nb::callable> typeCaster =
3883  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3884  mlirTypeGetDialect(self));
3885  if (!typeCaster)
3886  return nb::cast(self);
3887  return typeCaster.value()(self);
3888  })
// Raises ValueError ("<repr> has no typeid.") when the type has a null TypeID.
3889  .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID {
3890  MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3891  if (!mlirTypeIDIsNull(mlirTypeID))
3892  return mlirTypeID;
3893  auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
3894  throw nb::value_error(
3895  (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
3896  });
3897 
3898  //----------------------------------------------------------------------------
3899  // Mapping of PyTypeID.
3900  //----------------------------------------------------------------------------
3901  nb::class_<PyTypeID>(m, "TypeID")
3902  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3903  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3904  // Note, this tests whether the underlying TypeIDs are the same,
3905  // not whether the wrapper MlirTypeIDs are the same, nor whether
3906  // the Python objects are the same (i.e., PyTypeID is a value type).
3907  .def("__eq__",
3908  [](PyTypeID &self, PyTypeID &other) { return self == other; })
3909  .def("__eq__",
3910  [](PyTypeID &self, const nb::object &other) { return false; })
3911  // Note, this gives the hash value of the underlying TypeID, not the
3912  // hash value of the Python object, nor the hash value of the
3913  // MlirTypeID wrapper.
3914  .def("__hash__", [](PyTypeID &self) {
3915  return static_cast<size_t>(mlirTypeIDHashValue(self));
3916  });
3917 
3918  //----------------------------------------------------------------------------
3919  // Mapping of Value.
3920  //----------------------------------------------------------------------------
// Python `Value`: wraps an MlirValue together with its parent operation.
// Exposes owner/uses/type/location, pointer-based eq/hash, printing and
// naming, and use-replacement helpers.
3921  nb::class_<PyValue>(m, "Value")
3922  .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
3923  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3924  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3925  .def_prop_ro(
3926  "context",
3927  [](PyValue &self) { return self.getParentOperation()->getContext(); },
3928  "Context in which the value lives.")
// NOTE(review): source line 3931 (the trailing argument of this `dump` .def)
// is missing from this rendering.
3929  .def(
3930  "dump", [](PyValue &self) { mlirValueDump(self.get()); },
// An op result is reported as owned by the parent operation; a block argument
// by its owning block. Anything else trips the assert and yields None.
3932  .def_prop_ro(
3933  "owner",
3934  [](PyValue &self) -> nb::object {
3935  MlirValue v = self.get();
3936  if (mlirValueIsAOpResult(v)) {
3937  assert(
3938  mlirOperationEqual(self.getParentOperation()->get(),
3939  mlirOpResultGetOwner(self.get())) &&
3940  "expected the owner of the value in Python to match that in "
3941  "the IR");
3942  return self.getParentOperation().getObject();
3943  }
3944 
3945  if (mlirValueIsABlockArgument(v)) {
3946  MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3947  return nb::cast(PyBlock(self.getParentOperation(), block));
3948  }
3949 
3950  assert(false && "Value must be a block argument or an op result");
3951  return nb::none();
3952  })
3953  .def_prop_ro("uses",
3954  [](PyValue &self) {
3955  return PyOpOperandIterator(
3956  mlirValueGetFirstUse(self.get()));
3957  })
3958  .def("__eq__",
3959  [](PyValue &self, PyValue &other) {
3960  return self.get().ptr == other.get().ptr;
3961  })
// NOTE(review): unlike Type.__eq__, this fallback has no
// `nb::arg("other").none()`; confirm `value == None` behaves as intended.
3962  .def("__eq__", [](PyValue &self, nb::object other) { return false; })
3963  .def("__hash__",
3964  [](PyValue &self) {
3965  return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3966  })
3967  .def(
3968  "__str__",
3969  [](PyValue &self) {
3970  PyPrintAccumulator printAccum;
3971  printAccum.parts.append("Value(");
3972  mlirValuePrint(self.get(), printAccum.getCallback(),
3973  printAccum.getUserData());
3974  printAccum.parts.append(")");
3975  return printAccum.join();
3976  },
// NOTE(review): source line 3977 (the trailing argument of `__str__`) is
// missing from this rendering.
3978  .def(
3979  "get_name",
3980  [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
3981  PyPrintAccumulator printAccum;
3982  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
// NOTE(review): source lines 3984 and 3986 (the statements guarded by the two
// `if`s below, presumably setting the matching printing flags) are missing
// from this rendering. As shown, the first `if` would guard the second.
3983  if (useLocalScope)
3985  if (useNameLocAsPrefix)
3987  MlirAsmState valueState =
3988  mlirAsmStateCreateForValue(self.get(), flags);
3989  mlirValuePrintAsOperand(self.get(), valueState,
3990  printAccum.getCallback(),
3991  printAccum.getUserData());
// NOTE(review): source line 3992 is missing here; confirm that `flags`
// (created with mlirOpPrintingFlagsCreate above) is released.
3993  mlirAsmStateDestroy(valueState);
3994  return printAccum.join();
3995  },
3996  nb::arg("use_local_scope") = false,
3997  nb::arg("use_name_loc_as_prefix") = false)
// Overload reusing a caller-provided AsmState instead of creating one.
3998  .def(
3999  "get_name",
4000  [](PyValue &self, PyAsmState &state) {
4001  PyPrintAccumulator printAccum;
4002  MlirAsmState valueState = state.get();
4003  mlirValuePrintAsOperand(self.get(), valueState,
4004  printAccum.getCallback(),
4005  printAccum.getUserData());
4006  return printAccum.join();
4007  },
4008  nb::arg("state"), kGetNameAsOperand)
4009  .def_prop_ro("type",
4010  [](PyValue &self) { return mlirValueGetType(self.get()); })
4011  .def(
4012  "set_type",
4013  [](PyValue &self, const PyType &type) {
4014  return mlirValueSetType(self.get(), type);
4015  },
4016  nb::arg("type"))
4017  .def(
4018  "replace_all_uses_with",
4019  [](PyValue &self, PyValue &with) {
4020  mlirValueReplaceAllUsesOfWith(self.get(), with.get());
4021  },
// NOTE(review): source line 4022 (trailing argument of the .def above) is
// missing from this rendering.
// Two overloads: exempt a single operation, or every operation in a list.
4023  .def(
4024  "replace_all_uses_except",
4025  [](MlirValue self, MlirValue with, PyOperation &exception) {
4026  MlirOperation exceptedUser = exception.get();
4027  mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4028  },
4029  nb::arg("with"), nb::arg("exceptions"),
4031  .def(
4032  "replace_all_uses_except",
4033  [](MlirValue self, MlirValue with, nb::list exceptions) {
4034  // Convert Python list to a SmallVector of MlirOperations
4035  llvm::SmallVector<MlirOperation> exceptionOps;
4036  for (nb::handle exception : exceptions) {
4037  exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
4038  }
4039 
// NOTE(review): source line 4040 (the callee receiving the arguments below)
// is missing from this rendering.
4041  self, with, static_cast<intptr_t>(exceptionOps.size()),
4042  exceptionOps.data());
4043  },
4044  nb::arg("with"), nb::arg("exceptions"),
// NOTE(review): source lines 4045-4046 (end of the previous .def and the
// opening of the .def whose lambda follows) are missing from this rendering.
4047  [](PyValue &self) { return self.maybeDownCast(); })
// NOTE(review): the docstring below reads "location the value" and appears
// to be missing "of".
4048  .def_prop_ro(
4049  "location",
4050  [](MlirValue self) {
4051  return PyLocation(
4052  PyMlirContext::forContext(mlirValueGetContext(self)),
4053  mlirValueGetLocation(self));
4054  },
4055  "Returns the source location the value");
4056 
// Register the Value-related classes right after `Value` itself.
// NOTE(review): these presumably subclass the Python `Value` class and thus
// rely on it being registered first; keep this ordering.
4057  PyBlockArgument::bind(m);
4058  PyOpResult::bind(m);
4059  PyOpOperand::bind(m);
4060 
4061  nb::class_<PyAsmState>(m, "AsmState")
4062  .def(nb::init<PyValue &, bool>(), nb::arg("value"),
4063  nb::arg("use_local_scope") = false)
4064  .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
4065  nb::arg("use_local_scope") = false);
4066 
4067  //----------------------------------------------------------------------------
4068  // Mapping of SymbolTable.
4069  //----------------------------------------------------------------------------
// Python `SymbolTable`: constructed from an operation, with item access,
// insert/erase, membership, and static helpers that operate on symbol ops.
4070  nb::class_<PySymbolTable>(m, "SymbolTable")
4071  .def(nb::init<PyOperationBase &>())
4072  .def("__getitem__", &PySymbolTable::dunderGetItem)
4073  .def("insert", &PySymbolTable::insert, nb::arg("operation"))
4074  .def("erase", &PySymbolTable::erase, nb::arg("operation"))
4075  .def("__delitem__", &PySymbolTable::dunderDel)
// NOTE(review): source line 4078 (the start of the return expression that
// consumes `table` and the name StringRef) is missing from this rendering.
4076  .def("__contains__",
4077  [](PySymbolTable &table, const std::string &name) {
4079  table, mlirStringRefCreate(name.data(), name.length())));
4080  })
4081  // Static helpers.
4082  .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
4083  nb::arg("symbol"), nb::arg("name"))
4084  .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
4085  nb::arg("symbol"))
4086  .def_static("get_visibility", &PySymbolTable::getVisibility,
4087  nb::arg("symbol"))
4088  .def_static("set_visibility", &PySymbolTable::setVisibility,
4089  nb::arg("symbol"), nb::arg("visibility"))
4090  .def_static("replace_all_symbol_uses",
4091  &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
4092  nb::arg("new_symbol"), nb::arg("from_op"))
4093  .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
4094  nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
4095  nb::arg("callback"));
4096 
4097  // Container bindings.
// Each helper registers its own Python container/iterator class on `m`.
4098  PyBlockArgumentList::bind(m);
4099  PyBlockIterator::bind(m);
4100  PyBlockList::bind(m);
4101  PyOperationIterator::bind(m);
4102  PyOperationList::bind(m);
4103  PyOpAttributeMap::bind(m);
4104  PyOpOperandIterator::bind(m);
4105  PyOpOperandList::bind(m);
// NOTE(review): source line 4106 is missing from this rendering; presumably
// another container binding (possibly PyOpResultList) — confirm.
4107  PyOpSuccessors::bind(m);
4108  PyRegionIterator::bind(m);
4109  PyRegionList::bind(m);
4110 
4111  // Debug bindings.
// NOTE(review): the binding statement on source line 4112 is not shown here.
4113 
4114  // Attribute builder getter.
// NOTE(review): the binding statement on source line 4115 is not shown here.
4116 
4117  nb::register_exception_translator([](const std::exception_ptr &p,
4118  void *payload) {
4119  // We can't define exceptions with custom fields through pybind, so instead
4120  // the exception class is defined in python and imported here.
4121  try {
4122  if (p)
4123  std::rethrow_exception(p);
4124  } catch (const MLIRError &e) {
4125  nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
4126  .attr("MLIRError")(e.message, e.errorDiagnostics);
4127  PyErr_SetObject(PyExc_Exception, obj.ptr());
4128  }
4129  });
4130 }
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition: Debug.cpp:20
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Sets multiple current debug types, similarly to `-debug-only=type1,type2` in the command-line tools.
Definition: Debug.cpp:28
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition: Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition: Debug.cpp:16
static const char kOperationPrintStateDocstring[]
Definition: IRCore.cpp:119
static const char kValueReplaceAllUsesWithDocstring[]
Definition: IRCore.cpp:181
static const char kContextGetNameLocationDocString[]
Definition: IRCore.cpp:62
static const char kGetNameAsOperand[]
Definition: IRCore.cpp:177
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:216
static const char kModuleParseDocstring[]
Definition: IRCore.cpp:65
static const char kOperationStrDunderDocstring[]
Definition: IRCore.cpp:151
static const char kOperationPrintDocstring[]
Definition: IRCore.cpp:92
static MlirValue getUniqueResult(MlirOperation operation)
Definition: IRCore.cpp:1885
static const char kContextGetFileLocationDocstring[]
Definition: IRCore.cpp:53
static const char kDumpDocstring[]
Definition: IRCore.cpp:159
static const char kAppendBlockDocstring[]
Definition: IRCore.cpp:162
static MlirValue getOpResultOrValue(nb::handle operand)
Definition: IRCore.cpp:1900
static const char kContextGetFusedLocationDocstring[]
Definition: IRCore.cpp:59
static const char kContextGetFileRangeDocstring[]
Definition: IRCore.cpp:56
static void maybeInsertOperation(PyOperationRef &op, const nb::object &maybeIp)
Definition: IRCore.cpp:1491
static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor)
Definition: IRCore.cpp:204
nb::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition: IRCore.cpp:198
static const char kOperationPrintBytecodeDocstring[]
Definition: IRCore.cpp:141
static const char kOperationGetAsmDocstring[]
Definition: IRCore.cpp:128
static MlirBlock createBlock(const nb::sequence &pyArgTypes, const std::optional< nb::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition: IRCore.cpp:230
static const char kOperationCreateDocstring[]
Definition: IRCore.cpp:73
static std::vector< MlirType > getValueTypes(Container &container, PyMlirContextRef &context)
Returns the list of types of the values held by container.
Definition: IRCore.cpp:1725
static const char kContextParseTypeDocstring[]
Definition: IRCore.cpp:42
static void populateResultTypes(StringRef name, nb::list resultTypeList, const nb::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition: IRCore.cpp:1788
static const char kContextGetCallSiteLocationDocstring[]
Definition: IRCore.cpp:50
static const char kValueDunderStrDocstring[]
Definition: IRCore.cpp:169
static const char kValueReplaceAllUsesExceptDocstring[]
Definition: IRCore.cpp:186
static MLIRContext * getContext(OpFoldResult val)
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition: Interop.h:273
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition: Interop.h:118
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition: Interop.h:348
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition: Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition: Interop.h:189
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition: Interop.h:180
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition: Interop.h:255
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition: Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition: Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts an MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition: Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition: Interop.h:357
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition: Interop.h:235
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition: Interop.h:367
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition: Interop.h:245
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition: Interop.h:57
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition: Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition: Interop.h:454
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition: Interop.h:198
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition: Interop.h:330
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition: Interop.h:264
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition: Interop.h:445
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition: Interop.h:216
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::context getDefaultContext()
const float * table
A list of operation results.
Definition: IRCore.cpp:1739
PyOperationRef & getOperation()
Definition: IRCore.cpp:1761
static void bindDerived(ClassTy &c)
Definition: IRCore.cpp:1752
PyOpResultList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition: IRCore.cpp:1744
Python wrapper for MlirOpResult.
Definition: IRCore.cpp:1703
static void bindDerived(ClassTy &c)
Definition: IRCore.cpp:1709
Accumulates into a file, either writing text (default) or binary.
MlirStringCallback getCallback()
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
static void bind(nanobind::module_ &m)
Binds the indexing and length methods in the Python class.
nanobind::class_< PyOpResultList > ClassTy
Base class for all objects that directly or indirectly depend on an MlirContext.
Definition: IRModule.h:331
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:339
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:546
static PyLocation & resolve()
Definition: IRCore.cpp:1135
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:320
static PyMlirContext & resolve()
Definition: IRCore.cpp:856
ReferrentTy * get() const
Definition: NanobindUtils.h:60
Wrapper around an MlirAsmState.
Definition: IRModule.h:817
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:1038
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition: IRModule.h:1040
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition: IRCore.cpp:2193
static PyAttribute createFromCapsule(nanobind::object capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition: IRCore.cpp:2197
bool operator==(const PyAttribute &other) const
Definition: IRCore.cpp:2189
Wrapper around an MlirBlock.
Definition: IRModule.h:852
MlirBlock get()
Definition: IRModule.h:859
PyOperationRef & getParentOperation()
Definition: IRModule.h:860
Represents a diagnostic handler attached to the context.
Definition: IRModule.h:427
PyDiagnosticHandler(MlirContext context, nanobind::object callback)
Definition: IRCore.cpp:1012
void detach()
Detaches the handler. Does nothing if not attached.
Definition: IRCore.cpp:1018
Python class mirroring the C MlirDiagnostic struct.
Definition: IRModule.h:377
PyLocation getLocation()
Definition: IRCore.cpp:1041
nanobind::tuple getNotes()
Definition: IRCore.cpp:1056
nanobind::str getMessage()
Definition: IRCore.cpp:1048
DiagnosticInfo getInfo()
Definition: IRCore.cpp:1072
PyDiagnostic(MlirDiagnostic diagnostic)
Definition: IRModule.h:379
MlirDiagnosticSeverity getSeverity()
Definition: IRCore.cpp:1036
Wrapper around an MlirDialect.
Definition: IRModule.h:482
Wrapper around an MlirDialectRegistry.
Definition: IRModule.h:519
nanobind::object getCapsule()
Definition: IRCore.cpp:1097
static PyDialectRegistry createFromCapsule(nanobind::object capsule)
Definition: IRCore.cpp:1101
User-level dialect object.
Definition: IRModule.h:506
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition: IRModule.h:495
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition: IRCore.cpp:1084
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:155
std::optional< nanobind::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition: IRModule.cpp:184
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition: Globals.h:34
An insertion point maintains a pointer to a Block and a reference operation.
Definition: IRModule.h:876
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition: IRCore.cpp:2166
PyInsertionPoint(PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition: IRCore.cpp:2121
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition: IRCore.cpp:2153
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition: IRCore.cpp:2179
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition: IRCore.cpp:2127
static nanobind::object contextEnter(nanobind::object insertionPoint)
Enter and exit the context manager.
Definition: IRCore.cpp:2175
Wrapper around an MlirLocation.
Definition: IRModule.h:346
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition: IRCore.cpp:1113
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition: IRModule.h:348
static PyLocation createFromCapsule(nanobind::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition: IRCore.cpp:1117
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition: IRCore.cpp:1129
static nanobind::object contextEnter(nanobind::object location)
Enter and exit the context manager.
Definition: IRCore.cpp:1125
MlirLocation get() const
Definition: IRModule.h:352
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:211
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition: IRModule.h:215
void clearOperationsInside(PyOperationBase &op)
Clears all operations nested inside the given op using clearOperation(MlirOperation).
Definition: IRCore.cpp:744
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition: IRCore.cpp:699
void clearOperationAndInside(PyOperationBase &op)
Clears the operaiton and all operations inside using clearOperation(MlirOperation).
Definition: IRCore.cpp:769
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition: IRCore.cpp:780
nanobind::object attachDiagnosticHandler(nanobind::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition: IRCore.cpp:792
size_t clearLiveOperations()
Clears the live operations map, returning the number of entries which were invalidated.
Definition: IRCore.cpp:717
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition: IRCore.cpp:663
std::vector< PyOperation * > getLiveOperationObjects()
Get a list of Python objects which are still in the live context map.
Definition: IRCore.cpp:709
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition: IRCore.cpp:786
void clearOperation(MlirOperation op)
Removes an operation from the live operations map and sets it invalid.
Definition: IRCore.cpp:730
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition: IRCore.cpp:674
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition: IRCore.cpp:667
size_t getLiveOperationCount()
Gets the count of live operations associated with this context.
Definition: IRCore.cpp:704
static nanobind::object contextEnter(nanobind::object context)
Enter and exit the context manager.
Definition: IRCore.cpp:782
MlirModule get()
Gets the backing MlirModule.
Definition: IRModule.h:569
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition: IRCore.cpp:1162
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition: IRCore.cpp:1187
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition: IRCore.cpp:1194
PyModule(PyModule &)=delete
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition: IRModule.h:1062
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition: IRCore.cpp:2209
MlirNamedAttribute namedAttr
Definition: IRModule.h:1071
nanobind::object getObject()
Definition: IRModule.h:90
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
Definition: IRModule.h:78
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition: IRModule.h:762
static nanobind::object constructDerived(const nanobind::object &cls, const nanobind::object &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition: IRCore.cpp:2103
static nanobind::object buildGeneric(std::string_view name, std::tuple< int, bool > opRegionSpec, nanobind::object operandSegmentSpecObj, nanobind::object resultSegmentSpecObj, std::optional< nanobind::list > resultTypeList, nanobind::list operandList, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * >> successors, std::optional< int > regions, DefaultingPyLocation location, const nanobind::object &maybeIp)
Definition: IRCore.cpp:1919
PyOpView(const nanobind::object &operationObject)
Definition: IRCore.cpp:2111
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition: IRModule.h:598
void walk(std::function< MlirWalkResult(MlirOperation)> callback, MlirWalkOrder walkOrder)
Definition: IRCore.cpp:1374
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
void print(std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
nanobind::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions)
Definition: IRCore.cpp:1406
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition: IRCore.cpp:1432
void writeBytecode(const nanobind::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition: IRCore.cpp:1352
void moveBefore(PyOperationBase &other)
Definition: IRCore.cpp:1441
bool verify()
Verify the operation.
Definition: IRCore.cpp:1450
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition: IRModule.h:668
PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
Definition: IRCore.cpp:1202
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition: IRCore.cpp:1641
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition: IRCore.cpp:1477
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Creates a detached operation.
Definition: IRCore.cpp:1271
nanobind::object clone(const nanobind::object &ip)
Clones this operation.
Definition: IRCore.cpp:1621
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition: IRModule.h:646
PyOperationRef getRef()
Definition: IRModule.h:681
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition: IRCore.cpp:1482
MlirOperation get() const
Definition: IRModule.h:676
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition: IRCore.cpp:1251
void setAttached(const nanobind::object &parent=nanobind::object())
Definition: IRModule.h:686
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition: IRCore.cpp:1458
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition: IRCore.cpp:1630
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition: IRCore.cpp:1468
static nanobind::object create(std::string_view name, std::optional< std::vector< PyType * >> results, llvm::ArrayRef< MlirValue > operands, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * >> successors, int regions, DefaultingPyLocation location, const nanobind::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition: IRCore.cpp:1506
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition: IRCore.cpp:1287
void checkValid() const
Definition: IRCore.cpp:1299
void setInvalid()
Invalidate the operation.
Definition: IRModule.h:729
Wrapper around an MlirRegion.
Definition: IRModule.h:798
PyOperationRef & getParentOperation()
Definition: IRModule.h:807
MlirRegion get()
Definition: IRModule.h:806
Bindings for MLIR symbol tables.
Definition: IRModule.h:1277
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition: IRCore.cpp:2330
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, nanobind::object callback)
Walks all symbol tables under and including 'from'.
Definition: IRCore.cpp:2413
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition: IRCore.cpp:2401
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition: IRCore.cpp:2383
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition: IRCore.cpp:2357
MlirAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition: IRCore.cpp:2335
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition: IRCore.cpp:2320
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition: IRCore.cpp:2299
static MlirAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition: IRCore.cpp:2345
nanobind::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition: IRCore.cpp:2307
static MlirAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition: IRCore.cpp:2372
Tracks an entry in the thread context stack.
Definition: IRModule.h:109
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition: IRCore.cpp:876
static void popLocation(PyLocation &location)
Definition: IRCore.cpp:988
static nanobind::object pushLocation(nanobind::object location)
Definition: IRCore.cpp:979
static nanobind::object pushContext(nanobind::object context)
Definition: IRCore.cpp:938
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition: IRCore.cpp:933
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition: IRCore.cpp:968
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint)
Definition: IRCore.cpp:956
static void popContext(PyMlirContext &context)
Definition: IRCore.cpp:945
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and return nullptr if not defined.
Definition: IRCore.cpp:928
PyMlirContext * getContext()
Definition: IRCore.cpp:905
static PyMlirContext * getDefaultContext()
Gets the top of stack context and return nullptr if not defined.
Definition: IRCore.cpp:923
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition: IRCore.cpp:871
PyInsertionPoint * getInsertionPoint()
Definition: IRCore.cpp:911
Wrapper around MlirLlvmThreadPool Python object owns the C++ thread pool.
Definition: IRModule.h:165
MlirLlvmThreadPool get()
Definition: IRModule.h:174
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition: IRModule.h:936
static PyTypeID createFromCapsule(nanobind::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition: IRCore.cpp:2245
bool operator==(const PyTypeID &other) const
Definition: IRCore.cpp:2251
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition: IRCore.cpp:2241
PyTypeID(MlirTypeID typeID)
Definition: IRModule.h:938
Wrapper around the generic MlirType.
Definition: IRModule.h:912
PyType(PyMlirContextRef contextRef, MlirType type)
Definition: IRModule.h:914
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition: IRCore.cpp:2225
static PyType createFromCapsule(nanobind::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition: IRCore.cpp:2229
bool operator==(const PyType &other) const
Definition: IRCore.cpp:2221
Wrapper around the generic MlirValue.
Definition: IRModule.h:1178
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition: IRModule.h:1184
static PyValue createFromCapsule(nanobind::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition: IRCore.cpp:2278
nanobind::object maybeDownCast()
Definition: IRCore.cpp:2263
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition: IRCore.cpp:2259
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
Definition: Diagnostics.cpp:44
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
Definition: Diagnostics.cpp:28
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
Definition: Diagnostics.cpp:18
MlirDiagnosticSeverity
Severity of a diagnostic.
Definition: Diagnostics.h:32
@ MlirDiagnosticNote
Definition: Diagnostics.h:35
@ MlirDiagnosticRemark
Definition: Diagnostics.h:36
@ MlirDiagnosticWarning
Definition: Diagnostics.h:34
@ MlirDiagnosticError
Definition: Diagnostics.h:33
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
Definition: Diagnostics.cpp:50
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
Definition: Diagnostics.cpp:78
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
Definition: Diagnostics.cpp:56
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
Definition: Diagnostics.cpp:72
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition: Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
Definition: Diagnostics.cpp:24
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition: IR.cpp:265
MLIR_CAPI_EXPORTED intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Returns the position of the value in the argument list of its block.
Definition: IR.cpp:1082
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1145
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition: IR.h:831
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v)
Gets the location of the value.
Definition: IR.cpp:1152
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context)
Gets the number of threads of the thread pool of the context when multithreading is enabled.
Definition: IR.cpp:117
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition: IR.cpp:825
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
Definition: IR.cpp:653
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates an File/Line/Column location owned by the given context.
Definition: IR.cpp:273
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition: IR.cpp:1334
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition: IR.cpp:137
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op)
Returns the number of results of the operation.
Definition: IR.cpp:712
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location)
Getter for end_column of FileLineColRange.
Definition: IR.cpp:311
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition: IR.cpp:1313
MlirWalkOrder
Traversal order for operation walk.
Definition: IR.h:824
@ MlirWalkPreOrder
Definition: IR.h:825
@ MlirWalkPostOrder
Definition: IR.h:826
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags)
Print the name and location, if NamedLoc, as a prefix to the SSA ID.
Definition: IR.cpp:229
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Return pos-th attribute of the operation.
Definition: IR.cpp:786
MLIR_CAPI_EXPORTED void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition: IR.cpp:501
MLIR_CAPI_EXPORTED void mlirModuleDestroy(MlirModule module)
Takes a module owned by the caller and deletes it.
Definition: IR.cpp:454
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition: IR.cpp:1261
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGetChildLoc(MlirLocation location)
Getter for childLoc of Name.
Definition: IR.cpp:392
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition: IR.cpp:1318
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Use local scope when printing the operation.
Definition: IR.cpp:233
MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value)
Returns 1 if the value is a block argument, 0 otherwise.
Definition: IR.cpp:1070
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition: IR.cpp:84
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:1282
static bool mlirModuleIsNull(MlirModule module)
Checks whether a module is null.
Definition: IR.h:406
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition: IR.cpp:1195
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition: IR.cpp:1178
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute)
Gets the type of this attribute.
Definition: IR.cpp:1234
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition: IR.cpp:73
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition: IR.cpp:926
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location)
Checks whether the given location is an FileLineColRange.
Definition: IR.cpp:321
MLIR_CAPI_EXPORTED unsigned mlirLocationFusedGetNumLocations(MlirLocation location)
Getter for number of locations fused together.
Definition: IR.cpp:355
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition: IR.cpp:1134
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Returns pos-th successor of the operation.
Definition: IR.cpp:724
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition: IR.cpp:1117
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition: IR.cpp:403
MLIR_CAPI_EXPORTED void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding userData tocallback`.
Definition: IR.cpp:1215
MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Sets an attribute by name, replacing the existing if it exists or adding a new one otherwise.
Definition: IR.cpp:796
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition: IR.cpp:1166
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
Enables the elision of large resources strings by omitting them from the dialect_resources section.
Definition: IR.cpp:215
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location)
Getter for filename of FileLineColRange.
Definition: IR.cpp:289
MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute)
Gets the dialect of the attribute.
Definition: IR.cpp:1245
MLIR_CAPI_EXPORTED void mlirLocationFusedGetLocations(MlirLocation location, MlirLocation *locationsCPtr)
Getter for locations of Fused.
Definition: IR.cpp:361
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding userData tocallback...
Definition: IR.cpp:1253
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition: IR.cpp:965
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition: IR.cpp:849
static bool mlirValueIsNull(MlirValue value)
Returns whether the value is null.
Definition: IR.h:994
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
Definition: IR.cpp:1138
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition: IR.cpp:817
MlirWalkResult
Operation walk result.
Definition: IR.h:817
@ MlirWalkResultInterrupt
Definition: IR.h:819
@ MlirWalkResultSkip
Definition: IR.h:820
@ MlirWalkResultAdvance
Definition: IR.h:818
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition: IR.cpp:906
MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Returns an attribute attached to the operation given its name.
Definition: IR.cpp:791
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition: IR.h:1110
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e.
Definition: IR.cpp:100
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Returns the number of successor blocks of the operation.
Definition: IR.cpp:720
MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op)
Creates a deep copy of an operation.
Definition: IR.cpp:627
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition: IR.cpp:1034
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartLine(MlirLocation location)
Getter for start_line of FileLineColRange.
Definition: IR.cpp:293
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Always print operations in the generic form.
Definition: IR.cpp:225
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition: IR.cpp:347
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition: IR.cpp:1015
MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state)
Destroys printing flags created with mlirAsmStateCreate.
Definition: IR.cpp:196
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition: IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition: IR.cpp:95
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location)
Checks whether the given location is an CallSite.
Definition: IR.cpp:343
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningl...
Definition: IR.cpp:210
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition: IR.cpp:912
MLIR_CAPI_EXPORTED void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Sets the type of the block argument to the given type.
Definition: IR.cpp:1087
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op)
Gets the context this operation is associated with.
Definition: IR.cpp:639
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition: IR.cpp:949
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition: IR.h:913
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition: IR.cpp:990
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition: IR.cpp:1052
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition: IR.cpp:1308
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition: IR.cpp:1199
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, unsigned start_line, unsigned start_col, unsigned end_line, unsigned end_col)
Creates a File/Line/Column range location owned by the given context.
Definition: IR.cpp:281
MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value)
Prints the value to the standard error stream.
Definition: IR.cpp:1109
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location)
Creates a new, empty module and transfers ownership to the caller.
Definition: IR.cpp:425
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition: IR.cpp:1164
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition: IR.cpp:1298
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition: IR.cpp:407
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location)
Getter for start_column of FileLineColRange.
Definition: IR.cpp:299
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location)
Checks whether the given location is a Fused.
Definition: IR.cpp:375
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetBlock(MlirOperation op)
Gets the block that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:657
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition: IR.h:370
MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Checks whether two operation handles point to the same operation.
Definition: IR.cpp:635
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition: IR.cpp:1038
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition: IR.cpp:811
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition: IR.cpp:1124
MLIR_CAPI_EXPORTED void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:415
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool)
Sets the thread pool of the context explicitly, enabling multithreading in the process.
Definition: IR.cpp:112
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition: IR.cpp:841
MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module)
Views the module as a generic operation.
Definition: IR.cpp:460
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition: IR.cpp:1211
MLIR_CAPI_EXPORTED MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Constructs an operation state from a name and a location.
Definition: IR.cpp:472
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition: IR.cpp:1174
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location)
Getter for caller of CallSite.
Definition: IR.cpp:334
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition: IR.cpp:980
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Skip printing regions.
Definition: IR.cpp:241
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationNameGetName(MlirLocation location)
Getter for name of Name.
Definition: IR.cpp:388
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Returns the operation immediately following the given operation in its enclosing block.
Definition: IR.cpp:689
MLIR_CAPI_EXPORTED MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Gets the operation that owns this operation, returning null if the operation is not owned.
Definition: IR.cpp:661
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module)
Gets the context that a module was created with.
Definition: IR.cpp:446
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition: IR.cpp:269
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Do not verify the operation when using custom operation printers.
Definition: IR.cpp:237
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition: IR.cpp:1203
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition: IR.cpp:1294
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition: IR.h:182
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Destroys a bytecode writer config created with mlirBytecodeWriterConfigCreate.
Definition: IR.cpp:252
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Returns pos-th operand of the operation.
Definition: IR.cpp:697
MLIR_CAPI_EXPORTED void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition: IR.cpp:513
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition: IR.cpp:969
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition: IR.cpp:325
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location)
Checks whether the given location is a Name.
Definition: IR.cpp:399
MLIR_CAPI_EXPORTED MlirOperation mlirOpResultGetOwner(MlirValue value)
Returns an operation that produced this value as its result.
Definition: IR.cpp:1092
MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value)
Returns 1 if the value is an operation result, 0 otherwise.
Definition: IR.cpp:1074
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op)
Returns the number of operands of the operation.
Definition: IR.cpp:693
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition: IR.h:244
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition: IR.cpp:867
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
Definition: IR.cpp:55
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition: IR.cpp:961
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumRegions(MlirOperation op)
Returns the number of regions attached to the given operation.
Definition: IR.cpp:665
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition: IR.cpp:411
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition: IR.cpp:1043
MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Removes an attribute by name.
Definition: IR.cpp:801
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition: IR.cpp:1259
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition: IR.cpp:1323
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Parses an attribute. The attribute is owned by the context.
Definition: IR.cpp:1226
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Parses a module from the string and transfers ownership to the caller.
Definition: IR.cpp:429
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition: IR.cpp:902
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition: IR.cpp:973
MLIR_CAPI_EXPORTED void mlirTypeDump(MlirType type)
Prints the type to the standard error stream.
Definition: IR.cpp:1220
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Returns pos-th result of the operation.
Definition: IR.cpp:716
MLIR_CAPI_EXPORTED MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate(void)
Creates a new bytecode writer config with defaults, intended for customization.
Definition: IR.cpp:248
MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Gets the context that an attribute was created with.
Definition: IR.cpp:1230
MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Returns the block in which this value is defined as an argument.
Definition: IR.cpp:1078
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition: IR.h:852
MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op)
Takes an operation owned by the caller and destroys it.
Definition: IR.cpp:631
MLIR_CAPI_EXPORTED MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Returns pos-th region attached to the operation.
Definition: IR.cpp:669
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition: IR.cpp:1207
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition: IR.cpp:1270
MLIR_CAPI_EXPORTED void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Set pos-th successor of the operation.
Definition: IR.cpp:777
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition: IR.cpp:108
MLIR_CAPI_EXPORTED void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition: IR.cpp:505
MLIR_CAPI_EXPORTED void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition: IR.cpp:509
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context)
Gets the thread pool of the context when multithreading is enabled; otherwise an assertion is raised.
Definition: IR.cpp:121
MLIR_CAPI_EXPORTED MlirBlock mlirModuleGetBody(MlirModule module)
Gets the body of the module, i.e. the only block it contains.
Definition: IR.cpp:450
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndLine(MlirLocation location)
Getter for end_line of FileLineColRange.
Definition: IR.cpp:305
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Destroys printing flags created with mlirOpPrintingFlagsCreate.
Definition: IR.cpp:206
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition: IR.cpp:379
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Sets the threading mode (must be set to false to use mlir-print-ir-after-all).
Definition: IR.cpp:104
MLIR_CAPI_EXPORTED void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Prints a block by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1056
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location)
Getter for callee of CallSite.
Definition: IR.cpp:329
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v)
Gets the context that a value was created with.
Definition: IR.cpp:1156
MLIR_CAPI_EXPORTED void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Sets the version to emit in the writer config.
Definition: IR.cpp:256
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition: IR.cpp:1290
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition: IR.cpp:889
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName)
Parses a module from file and transfers ownership to the caller.
Definition: IR.cpp:437
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition: IR.cpp:1029
MLIR_CAPI_EXPORTED void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Adds a list of components to the operation state.
Definition: IR.cpp:496
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Enable or disable printing of debug information (based on enable).
Definition: IR.cpp:220
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op)
Gets the location of the operation.
Definition: IR.cpp:643
MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute)
Gets the type id of the attribute.
Definition: IR.cpp:1241
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Sets the pos-th operand of the operation.
Definition: IR.cpp:701
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition: IR.cpp:839
MLIR_CAPI_EXPORTED intptr_t mlirOpResultGetResultNumber(MlirValue value)
Returns the position of the value in the list of results of the operation that produced it.
Definition: IR.cpp:1096
MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void)
Creates new printing flags with defaults, intended for customization.
Definition: IR.cpp:202
MLIR_CAPI_EXPORTED MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Creates new AsmState from value.
Definition: IR.cpp:178
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state)
Creates an operation and transfers ownership to the caller.
Definition: IR.cpp:581
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition: IR.h:1200
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition: IR.cpp:77
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition: IR.cpp:845
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Returns the number of attributes attached to the operation.
Definition: IR.cpp:782
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition: IR.cpp:1111
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition: IR.cpp:832
MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type)
Set the type of the value.
Definition: IR.cpp:1105
MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value)
Returns the type of the value.
Definition: IR.cpp:1101
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition: IR.cpp:71
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Parses an operation, giving ownership to the caller.
Definition: IR.cpp:618
MLIR_CAPI_EXPORTED bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Checks if two attributes are equal.
Definition: IR.cpp:1249
static bool mlirOperationIsNull(MlirOperation op)
Checks whether the underlying operation is null.
Definition: IR.h:615
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition: IR.cpp:895
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition: Support.h:138
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition: Support.cpp:51
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition: Support.h:132
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition: Support.h:163
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition: Support.cpp:47
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
inline ::llvm::hash_code hash_value(const PolynomialBase< D, T > &arg)
Definition: Polynomial.h:262
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition: IRModule.h:187
PyObjectRef< PyModule > PyModuleRef
Definition: IRModule.h:558
void populateIRCore(nanobind::module_ &m)
PyObjectRef< PyOperation > PyOperationRef
Definition: IRModule.h:642
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition: Diagnostics.h:26
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
Named MLIR attribute.
Definition: IR.h:76
MlirAttribute attribute
Definition: IR.h:78
MlirIdentifier name
Definition: IR.h:77
An auxiliary class for constructing operations.
Definition: IR.h:432
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
static bool dunderContains(const std::string &attributeKind)
Definition: IRCore.cpp:295
static void dunderSetItemNamed(const std::string &attributeKind, nb::callable func, bool replace)
Definition: IRCore.cpp:304
static nb::callable dunderGetItemNamed(const std::string &attributeKind)
Definition: IRCore.cpp:298
static void bind(nb::module_ &m)
Definition: IRCore.cpp:310
Wrapper for the global LLVM debugging flag.
Definition: IRCore.cpp:255
static void bind(nb::module_ &m)
Definition: IRCore.cpp:266
static void set(nb::object &o, bool enable)
Definition: IRCore.cpp:256
static bool get(const nb::object &)
Definition: IRCore.cpp:261
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1329
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Definition: IRModule.h:1334
Materialized diagnostic information.
Definition: IRModule.h:389
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:455
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition: IRModule.h:465
ErrorCapture(PyMlirContextRef ctx)
Definition: IRModule.h:456