MLIR 23.0.0git
IRCore.cpp
Go to the documentation of this file.
1//===- IRCore.cpp - Core IR bindings of the nanobind module ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// clang-format off
13#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
14// clang-format on
16#include "mlir-c/Debug.h"
17#include "mlir-c/Diagnostics.h"
19#include "mlir-c/IR.h"
20#include "mlir-c/Support.h"
21
22#include <array>
23#include <cassert>
24#include <functional>
25#include <optional>
26#include <string>
27
28namespace nb = nanobind;
29using namespace nb::literals;
30using namespace mlir;
32
// Python help text for parsing a module from its textual assembly; presumably
// attached to the Module parse binding elsewhere in this file (TODO confirm).
33static const char kModuleParseDocstring[] =
34 R"(Parses a module's assembly format from a string.
35
36Returns a new MlirModule or raises an MLIRError if the parsing fails.
37
38See also: https://mlir.llvm.org/docs/LangRef/
39)";
40
// Python help text shared by `dump`-style debug printing bindings.
41static const char kDumpDocstring[] =
42 "Dumps a debug representation of the object to stderr.";
43
45 R"(Replace all uses of this value with the `with` value, except for those
46in `exceptions`. `exceptions` can be either a single operation or a list of
47operations.
48)";
49
50//------------------------------------------------------------------------------
51// Utilities.
52//------------------------------------------------------------------------------
53
/// Returns `std::hash<T>` applied to `value`.
///
/// A small local shorthand so call sites do not have to spell out the hasher
/// type.
template <typename T>
static size_t hash(const T &value) {
  const std::hash<T> hasher{};
  return hasher(value);
}
59
// Builds the Python dialect object that wraps `dialectDescriptor` for the
// dialect named `dialectNamespace`.
//
// `dialectClass` is obtained from `dialectNamespace`; presumably this is a
// registry lookup of a user-defined dialect class (TODO confirm). When no
// class is found, the generic base dialect class is used. Otherwise the found
// class is called with the descriptor to create the custom wrapper.
60static nb::object
61createCustomDialectWrapper(const std::string &dialectNamespace,
62 nb::object dialectDescriptor) {
63 auto dialectClass =
65 dialectNamespace);
66 if (!dialectClass) {
67 // Use the base class.
69 std::move(dialectDescriptor)));
70 }
71
72 // Create the custom implementation.
73 return (*dialectClass)(std::move(dialectDescriptor));
74}
75
76namespace mlir {
77namespace python {
79
// Creates a new MlirBlock with one argument per type in `pyArgTypes`.
//
// Argument locations:
//  - if `pyArgLocs` is given, it must contain exactly one location per type;
//  - otherwise, when there are arguments, every argument gets the same single
//    location value (presumably the ambient default location — TODO confirm).
// Throws ValueError when the number of locations differs from the number of
// types. The returned block is not inserted anywhere; callers such as
// PyBlockList::appendBlock hand it to a region as an owned block.
80MlirBlock createBlock(
81 const nb::typed<nb::sequence, PyType> &pyArgTypes,
82 const std::optional<nb::typed<nb::sequence, PyLocation>> &pyArgLocs) {
83 std::vector<MlirType> argTypes;
84 argTypes.reserve(nb::len(pyArgTypes));
85 for (nb::handle pyType : pyArgTypes)
86 argTypes.push_back(
87 nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyType &>(pyType));
88
89 std::vector<MlirLocation> argLocs;
90 if (pyArgLocs) {
91 argLocs.reserve(nb::len(*pyArgLocs));
92 for (nb::handle pyLoc : *pyArgLocs)
93 argLocs.push_back(
94 nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyLocation &>(pyLoc));
95 } else if (!argTypes.empty()) {
// Fill every argument slot with one shared location.
96 argLocs.assign(
97 argTypes.size(),
99 }
100
// mlirBlockCreate reads argLocs.data() for argTypes.size() entries, so the two
// vectors must be the same length.
101 if (argTypes.size() != argLocs.size())
102 throw nb::value_error(
103 join("Expected ", argTypes.size(), " locations, got: ", argLocs.size())
104 .c_str());
105 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
106}
107
108void PyGlobalDebugFlag::set(nb::object &o, bool enable) {
109 nb::ft_lock_guard lock(mutex);
110 mlirEnableGlobalDebug(enable);
111}
112
// Getter of the static `_GlobalDebug.flag` property ("LLVM-wide debug flag").
// The object argument is unused. Runs under `mutex`, like the setter;
// presumably returns the current global debug state (TODO confirm).
113bool PyGlobalDebugFlag::get(const nb::object &) {
114 nb::ft_lock_guard lock(mutex);
116}
117
118void PyGlobalDebugFlag::bind(nb::module_ &m) {
119 // Debug flags.
120 nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
121 .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
122 &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
123 .def_static(
124 "set_types",
125 [](const std::string &type) {
126 nb::ft_lock_guard lock(mutex);
127 mlirSetGlobalDebugType(type.c_str());
128 },
129 "types"_a, "Sets specific debug types to be produced by LLVM.")
130 .def_static(
131 "set_types",
132 [](const std::vector<std::string> &types) {
133 std::vector<const char *> pointers;
134 pointers.reserve(types.size());
135 for (const std::string &str : types)
136 pointers.push_back(str.c_str());
137 nb::ft_lock_guard lock(mutex);
138 mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
139 },
140 "types"_a,
141 "Sets multiple specific debug types to be produced by LLVM.");
142}
143
144nb::ft_mutex PyGlobalDebugFlag::mutex;
145
146bool PyAttrBuilderMap::dunderContains(const std::string &attributeKind) {
147 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
148}
149
150nb::callable
151PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
152 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
153 if (!builder)
154 throw nb::key_error(attributeKind.c_str());
155 return *builder;
156}
157
158void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
159 nb::callable func, bool replace,
160 bool allow_existing) {
161 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
162 replace, allow_existing);
163}
164
165void PyAttrBuilderMap::bind(nb::module_ &m) {
166 nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
167 .def_static("contains", &PyAttrBuilderMap::dunderContains,
168 "attribute_kind"_a,
169 "Checks whether an attribute builder is registered for the "
170 "given attribute kind.")
171 .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
172 "attribute_kind"_a,
173 "Gets the registered attribute builder for the given "
174 "attribute kind.")
175 .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
176 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
177 "allow_existing"_a = false,
178 "Register an attribute builder for building MLIR "
179 "attributes from Python values.");
180}
181
182//------------------------------------------------------------------------------
183// PyBlock
184//------------------------------------------------------------------------------
185
187 return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
188}
189
190//------------------------------------------------------------------------------
191// Collections.
192//------------------------------------------------------------------------------
193
197 length == -1 ? mlirOperationGetNumRegions(operation->get())
198 : length,
199 step),
200 operation(std::move(operation)) {}
201
202intptr_t PyRegionList::getRawNumElements() {
203 operation->checkValid();
204 return mlirOperationGetNumRegions(operation->get());
205}
206
207PyRegion PyRegionList::getRawElement(intptr_t pos) {
208 operation->checkValid();
209 return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
210}
211
212PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
213 intptr_t step) const {
214 return PyRegionList(operation, startIndex, length, step);
215}
216
217nb::typed<nb::object, PyBlock> PyBlockIterator::dunderNext() {
218 operation->checkValid();
219 if (mlirBlockIsNull(next)) {
220 PyErr_SetNone(PyExc_StopIteration);
221 // python functions should return NULL after setting any exception
222 return nb::object();
223 }
224
225 PyBlock returnBlock(operation, next);
226 next = mlirBlockGetNextInRegion(next);
227 return nb::cast(returnBlock);
228}
229
230void PyBlockIterator::bind(nb::module_ &m) {
231 nb::class_<PyBlockIterator>(m, "BlockIterator")
232 .def("__iter__", &PyBlockIterator::dunderIter,
233 "Returns an iterator over the blocks in the operation's region.")
234 .def("__next__", &PyBlockIterator::dunderNext,
235 "Returns the next block in the iteration.");
236}
237
239 operation->checkValid();
240 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
241}
242
244 operation->checkValid();
245 intptr_t count = 0;
246 MlirBlock block = mlirRegionGetFirstBlock(region);
247 while (!mlirBlockIsNull(block)) {
248 count += 1;
249 block = mlirBlockGetNextInRegion(block);
250 }
251 return count;
252}
253
255 operation->checkValid();
256 if (index < 0) {
257 index += dunderLen();
258 }
259 if (index < 0) {
260 throw nb::index_error("attempt to access out of bounds block");
261 }
262 MlirBlock block = mlirRegionGetFirstBlock(region);
263 while (!mlirBlockIsNull(block)) {
264 if (index == 0) {
265 return PyBlock(operation, block);
266 }
267 block = mlirBlockGetNextInRegion(block);
268 index -= 1;
269 }
270 throw nb::index_error("attempt to access out of bounds block");
271}
272
273PyBlock PyBlockList::appendBlock(const nb::args &pyArgTypes,
274 const std::optional<nb::sequence> &pyArgLocs) {
275 operation->checkValid();
276 MlirBlock block = createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
277 mlirRegionAppendOwnedBlock(region, block);
278 return PyBlock(operation, block);
279}
280
// Registers the `BlockList` Python class. It supports indexing, iteration and
// `len()` over a region's blocks, plus `append(*arg_types, arg_locs=None)`,
// which creates and appends a new block.
281void PyBlockList::bind(nb::module_ &m) {
282 nb::class_<PyBlockList>(m, "BlockList")
283 .def("__getitem__", &PyBlockList::dunderGetItem,
284 "Returns the block at the specified index.")
285 .def("__iter__", &PyBlockList::dunderIter,
286 "Returns an iterator over blocks in the operation's region.")
287 .def("__len__", &PyBlockList::dunderLen,
288 "Returns the number of blocks in the operation's region.")
// Argument types are taken positionally (`*args`); `arg_locs` is keyword-only.
289 .def("append", &PyBlockList::appendBlock,
290 R"(
291 Appends a new block, with argument types as positional args.
292
293 Returns:
294 The created block.
295 )",
296 "args"_a, nb::kw_only(), "arg_locs"_a = std::nullopt);
297}
298
299nb::typed<nb::object, PyOpView> PyOperationIterator::dunderNext() {
300 parentOperation->checkValid();
301 if (mlirOperationIsNull(next)) {
302 PyErr_SetNone(PyExc_StopIteration);
303 // python functions should return NULL after setting any exception
304 return nb::object();
305 }
306
307 PyOperationRef returnOperation =
308 PyOperation::forOperation(parentOperation->getContext(), next);
309 next = mlirOperationGetNextInBlock(next);
310 return returnOperation->createOpView();
311}
312
313void PyOperationIterator::bind(nb::module_ &m) {
314 nb::class_<PyOperationIterator>(m, "OperationIterator")
315 .def("__iter__", &PyOperationIterator::dunderIter,
316 "Returns an iterator over the operations in an operation's block.")
317 .def("__next__", &PyOperationIterator::dunderNext,
318 "Returns the next operation in the iteration.");
319}
320
322 parentOperation->checkValid();
323 return PyOperationIterator(parentOperation,
325}
326
328 parentOperation->checkValid();
329 intptr_t count = 0;
330 MlirOperation childOp = mlirBlockGetFirstOperation(block);
331 while (!mlirOperationIsNull(childOp)) {
332 count += 1;
333 childOp = mlirOperationGetNextInBlock(childOp);
334 }
335 return count;
336}
337
338nb::typed<nb::object, PyOpView> PyOperationList::dunderGetItem(intptr_t index) {
339 parentOperation->checkValid();
340 if (index < 0) {
341 index += dunderLen();
342 }
343 if (index < 0) {
344 throw nb::index_error("attempt to access out of bounds operation");
345 }
346 MlirOperation childOp = mlirBlockGetFirstOperation(block);
347 while (!mlirOperationIsNull(childOp)) {
348 if (index == 0) {
349 return PyOperation::forOperation(parentOperation->getContext(), childOp)
350 ->createOpView();
351 }
352 childOp = mlirOperationGetNextInBlock(childOp);
353 index -= 1;
354 }
355 throw nb::index_error("attempt to access out of bounds operation");
356}
357
358void PyOperationList::bind(nb::module_ &m) {
359 nb::class_<PyOperationList>(m, "OperationList")
360 .def("__getitem__", &PyOperationList::dunderGetItem,
361 "Returns the operation at the specified index.")
362 .def("__iter__", &PyOperationList::dunderIter,
363 "Returns an iterator over operations in the list.")
364 .def("__len__", &PyOperationList::dunderLen,
365 "Returns the number of operations in the list.");
366}
367
// Returns the operation that owns this operand. The owner is read with
// mlirOpOperandGetOwner; per the return type it is handed back as an OpView
// (how the wrapper is created is presumably via PyOperation::forOperation —
// TODO confirm).
368nb::typed<nb::object, PyOpView> PyOpOperand::getOwner() const {
369 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
373}
375size_t PyOpOperand::getOperandNumber() const {
376 return mlirOpOperandGetOperandNumber(opOperand);
377}
378
379void PyOpOperand::bind(nb::module_ &m) {
380 nb::class_<PyOpOperand>(m, "OpOperand")
381 .def_prop_ro("owner", &PyOpOperand::getOwner,
382 "Returns the operation that owns this operand.")
383 .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
384 "Returns the operand number in the owning operation.");
385}
386
387nb::typed<nb::object, PyOpOperand> PyOpOperandIterator::dunderNext() {
388 if (mlirOpOperandIsNull(opOperand)) {
389 PyErr_SetNone(PyExc_StopIteration);
390 // python functions should return NULL after setting any exception
391 return nb::object();
392 }
393
394 PyOpOperand returnOpOperand(opOperand);
395 opOperand = mlirOpOperandGetNextUse(opOperand);
396 return nb::cast(returnOpOperand);
397}
398
399void PyOpOperandIterator::bind(nb::module_ &m) {
400 nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
401 .def("__iter__", &PyOpOperandIterator::dunderIter,
402 "Returns an iterator over operands.")
403 .def("__next__", &PyOpOperandIterator::dunderNext,
404 "Returns the next operand in the iteration.");
405}
407//------------------------------------------------------------------------------
408// PyThreadPool
409//------------------------------------------------------------------------------
410
412
414 if (threadPool.ptr)
415 mlirLlvmThreadPoolDestroy(threadPool);
416}
419 return mlirLlvmThreadPoolGetMaxConcurrency(threadPool);
420}
421
422std::string PyThreadPool::_mlir_thread_pool_ptr() const {
423 std::stringstream ss;
424 ss << threadPool.ptr;
425 return ss.str();
426}
428//------------------------------------------------------------------------------
429// PyMlirContext
430//------------------------------------------------------------------------------
431
432PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
433 nb::gil_scoped_acquire acquire;
434 nb::ft_lock_guard lock(live_contexts_mutex);
435 auto &liveContexts = getLiveContexts();
436 liveContexts[context.ptr] = this;
437}
438
440 // Note that the only public way to construct an instance is via the
441 // forContext method, which always puts the associated handle into
442 // liveContexts.
443 nb::gil_scoped_acquire acquire;
444 {
445 nb::ft_lock_guard lock(live_contexts_mutex);
446 getLiveContexts().erase(context.ptr);
447 }
448 mlirContextDestroy(context);
449}
452 return PyMlirContextRef(this, nb::cast(this));
453}
455nb::object PyMlirContext::getCapsule() {
456 return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
457}
458
459nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
460 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
461 if (mlirContextIsNull(rawContext))
462 throw nb::python_error();
463 return forContext(rawContext).releaseObject();
464}
465
466PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
467 nb::gil_scoped_acquire acquire;
468 nb::ft_lock_guard lock(live_contexts_mutex);
469 auto &liveContexts = getLiveContexts();
470 auto it = liveContexts.find(context.ptr);
471 if (it == liveContexts.end()) {
472 // Create.
473 PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
474 nb::object pyRef = nb::cast(unownedContextWrapper);
475 assert(pyRef && "cast to nb::object failed");
476 liveContexts[context.ptr] = unownedContextWrapper;
477 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
478 }
479 // Use existing.
480 nb::object pyRef = nb::cast(it->second);
481 return PyMlirContextRef(it->second, std::move(pyRef));
482}
483
484nb::ft_mutex PyMlirContext::live_contexts_mutex;
485
486PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
487 static LiveContextMap liveContexts;
488 return liveContexts;
489}
490
492 nb::ft_lock_guard lock(live_contexts_mutex);
493 return getLiveContexts().size();
494}
496nb::object PyMlirContext::contextEnter(nb::object context) {
497 return PyThreadContextEntry::pushContext(context);
498}
499
// `__exit__`: counterpart of contextEnter. The exception type/value/traceback
// arguments are accepted to satisfy the context-manager protocol. Presumably
// this pops the context frame pushed by contextEnter (TODO confirm).
500void PyMlirContext::contextExit(const nb::object &excType,
501 const nb::object &excVal,
502 const nb::object &excTb) {
504}
505
// Attaches a Python `callback` as a diagnostic handler on this context and
// returns the PyDiagnosticHandler object representing the registration.
//
// Each diagnostic is wrapped in a PyDiagnostic and passed to the callback.
// The callback's truthiness becomes the handler's `result`. If the callback
// raises, the error is printed to stderr and `hadError` is set. The wrapper
// is invalidated after the callback returns, so Python code cannot keep using
// it. When MLIR deletes the handler, the extra reference taken here is
// dropped.
506nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
507 // Note that ownership is transferred to the delete callback below by way of
508 // an explicit inc_ref (borrow).
509 PyDiagnosticHandler *pyHandler =
510 new PyDiagnosticHandler(get(), std::move(callback));
511 nb::object pyHandlerObject =
512 nb::cast(pyHandler, nb::rv_policy::take_ownership);
513 (void)pyHandlerObject.inc_ref();
514
515 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
516 // guaranteed to be known to nanobind.
517 auto handlerCallback =
518 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
// NOTE(review): this nb::cast creates a Python object BEFORE the
// gil_scoped_acquire below, and pyDiagnosticObject is released after the GIL
// scope ends. If this callback can fire on a thread not holding the GIL (as
// the comment below suggests), both should move inside the GIL scope — TODO
// confirm.
519 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
520 nb::object pyDiagnosticObject =
521 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
522
523 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
524 bool result = false;
525 {
526 // Since this can be called from arbitrary C++ contexts, always get the
527 // gil.
528 nb::gil_scoped_acquire gil;
529 try {
530 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
531 } catch (std::exception &e) {
532 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
533 e.what());
534 pyHandler->hadError = true;
535 }
536 }
537
// Prevent the Python side from using the diagnostic after the callback.
538 pyDiagnostic->invalidate();
540 };
541 auto deleteCallback = +[](void *userData) {
542 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
543 assert(pyHandler->registeredID && "handler is not registered");
544 pyHandler->registeredID.reset();
545
546 // Decrement reference, balancing the inc_ref() above.
547 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
548 pyHandlerObject.dec_ref();
549 };
550
551 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
552 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
553 return pyHandlerObject;
554}
555
// C diagnostic-handler callback used by ErrorCapture; `userData` is the
// ErrorCapture instance. When the context asks to emit error diagnostics
// rather than capture them, capturing is skipped. Otherwise, error-severity
// diagnostics appear to be recorded into `self->errors` as DiagnosticInfo
// (the exact control flow around the severity check should be confirmed
// against the full source).
556MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
557 void *userData) {
558 auto *self = static_cast<ErrorCapture *>(userData);
559 // Check if the context requested we emit errors instead of capturing them.
560 if (self->ctx->emitErrorDiagnostics)
562
564 MlirDiagnosticSeverity::MlirDiagnosticError)
567 self->errors.emplace_back(PyDiagnostic(diag).getInfo());
569}
570
573 if (!context) {
574 throw std::runtime_error(
575 "An MLIR function requires a Context but none was provided in the call "
576 "or from the surrounding environment. Either pass to the function with "
577 "a 'context=' argument or establish a default using 'with Context():'");
578 }
579 return *context;
580}
582//------------------------------------------------------------------------------
583// PyThreadContextEntry management
584//------------------------------------------------------------------------------
585
586std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
587 static thread_local std::vector<PyThreadContextEntry> stack;
588 return stack;
589}
590
592 auto &stack = getStack();
593 if (stack.empty())
594 return nullptr;
595 return &stack.back();
596}
597
598void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
599 nb::object insertionPoint,
600 nb::object location) {
601 auto &stack = getStack();
602 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
603 std::move(location));
604 // If the new stack has more than one entry and the context of the new top
605 // entry matches the previous, copy the insertionPoint and location from the
606 // previous entry if missing from the new top entry.
607 if (stack.size() > 1) {
608 auto &prev = *(stack.rbegin() + 1);
609 auto &current = stack.back();
610 if (current.context.is(prev.context)) {
611 // Default non-context objects from the previous entry.
612 if (!current.insertionPoint)
613 current.insertionPoint = prev.insertionPoint;
614 if (!current.location)
615 current.location = prev.location;
616 }
617 }
618}
619
621 if (!context)
622 return nullptr;
623 return nb::cast<PyMlirContext *>(context);
624}
625
627 if (!insertionPoint)
628 return nullptr;
629 return nb::cast<PyInsertionPoint *>(insertionPoint);
630}
631
633 if (!location)
634 return nullptr;
635 return nb::cast<PyLocation *>(location);
636}
637
639 auto *tos = getTopOfStack();
640 return tos ? tos->getContext() : nullptr;
641}
642
644 auto *tos = getTopOfStack();
645 return tos ? tos->getInsertionPoint() : nullptr;
646}
647
649 auto *tos = getTopOfStack();
650 return tos ? tos->getLocation() : nullptr;
651}
652
653nb::object PyThreadContextEntry::pushContext(nb::object context) {
654 push(FrameKind::Context, /*context=*/context,
655 /*insertionPoint=*/nb::object(),
656 /*location=*/nb::object());
657 return context;
658}
659
661 auto &stack = getStack();
662 if (stack.empty())
663 throw std::runtime_error("Unbalanced Context enter/exit");
664 auto &tos = stack.back();
665 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
666 throw std::runtime_error("Unbalanced Context enter/exit");
667 stack.pop_back();
668}
669
670nb::object
671PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
672 PyInsertionPoint &insertionPoint =
673 nb::cast<PyInsertionPoint &>(insertionPointObj);
674 nb::object contextObj =
675 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
676 push(FrameKind::InsertionPoint,
677 /*context=*/contextObj,
678 /*insertionPoint=*/insertionPointObj,
679 /*location=*/nb::object());
680 return insertionPointObj;
681}
682
684 auto &stack = getStack();
685 if (stack.empty())
686 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
687 auto &tos = stack.back();
688 if (tos.frameKind != FrameKind::InsertionPoint &&
689 tos.getInsertionPoint() != &insertionPoint)
690 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
691 stack.pop_back();
692}
693
694nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
695 PyLocation &location = nb::cast<PyLocation &>(locationObj);
696 nb::object contextObj = location.getContext().getObject();
697 push(FrameKind::Location, /*context=*/contextObj,
698 /*insertionPoint=*/nb::object(),
699 /*location=*/locationObj);
700 return locationObj;
701}
702
704 auto &stack = getStack();
705 if (stack.empty())
706 throw std::runtime_error("Unbalanced Location enter/exit");
707 auto &tos = stack.back();
708 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
709 throw std::runtime_error("Unbalanced Location enter/exit");
710 stack.pop_back();
711}
713//------------------------------------------------------------------------------
714// PyDiagnostic*
715//------------------------------------------------------------------------------
716
718 valid = false;
719 if (materializedNotes) {
720 for (nb::handle noteObject : *materializedNotes) {
721 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
722 note->invalidate();
723 }
724 }
725}
726
728 nb::object callback)
729 : context(context), callback(std::move(callback)) {}
730
732
734 if (!registeredID)
735 return;
736 MlirDiagnosticHandlerID localID = *registeredID;
737 mlirContextDetachDiagnosticHandler(context, localID);
738 assert(!registeredID && "should have unregistered");
739 // Not strictly necessary but keeps stale pointers from being around to cause
740 // issues.
741 context = {nullptr};
742}
743
744void PyDiagnostic::checkValid() {
745 if (!valid) {
746 throw std::invalid_argument(
747 "Diagnostic is invalid (used outside of callback)");
748 }
749}
750
752 checkValid();
753 return static_cast<PyDiagnosticSeverity>(
754 mlirDiagnosticGetSeverity(diagnostic));
755}
756
758 checkValid();
759 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
760 MlirContext context = mlirLocationGetContext(loc);
761 return PyLocation(PyMlirContext::forContext(context), loc);
762}
763
764nb::str PyDiagnostic::getMessage() {
765 checkValid();
766 nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
767 PyFileAccumulator accum(fileObject, /*binary=*/false);
768 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
769 return nb::cast<nb::str>(fileObject.attr("getvalue")());
770}
771
772nb::typed<nb::tuple, PyDiagnostic> PyDiagnostic::getNotes() {
773 checkValid();
774 if (materializedNotes)
775 return *materializedNotes;
776 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
777 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
778 for (intptr_t i = 0; i < numNotes; ++i) {
779 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
780 nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
781 PyTuple_SetItem(notes.ptr(), i, diagnostic.release().ptr());
782 }
783 materializedNotes = std::move(notes);
784
785 return *materializedNotes;
786}
787
789 std::vector<DiagnosticInfo> notes;
790 for (nb::handle n : getNotes())
791 notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
792 return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
793 std::move(notes)};
794}
796//------------------------------------------------------------------------------
797// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
798//------------------------------------------------------------------------------
799
800MlirDialect PyDialects::getDialectForKey(const std::string &key,
801 bool attrError) {
802 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
803 {key.data(), key.size()});
804 if (mlirDialectIsNull(dialect)) {
805 std::string msg = join("Dialect '", key, "' not found");
806 if (attrError)
807 throw nb::attribute_error(msg.c_str());
808 throw nb::index_error(msg.c_str());
809 }
810 return dialect;
811}
814 return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
815}
816
818 MlirDialectRegistry rawRegistry =
820 if (mlirDialectRegistryIsNull(rawRegistry))
821 throw nb::python_error();
822 return PyDialectRegistry(rawRegistry);
823}
825//------------------------------------------------------------------------------
826// PyLocation
827//------------------------------------------------------------------------------
829nb::object PyLocation::getCapsule() {
830 return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
831}
832
// Builds a PyLocation from the MlirLocation stored in `capsule`. A null
// location from the capsule is reported by raising the pending Python error.
// Presumably the result is tied to the location's context via
// PyMlirContext::forContext (TODO confirm).
833PyLocation PyLocation::createFromCapsule(nb::object capsule) {
834 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
835 if (mlirLocationIsNull(rawLoc))
836 throw nb::python_error();
838 rawLoc);
839}
841nb::object PyLocation::contextEnter(nb::object locationObj) {
842 return PyThreadContextEntry::pushLocation(locationObj);
843}
844
// `__exit__`: counterpart of contextEnter. The exception type/value/traceback
// arguments are accepted to satisfy the context-manager protocol. Presumably
// this pops the location frame pushed by contextEnter (TODO confirm).
845void PyLocation::contextExit(const nb::object &excType,
846 const nb::object &excVal,
847 const nb::object &excTb) {
849}
850
853 if (!location) {
854 throw std::runtime_error(
855 "An MLIR function requires a Location but none was provided in the "
856 "call or from the surrounding environment. Either pass to the function "
857 "with a 'loc=' argument or establish a default using 'with loc:'");
858 }
859 return *location;
860}
861
862//------------------------------------------------------------------------------
863// PyModule
864//------------------------------------------------------------------------------
865
866PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
867 : BaseContextObject(std::move(contextRef)), module(module) {}
868
870 nb::gil_scoped_acquire acquire;
871 auto &liveModules = getContext()->liveModules;
872 assert(liveModules.count(module.ptr) == 1 &&
873 "destroying module not in live map");
874 liveModules.erase(module.ptr);
875 mlirModuleDestroy(module);
876}
877
878PyModuleRef PyModule::forModule(MlirModule module) {
879 MlirContext context = mlirModuleGetContext(module);
880 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
881
882 nb::gil_scoped_acquire acquire;
883 auto &liveModules = contextRef->liveModules;
884 auto it = liveModules.find(module.ptr);
885 if (it == liveModules.end()) {
886 // Create.
887 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
888 // Note that the default return value policy on cast is automatic_reference,
889 // which does not take ownership (delete will not be called).
890 // Just be explicit.
891 nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
892 unownedModule->handle = pyRef;
893 liveModules[module.ptr] =
894 std::make_pair(unownedModule->handle, unownedModule);
895 return PyModuleRef(unownedModule, std::move(pyRef));
896 }
897 // Use existing.
898 PyModule *existing = it->second.second;
899 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
900 return PyModuleRef(existing, std::move(pyRef));
901}
902
903nb::object PyModule::createFromCapsule(nb::object capsule) {
904 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
905 if (mlirModuleIsNull(rawModule))
906 throw nb::python_error();
907 return forModule(rawModule).releaseObject();
908}
909
910nb::object PyModule::getCapsule() {
911 return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
912}
914//------------------------------------------------------------------------------
915// PyOperation
916//------------------------------------------------------------------------------
917
918PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
919 : BaseContextObject(std::move(contextRef)), operation(operation) {}
920
922 // If the operation has already been invalidated there is nothing to do.
923 if (!valid)
924 return;
925 // Otherwise, invalidate the operation when it is attached.
926 if (isAttached())
927 setInvalid();
928 else {
929 // And destroy it when it is detached, i.e. owned by Python.
930 erase();
931 }
932}
933
934namespace {
935
936// Constructs a new object of type T in-place on the Python heap, returning a
937// PyObjectRef to it, loosely analogous to std::make_shared<T>().
//
// The statement order matters: the instance storage is allocated first
// (nb::inst_alloc), then T is placement-constructed into it, and only then is
// the instance marked ready for use by nanobind (nb::inst_mark_ready).
// NOTE(review): if T's constructor throws, `instance` is released without
// being marked ready; presumably nanobind then skips T's destructor — TODO
// confirm.
938template <typename T, class... Args>
939PyObjectRef<T> makeObjectRef(Args &&...args) {
940 nb::handle type = nb::type<T>();
941 nb::object instance = nb::inst_alloc(type);
942 T *ptr = nb::inst_ptr<T>(instance);
943 new (ptr) T(std::forward<Args>(args)...);
944 nb::inst_mark_ready(instance);
945 return PyObjectRef<T>(ptr, std::move(instance));
946}
947
948} // namespace
949
950PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
951 MlirOperation operation,
952 nb::object parentKeepAlive) {
953 // Create.
954 PyOperationRef unownedOperation =
955 makeObjectRef<PyOperation>(std::move(contextRef), operation);
956 unownedOperation->handle = unownedOperation.getObject();
957 if (parentKeepAlive) {
958 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
959 }
960 return unownedOperation;
961}
962
964 MlirOperation operation,
965 nb::object parentKeepAlive) {
966 return createInstance(std::move(contextRef), operation,
967 std::move(parentKeepAlive));
968}
969
971 MlirOperation operation,
972 nb::object parentKeepAlive) {
973 PyOperationRef created = createInstance(std::move(contextRef), operation,
974 std::move(parentKeepAlive));
975 created->attached = false;
976 return created;
977}
978
980 const std::string &sourceStr,
981 const std::string &sourceName) {
982 PyMlirContext::ErrorCapture errors(contextRef);
983 MlirOperation op =
984 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
985 toMlirStringRef(sourceName));
986 if (mlirOperationIsNull(op))
987 throw MLIRError("Unable to parse operation assembly", errors.take());
988 return PyOperation::createDetached(std::move(contextRef), op);
989}
990
993 setDetached();
994 parentKeepAlive = nb::object();
995}
996
997MlirOperation PyOperation::get() const {
998 checkValid();
999 return operation;
1000}
1003 return PyOperationRef(this, nb::borrow<nb::object>(handle));
1004}
1005
1006void PyOperation::setAttached(const nb::object &parent) {
1007 assert(!attached && "operation already attached");
1008 attached = true;
1009}
1010
1012 assert(attached && "operation already detached");
1013 attached = false;
1014}
1015
1016void PyOperation::checkValid() const {
1017 if (!valid) {
1018 throw std::runtime_error("the operation has been invalidated");
1019 }
1020}
1021
/// Prints this operation to `fileObject` (Python's sys.stdout when None).
/// The keyword-style arguments are translated into an MlirOpPrintingFlags
/// object; `binary` selects byte vs. text output on the file object.
1022void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1023 std::optional<int64_t> largeResourceLimit,
1024 bool enableDebugInfo, bool prettyDebugInfo,
1025 bool printGenericOpForm, bool useLocalScope,
1026 bool useNameLocAsPrefix, bool assumeVerified,
1027 nb::object fileObject, bool binary,
1028 bool skipRegions) {
1029 PyOperation &operation = getOperation();
1030 operation.checkValid();
1031 if (fileObject.is_none())
1032 fileObject = nb::module_::import_("sys").attr("stdout");
1033
  // Only the options that were explicitly requested are applied to the flags.
1034 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1035 if (largeElementsLimit)
1036 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1037 if (largeResourceLimit)
1038 mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit);
1039 if (enableDebugInfo)
1040 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1041 /*prettyForm=*/prettyDebugInfo);
1042 if (printGenericOpForm)
1044 if (useLocalScope)
1046 if (assumeVerified)
1048 if (skipRegions)
1050 if (useNameLocAsPrefix)
1052
  // Stream the printer's output chunks into the Python file object.
1053 PyFileAccumulator accum(fileObject, binary);
1054 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1055 accum.getUserData());
1057}
1058
1059void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1060 bool binary) {
1061 PyOperation &operation = getOperation();
1062 operation.checkValid();
1063 if (fileObject.is_none())
1064 fileObject = nb::module_::import_("sys").attr("stdout");
1065 PyFileAccumulator accum(fileObject, binary);
1066 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1067 accum.getUserData());
1068}
1069
/// Serializes this operation as MLIR bytecode into `fileOrStringObject`
/// (always in binary mode). When `bytecodeVersion` is given, that emit
/// version is requested; a ValueError is raised if it cannot be honored.
1070void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
1071 std::optional<int64_t> bytecodeVersion) {
1072 PyOperation &operation = getOperation();
1073 operation.checkValid();
1074 PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
1075
  // No explicit version: use the plain writer entry point.
1076 if (!bytecodeVersion.has_value())
1077 return mlirOperationWriteBytecode(operation, accum.getCallback(),
1078 accum.getUserData());
1079
  // Explicit version: configure the writer to target it.
1080 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1081 mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1083 operation, config, accum.getCallback(), accum.getUserData());
1086 throw nb::value_error(
1087 join("Unable to honor desired bytecode version ", *bytecodeVersion)
1088 .c_str());
1089}
1090
1091void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
1092 PyWalkOrder walkOrder) {
1093 PyOperation &operation = getOperation();
1094 operation.checkValid();
1095 struct UserData {
1096 std::function<PyWalkResult(MlirOperation)> callback;
1097 bool gotException;
1098 std::string exceptionWhat;
1099 nb::object exceptionType;
1100 };
1101 UserData userData{callback, false, {}, {}};
1102 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1103 void *userData) {
1104 UserData *calleeUserData = static_cast<UserData *>(userData);
1105 try {
1106 return static_cast<MlirWalkResult>((calleeUserData->callback)(op));
1107 } catch (nb::python_error &e) {
1108 calleeUserData->gotException = true;
1109 calleeUserData->exceptionWhat = std::string(e.what());
1110 calleeUserData->exceptionType = nb::borrow(e.type());
1111 return MlirWalkResult::MlirWalkResultInterrupt;
1112 }
1113 };
1114 mlirOperationWalk(operation, walkCallback, &userData,
1115 static_cast<MlirWalkOrder>(walkOrder));
1116 if (userData.gotException) {
1117 std::string message("Exception raised in callback: ");
1118 message.append(userData.exceptionWhat);
1119 throw std::runtime_error(message);
1120 }
1121}
1122
1123nb::object PyOperationBase::getAsm(bool binary,
1124 std::optional<int64_t> largeElementsLimit,
1125 std::optional<int64_t> largeResourceLimit,
1126 bool enableDebugInfo, bool prettyDebugInfo,
1127 bool printGenericOpForm, bool useLocalScope,
1128 bool useNameLocAsPrefix, bool assumeVerified,
1129 bool skipRegions) {
1130 nb::object fileObject;
1131 if (binary) {
1132 fileObject = nb::module_::import_("io").attr("BytesIO")();
1133 } else {
1134 fileObject = nb::module_::import_("io").attr("StringIO")();
1135 }
1136 print(/*largeElementsLimit=*/largeElementsLimit,
1137 /*largeResourceLimit=*/largeResourceLimit,
1138 /*enableDebugInfo=*/enableDebugInfo,
1139 /*prettyDebugInfo=*/prettyDebugInfo,
1140 /*printGenericOpForm=*/printGenericOpForm,
1141 /*useLocalScope=*/useLocalScope,
1142 /*useNameLocAsPrefix=*/useNameLocAsPrefix,
1143 /*assumeVerified=*/assumeVerified,
1144 /*fileObject=*/fileObject,
1145 /*binary=*/binary,
1146 /*skipRegions=*/skipRegions);
1147
1148 return fileObject.attr("getvalue")();
1149}
1150
1152 PyOperation &operation = getOperation();
1153 PyOperation &otherOp = other.getOperation();
  // Both operations must still refer to live IR before mutating the block.
1154 operation.checkValid();
1155 otherOp.checkValid();
1156 mlirOperationMoveAfter(operation, otherOp);
  // The moved op now lives alongside `other`, so share its parent keep-alive.
1157 operation.parentKeepAlive = otherOp.parentKeepAlive;
1158}
1159
1161 PyOperation &operation = getOperation();
1162 PyOperation &otherOp = other.getOperation();
  // Both operations must still refer to live IR before mutating the block.
1163 operation.checkValid();
1164 otherOp.checkValid();
1165 mlirOperationMoveBefore(operation, otherOp);
  // The moved op now lives alongside `other`, so share its parent keep-alive.
1166 operation.parentKeepAlive = otherOp.parentKeepAlive;
1167}
1168
1170 PyOperation &operation = getOperation();
1171 PyOperation &otherOp = other.getOperation();
  // Validate both sides, then defer the ordering query to the C API.
1172 operation.checkValid();
1173 otherOp.checkValid();
1174 return mlirOperationIsBeforeInBlock(operation, otherOp);
1175}
1176
1178 PyOperation &op = getOperation();
1181 throw MLIRError("Verification failed", errors.take());
  // Reached only when the failure path above did not throw.
1182 return true;
1183}
1184
1185std::optional<PyOperationRef> PyOperation::getParentOperation() {
1186 checkValid();
1187 if (!isAttached())
1188 throw nb::value_error("Detached operations have no parent");
1189 MlirOperation operation = mlirOperationGetParentOperation(get());
1190 if (mlirOperationIsNull(operation))
1191 return {};
1192 return PyOperation::forOperation(getContext(), operation);
1193}
1194
1196 checkValid();
  // getParentOperation() raises ValueError for detached operations, so past
  // this point the operation is attached.
1197 std::optional<PyOperationRef> parentOperation = getParentOperation();
1198 MlirBlock block = mlirOperationGetBlock(get());
1199 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1200 assert(parentOperation && "Operation has no parent");
1201 return PyBlock{std::move(*parentOperation), block};
1202}
1203
1205 checkValid();
  // The C helper returns a new reference; steal it into an owning nb::object.
1206 return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1207}
1208
1209nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
1210 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1211 if (mlirOperationIsNull(rawOperation))
1212 throw nb::python_error();
1213 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1214 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1215 .releaseObject();
1216}
1217
/// Inserts `op` at an insertion point unless `maybeIp` is the Python `False`
/// singleton (which disables insertion).
/// - None: an insertion point is selected by a line not shown here
///   (presumably the current thread-context insertion point — TODO confirm).
/// - Otherwise `maybeIp` must cast to a PyInsertionPoint.
/// Insertion is skipped when the resolved insertion point pointer is null.
1218static void maybeInsertOperation(PyOperationRef &op,
1219 const nb::object &maybeIp) {
1220 // InsertPoint active?
1221 if (!maybeIp.is(nb::cast(false))) {
1222 PyInsertionPoint *ip;
1223 if (maybeIp.is_none()) {
1225 } else {
1226 ip = nb::cast<PyInsertionPoint *>(maybeIp);
1227 }
1228 if (ip)
1229 ip->insert(*op.get());
1230 }
1231}
1232
/// Creates a new operation named `name` from already-unpacked components and
/// optionally inserts it (see maybeInsertOperation) at `maybeIp`.
///
/// - results: explicit result types; a None entry raises ValueError.
/// - operands/numOperands: raw operand values.
/// - attributes: dict of str -> Attribute; bad keys/values raise TypeError
///   (or RuntimeError for the None-value case).
/// - successors: successor blocks; a None entry raises ValueError.
/// - regions: number of empty regions to create; must be >= 0.
/// - inferType: enables result type inference on the operation state.
/// Returns the Python object of the newly created (initially detached)
/// operation; raises MLIRError when creation fails.
1233nb::object PyOperation::create(std::string_view name,
1234 std::optional<std::vector<PyType *>> results,
1235 const MlirValue *operands, size_t numOperands,
1236 std::optional<nb::dict> attributes,
1237 std::optional<std::vector<PyBlock *>> successors,
1238 int regions, PyLocation &location,
1239 const nb::object &maybeIp, bool inferType) {
1240 std::vector<MlirType> mlirResults;
1241 std::vector<MlirBlock> mlirSuccessors;
1242 std::vector<std::pair<std::string, MlirAttribute>> mlirAttributes;
1243
1244 // General parameter validation.
1245 if (regions < 0)
1246 throw nb::value_error("number of regions must be >= 0");
1247
1248 // Unpack/validate results.
1249 if (results) {
1250 mlirResults.reserve(results->size());
1251 for (PyType *result : *results) {
1252 // TODO: Verify that the result type originates from the same context.
1253 if (!result)
1254 throw nb::value_error("result type cannot be None");
1255 mlirResults.push_back(*result);
1256 }
1257 }
1258 // Unpack/validate attributes.
1259 if (attributes) {
1260 mlirAttributes.reserve(attributes->size());
1261 for (std::pair<nb::handle, nb::handle> it : *attributes) {
1262 std::string key;
1263 try {
1264 key = nb::cast<std::string>(it.first);
1265 } catch (nb::cast_error &err) {
1266 std::string msg = join("Invalid attribute key (not a string) when "
1267 "attempting to create the operation \"",
1268 name, "\" (", err.what(), ")");
1269 throw nb::type_error(msg.c_str());
1270 }
1271 try {
1272 auto &attribute = nb::cast<PyAttribute &>(it.second);
1273 // TODO: Verify attribute originates from the same context.
1274 mlirAttributes.emplace_back(std::move(key), attribute);
1275 } catch (nb::cast_error &err) {
1276 std::string msg = join("Invalid attribute value for the key \"", key,
1277 "\" when attempting to create the operation \"",
1278 name, "\" (", err.what(), ")");
1279 throw nb::type_error(msg.c_str());
1280 } catch (std::runtime_error &) {
1281 // This exception appears to be thrown when the value is None.
1282 std::string msg = join(
1283 "Found an invalid (`None`?) attribute value for the key \"", key,
1284 "\" when attempting to create the operation \"", name, "\"");
1285 throw std::runtime_error(msg);
1286 }
1287 }
1288 }
1289 // Unpack/validate successors.
1290 if (successors) {
1291 mlirSuccessors.reserve(successors->size());
1292 for (PyBlock *successor : *successors) {
1293 // TODO: Verify that the successor originates from the same context.
1294 if (!successor)
1295 throw nb::value_error("successor block cannot be None");
1296 mlirSuccessors.push_back(successor->get());
1297 }
1298 }
1299
1300 // Apply the unpacked/validated values to the operation state. Beyond this
1301 // point, exceptions cannot be thrown or else the state will leak.
1302 MlirOperationState state =
1303 mlirOperationStateGet(toMlirStringRef(name), location);
1304 if (numOperands > 0)
1305 mlirOperationStateAddOperands(&state, numOperands, operands);
1306 state.enableResultTypeInference = inferType;
1307 if (!mlirResults.empty())
1308 mlirOperationStateAddResults(&state, mlirResults.size(),
1309 mlirResults.data());
1310 if (!mlirAttributes.empty()) {
1311 // Note that the attribute names directly reference bytes in
1312 // mlirAttributes, so that vector must not be changed from here
1313 // on.
1314 std::vector<MlirNamedAttribute> mlirNamedAttributes;
1315 mlirNamedAttributes.reserve(mlirAttributes.size());
1316 for (const std::pair<std::string, MlirAttribute> &it : mlirAttributes)
1317 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1319 toMlirStringRef(it.first)),
1320 it.second));
1321 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1322 mlirNamedAttributes.data());
1323 }
1324 if (!mlirSuccessors.empty())
1325 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1326 mlirSuccessors.data());
1327 if (regions) {
1328 std::vector<MlirRegion> mlirRegions;
1329 mlirRegions.resize(regions);
1330 for (int i = 0; i < regions; ++i)
1331 mlirRegions[i] = mlirRegionCreate();
1332 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1333 mlirRegions.data());
1334 }
1335
1336 // Construct the operation, capturing diagnostics for the error path.
1337 PyMlirContext::ErrorCapture errors(location.getContext());
1338 MlirOperation operation = mlirOperationCreate(&state);
1339 if (!operation.ptr)
1340 throw MLIRError("Operation creation failed", errors.take());
1341 PyOperationRef created =
1342 PyOperation::createDetached(location.getContext(), operation);
1343 maybeInsertOperation(created, maybeIp);
1344
1345 return created.getObject();
1346}
1347
1348nb::object PyOperation::clone(const nb::object &maybeIp) {
1349 MlirOperation clonedOperation = mlirOperationClone(operation);
1350 PyOperationRef cloned =
1351 PyOperation::createDetached(getContext(), clonedOperation);
1352 maybeInsertOperation(cloned, maybeIp);
1353
1354 return cloned->createOpView();
1355}
1356
1357nb::object PyOperation::createOpView() {
1358 checkValid();
1359 MlirIdentifier ident = mlirOperationGetName(get());
1360 MlirStringRef identStr = mlirIdentifierStr(ident);
1361 auto operationCls = PyGlobals::get().lookupOperationClass(
1362 std::string_view(identStr.data, identStr.length));
1363 if (operationCls)
1364 return PyOpView::constructDerived(*operationCls, getRef().getObject());
1365 return nb::cast(PyOpView(getRef().getObject()));
1366}
1367
/// Destroys the underlying MLIR operation and invalidates this wrapper.
1368void PyOperation::erase() {
1370 setInvalid();
  // Invalidate before destroying: setInvalid() is expected to make later
  // checkValid() calls on this wrapper fail instead of touching freed IR.
1371 mlirOperationDestroy(operation);
1372}
1373
/// Binds the OpResult-specific Python properties: `owner` (the producing
/// operation as an OpView) and `result_number` (its index in the results).
1374void PyOpResult::bindDerived(ClassTy &c) {
1375 c.def_prop_ro(
1376 "owner",
1377 [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
1378 assert(mlirOperationEqual(self.getParentOperation()->get(),
1379 mlirOpResultGetOwner(self.get())) &&
1380 "expected the owner of the value in Python to match that in "
1381 "the IR");
1382 return self.getParentOperation()->createOpView();
1383 },
1384 "Returns the operation that produces this result.");
1385 c.def_prop_ro(
1386 "result_number",
1387 [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); },
1388 "Returns the position of this result in the operation's result list.");
1390
1391/// Returns the types of the values held by `container`, in element order.
/// `container` must provide size() and getElement(i) returning a value
/// wrapper with get(); each type is wrapped as a PyType in `context`.
1392template <typename Container>
1393static std::vector<nb::typed<nb::object, PyType>>
1394getValueTypes(Container &container, PyMlirContextRef &context) {
1395 std::vector<nb::typed<nb::object, PyType>> result;
1396 result.reserve(container.size());
1397 for (int i = 0, e = container.size(); i < e; ++i) {
1398 result.push_back(PyType(context->getRef(),
1399 mlirValueGetType(container.getElement(i).get()))
1401 }
1402 return result;
1403}
1404
1406 intptr_t length, intptr_t step)
// Initializes the Sliceable view; the length expression (partly not visible
// here) presumably substitutes the operation's result count for a sentinel
// length — TODO confirm against the declaration's default arguments.
1407 : Sliceable(startIndex,
1409 : length,
1410 step),
1411 operation(std::move(operation)) {}
1412
1413void PyOpResultList::bindDerived(ClassTy &c) {
1414 c.def_prop_ro(
1415 "types",
1416 [](PyOpResultList &self) {
1417 return getValueTypes(self, self.operation->getContext());
1418 },
1419 "Returns a list of types for all results in this result list.");
1420 c.def_prop_ro(
1421 "owner",
1422 [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
1423 return self.operation->createOpView();
1424 },
1425 "Returns the operation that owns this result list.");
1426}
1427
1428intptr_t PyOpResultList::getRawNumElements() {
1429 operation->checkValid();
1430 return mlirOperationGetNumResults(operation->get());
1431}
1432
1433PyOpResult PyOpResultList::getRawElement(intptr_t index) {
1434 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1435 return PyOpResult(value);
1436}
1437
1438PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
1439 intptr_t step) const {
1440 return PyOpResultList(operation, startIndex, length, step);
1441}
1443//------------------------------------------------------------------------------
1444// PyOpView
1445//------------------------------------------------------------------------------
1446
/// Unpacks the Python result-type arguments of an OpView builder.
///
/// When `resultSegmentSpecObj` is None, every element of `resultTypeList`
/// must be a Type and is appended to `resultTypes`. Otherwise the spec must
/// be a list of ints, one per element of `resultTypeList`:
///   1  -> a required single Type,
///   0  -> an optional single Type (None allowed, contributes 0 results),
///   -1 -> a sequence of Types (None treated as empty).
/// In the sized case the number of types contributed by each element is
/// appended to `resultSegmentLengths`. Invalid input raises ValueError
/// naming the offending result index and operation `name`.
1447static void populateResultTypes(std::string_view name,
1448 nb::sequence resultTypeList,
1449 const nb::object &resultSegmentSpecObj,
1450 std::vector<int32_t> &resultSegmentLengths,
1451 std::vector<PyType *> &resultTypes) {
1452 resultTypes.reserve(nb::len(resultTypeList));
1453 if (resultSegmentSpecObj.is_none()) {
1454 // Non-variadic result unpacking.
1455 size_t index = 0;
1456 for (nb::handle resultType : resultTypeList) {
1457 try {
1458 resultTypes.push_back(nb::cast<PyType *>(resultType));
  // A None element casts to a null pointer; reject it like a bad cast.
1459 if (!resultTypes.back())
1460 throw nb::cast_error();
1461 } catch (nb::cast_error &err) {
1462 throw nb::value_error(join("Result ", index, " of operation \"", name,
1463 "\" must be a Type (", err.what(), ")")
1464 .c_str());
1465 }
1466 ++index;
1467 }
1468 } else {
1469 // Sized result unpacking.
1470 auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1471 if (resultSegmentSpec.size() != nb::len(resultTypeList)) {
1472 throw nb::value_error(
1473 join("Operation \"", name, "\" requires ", resultSegmentSpec.size(),
1474 " result segments but was provided ", nb::len(resultTypeList))
1475 .c_str());
1476 }
1477 resultSegmentLengths.reserve(nb::len(resultTypeList));
1478 for (size_t i = 0, e = resultSegmentSpec.size(); i < e; ++i) {
1479 int segmentSpec = resultSegmentSpec[i];
1480 if (segmentSpec == 1 || segmentSpec == 0) {
1481 // Unpack unary element.
1482 try {
1483 auto *resultType = nb::cast<PyType *>(resultTypeList[i]);
1484 if (resultType) {
1485 resultTypes.push_back(resultType);
1486 resultSegmentLengths.push_back(1);
1487 } else if (segmentSpec == 0) {
1488 // Allowed to be optional.
1489 resultSegmentLengths.push_back(0);
1490 } else {
1491 throw nb::value_error(
1492 join("Result ", i, " of operation \"", name,
1493 "\" must be a Type (was None and result is not optional)")
1494 .c_str());
1495 }
1496 } catch (nb::cast_error &err) {
1497 throw nb::value_error(join("Result ", i, " of operation \"", name,
1498 "\" must be a Type (", err.what(), ")")
1499 .c_str());
1500 }
1501 } else if (segmentSpec == -1) {
1502 // Unpack sequence by appending.
1503 try {
1504 if (resultTypeList[i].is_none()) {
1505 // Treat it as an empty list.
1506 resultSegmentLengths.push_back(0);
1507 } else {
1508 // Unpack the list.
1509 auto segment = nb::cast<nb::sequence>(resultTypeList[i]);
1510 for (nb::handle segmentItem : segment) {
1511 resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1512 if (!resultTypes.back()) {
1513 throw nb::type_error("contained a None item");
1514 }
1515 }
1516 resultSegmentLengths.push_back(nb::len(segment));
1517 }
1518 } catch (std::exception &err) {
1519 // NOTE: Sloppy to be using a catch-all here, but there are at least
1520 // three different unrelated exceptions that can be thrown in the
1521 // above "casts". Just keep the scope above small and catch them all.
1522 throw nb::value_error(join("Result ", i, " of operation \"", name,
1523 "\" must be a Sequence of Types (",
1524 err.what(), ")")
1525 .c_str());
1526 }
1527 } else {
1528 throw nb::value_error("Unexpected segment spec");
1530 }
1531 }
1532}
1533
1534MlirValue getUniqueResult(MlirOperation operation) {
1535 auto numResults = mlirOperationGetNumResults(operation);
1536 if (numResults != 1) {
1537 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1538 throw nb::value_error(
1539 join("Cannot call .result on operation ",
1540 std::string_view(name.data, name.length), " which has ",
1541 numResults,
1542 " results (it is only valid for operations with a "
1543 "single result)")
1544 .c_str());
1545 }
1546 return mlirOperationGetResult(operation, 0);
1547}
1548
1549static MlirValue getOpResultOrValue(nb::handle operand) {
1550 if (operand.is_none()) {
1551 throw nb::value_error("contained a None item");
1552 }
1553 PyOperationBase *op;
1554 if (nb::try_cast<PyOperationBase *>(operand, op)) {
1555 return getUniqueResult(op->getOperation());
1556 }
1557 PyOpResultList *opResultList;
1558 if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1559 return getUniqueResult(opResultList->getOperation()->get());
1560 }
1561 PyValue *value;
1562 if (nb::try_cast<PyValue *>(operand, value)) {
1563 return value->get();
1564 }
1565 throw nb::value_error("is not a Value");
1566}
1567
1568nb::typed<nb::object, PyOperation> PyOpView::buildGeneric(
1569 std::string_view name, std::tuple<int, bool> opRegionSpec,
1570 nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1571 std::optional<nb::sequence> resultTypeList, nb::sequence operandList,
1572 std::optional<nb::dict> attributes,
1573 std::optional<std::vector<PyBlock *>> successors,
1574 std::optional<int> regions, PyLocation &location,
1575 const nb::object &maybeIp) {
1576 PyMlirContextRef context = location.getContext();
1577
1578 // Class level operation construction metadata.
1579 // Operand and result segment specs are either none, which does no
1580 // variadic unpacking, or a list of ints with segment sizes, where each
1581 // element is either a positive number (typically 1 for a scalar) or -1 to
1582 // indicate that it is derived from the length of the same-indexed operand
1583 // or result (implying that it is a list at that position).
1584 std::vector<int32_t> operandSegmentLengths;
1585 std::vector<int32_t> resultSegmentLengths;
1586
1587 // Validate/determine region count.
1588 int opMinRegionCount = std::get<0>(opRegionSpec);
1589 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1590 if (!regions) {
1591 regions = opMinRegionCount;
1592 }
1593 if (*regions < opMinRegionCount) {
1594 throw nb::value_error(join("Operation \"", name,
1595 "\" requires a minimum of ", opMinRegionCount,
1596 " regions but was built with regions=", *regions)
1597 .c_str());
1598 }
1599 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1600 throw nb::value_error(join("Operation \"", name,
1601 "\" requires a maximum of ", opMinRegionCount,
1602 " regions but was built with regions=", *regions)
1603 .c_str());
1604 }
1605
1606 // Unpack results.
1607 std::vector<PyType *> resultTypes;
1608 if (resultTypeList.has_value()) {
1609 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1610 resultSegmentLengths, resultTypes);
1611 }
1612
1613 // Unpack operands.
1614 std::vector<MlirValue> operands;
1615 operands.reserve(operands.size());
1616 size_t index = 0;
1617 if (operandSegmentSpecObj.is_none()) {
1618 // Non-sized operand unpacking.
1619 for (nb::handle operand : operandList) {
1620 try {
1621 operands.push_back(getOpResultOrValue(operand));
1622 } catch (nb::builtin_exception &err) {
1623 throw nb::value_error(join("Operand ", index, " of operation \"", name,
1624 "\" must be a Value (", err.what(), ")")
1625 .c_str());
1626 }
1627 ++index;
1628 }
1629 } else {
1630 // Sized operand unpacking.
1631 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1632 if (operandSegmentSpec.size() != nb::len(operandList)) {
1633 throw nb::value_error(
1634 join("Operation \"", name, "\" requires ", operandSegmentSpec.size(),
1635 "operand segments but was provided ", nb::len(operandList))
1636 .c_str());
1637 }
1638 operandSegmentLengths.reserve(nb::len(operandList));
1639 for (size_t i = 0, e = operandSegmentSpec.size(); i < e; ++i) {
1640 int segmentSpec = operandSegmentSpec[i];
1641 if (segmentSpec == 1 || segmentSpec == 0) {
1642 // Unpack unary element.
1643 const nanobind::handle operand = operandList[i];
1644 if (!operand.is_none()) {
1645 try {
1646 operands.push_back(getOpResultOrValue(operand));
1647 } catch (nb::builtin_exception &err) {
1648 throw nb::value_error(join("Operand ", i, " of operation \"", name,
1649 "\" must be a Value (", err.what(), ")")
1650 .c_str());
1651 }
1652
1653 operandSegmentLengths.push_back(1);
1654 } else if (segmentSpec == 0) {
1655 // Allowed to be optional.
1656 operandSegmentLengths.push_back(0);
1657 } else {
1658 throw nb::value_error(
1659 join("Operand ", i, " of operation \"", name,
1660 "\" must be a Value (was None and operand is not optional)")
1661 .c_str());
1662 }
1663 } else if (segmentSpec == -1) {
1664 // Unpack sequence by appending.
1665 try {
1666 if (operandList[i].is_none()) {
1667 // Treat it as an empty list.
1668 operandSegmentLengths.push_back(0);
1669 } else {
1670 // Unpack the list.
1671 auto segment = nb::cast<nb::sequence>(operandList[i]);
1672 for (nb::handle segmentItem : segment) {
1673 operands.push_back(getOpResultOrValue(segmentItem));
1674 }
1675 operandSegmentLengths.push_back(nb::len(segment));
1676 }
1677 } catch (std::exception &err) {
1678 // NOTE: Sloppy to be using a catch-all here, but there are at least
1679 // three different unrelated exceptions that can be thrown in the
1680 // above "casts". Just keep the scope above small and catch them all.
1681 throw nb::value_error(join("Operand ", i, " of operation \"", name,
1682 "\" must be a Sequence of Values (",
1683 err.what(), ")")
1684 .c_str());
1685 }
1686 } else {
1687 throw nb::value_error("Unexpected segment spec");
1688 }
1689 }
1690 }
1691
1692 // Merge operand/result segment lengths into attributes if needed.
1693 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1694 // Dup.
1695 if (attributes) {
1696 attributes = nb::dict(*attributes);
1697 } else {
1698 attributes = nb::dict();
1699 }
1700 if (attributes->contains("resultSegmentSizes") ||
1701 attributes->contains("operandSegmentSizes")) {
1702 throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
1703 "'operandSegmentSizes' attribute is unsupported. "
1704 "Use Operation.create for such low-level access.");
1705 }
1706
1707 // Add resultSegmentSizes attribute.
1708 if (!resultSegmentLengths.empty()) {
1709 MlirAttribute segmentLengthAttr =
1710 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1711 resultSegmentLengths.data());
1712 (*attributes)["resultSegmentSizes"] =
1713 PyAttribute(context, segmentLengthAttr);
1714 }
1715
1716 // Add operandSegmentSizes attribute.
1717 if (!operandSegmentLengths.empty()) {
1718 MlirAttribute segmentLengthAttr =
1719 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1720 operandSegmentLengths.data());
1721 (*attributes)["operandSegmentSizes"] =
1722 PyAttribute(context, segmentLengthAttr);
1723 }
1724 }
1725
1726 // Delegate to create.
1727 return PyOperation::create(name,
1728 /*results=*/std::move(resultTypes),
1729 /*operands=*/operands.data(),
1730 /*numOperands=*/operands.size(),
1731 /*attributes=*/std::move(attributes),
1732 /*successors=*/std::move(successors),
1733 /*regions=*/*regions, location, maybeIp,
1734 !resultTypeList);
1735}
1736
1737nb::object PyOpView::constructDerived(const nb::object &cls,
1738 const nb::object &operation) {
1739 nb::handle opViewType = nb::type<PyOpView>();
1740 nb::object instance = cls.attr("__new__")(cls);
1741 opViewType.attr("__init__")(instance, operation);
1742 return instance;
1743}
1744
1745PyOpView::PyOpView(const nb::object &operationObject)
1746 // Casting through the PyOperationBase base-class and then back to the
1747 // Operation lets us accept any PyOperationBase subclass.
1748 : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
1749 operationObject(operation.getRef().getObject()) {}
1751//------------------------------------------------------------------------------
1752// PyAsmState
1753//------------------------------------------------------------------------------
1754
/// Creates an AsmState for printing `value`, owning a locally created
/// printing-flags object whose lifetime is tied to this state.
1755PyAsmState::PyAsmState(MlirValue value, bool useLocalScope) {
1756 flags = mlirOpPrintingFlagsCreate();
1757 // The OpPrintingFlags are not exposed Python side, create locally and
1758 // associate lifetime with the state.
1759 if (useLocalScope)
1761 state = mlirAsmStateCreateForValue(value, flags);
1762}
1763
/// Creates an AsmState for printing `operation`, owning a locally created
/// printing-flags object whose lifetime is tied to this state.
1764PyAsmState::PyAsmState(PyOperationBase &operation, bool useLocalScope) {
1765 flags = mlirOpPrintingFlagsCreate();
1766 // The OpPrintingFlags are not exposed Python side, create locally and
1767 // associate lifetime with the state.
1768 if (useLocalScope)
1770 state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
1771}
1773//------------------------------------------------------------------------------
1774// PyInsertionPoint.
1775//------------------------------------------------------------------------------
1776
1777PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
1780 : refOperation(beforeOperationBase.getOperation().getRef()),
// `block` is computed from `refOperation`; this relies on refOperation being
// declared before block in the class — TODO(review) confirm member order.
1781 block((*refOperation)->getBlock()) {}
1782
// Insert before `beforeOperationRef`; the block is taken from that operation
// (relies on refOperation being initialized before block — TODO confirm).
1784 : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
1785
/// Inserts a detached operation at this insertion point: before the
/// reference operation when there is one, otherwise at the end of the block.
/// Raises ValueError if the operation is already attached and IndexError when
/// appending to a block that already ends in a terminator.
1786void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1787 PyOperation &operation = operationBase.getOperation();
1788 if (operation.isAttached())
1789 throw nb::value_error(
1790 "Attempt to insert operation that is already attached");
1791 block.getParentOperation()->checkValid();
1792 MlirOperation beforeOp = {nullptr};
1793 if (refOperation) {
1794 // Insert before operation.
1795 (*refOperation)->checkValid();
1796 beforeOp = (*refOperation)->get();
1797 } else {
1798 // Insert at end (before null) is only valid if the block does not
1799 // already end in a known terminator (violating this will cause assertion
1800 // failures later).
1801 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1802 throw nb::index_error("Cannot insert operation at the end of a block "
1803 "that already has a terminator. Did you mean to "
1804 "use 'InsertionPoint.at_block_terminator(block)' "
1805 "versus 'InsertionPoint(block)'?");
1806 }
1808 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1809 operation.setAttached();
1810}
1811
1813 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
  // An empty block's beginning is also its end.
1814 if (mlirOperationIsNull(firstOp)) {
1815 // Just insert at end.
1816 return PyInsertionPoint(block);
1817 }
1818
1819 // Insert before first op.
1821 block.getParentOperation()->getContext(), firstOp);
1822 return PyInsertionPoint{block, std::move(firstOpRef)};
1823}
1824
1826 MlirOperation terminator = mlirBlockGetTerminator(block.get());
  // Insertion happens immediately before the terminator, which must exist.
1827 if (mlirOperationIsNull(terminator))
1828 throw nb::value_error("Block has no terminator");
1830 block.getParentOperation()->getContext(), terminator);
1831 return PyInsertionPoint{block, std::move(terminatorOpRef)};
1832}
1833
1835 PyOperation &operation = op.getOperation();
  // Inserting "after" an op means inserting before its successor in the
  // block, or at the block end when it is the last operation.
1836 PyBlock block = operation.getBlock();
1837 MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
1838 if (mlirOperationIsNull(nextOperation))
1839 return PyInsertionPoint(block);
1841 block.getParentOperation()->getContext(), nextOperation);
1842 return PyInsertionPoint{block, std::move(nextOpRef)};
1843}
1844
1845size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
1847nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
1848 return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
1849}
1850
/// Context-manager exit for an insertion point. The exception arguments follow
/// the Python __exit__ protocol (presumably the body pops this insertion point
/// from the thread context — TODO confirm).
1851void PyInsertionPoint::contextExit(const nb::object &excType,
1852 const nb::object &excVal,
1853 const nb::object &excTb) {
1855}
1857//------------------------------------------------------------------------------
1858// PyAttribute.
1859//------------------------------------------------------------------------------
1861bool PyAttribute::operator==(const PyAttribute &other) const {
1862 return mlirAttributeEqual(attr, other.attr);
1863}
1865nb::object PyAttribute::getCapsule() {
1866 return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
1867}
1868
/// Rebuilds a PyAttribute from a Python capsule. A null attribute from the
/// capsule conversion is reported by rethrowing the pending Python error.
1869PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) {
1870 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1871 if (mlirAttributeIsNull(rawAttr))
1872 throw nb::python_error();
1873 return PyAttribute(
1875}
1876
1877nb::typed<nb::object, PyAttribute> PyAttribute::maybeDownCast() {
1878 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get());
1879 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1880 "mlirTypeID was expected to be non-null.");
1881 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
1882 mlirTypeID, mlirAttributeGetDialect(this->get()));
1883 // nb::rv_policy::move means use std::move to move the return value
1884 // contents into a new instance that will be owned by Python.
1885 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
1886 if (!typeCaster)
1887 return thisObj;
1888 return typeCaster.value()(thisObj);
1889}
1891//------------------------------------------------------------------------------
1892// PyLocation::maybeDownCast.
1893//------------------------------------------------------------------------------
1894
1895nb::typed<nb::object, PyLocation> PyLocation::maybeDownCast() {
1896 MlirAttribute locAttr = mlirLocationGetAttribute(this->get());
1897 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(locAttr);
1898 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1899 "mlirTypeID was expected to be non-null.");
1900 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
1901 mlirTypeID, mlirAttributeGetDialect(locAttr));
1902 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
1903 if (!typeCaster)
1904 return thisObj;
1905 return typeCaster.value()(thisObj);
1906}
1908//------------------------------------------------------------------------------
1909// PyNamedAttribute.
1910//------------------------------------------------------------------------------
1911
/// Builds a named attribute. The name is copied into a heap-allocated string
/// held by this object, and the identifier is created from that owned copy.
/// Presumably this keeps the identifier's backing text valid for as long as
/// the wrapper lives; confirm against the member's declaration.
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
    : ownedName(new std::string(std::move(ownedName))) {
  // `this->ownedName` names the member. The by-value parameter of the same
  // name has already been moved from in the initializer list above.
                                        toMlirStringRef(*this->ownedName)),
                                        attr);
}
1920//------------------------------------------------------------------------------
1921// PyType.
1922//------------------------------------------------------------------------------
1924bool PyType::operator==(const PyType &other) const {
1925 return mlirTypeEqual(type, other.type);
1926}
1928nb::object PyType::getCapsule() {
1929 return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
1930}
1931
/// Rebuilds a PyType from a Python capsule that wraps an MlirType.
/// Raises the pending Python error when the capsule yields no type.
PyType PyType::createFromCapsule(nb::object capsule) {
  MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
  // Null means the capsule conversion failed; surface the Python error.
  if (mlirTypeIsNull(rawType))
    throw nb::python_error();
                rawType);
}
1939
1940nb::typed<nb::object, PyType> PyType::maybeDownCast() {
1941 MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get());
1942 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1943 "mlirTypeID was expected to be non-null.");
1944 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
1945 mlirTypeID, mlirTypeGetDialect(this->get()));
1946 // nb::rv_policy::move means use std::move to move the return value
1947 // contents into a new instance that will be owned by Python.
1948 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
1949 if (!typeCaster)
1950 return thisObj;
1951 return typeCaster.value()(thisObj);
1952}
1954//------------------------------------------------------------------------------
1955// PyTypeID.
1956//------------------------------------------------------------------------------
1958nb::object PyTypeID::getCapsule() {
1959 return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
1960}
1961
1962PyTypeID PyTypeID::createFromCapsule(nb::object capsule) {
1963 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
1964 if (mlirTypeIDIsNull(mlirTypeID))
1965 throw nb::python_error();
1966 return PyTypeID(mlirTypeID);
1967}
1968bool PyTypeID::operator==(const PyTypeID &other) const {
1969 return mlirTypeIDEqual(typeID, other.typeID);
1970}
1972//------------------------------------------------------------------------------
1973// PyValue and subclasses.
1974//------------------------------------------------------------------------------
1976nb::object PyValue::getCapsule() {
1977 return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
1978}
1979
/// Returns a reference to the operation that owns `value`.
/// - For an op result, this is the defining operation.
/// - For a block argument, it is computed in a branch whose code is not
///   shown here (presumably the block's parent operation; confirm).
/// Raises the pending Python error if the owner is null.
static PyOperationRef getValueOwnerRef(MlirValue value) {
  // NOTE(review): `owner` has no initializer. Under NDEBUG the assert in the
  // final `else` compiles out, so a value that is neither an op result nor a
  // block argument would leave `owner` uninitialized when it is read below.
  // Consider initializing it to a null operation or throwing instead.
  MlirOperation owner;
  if (mlirValueIsAOpResult(value))
    owner = mlirOpResultGetOwner(value);
  else if (mlirValueIsABlockArgument(value))
  else
    assert(false && "Value must be an block arg or op result.");
  if (mlirOperationIsNull(owner))
    throw nb::python_error();
  MlirContext ctx = mlirOperationGetContext(owner);
}
1993
/// Converts this value to the most specific Python wrapper: an OpResult or a
/// BlockArgument, depending on the value's kind. A value caster (looked up
/// after computing the TypeID of the value's type) post-processes the
/// wrapper when one is found.
nb::typed<nb::object, std::variant<PyBlockArgument, PyOpResult, PyValue>>
  MlirType type = mlirValueGetType(get());
  MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
  assert(!mlirTypeIDIsNull(mlirTypeID) &&
         "mlirTypeID was expected to be non-null.");
  std::optional<nb::callable> valueCaster =
  // nb::rv_policy::move means use std::move to move the return value
  // contents into a new instance that will be owned by Python.
  // NOTE(review): `thisObj` starts out as a null object. Under NDEBUG the
  // assert in the final `else` compiles out, so a value of neither kind
  // would return (or pass to the caster) a null handle.
  nb::object thisObj;
  if (mlirValueIsAOpResult(value))
    thisObj = nb::cast<PyOpResult>(*this, nb::rv_policy::move);
  else if (mlirValueIsABlockArgument(value))
    thisObj = nb::cast<PyBlockArgument>(*this, nb::rv_policy::move);
  else
    assert(false && "Value must be an block arg or op result.");
  if (valueCaster)
    return valueCaster.value()(thisObj);
  return thisObj;
}
2015
2016PyValue PyValue::createFromCapsule(nb::object capsule) {
2017 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2018 if (mlirValueIsNull(value))
2019 throw nb::python_error();
2020 PyOperationRef ownerRef = getValueOwnerRef(value);
2021 return PyValue(ownerRef, value);
2022}
2024//------------------------------------------------------------------------------
2025// PySymbolTable.
2026//------------------------------------------------------------------------------
2027
: operation(operation.getOperation().getRef()) {
  // Build a C-API symbol table over the wrapped operation. A null handle means
  // the operation is not a symbol table, which is reported as a TypeError.
  symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
  if (mlirSymbolTableIsNull(symbolTable)) {
    throw nb::type_error("Operation is not a Symbol Table.");
  }
}
2035
2036nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2037 operation->checkValid();
2038 MlirOperation symbol = mlirSymbolTableLookup(
2039 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2040 if (mlirOperationIsNull(symbol))
2041 throw nb::key_error(
2042 join("Symbol '", name, "' not in the symbol table.").c_str());
2043
2044 return PyOperation::forOperation(operation->getContext(), symbol,
2045 operation.getObject())
2046 ->createOpView();
2047}
2048
// Both the table's operation and the symbol must still be valid to erase.
  operation->checkValid();
  symbol.getOperation().checkValid();
  mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
  // The operation is also erased, so we must invalidate it. There may be Python
  // references to this operation so we don't want to delete it from the list of
  // live operations here.
  symbol.getOperation().valid = false;
}
2058
2059void PySymbolTable::dunderDel(const std::string &name) {
2060 nb::object operation = dunderGetItem(name);
2061 erase(nb::cast<PyOperationBase &>(operation));
2062}
2063
// Both the table's operation and the symbol must still be valid.
  operation->checkValid();
  symbol.getOperation().checkValid();
  // Only operations that already carry a symbol-name attribute can be
  // inserted; otherwise a ValueError is raised.
  MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
  if (mlirAttributeIsNull(symbolAttr))
    throw nb::value_error("Expected operation to have a symbol name.");
  // The value returned by mlirSymbolTableInsert (presumably the symbol's
  // possibly-renamed name) is wrapped with the symbol's context.
      symbol.getOperation().getContext(),
      mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
}
2075
// Op must already be a symbol.
  PyOperation &operation = symbol.getOperation();
  operation.checkValid();
  // Reads the existing symbol-name attribute; a missing attribute means the
  // operation is not a symbol and is reported as ValueError.
  MlirAttribute existingNameAttr =
      mlirOperationGetAttributeByName(operation.get(), attrName);
  if (mlirAttributeIsNull(existingNameAttr))
    throw nb::value_error("Expected operation to have a symbol name.");
  return PyStringAttribute(symbol.getOperation().getContext(),
                           existingNameAttr);
}
2088
const std::string &name) {
  // Op must already be a symbol.
  PyOperation &operation = symbol.getOperation();
  operation.checkValid();
  // Renaming is only allowed when a symbol-name attribute already exists.
  MlirAttribute existingNameAttr =
      mlirOperationGetAttributeByName(operation.get(), attrName);
  if (mlirAttributeIsNull(existingNameAttr))
    throw nb::value_error("Expected operation to have a symbol name.");
  // Overwrite the attribute in place with a new string attribute.
  MlirAttribute newNameAttr =
      mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
  mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
}
2103
PyOperation &operation = symbol.getOperation();
  operation.checkValid();
  // Reads the symbol-visibility attribute; it must already be present.
  MlirAttribute existingVisAttr =
      mlirOperationGetAttributeByName(operation.get(), attrName);
  if (mlirAttributeIsNull(existingVisAttr))
    throw nb::value_error("Expected operation to have a symbol visibility.");
  return PyStringAttribute(symbol.getOperation().getContext(), existingVisAttr);
}
2114
const std::string &visibility) {
  // Validate the requested visibility before touching the operation.
  if (visibility != "public" && visibility != "private" &&
      visibility != "nested")
    throw nb::value_error(
        "Expected visibility to be 'public', 'private' or 'nested'");
  PyOperation &operation = symbol.getOperation();
  operation.checkValid();
  // Only an existing visibility attribute is replaced; none is added.
  MlirAttribute existingVisAttr =
      mlirOperationGetAttributeByName(operation.get(), attrName);
  if (mlirAttributeIsNull(existingVisAttr))
    throw nb::value_error("Expected operation to have a symbol visibility.");
  MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
                                               toMlirStringRef(visibility));
  mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
}
2132
/// Renames every use of `oldSymbol` to `newSymbol` within `from`. A failed
/// rename is reported to Python as ValueError.
void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
                                         const std::string &newSymbol,
                                         PyOperationBase &from) {
  PyOperation &fromOperation = from.getOperation();
  fromOperation.checkValid();
          toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),

    throw nb::value_error("Symbol rename failed");
}
2144
bool allSymUsesVisible,
                                     nb::object callback) {
  PyOperation &fromOperation = from.getOperation();
  fromOperation.checkValid();
  // State shared with the C walk callback through its opaque user-data
  // pointer.
  struct UserData {
    PyMlirContextRef context;
    nb::object callback;
    bool gotException;
    std::string exceptionWhat;
    nb::object exceptionType;
  };
  UserData userData{
      fromOperation.getContext(), std::move(callback), false, {}, {}};
      fromOperation.get(), allSymUsesVisible,
      [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
        UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
        // NOTE(review): the wrapper is created before checking gotException,
        // so it is still built for every remaining table after a failure.
        // Checking the flag first would avoid that work.
        auto pyFoundOp =
            PyOperation::forOperation(calleeUserData->context, foundOp);
        // After the Python callback has raised once, remaining symbol tables
        // are skipped.
        if (calleeUserData->gotException)
          return;
        // Python errors are captured here instead of propagating, presumably
        // so that no C++ exception unwinds through the C-API walk. They are
        // re-raised after the walk returns.
        try {
          calleeUserData->callback(pyFoundOp.getObject(), isVisible);
        } catch (nb::python_error &e) {
          calleeUserData->gotException = true;
          calleeUserData->exceptionWhat = e.what();
          calleeUserData->exceptionType = nb::borrow(e.type());
        }
      },
      static_cast<void *>(&userData));
  // NOTE(review): exceptionType is recorded but never used. A
  // std::runtime_error is thrown instead, so the original Python exception
  // type is lost to the caller.
  if (userData.gotException) {
    std::string message("Exception raised in callback: ");
    message.append(userData.exceptionWhat);
    throw std::runtime_error(message);
  }
}
2182
/// Adds BlockArgument-specific Python members:
/// - read-only `owner` and `arg_number` properties;
/// - `set_type` and `set_location` methods that mutate the argument via the
///   C API.
void PyBlockArgument::bindDerived(ClassTy &c) {
  c.def_prop_ro(
      "owner",
      [](PyBlockArgument &self) {
        return PyBlock(self.getParentOperation(),
      },
      "Returns the block that owns this argument.");
  c.def_prop_ro(
      "arg_number",
      [](PyBlockArgument &self) {
        return mlirBlockArgumentGetArgNumber(self.get());
      },
      "Returns the position of this argument in the block's argument list.");
  c.def(
      "set_type",
      [](PyBlockArgument &self, PyType type) {
        return mlirBlockArgumentSetType(self.get(), type);
      },
      "type"_a, "Sets the type of this block argument.");
  c.def(
      "set_location",
      [](PyBlockArgument &self, PyLocation loc) {
      },
      "loc"_a, "Sets the location of this block argument.");
}
2210
MlirBlock block, intptr_t startIndex,
                // A length of -1 selects every argument currently in `block`.
                length == -1 ? mlirBlockGetNumArguments(block) : length, step),
      operation(std::move(operation)), block(block) {}
2217
2218void PyBlockArgumentList::bindDerived(ClassTy &c) {
2219 c.def_prop_ro(
2220 "types",
2221 [](PyBlockArgumentList &self) {
2222 return getValueTypes(self, self.operation->getContext());
2223 },
2224 "Returns a list of types for all arguments in this argument list.");
2225}
2226
2227intptr_t PyBlockArgumentList::getRawNumElements() {
2228 operation->checkValid();
2229 return mlirBlockGetNumArguments(block);
2230}
2231
2232PyBlockArgument PyBlockArgumentList::getRawElement(intptr_t pos) const {
2233 MlirValue argument = mlirBlockGetArgument(block, pos);
2234 return PyBlockArgument(operation, argument);
2235}
2236
/// Returns a view over a sub-range of the same block's arguments, sharing the
/// owning operation reference.
PyBlockArgumentList PyBlockArgumentList::slice(intptr_t startIndex,
                                               intptr_t step) const {
  return PyBlockArgumentList(operation, block, startIndex, length, step);
}
2242
intptr_t length, intptr_t step)
    // The length expression below presumably resolves -1 to the current
    // operand count, as PyOpOperands does; confirm.
    : Sliceable(startIndex,
                : length,
                step),
      operation(operation) {}
2250
// Rewire operand `index` of the wrapped operation to `value`.
  mlirOperationSetOperand(operation->get(), index, value.get());
}
2255
2256void PyOpOperandList::bindDerived(ClassTy &c) {
2257 c.def("__setitem__", &PyOpOperandList::dunderSetItem, "index"_a, "value"_a,
2258 "Sets the operand at the specified index to a new value.");
2259}
2260
2261intptr_t PyOpOperandList::getRawNumElements() {
2262 operation->checkValid();
2263 return mlirOperationGetNumOperands(operation->get());
2264}
2265
2266PyValue PyOpOperandList::getRawElement(intptr_t pos) {
2267 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2268 PyOperationRef pyOwner = getValueOwnerRef(operand);
2269 return PyValue(pyOwner, operand);
2270}
2271
2272PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
2273 intptr_t step) const {
2274 return PyOpOperandList(operation, startIndex, length, step);
2275}
/// A list of OpOperands. Internally, these are stored as consecutive elements,
/// so random access is cheap. The (returned) OpOperand list is associated with
/// the operation whose operands these are, and thus extends the lifetime of
/// this operation.
class PyOpOperands : public Sliceable<PyOpOperands, PyOpOperand> {
public:
  static constexpr const char *pyClassName = "OpOperands";

               intptr_t length = -1, intptr_t step = 1)
                // A length of -1 means "all operands of the operation".
                length == -1 ? mlirOperationGetNumOperands(operation->get())
                             : length,
                step),
        operation(operation) {}

private:
  /// Give the parent CRTP class access to hook implementations below.
  friend class Sliceable<PyOpOperands, PyOpOperand>;

  /// Operand count of the wrapped operation, after a validity check.
  intptr_t getRawNumElements() {
    operation->checkValid();
    return mlirOperationGetNumOperands(operation->get());
  }

  /// Wraps the OpOperand handle at raw position `pos`.
  PyOpOperand getRawElement(intptr_t pos) {
    MlirOpOperand opOperand = mlirOperationGetOpOperand(operation->get(), pos);
    return PyOpOperand(opOperand);
  }

    return PyOpOperands(operation, startIndex, length, step);

  /// Reference to the operation whose operands are listed.
  PyOperationRef operation;
};
2314
intptr_t length, intptr_t step)
    // The length expression below presumably resolves -1 to the operation's
    // current successor count; confirm.
    : Sliceable(startIndex,
                : length,
                step),
      operation(operation) {}
2322
// Retarget successor `index` of the wrapped operation to `block`.
  mlirOperationSetSuccessor(operation->get(), index, block.get());
}
2327
2328void PyOpSuccessors::bindDerived(ClassTy &c) {
2329 c.def("__setitem__", &PyOpSuccessors::dunderSetItem, "index"_a, "block"_a,
2330 "Sets the successor block at the specified index.");
2331}
2332
2333intptr_t PyOpSuccessors::getRawNumElements() {
2334 operation->checkValid();
2335 return mlirOperationGetNumSuccessors(operation->get());
2336}
2337
2338PyBlock PyOpSuccessors::getRawElement(intptr_t pos) {
2339 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2340 return PyBlock(operation, block);
2341}
2342
intptr_t step) const {
  // Sub-range view over the same operation's successors.
  return PyOpSuccessors(operation, startIndex, length, step);
}
2347
intptr_t startIndex, intptr_t length,
                                     intptr_t step)
    : Sliceable(startIndex,
                // A length of -1 selects every successor of `block`.
                length == -1 ? mlirBlockGetNumSuccessors(block.get()) : length,
                step),
      operation(operation), block(block) {}
2355
2356intptr_t PyBlockSuccessors::getRawNumElements() {
2357 block.checkValid();
2358 return mlirBlockGetNumSuccessors(block.get());
2359}
2360
2361PyBlock PyBlockSuccessors::getRawElement(intptr_t pos) {
2362 MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2363 return PyBlock(operation, block);
2364}
2365
intptr_t step) const {
  // Sub-range view over the same block's successors.
  return PyBlockSuccessors(block, operation, startIndex, length, step);
}
2370
PyOperationRef operation,
                                         intptr_t startIndex, intptr_t length,
                                         intptr_t step)
    : Sliceable(startIndex,
                // A length of -1 selects every predecessor of `block`.
                length == -1 ? mlirBlockGetNumPredecessors(block.get())
                             : length,
                step),
      operation(operation), block(block) {}
2380
2381intptr_t PyBlockPredecessors::getRawNumElements() {
2382 block.checkValid();
2383 return mlirBlockGetNumPredecessors(block.get());
2384}
2385
2386PyBlock PyBlockPredecessors::getRawElement(intptr_t pos) {
2387 MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2388 return PyBlock(operation, block);
2389}
2390
2391PyBlockPredecessors PyBlockPredecessors::slice(intptr_t startIndex,
2392 intptr_t length,
2393 intptr_t step) const {
2394 return PyBlockPredecessors(block, operation, startIndex, length, step);
2395}
2396
/// Implements `attributes[name]`. Returns the (possibly down-cast) attribute,
/// or raises KeyError when the operation has no attribute with that name.
nb::typed<nb::object, PyAttribute>
PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
  MlirAttribute attr =
  if (mlirAttributeIsNull(attr)) {
    throw nb::key_error("attempt to access a non-existent attribute");
  return PyAttribute(operation->getContext(), attr).maybeDownCast();
}
2406
/// dict-style `get`. Returns the (possibly down-cast) attribute for `key`, or
/// `defaultValue` when no such attribute exists.
nb::typed<nb::object, std::optional<PyAttribute>>
PyOpAttributeMap::get(const std::string &key, nb::object defaultValue) {
  MlirAttribute attr =
  if (mlirAttributeIsNull(attr))
    return defaultValue;
  return PyAttribute(operation->getContext(), attr).maybeDownCast();
}
2415
// Python-style negative indices count back from the end of the map.
  if (index < 0) {
    index += dunderLen();
  }
  if (index < 0 || index >= dunderLen()) {
    throw nb::index_error("attempt to access out of bounds attribute");
  }
  MlirNamedAttribute namedAttr =
      mlirOperationGetAttribute(operation->get(), index);
  // Copy the identifier text into an owned std::string for the named
  // attribute.
  return PyNamedAttribute(
      namedAttr.attribute,
      std::string(mlirIdentifierStr(namedAttr.name).data,
                  mlirIdentifierStr(namedAttr.name).length));
}
2430
2431void PyOpAttributeMap::dunderSetItem(const std::string &name,
2432 const PyAttribute &attr) {
2433 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2434 attr);
2435}
2436
/// Implements `del attributes[name]`. The C API reports whether anything was
/// removed; a zero result becomes a KeyError.
void PyOpAttributeMap::dunderDelItem(const std::string &name) {
  int removed = mlirOperationRemoveAttributeByName(operation->get(),
  if (!removed)
    throw nb::key_error("attempt to delete a non-existent attribute");
}
// Number of attributes currently attached to the wrapped operation.
  return mlirOperationGetNumAttributes(operation->get());
}
2447
2448bool PyOpAttributeMap::dunderContains(const std::string &name) {
2449 return !mlirAttributeIsNull(
2450 mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)));
2451}
2452
MlirOperation op, std::function<void(MlirStringRef, MlirAttribute)> fn) {
  // Visit attributes in index order, passing each (name, attribute) pair to
  // `fn`.
  for (intptr_t i = 0; i < n; ++i) {
    fn(name, na.attribute);
  }
}
2462
/// Registers the Python `OpAttributeMap` class, a dict-like view over an
/// operation's attributes.
void PyOpAttributeMap::bind(nb::module_ &m) {
  nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
      .def("__contains__", &PyOpAttributeMap::dunderContains, "name"_a,
           "Checks if an attribute with the given name exists in the map.")
      .def("__len__", &PyOpAttributeMap::dunderLen,
           "Returns the number of attributes in the map.")
      // `__getitem__` is overloaded: by name (str) and by position (int).
      .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed, "name"_a,
           "Gets an attribute by name.")
      .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed, "index"_a,
           "Gets a named attribute by index.")
      .def("__setitem__", &PyOpAttributeMap::dunderSetItem, "name"_a, "attr"_a,
           "Sets an attribute with the given name.")
      .def("__delitem__", &PyOpAttributeMap::dunderDelItem, "name"_a,
           "Deletes an attribute with the given name.")
      .def("get", &PyOpAttributeMap::get, nb::arg("key"),
           nb::arg("default") = nb::none(),
           "Gets an attribute by name or the default value, if it does not "
           "exist.")
      // Iteration snapshots the names into a list first, so the iterator
      // does not observe later changes to the map.
      .def(
          "__iter__",
          [](PyOpAttributeMap &self) -> nb::typed<nb::iterator, nb::str> {
            nb::list keys;
                self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
                  keys.append(nb::str(name.data, name.length));
                });
            return nb::iter(keys);
          },
          "Iterates over attribute names.")
      .def(
          "keys",
          [](PyOpAttributeMap &self) -> nb::typed<nb::list, nb::str> {
            nb::list out;
                self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
                  out.append(nb::str(name.data, name.length));
                });
            return out;
          },
          "Returns a list of attribute names.")
      .def(
          "values",
          [](PyOpAttributeMap &self) -> nb::typed<nb::list, PyAttribute> {
            nb::list out;
                self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
                  out.append(PyAttribute(self.operation->getContext(), attr)
                                 .maybeDownCast());
                });
            return out;
          },
          "Returns a list of attribute values.")
      .def(
          "items",
          [](PyOpAttributeMap &self)
              -> nb::typed<nb::list,
                           nb::typed<nb::tuple, nb::str, PyAttribute>> {
            nb::list out;
                self.operation->get(),
                [&](MlirStringRef name, MlirAttribute attr) {
                  out.append(nb::make_tuple(
                      nb::str(name.data, name.length),
                      PyAttribute(self.operation->getContext(), attr)
                          .maybeDownCast()));
                });
            return out;
          },
          "Returns a list of `(name, attribute)` tuples.");
}
2533
2534void PyOpAdaptor::bind(nb::module_ &m) {
2535 nb::class_<PyOpAdaptor>(m, "OpAdaptor")
2536 .def(nb::init<nb::typed<nb::list, PyValue>, PyOpAttributeMap>(),
2537 "Creates an OpAdaptor with the given operands and attributes.",
2538 "operands"_a, "attributes"_a)
2539 .def(nb::init<nb::typed<nb::list, PyValue>, PyOpView &>(),
2540 "Creates an OpAdaptor with the given operands and operation view.",
2541 "operands"_a, "opview"_a)
2542 .def_prop_ro(
2543 "operands", [](PyOpAdaptor &self) { return self.operands; },
2544 "Returns the operands of the adaptor.")
2545 .def_prop_ro(
2546 "attributes", [](PyOpAdaptor &self) { return self.attributes; },
2547 "Returns the attributes of the adaptor.");
2548}
2549
/// Shared body of the dynamic-trait verification callbacks.
/// - `userData` is the Python target object.
/// - If the target has no attribute named `methodName`, verification
///   succeeds.
/// - Otherwise the method is called with an OpView of `op`, and its result is
///   converted to bool.
/// NOTE(review): a Python exception raised by the method, or a failing
/// nb::cast<bool>, propagates as a C++ exception out of this C-API callback.
/// Confirm that the caller tolerates that.
static MlirLogicalResult verifyTraitByMethod(MlirOperation op, void *userData,
                                             const char *methodName) {
  nb::handle targetObj(static_cast<PyObject *>(userData));
  if (!nb::hasattr(targetObj, methodName))
    return mlirLogicalResultSuccess();
  nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
  bool success = nb::cast<bool>(targetObj.attr(methodName)(opView));
  // NOTE(review): the `;` after the closing brace is a stray empty
  // declaration (flagged by -Wextra-semi); it can be dropped.
};
2560
/// Resolves `opName` to an operation-name string and attaches `trait` to it
/// in `context`.
/// - `opName` may be a type, in which case its `OPERATION_NAME` attribute is
///   used, or a str.
/// - Any other kind of object raises TypeError.
/// NOTE(review): the error message says "root argument", but the Python-facing
/// parameter is `op_name`.
static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait,
                          PyMlirContext &context) {
  std::string opNameStr;
  if (opName.is_type()) {
    opNameStr = nb::cast<std::string>(opName.attr("OPERATION_NAME"));
  } else if (nb::isinstance<nb::str>(opName)) {
    opNameStr = nb::cast<std::string>(opName);
  } else {
    throw nb::type_error("the root argument must be a type or a string");
  }
      trait, MlirStringRef{opNameStr.data(), opNameStr.size()}, context.get());
}
2575
/// Creates a dynamic op trait backed by the Python object `target` and
/// attaches it to `opName`.
/// - `target` must define `verify_invariants` and/or
///   `verify_region_invariants`; otherwise TypeError is raised.
/// - The trait holds a strong reference to `target`, taken in `construct` and
///   released in `destruct`.
void PyDynamicOpTrait::attach_doc_placeholder_never_emitted();
bool PyDynamicOpTrait::attach(const nb::object &opName,
                              const nb::object &target,
                              PyMlirContext &context) {
  if (!nb::hasattr(target, "verify_invariants") &&
      !nb::hasattr(target, "verify_region_invariants"))
    throw nb::type_error(
        "the target object must have at least one of 'verify_invariants' or "
        "'verify_region_invariants' methods");

  // Lifetime hooks: keep the Python target alive while the C++ trait uses
  // it as user data.
  callbacks.construct = [](void *userData) {
    nb::handle(static_cast<PyObject *>(userData)).inc_ref();
  };
  callbacks.destruct = [](void *userData) {
    nb::handle(static_cast<PyObject *>(userData)).dec_ref();
  };

  // Verification hooks dispatch to the corresponding Python methods (a
  // missing method counts as success).
  callbacks.verifyTrait = [](MlirOperation op,
                             void *userData) -> MlirLogicalResult {
    return verifyTraitByMethod(op, userData, "verify_invariants");
  };
  callbacks.verifyRegionTrait = [](MlirOperation op,
                                   void *userData) -> MlirLogicalResult {
    return verifyTraitByMethod(op, userData, "verify_region_invariants");
  };

  // To ensure that the same dynamic trait gets the same TypeID despite how many
  // times `attach` is called, we store it as an attribute on the target class.
  // NOTE(review): hasattr also sees attributes inherited from a base class,
  // so a subclass of an already-attached trait class would reuse its base's
  // TypeID. Confirm that this is intended.
  if (!nb::hasattr(target, typeIDAttr)) {
    nb::setattr(target, typeIDAttr,
                nb::cast(PyTypeID(PyGlobals::get().allocateTypeID())));
  }
  MlirDynamicOpTrait trait = mlirDynamicOpTraitCreate(
      nb::cast<PyTypeID>(target.attr(typeIDAttr)).get(), callbacks,
      static_cast<void *>(target.ptr()));
  return attachOpTrait(opName, trait, context);
}
2613
2614void PyDynamicOpTrait::bind(nb::module_ &m) {
2615 nb::class_<PyDynamicOpTrait> cls(m, "DynamicOpTrait");
2616 cls.attr("attach") = classmethod(
2617 [](const nb::object &cls, const nb::object &opName, nb::object target,
2618 DefaultingPyMlirContext context) {
2619 if (target.is_none())
2620 target = cls;
2621 return PyDynamicOpTrait::attach(opName, target, *context.get());
2622 },
2623 nb::arg("cls"), nb::arg("op_name"), nb::arg("target").none() = nb::none(),
2624 nb::arg("context").none() = nb::none(),
2625 "Attach the dynamic op trait subclass to the given operation name.");
2626}
2627
2628bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName,
2629 PyMlirContext &context) {
2630 MlirDynamicOpTrait trait = mlirDynamicOpTraitIsTerminatorCreate();
2631 return attachOpTrait(opName, trait, context);
2632}
2633
/// Registers `IsTerminatorTrait`:
/// - the Python class is tagged with the built-in trait's TypeID under
///   `typeIDAttr`;
/// - an `attach(op_name, context=None)` classmethod is exposed.
void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) {
  nb::class_<PyDynamicOpTraits::IsTerminator, PyDynamicOpTrait> cls(
      m, "IsTerminatorTrait");
  cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitIsTerminatorGetTypeID());
  cls.attr("attach") = classmethod(
      [](const nb::object &cls, const nb::object &opName,
         DefaultingPyMlirContext context) {
        return PyDynamicOpTraits::IsTerminator::attach(opName, *context.get());
      "Attach IsTerminator trait to the given operation name.", nb::arg("cls"),
      nb::arg("op_name"), nb::arg("context").none() = nb::none());
}
2646
2647bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName,
2648 PyMlirContext &context) {
2649 MlirDynamicOpTrait trait = mlirDynamicOpTraitNoTerminatorCreate();
2650 return attachOpTrait(opName, trait, context);
2651}
2652
2653void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) {
2654 nb::class_<PyDynamicOpTraits::NoTerminator, PyDynamicOpTrait> cls(
2655 m, "NoTerminatorTrait");
2656 cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitNoTerminatorGetTypeID());
2657 cls.attr("attach") = classmethod(
2658 [](const nb::object &cls, const nb::object &opName,
2659 DefaultingPyMlirContext context) {
2660 return PyDynamicOpTraits::NoTerminator::attach(opName, *context.get());
2661 },
2662 "Attach NoTerminator trait to the given operation name.", nb::arg("cls"),
2663 nb::arg("op_name"), nb::arg("context").none() = nb::none());
2664}
2665
2666} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
2667} // namespace python
2668} // namespace mlir
2669
2670namespace {
2671
2672using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
2673
/// Builds an MLIR location from the current Python call stack.
/// - Frames are walked from the innermost outwards, skipping frames whose
///   file name is not a "user" file per the traceback configuration.
/// - Each kept frame becomes NameLoc(function name, file/line[/col] loc).
/// - The result is the innermost kept frame, or a CallSite chain from the
///   innermost (callee) out to the outermost kept frame (caller).
/// - Returns an unknown location when no frame is kept or when built against
///   the limited API.
MlirLocation tracebackToLocation(MlirContext ctx) {
#if defined(Py_LIMITED_API)
  // Frame introspection C APIs are not available under the limited API.
  // Traceback-based auto-location is not supported; return unknown.
  return mlirLocationUnknownGet(ctx);
#else
  // NOTE(review): `frames` holds kMaxFrames entries and is indexed by
  // `count < framesLimit`, so framesLimit must not exceed kMaxFrames; confirm
  // that its initializer clamps it.
  size_t framesLimit =
  // Use a thread_local here to avoid requiring a large amount of space.
  thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
      frames;
  size_t count = 0;

  nb::gil_scoped_acquire acquire;

  PyThreadState *tstate = PyThreadState_GET();
  PyFrameObject *next;
  PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
  // In the increment expression:
  // 1. get the next prev frame;
  // 2. decrement the ref count on the current frame (in order that it can get
  // gc'd, along with any objects in its closure and etc);
  // 3. set current = next.
  for (; pyFrame != nullptr && count < framesLimit;
       next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
    // NOTE(review): PyFrame_GetCode returns a strong reference (per the CPython
    // docs), and `code` is never released in this loop, so each visited frame
    // leaks one reference to its code object. A Py_DECREF(code) after its last
    // use (on both the `continue` and fall-through paths) would fix this.
    PyCodeObject *code = PyFrame_GetCode(pyFrame);
    auto fileNameStr =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
    std::string_view fileName(fileNameStr);
    if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
      continue;

    // co_qualname and PyCode_Addr2Location added in py3.11
#if PY_VERSION_HEX < 0x030B00F0
    std::string name =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
    std::string_view funcName(name);
    int startLine = PyFrame_GetLineNumber(pyFrame);
    MlirLocation loc = mlirLocationFileLineColGet(
        ctx, mlirStringRefCreate(fileName.data(), fileName.size()), startLine,
        0);
#else
    std::string name =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
    std::string_view funcName(name);
    int startLine, startCol, endLine, endCol;
    int lasti = PyFrame_GetLasti(pyFrame);
    // NOTE(review): throwing here bypasses both the loop's increment
    // expression and the Py_XDECREF after the loop, so `pyFrame` leaks.
    if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
                              &endCol)) {
      throw nb::python_error();
    }
    MlirLocation loc = mlirLocationFileLineColRangeGet(
        ctx, mlirStringRefCreate(fileName.data(), fileName.size()), startLine,
        startCol, endLine, endCol);
#endif

    frames[count] = mlirLocationNameGet(
        ctx, mlirStringRefCreate(funcName.data(), funcName.size()), loc);
    ++count;
  }
  // When the loop breaks (after the last iter), current frame (if non-null)
  // is leaked without this.
  Py_XDECREF(pyFrame);

  if (count == 0)
    return mlirLocationUnknownGet(ctx);

  MlirLocation callee = frames[0];
  assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
  if (count == 1)
    return callee;

  // Fold from the outermost kept frame inwards:
  // CallSite(frames[1], ... CallSite(frames[count-2], frames[count-1])).
  MlirLocation caller = frames[count - 1];
  assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
  for (int i = count - 2; i >= 1; i--)
    caller = mlirLocationCallSiteGet(frames[i], caller);

  return mlirLocationCallSiteGet(callee, caller);
#endif
}
2754
/// Apply currentLocAction: wrap or fuse Location.current onto baseLoc.
/// Returns `baseLoc` unchanged for Action::Fallback, or when no
/// Location.current is set. Otherwise each NameLoc layer of Location.current
/// is re-applied around `baseLoc`, preserving the outer-to-inner order.
/// NOTE(review): only the Fallback early-out and NameLoc wrapping are
/// implemented here; no separate "fuse" path exists in this body.
static MlirLocation
applyCurrentLocAction(MlirContext ctx, MlirLocation baseLoc,
  if (action == Action::Fallback)
    return baseLoc;

  auto *currentLoc = PyThreadContextEntry::getDefaultLocation();
  if (!currentLoc)
    return baseLoc;
  assert(mlirLocationGetContext(currentLoc->get()).ptr == ctx.ptr &&
         "Location.current must belong to the current MLIR context");

  // NamelocWrap: walk the NameLoc chain on Location.current, collect scope
  // names, wrap baseLoc innermost-first so result is Outer(Inner(baseLoc)).
  // If Location.current is not a NameLoc, scopeNames is empty and baseLoc
  // is returned unchanged (nameloc_wrap is a no-op for non-NameLoc contexts).
  // The buffer is thread_local and cleared on entry, so its capacity is reused
  // across calls on the same thread.
  thread_local std::vector<MlirStringRef> scopeNames;
  scopeNames.clear();
  MlirLocation walk = currentLoc->get();
  while (mlirLocationIsAName(walk)) {
    scopeNames.push_back(mlirIdentifierStr(mlirLocationNameGetName(walk)));
  }
  for (auto it = scopeNames.rbegin(); it != scopeNames.rend(); ++it)
    baseLoc = mlirLocationNameGet(ctx, *it, baseLoc);
  return baseLoc;
}
2784
/// Resolves the location to attach to a newly created operation.
/// - With tracebacks disabled, the explicit `location` is used when given;
///   otherwise the code falls back to Location.current.
/// - With tracebacks enabled:
///   1. The on-explicit policy picks between the explicit location and the
///      Python traceback.
///   2. The result is composed with Location.current per the current-loc
///      policy.
PyLocation
maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
  auto &tbl = PyGlobals::get().getTracebackLoc();

  // Tracebacks not enabled — return explicit loc or fall back to
  // Location.current.
  if (!tbl.locTracebacksEnabled())
    return location.has_value() ? location.value()

  // From here: tracebacks are enabled.
  PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
  // NOTE(review): no initializer. If OnExplicit ever gains a value not
  // handled by the switch below, baseLoc would be read uninitialized.
  MlirLocation baseLoc;

  // Step 1: on_explicit — resolve explicit loc= vs traceback.
  if (location.has_value()) {
    switch (tbl.tracebackActionOnExplicitLoc()) {
    case OnExplicit::UseExplicit:
      baseLoc = location->get();
      break;
    case OnExplicit::UseTraceback:
      baseLoc = tracebackToLocation(ctx.get());
      break;
    }
  } else {
    baseLoc = tracebackToLocation(ctx.get());
  }

  // Step 2: current_loc — compose with Location.current.
  baseLoc = applyCurrentLocAction(ctx.get(), baseLoc,
                                  tbl.tracebackActionOnCurrentLoc());

  return {ref, baseLoc};
}
2821} // namespace
2822
2823namespace mlir {
2824namespace python {
2826
2827static std::string formatMLIRError(const MLIRError &e) {
2828 auto locStr = [](const PyLocation &loc) {
2829 PyPrintAccumulator accum;
2830 mlirLocationPrint(loc, accum.getCallback(), accum.getUserData());
2831 std::string s = nb::cast<std::string>(nb::str(accum.join()));
2832 std::string_view sv(s);
2833 if (sv.size() > 5) {
2834 sv.remove_prefix(4); // "loc("
2835 sv.remove_suffix(1); // ")"
2836 }
2837 return std::string(sv);
2838 };
2839 auto indent = [](std::string s) {
2840 size_t pos = 0;
2841 while ((pos = s.find('\n', pos)) != std::string::npos) {
2842 s.replace(pos, 1, "\n ");
2843 pos += 3;
2844 }
2845 return s;
2846 };
2847
2848 std::ostringstream os;
2849 os << e.message;
2850 if (!e.errorDiagnostics.empty())
2851 os << ":";
2852 for (const auto &diag : e.errorDiagnostics) {
2853 os << "\nerror: " << locStr(diag.location) << ": " << indent(diag.message);
2854 for (const auto &note : diag.notes) {
2855 os << "\n note: " << locStr(note.location) << ": "
2856 << indent(note.message);
2857 }
2858 }
2859 return os.str();
2860}
2861
2862void MLIRError::bind(nb::module_ &m) {
2863 auto cls = nb::exception<MLIRError>(m, "MLIRError", PyExc_Exception);
2864 nb::register_exception_translator(
2865 [](const std::exception_ptr &p, void *payload) {
2866 try {
2867 if (p)
2868 std::rethrow_exception(p);
2869 } catch (MLIRError &e) {
2870 std::string formatted = formatMLIRError(e);
2871 nb::object ty = nb::borrow(static_cast<PyObject *>(payload));
2872 nb::object obj = ty(formatted);
2873 obj.attr("_message") = nb::cast(std::move(e.message));
2874 obj.attr("_error_diagnostics") =
2875 nb::cast(std::move(e.errorDiagnostics));
2876 PyErr_SetObject(static_cast<PyObject *>(payload), obj.ptr());
2877 }
2878 },
2879 cls.ptr());
2880 auto propertyType = nb::borrow<nb::type_object>(
2881 reinterpret_cast<PyObject *>(&PyProperty_Type));
2882 nb::setattr(
2883 cls, "message",
2884 propertyType(nb::cpp_function(
2885 [](nb::object self) -> nb::str { return self.attr("_message"); },
2886 nb::is_method())));
2887 nb::setattr(cls, "error_diagnostics",
2888 propertyType(nb::cpp_function(
2889 [](nb::object self)
2890 -> nb::typed<nb::list, PyDiagnostic::DiagnosticInfo> {
2891 return self.attr("_error_diagnostics");
2892 },
2893 nb::is_method())));
2894}
2895
/// Populates module `m` with the root-level binding objects:
/// - the typing variables `T` and `U` (used in the decorator signatures below),
/// - the `OnExplicitAction` / `CurrentLocAction` traceback-location policy enums,
/// - the `_Globals` class plus a `globals` instance owned by Python,
/// - registration decorators: `register_dialect`, `register_operation`,
///   `register_op_adaptor`, `register_type_caster`, `register_value_caster`.
2896void populateRoot(nb::module_ &m) {
2897 m.attr("T") = nb::type_var("T");
2898 m.attr("U") = nb::type_var("U");
2899
2900 // Policies for how loc_tracebacks() composes the three location sources
2901 // (explicit loc=, generated traceback, Location.current).
2902 nb::enum_<PyGlobals::TracebackLoc::OnExplicitAction>(m, "OnExplicitAction")
2903 .value("USE_EXPLICIT",
2905 .value("USE_TRACEBACK",
2907
2908 nb::enum_<PyGlobals::TracebackLoc::CurrentLocAction>(m, "CurrentLocAction")
2910 .value("NAMELOC_WRAP",
2912
// `_Globals`: Python view of PyGlobals state — dialect search modules,
// testing-only registration hooks, and traceback-location settings.
2913 nb::class_<PyGlobals>(m, "_Globals")
2914 .def_prop_rw("dialect_search_modules",
2917 .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
2918 "module_name"_a)
2919 .def(
2920 "_check_dialect_module_loaded",
2921 [](PyGlobals &self, const std::string &dialectNamespace) {
2922 return self.loadDialectModule(dialectNamespace);
2923 },
2924 "dialect_namespace"_a)
2925 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
2926 "dialect_namespace"_a, "dialect_class"_a, nb::kw_only(),
2927 "replace"_a = false,
2928 "Testing hook for directly registering a dialect")
2929 .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
2930 "operation_name"_a, "operation_class"_a, nb::kw_only(),
2931 "replace"_a = false,
2932 "Testing hook for directly registering an operation")
// Traceback-location settings: getter/setter pairs for the enabled flag, the
// frame limit, file inclusion/exclusion registration, and the explicit-loc /
// current-loc policies (see the enums above).
2933 .def("loc_tracebacks_enabled",
2934 [](PyGlobals &self) {
2935 return self.getTracebackLoc().locTracebacksEnabled();
2936 })
2937 .def("set_loc_tracebacks_enabled",
2938 [](PyGlobals &self, bool enabled) {
2940 })
2941 .def("loc_tracebacks_frame_limit",
2942 [](PyGlobals &self) {
2944 })
2945 .def("set_loc_tracebacks_frame_limit",
2946 [](PyGlobals &self, std::optional<int> n) {
2949 })
2950 .def("register_traceback_file_inclusion",
2951 [](PyGlobals &self, const std::string &filename) {
2953 })
2954 .def("register_traceback_file_exclusion",
2955 [](PyGlobals &self, const std::string &filename) {
2957 })
2958 .def("traceback_action_on_explicit_loc",
2959 [](PyGlobals &self) {
2961 })
2962 .def("set_traceback_action_on_explicit_loc",
2963 [](PyGlobals &self,
2966 })
2967 .def("traceback_action_on_current_loc",
2968 [](PyGlobals &self) {
2970 })
2971 .def("set_traceback_action_on_current_loc",
2972 [](PyGlobals &self,
2975 });
2976
2977 // Aside from making the globals accessible to python, having python manage
2978 // it is necessary to make sure it is destroyed (and releases its python
2979 // resources) properly.
2980 m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
2981
2982 // Registration decorators.
// register_dialect: class decorator. Reads `DIALECT_NAMESPACE` from the class,
// registers the class with PyGlobals, and returns it unchanged.
2983 m.def(
2984 "register_dialect",
2985 [](nb::type_object pyClass) {
2986 std::string dialectNamespace =
2987 nb::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
2988 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
2989 return pyClass;
2990 },
2991 "dialect_class"_a,
2992 "Class decorator for registering a custom Dialect wrapper");
// register_operation(dialect_class, *, replace=False) returns a decorator. The
// decorator registers the class under its `OPERATION_NAME` and also attaches it
// to `dialect_class` as an attribute named after the class's `__name__`.
2993 m.def(
2994 "register_operation",
2995 [](const nb::type_object &dialectClass, bool replace) -> nb::object {
2996 return nb::cpp_function(
2997 [dialectClass,
2998 replace](nb::type_object opClass) -> nb::type_object {
2999 std::string operationName =
3000 nb::cast<std::string>(opClass.attr("OPERATION_NAME"));
3001 PyGlobals::get().registerOperationImpl(operationName, opClass,
3002 replace);
3003 // Dict-stuff the new opClass by name onto the dialect class.
3004 nb::object opClassName = opClass.attr("__name__");
3005 dialectClass.attr(opClassName) = opClass;
3006 return opClass;
3007 });
3008 },
3009 // clang-format off
3010 nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) "
3011 "-> typing.Callable[[type[T]], type[T]]"),
3012 // clang-format on
3013 "dialect_class"_a, nb::kw_only(), "replace"_a = false,
3014 "Produce a class decorator for registering an Operation class as part of "
3015 "a dialect");
// register_op_adaptor(op_class, *, replace=False) returns a decorator. The
// decorator registers the adaptor class under its `OPERATION_NAME` and sets it
// as `op_class.Adaptor`.
3016 m.def(
3017 "register_op_adaptor",
3018 [](const nb::type_object &opClass, bool replace) -> nb::object {
3019 return nb::cpp_function(
3020 [opClass,
3021 replace](nb::type_object adaptorClass) -> nb::type_object {
3022 std::string operationName =
3023 nb::cast<std::string>(adaptorClass.attr("OPERATION_NAME"));
3024 PyGlobals::get().registerOpAdaptorImpl(operationName,
3025 adaptorClass, replace);
3026 // Expose the new adaptorClass as the fixed attribute `Adaptor` on opClass.
3027 opClass.attr("Adaptor") = adaptorClass;
3028 return adaptorClass;
3029 });
3030 },
3031 // clang-format off
3032 nb::sig("def register_op_adaptor(op_class: type, *, replace: bool = False) "
3033 "-> typing.Callable[[type[T]], type[T]]"),
3034 // clang-format on
3035 "op_class"_a, nb::kw_only(), "replace"_a = false,
3036 "Produce a class decorator for registering an OpAdaptor class for an "
3037 "operation.");
// register_type_caster(typeid, *, replace=False) (name per the nb::sig below)
// returns a decorator that registers the callable for `typeid` and returns the
// callable unchanged.
3038 m.def(
3040 [](PyTypeID mlirTypeID, bool replace) -> nb::object {
3041 return nb::cpp_function([mlirTypeID, replace](
3042 nb::callable typeCaster) -> nb::object {
3043 PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
3044 return typeCaster;
3045 });
3046 },
3047 // clang-format off
3048 nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
3049 "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
3050 // clang-format on
3051 "typeid"_a, nb::kw_only(), "replace"_a = false,
3052 "Register a type caster for casting MLIR types to custom user types.");
// register_value_caster(typeid, *, replace=False) (name per the nb::sig below)
// is the value counterpart of register_type_caster.
3053 m.def(
3055 [](PyTypeID mlirTypeID, bool replace) -> nb::object {
3056 return nb::cpp_function(
3057 [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
3058 PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
3059 replace);
3060 return valueCaster;
3061 });
3062 },
3063 // clang-format off
3064 nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
3065 "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
3066 // clang-format on
3067 "typeid"_a, nb::kw_only(), "replace"_a = false,
3068 "Register a value caster for casting MLIR values to custom user values.");
3069}
3070
3071//------------------------------------------------------------------------------
3072// Location subclass bindDerived implementations.
3073//------------------------------------------------------------------------------
3074
// Bindings for `UnknownLoc`: a static `get(context=None)` factory returning a
// PyUnknownLocation built with `mlirLocationUnknownGet` in the resolved context.
3076 c.def_static(
3077 "get",
3078 [](DefaultingPyMlirContext context) {
3079 return PyUnknownLocation(context->getRef(),
3080 mlirLocationUnknownGet(context->get()));
3081 },
3082 "context"_a = nb::none(),
3083 "Gets a Location representing an unknown location.");
3084}
3085
// Bindings for `FileLineColLoc`:
// - two static `get` overloads: a single (line, col) point, and a
//   (start_line, start_col, end_line, end_col) range built with
//   `mlirLocationFileLineColRangeGet`;
// - read-only properties filename, start_line, start_col, end_line, end_col.
3087 c.def_static(
3088 "get",
3089 [](std::string filename, int line, int col,
3090 DefaultingPyMlirContext context) {
3091 return PyFileLineColLocation(
3092 context->getRef(),
3094 toMlirStringRef(filename), line, col));
3095 },
3096 "filename"_a, "line"_a, "col"_a, "context"_a = nb::none(),
3097 "Gets a FileLineColLoc for a file, line, and column.");
// Range overload: covers start_line:start_col through end_line:end_col.
3098 c.def_static(
3099 "get",
3100 [](std::string filename, int startLine, int startCol, int endLine,
3101 int endCol, DefaultingPyMlirContext context) {
3102 return PyFileLineColLocation(
3103 context->getRef(), mlirLocationFileLineColRangeGet(
3104 context->get(), toMlirStringRef(filename),
3105 startLine, startCol, endLine, endCol));
3106 },
3107 "filename"_a, "start_line"_a, "start_col"_a, "end_line"_a, "end_col"_a,
3108 "context"_a = nb::none(),
3109 "Gets a FileLineColLoc spanning a file and line/column range.");
3110 c.def_prop_ro(
3111 "filename",
3112 [](PyFileLineColLocation &self) {
3113 return mlirIdentifierStr(
3115 },
3116 "Gets the filename from a `FileLineColLoc`.");
3117 c.def_prop_ro(
3118 "start_line",
3119 [](PyFileLineColLocation &self) {
3121 },
3122 "Gets the start line number from a `FileLineColLoc`.");
3123 c.def_prop_ro(
3124 "start_col",
3125 [](PyFileLineColLocation &self) {
3127 },
3128 "Gets the start column number from a `FileLineColLoc`.");
3129 c.def_prop_ro(
3130 "end_line",
3131 [](PyFileLineColLocation &self) {
3133 },
3134 "Gets the end line number from a `FileLineColLoc`.");
3135 c.def_prop_ro(
3136 "end_col",
3137 [](PyFileLineColLocation &self) {
3139 },
3140 "Gets the end column number from a `FileLineColLoc`.");
3141}
3142
// Bindings for `NameLoc`: a static `get(name, child_loc=None, context=None)`
// factory (an absent child becomes an unknown location), plus read-only
// `name_str` and `child_loc` properties. `child_loc` is returned downcast to
// its concrete Location subclass via maybeDownCast().
3144 c.def_static(
3145 "get",
3146 [](std::string name, std::optional<PyLocation> childLoc,
3147 DefaultingPyMlirContext context) {
3148 return PyNameLocation(
3149 context->getRef(),
3150 mlirLocationNameGet(context->get(), toMlirStringRef(name),
3151 childLoc
3152 ? childLoc->get()
3153 : mlirLocationUnknownGet(context->get())));
3154 },
3155 "name"_a, "child_loc"_a = nb::none(), "context"_a = nb::none(),
3156 "Gets a NameLoc with an optional child location.");
3157 c.def_prop_ro(
3158 "name_str",
3159 [](PyNameLocation &self) {
3161 },
3162 "Gets the name string from a `NameLoc`.");
3163 c.def_prop_ro(
3164 "child_loc",
3165 [](PyNameLocation &self) {
3166 return PyLocation(self.getContext(),
3168 .maybeDownCast();
3169 },
3170 "Gets the child location from a `NameLoc`.");
3171}
3172
// Bindings for `CallSiteLoc`:
// - a static `get(callee, frames, context=None)` factory;
// - read-only `callee` and `caller` properties, each downcast to its concrete
//   Location subclass.
3174 c.def_static(
3175 "get",
3176 [](PyLocation callee, const std::vector<PyLocation> &frames,
3177 DefaultingPyMlirContext context) {
// An empty `frames` list raises ValueError (nb::value_error).
3178 if (frames.empty())
3179 throw nb::value_error("No caller frames provided.");
// Fold frames from the back: the result is
//   CallSite(frames[0], CallSite(frames[1], ... frames[n-1]))
// so frames[0] is the caller closest to `callee` and frames[n-1] the
// outermost.
3180 MlirLocation caller = frames.back().get();
3181 for (size_t index = frames.size() - 1; index-- > 0;) {
3182 caller = mlirLocationCallSiteGet(frames[index].get(), caller);
3183 }
3184 return PyCallSiteLocation(
3185 context->getRef(), mlirLocationCallSiteGet(callee.get(), caller));
3186 },
3187 "callee"_a, "frames"_a, "context"_a = nb::none(),
3188 "Gets a CallSiteLoc chaining a callee and one or more caller frames.");
3189 c.def_prop_ro(
3190 "callee",
3191 [](PyCallSiteLocation &self) {
3192 return PyLocation(self.getContext(),
3194 .maybeDownCast();
3195 },
3196 "Gets the callee location from a `CallSiteLoc`.");
3197 c.def_prop_ro(
3198 "caller",
3199 [](PyCallSiteLocation &self) {
3200 return PyLocation(self.getContext(),
3202 .maybeDownCast();
3203 },
3204 "Gets the caller location from a `CallSiteLoc`.");
3205}
3206
// Bindings for `FusedLoc`:
// - a strict static `get(locations, metadata=None, context=None)` that raises
//   ValueError when the fuse would collapse to a non-fused location;
// - a read-only `locations` property returning the fused locations, each
//   downcast via maybeDownCast();
// - a read-only `metadata` property returning the metadata attribute, or None
//   when it is null.
3208 c.def_static(
3209 "get",
3210 [](const std::vector<PyLocation> &pyLocations,
3211 std::optional<PyAttribute> metadata, DefaultingPyMlirContext context) {
3212 std::vector<MlirLocation> locations;
3213 locations.reserve(pyLocations.size());
3214 for (const PyLocation &pyLocation : pyLocations)
3215 locations.push_back(pyLocation.get());
// A default-constructed (null) MlirAttribute is passed when no metadata is
// given.
3216 MlirLocation location = mlirLocationFusedGet(
3217 context->get(), locations.size(), locations.data(),
3218 metadata ? metadata->get() : MlirAttribute{0});
3219 // Strict: `Location.fused(...)` handles the collapse case.
3220 if (!mlirLocationIsAFused(location))
3221 throw nb::value_error(
3222 "FusedLoc.get would collapse to a non-fused location; use "
3223 "Location.fused(...) for the permissive variant.");
3224 return PyFusedLocation(context->getRef(), location);
3225 },
3226 "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(),
3227 "Gets a FusedLoc from an array of locations and optional metadata. "
3228 "Raises if the fuse would collapse to a non-fused location; use "
3229 "`Location.fused(...)` for the permissive variant.");
3230 c.def_prop_ro(
3231 "locations",
3232 [](PyFusedLocation &self) {
3233 unsigned numLocations = mlirLocationFusedGetNumLocations(self.get());
3234 std::vector<MlirLocation> locations(numLocations);
// Skip the C API call when empty: locations.data() may be null in that case.
3235 if (numLocations)
3236 mlirLocationFusedGetLocations(self.get(), locations.data());
3237 std::vector<nb::object> pyLocations;
3238 pyLocations.reserve(numLocations);
3239 for (unsigned i = 0; i < numLocations; ++i)
3240 pyLocations.push_back(
3241 PyLocation(self.getContext(), locations[i]).maybeDownCast());
3242 return pyLocations;
3243 },
3244 "Gets the list of locations from a `FusedLoc`.");
3245 c.def_prop_ro(
3246 "metadata",
3247 [](PyFusedLocation &self) -> std::optional<PyAttribute> {
3248 MlirAttribute metadata = mlirLocationFusedGetMetadata(self.get());
3249 if (mlirAttributeIsNull(metadata))
3250 return std::nullopt;
3251 return PyAttribute(self.getContext(), metadata);
3252 },
3253 "Gets the metadata attribute from a `FusedLoc`, or None if absent.");
3254}
3255
3256//------------------------------------------------------------------------------
3257// Populates the core exports of the 'ir' submodule.
3258//------------------------------------------------------------------------------
3259void populateIRCore(nb::module_ &m) {
3260 //----------------------------------------------------------------------------
3261 // Enums.
3262 //----------------------------------------------------------------------------
3263 nb::enum_<PyDiagnosticSeverity>(m, "DiagnosticSeverity")
3264 .value("ERROR", PyDiagnosticSeverity::Error)
3265 .value("WARNING", PyDiagnosticSeverity::Warning)
3266 .value("NOTE", PyDiagnosticSeverity::Note)
3267 .value("REMARK", PyDiagnosticSeverity::Remark);
3268
3269 nb::enum_<PyWalkOrder>(m, "WalkOrder")
3270 .value("PRE_ORDER", PyWalkOrder::PreOrder)
3271 .value("POST_ORDER", PyWalkOrder::PostOrder);
3272 nb::enum_<PyWalkResult>(m, "WalkResult")
3273 .value("ADVANCE", PyWalkResult::Advance)
3274 .value("INTERRUPT", PyWalkResult::Interrupt)
3275 .value("SKIP", PyWalkResult::Skip);
3276
3277 //----------------------------------------------------------------------------
3278 // Mapping of Diagnostics.
3279 //----------------------------------------------------------------------------
3280 nb::class_<PyDiagnostic>(m, "Diagnostic")
3281 .def_prop_ro("severity", &PyDiagnostic::getSeverity,
3282 "Returns the severity of the diagnostic.")
3283 .def_prop_ro("location", &PyDiagnostic::getLocation,
3284 "Returns the location associated with the diagnostic.")
3285 .def_prop_ro("message", &PyDiagnostic::getMessage,
3286 "Returns the message text of the diagnostic.")
3287 .def_prop_ro("notes", &PyDiagnostic::getNotes,
3288 "Returns a tuple of attached note diagnostics.")
3289 .def(
3290 "__str__",
3291 [](PyDiagnostic &self) -> nb::str {
3292 if (!self.isValid())
3293 return nb::str("<Invalid Diagnostic>");
3294 return self.getMessage();
3295 },
3296 "Returns the diagnostic message as a string.");
3297
3298 nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
3299 .def(
3300 "__init__",
3302 new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
3303 },
3304 "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
3305 .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
3306 "The severity level of the diagnostic.")
3307 .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
3308 "The location associated with the diagnostic.")
3309 .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
3310 "The message text of the diagnostic.")
3311 .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
3312 "List of attached note diagnostics.")
3313 .def(
3314 "__str__",
3315 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
3316 "Returns the diagnostic message as a string.");
3317
3318 nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
3319 .def("detach", &PyDiagnosticHandler::detach,
3320 "Detaches the diagnostic handler from the context.")
3321 .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
3322 "Returns True if the handler is attached to a context.")
3323 .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
3324 "Returns True if an error was encountered during diagnostic "
3325 "handling.")
3326 .def("__enter__", &PyDiagnosticHandler::contextEnter,
3327 "Enters the diagnostic handler as a context manager.",
3328 nb::sig("def __enter__(self, /) -> DiagnosticHandler"))
3329 .def("__exit__", &PyDiagnosticHandler::contextExit, "exc_type"_a.none(),
3330 "exc_value"_a.none(), "traceback"_a.none(),
3331 "Exits the diagnostic handler context manager.");
3332
3333 // Expose DefaultThreadPool to python
3334 nb::class_<PyThreadPool>(m, "ThreadPool")
3335 .def(
3336 "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
3337 "Creates a new thread pool with default concurrency.")
3338 .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
3339 "Returns the maximum number of threads in the pool.")
3340 .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
3341 "Returns the raw pointer to the LLVM thread pool as a string.");
3342
3343 nb::class_<PyMlirContext>(m, "Context")
3344 .def(
3345 "__init__",
3346 [](PyMlirContext &self) {
3347 MlirContext context = mlirContextCreateWithThreading(false);
3348 new (&self) PyMlirContext(context);
3349 },
3350 R"(
3351 Creates a new MLIR context.
3352
3353 The context is the top-level container for all MLIR objects. It owns the storage
3354 for types, attributes, locations, and other core IR objects. A context can be
3355 configured to allow or disallow unregistered dialects and can have dialects
3356 loaded on-demand.)")
3357 .def_static("_get_live_count", &PyMlirContext::getLiveCount,
3358 "Gets the number of live Context objects.")
3359 .def(
3360 "_get_context_again",
3361 [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
3363 return ref.releaseObject();
3364 },
3365 "Gets another reference to the same context.")
3366 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
3367 "Gets the number of live modules owned by this context.")
3369 "Gets a capsule wrapping the MlirContext.")
3372 "Creates a Context from a capsule wrapping MlirContext.")
3373 .def("__enter__", &PyMlirContext::contextEnter,
3374 "Enters the context as a context manager.",
3375 nb::sig("def __enter__(self, /) -> Context"))
3376 .def("__exit__", &PyMlirContext::contextExit, "exc_type"_a.none(),
3377 "exc_value"_a.none(), "traceback"_a.none(),
3378 "Exits the context manager.")
3379 .def_prop_ro_static(
3380 "current",
3381 [](nb::object & /*class*/)
3382 -> std::optional<nb::typed<nb::object, PyMlirContext>> {
3384 if (!context)
3385 return {};
3386 return nb::cast(context);
3387 },
3388 nb::sig("def current(/) -> Context | None"),
3389 "Gets the Context bound to the current thread or returns None if no "
3390 "context is set.")
3391 .def_prop_ro(
3392 "dialects",
3393 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3394 "Gets a container for accessing dialects by name.")
3395 .def_prop_ro(
3396 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3397 "Alias for `dialects`.")
3398 .def(
3399 "get_dialect_descriptor",
3400 [=](PyMlirContext &self, std::string &name) {
3401 MlirDialect dialect = mlirContextGetOrLoadDialect(
3402 self.get(), {name.data(), name.size()});
3403 if (mlirDialectIsNull(dialect)) {
3404 throw nb::value_error(
3405 join("Dialect '", name, "' not found").c_str());
3406 }
3407 return PyDialectDescriptor(self.getRef(), dialect);
3408 },
3409 "dialect_name"_a,
3410 "Gets or loads a dialect by name, returning its descriptor object.")
3411 .def_prop_rw(
3412 "allow_unregistered_dialects",
3413 [](PyMlirContext &self) -> bool {
3414 return mlirContextGetAllowUnregisteredDialects(self.get());
3415 },
3416 [](PyMlirContext &self, bool value) {
3417 mlirContextSetAllowUnregisteredDialects(self.get(), value);
3418 },
3419 "Controls whether unregistered dialects are allowed in this context.")
3420 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
3421 "callback"_a,
3422 "Attaches a diagnostic handler that will receive callbacks.")
3423 .def(
3424 "enable_multithreading",
3425 [](PyMlirContext &self, bool enable) {
3426 mlirContextEnableMultithreading(self.get(), enable);
3427 },
3428 "enable"_a,
3429 R"(
3430 Enables or disables multi-threading support in the context.
3431
3432 Args:
3433 enable: Whether to enable (True) or disable (False) multi-threading.
3434 )")
3435 .def(
3436 "set_thread_pool",
3437 [](PyMlirContext &self, PyThreadPool &pool) {
3438 // we should disable multi-threading first before setting
3439 // new thread pool otherwise the assert in
3440 // MLIRContext::setThreadPool will be raised.
3441 mlirContextEnableMultithreading(self.get(), false);
3442 mlirContextSetThreadPool(self.get(), pool.get());
3443 },
3444 R"(
3445 Sets a custom thread pool for the context to use.
3446
3447 Args:
3448 pool: A ThreadPool object to use for parallel operations.
3449
3450 Note:
3451 Multi-threading is automatically disabled before setting the thread pool.)")
3452 .def(
3453 "get_num_threads",
3454 [](PyMlirContext &self) {
3455 return mlirContextGetNumThreads(self.get());
3456 },
3457 "Gets the number of threads in the context's thread pool.")
3458 .def(
3459 "_mlir_thread_pool_ptr",
3460 [](PyMlirContext &self) {
3461 MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
3462 std::stringstream ss;
3463 ss << pool.ptr;
3464 return ss.str();
3465 },
3466 "Gets the raw pointer to the LLVM thread pool as a string.")
3467 .def(
3468 "is_registered_operation",
3469 [](PyMlirContext &self, std::string &name) {
3471 self.get(), MlirStringRef{name.data(), name.size()});
3472 },
3473 "operation_name"_a,
3474 R"(
3475 Checks whether an operation with the given name is registered.
3476
3477 Args:
3478 operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
3479
3480 Returns:
3481 True if the operation is registered, False otherwise.)")
3482 .def(
3483 "append_dialect_registry",
3484 [](PyMlirContext &self, PyDialectRegistry &registry) {
3485 mlirContextAppendDialectRegistry(self.get(), registry);
3486 },
3487 "registry"_a,
3488 R"(
3489 Appends the contents of a dialect registry to the context.
3490
3491 Args:
3492 registry: A DialectRegistry containing dialects to append.)")
3493 .def_prop_rw("emit_error_diagnostics",
3496 R"(
3497 Controls whether error diagnostics are emitted to diagnostic handlers.
3498
3499 By default, error diagnostics are captured and reported through MLIRError exceptions.)")
3500 .def(
3501 "load_all_available_dialects",
3502 [](PyMlirContext &self) {
3504 },
3505 R"(
3506 Loads all dialects available in the registry into the context.
3507
3508 This eagerly loads all dialects that have been registered, making them
3509 immediately available for use.)");
3510
3511 //----------------------------------------------------------------------------
3512 // Mapping of PyDialectDescriptor
3513 //----------------------------------------------------------------------------
3514 nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
3515 .def_prop_ro(
3516 "namespace",
3517 [](PyDialectDescriptor &self) {
3518 MlirStringRef ns = mlirDialectGetNamespace(self.get());
3519 return nb::str(ns.data, ns.length);
3520 },
3521 "Returns the namespace of the dialect.")
3522 .def(
3523 "__repr__",
3524 [](PyDialectDescriptor &self) {
3525 MlirStringRef ns = mlirDialectGetNamespace(self.get());
3526 std::string repr("<DialectDescriptor ");
3527 repr.append(ns.data, ns.length);
3528 repr.append(">");
3529 return repr;
3530 },
3531 nb::sig("def __repr__(self) -> str"),
3532 "Returns a string representation of the dialect descriptor.");
3533
3534 //----------------------------------------------------------------------------
3535 // Mapping of PyDialects
3536 //----------------------------------------------------------------------------
3537 nb::class_<PyDialects>(m, "Dialects")
3538 .def(
3539 "__getitem__",
3540 [=](PyDialects &self, std::string keyName) {
3541 MlirDialect dialect =
3542 self.getDialectForKey(keyName, /*attrError=*/false);
3543 nb::object descriptor =
3544 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3545 return createCustomDialectWrapper(keyName, std::move(descriptor));
3546 },
3547 "Gets a dialect by name using subscript notation.")
3548 .def(
3549 "__getattr__",
3550 [=](PyDialects &self, std::string attrName) {
3551 MlirDialect dialect =
3552 self.getDialectForKey(attrName, /*attrError=*/true);
3553 nb::object descriptor =
3554 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3555 return createCustomDialectWrapper(attrName, std::move(descriptor));
3556 },
3557 "Gets a dialect by name using attribute notation.");
3558
3559 //----------------------------------------------------------------------------
3560 // Mapping of PyDialect
3561 //----------------------------------------------------------------------------
3562 nb::class_<PyDialect>(m, "Dialect")
3563 .def(nb::init<nb::object>(), "descriptor"_a,
3564 "Creates a Dialect from a DialectDescriptor.")
3565 .def_prop_ro(
3566 "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
3567 "Returns the DialectDescriptor for this dialect.")
3568 .def(
3569 "__repr__",
3570 [](const nb::object &self) {
3571 auto clazz = self.attr("__class__");
3572 return nb::str("<Dialect ") +
3573 self.attr("descriptor").attr("namespace") +
3574 nb::str(" (class ") + clazz.attr("__module__") +
3575 nb::str(".") + clazz.attr("__name__") + nb::str(")>");
3576 },
3577 nb::sig("def __repr__(self) -> str"),
3578 "Returns a string representation of the dialect.");
3579
3580 //----------------------------------------------------------------------------
3581 // Mapping of PyDialectRegistry
3582 //----------------------------------------------------------------------------
3583 nb::class_<PyDialectRegistry>(m, "DialectRegistry")
3585 "Gets a capsule wrapping the MlirDialectRegistry.")
3588 "Creates a DialectRegistry from a capsule wrapping "
3589 "`MlirDialectRegistry`.")
3590 .def(nb::init<>(), "Creates a new empty dialect registry.");
3591
3592 //----------------------------------------------------------------------------
3593 // Mapping of Location
3594 //----------------------------------------------------------------------------
3595 nb::class_<PyLocation>(m, "Location")
3597 "Gets a capsule wrapping the MlirLocation.")
3599 "Creates a Location from a capsule wrapping MlirLocation.")
3600 .def("__enter__", &PyLocation::contextEnter,
3601 "Enters the location as a context manager.",
3602 nb::sig("def __enter__(self, /) -> Location"))
3603 .def("__exit__", &PyLocation::contextExit, "exc_type"_a.none(),
3604 "exc_value"_a.none(), "traceback"_a.none(),
3605 "Exits the location context manager.")
3606 .def(
3607 "__eq__",
3608 [](PyLocation &self, PyLocation &other) -> bool {
3609 return mlirLocationEqual(self, other);
3610 },
3611 "Compares two locations for equality.")
3612 .def(
3613 "__eq__", [](PyLocation &self, nb::object other) { return false; },
3614 "Compares location with non-location object (always returns False).")
3615 .def_prop_ro_static(
3616 "current",
3617 [](nb::object & /*class*/) -> std::optional<PyLocation *> {
3619 if (!loc)
3620 return std::nullopt;
3621 return loc;
3622 },
3623 // clang-format off
3624 nb::sig("def current(/) -> Location | None"),
3625 // clang-format on
3626 "Gets the Location bound to the current thread or raises ValueError.")
3627 .def_static(
3628 "from_attr",
3629 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
3630 return PyLocation(context->getRef(),
3631 mlirLocationFromAttribute(attribute))
3632 .maybeDownCast();
3633 },
3634 "attribute"_a, "context"_a = nb::none(),
3635 "Gets a Location from a `LocationAttr`.")
3636 // Factory shims kept for backward compatibility; return the concrete
3637 // subclass. New code should use the subclass `.get()` directly.
3638 .def_static(
3639 "unknown",
3640 [](DefaultingPyMlirContext context) {
3641 return PyUnknownLocation(context->getRef(),
3642 mlirLocationUnknownGet(context->get()));
3643 },
3644 "context"_a = nb::none(), "Alias for `UnknownLoc.get()`.")
3645 .def_static(
3646 "file",
3647 [](std::string filename, int line, int col,
3648 DefaultingPyMlirContext context) {
3649 return PyFileLineColLocation(
3650 context->getRef(),
3652 context->get(), toMlirStringRef(filename), line, col));
3653 },
3654 "filename"_a, "line"_a, "col"_a, "context"_a = nb::none(),
3655 "Alias for `FileLineColLoc.get()`.")
3656 .def_static(
3657 "file",
3658 [](std::string filename, int startLine, int startCol, int endLine,
3659 int endCol, DefaultingPyMlirContext context) {
3660 return PyFileLineColLocation(
3661 context->getRef(),
3663 context->get(), toMlirStringRef(filename), startLine,
3664 startCol, endLine, endCol));
3665 },
3666 "filename"_a, "start_line"_a, "start_col"_a, "end_line"_a,
3667 "end_col"_a, "context"_a = nb::none(),
3668 "Alias for `FileLineColLoc.get()` over a range.")
3669 .def_static(
3670 "name",
3671 [](std::string name, std::optional<PyLocation> childLoc,
3672 DefaultingPyMlirContext context) {
3673 return PyNameLocation(
3674 context->getRef(),
3676 context->get(), toMlirStringRef(name),
3677 childLoc ? childLoc->get()
3678 : mlirLocationUnknownGet(context->get())));
3679 },
3680 "name"_a, "childLoc"_a = nb::none(), "context"_a = nb::none(),
3681 "Alias for `NameLoc.get()`.")
3682 .def_static(
3683 "callsite",
3684 [](PyLocation callee, const std::vector<PyLocation> &frames,
3685 DefaultingPyMlirContext context) {
3686 if (frames.empty())
3687 throw nb::value_error("No caller frames provided.");
3688 MlirLocation caller = frames.back().get();
3689 for (size_t index = frames.size() - 1; index-- > 0;)
3690 caller = mlirLocationCallSiteGet(frames[index].get(), caller);
3691 return PyCallSiteLocation(
3692 context->getRef(),
3693 mlirLocationCallSiteGet(callee.get(), caller));
3694 },
3695 "callee"_a, "frames"_a, "context"_a = nb::none(),
3696 "Alias for `CallSiteLoc.get()`.")
3697 .def_static(
3698 "fused",
3699 [](const std::vector<PyLocation> &pyLocations,
3700 std::optional<PyAttribute> metadata,
3701 DefaultingPyMlirContext context) {
3702 std::vector<MlirLocation> locations;
3703 locations.reserve(pyLocations.size());
3704 for (const PyLocation &pyLocation : pyLocations)
3705 locations.push_back(pyLocation.get());
3706 MlirLocation location = mlirLocationFusedGet(
3707 context->get(), locations.size(), locations.data(),
3708 metadata ? metadata->get() : MlirAttribute{0});
3709 return PyLocation(context->getRef(), location).maybeDownCast();
3710 },
3711 "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(),
3712 "Alias for `FusedLoc.get()` (may collapse to a non-fused location).")
3713 .def_prop_ro(
3714 "context",
3715 [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
3716 return self.getContext().getObject();
3717 },
3718 "Context that owns the `Location`.")
3719 .def_prop_ro(
3720 "attr",
3721 [](PyLocation &self) {
3722 return PyAttribute(self.getContext(),
3724 },
3725 "Get the underlying `LocationAttr`.")
3726 .def_prop_ro(
3727 "typeid",
3728 [](PyLocation &self) {
3729 MlirTypeID mlirTypeID =
3731 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3732 "mlirTypeID was expected to be non-null.");
3733 return PyTypeID(mlirTypeID);
3734 },
3735 "Gets the `TypeID` of the underlying LocationAttr.")
3736 .def(
3737 "emit_error",
3738 [](PyLocation &self, std::string message) {
3739 mlirEmitError(self, message.c_str());
3740 },
3741 "message"_a,
3742 R"(
3743 Emits an error diagnostic at this location.
3744
3745 Args:
3746 message: The error message to emit.)")
3747 .def(
3748 "__str__",
3749 [](PyLocation &self) {
3750 PyPrintAccumulator printAccum;
3751 mlirLocationPrint(self, printAccum.getCallback(),
3752 printAccum.getUserData());
3753 return printAccum.join();
3754 },
3755 "Returns the assembly form of the Location.")
3756 .def(
3757 "__repr__",
3758 [](PyLocation &self) {
3759 PyPrintAccumulator printAccum;
3760 mlirLocationPrint(self, printAccum.getCallback(),
3761 printAccum.getUserData());
3762 return printAccum.join();
3763 },
3764 "Returns the assembly representation of the location.");
3765
3771
3772 //----------------------------------------------------------------------------
3773 // Mapping of Module
3774 //----------------------------------------------------------------------------
// Python `Module` class. It is declared weak-referenceable, and its
// `operation`/`body` accessors pass `self.getRef().releaseObject()` into
// PyOperation::forOperation (presumably so the Python module object stays
// alive while the returned operation/block is in use -- confirm against
// PyOperation::forOperation).
3775 nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3777 "Gets a capsule wrapping the MlirModule.")
3779 R"(
3780 Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
3781
3782 This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
3783 prevent double-frees (of the underlying `mlir::Module`).)")
3784 .def("_clear_mlir_module", &PyModule::clearMlirModule,
3785 R"(
3786 Clears the internal MLIR module reference.
3787
3788 This is used internally to prevent double-free when ownership is transferred
3789 via the C API capsule mechanism. Not intended for normal use.)")
// `parse` is overloaded for `str` and `bytes` input. Both overloads capture
// diagnostics emitted while parsing and, when mlirModuleCreateParse returns a
// null module, raise MLIRError carrying the captured diagnostics.
3790 .def_static(
3791 "parse",
3792 [](const std::string &moduleAsm, DefaultingPyMlirContext context)
3793 -> nb::typed<nb::object, PyModule> {
3794 PyMlirContext::ErrorCapture errors(context->getRef());
3795 MlirModule module = mlirModuleCreateParse(
3796 context->get(), toMlirStringRef(moduleAsm));
3797 if (mlirModuleIsNull(module))
3798 throw MLIRError("Unable to parse module assembly", errors.take());
3799 return PyModule::forModule(module).releaseObject();
3800 },
3801 "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
3802 .def_static(
3803 "parse",
3804 [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
3805 -> nb::typed<nb::object, PyModule> {
3806 PyMlirContext::ErrorCapture errors(context->getRef());
3807 MlirModule module = mlirModuleCreateParse(
3808 context->get(), toMlirStringRef(moduleAsm));
3809 if (mlirModuleIsNull(module))
3810 throw MLIRError("Unable to parse module assembly", errors.take());
3811 return PyModule::forModule(module).releaseObject();
3812 },
3813 "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
// Parses a module from a file path, with the same error handling as `parse`.
// NOTE(review): this reuses kModuleParseDocstring, which describes parsing
// "from a string"; a file-specific docstring would be more accurate.
3814 .def_static(
3815 "parseFile",
3816 [](const std::string &path, DefaultingPyMlirContext context)
3817 -> nb::typed<nb::object, PyModule> {
3818 PyMlirContext::ErrorCapture errors(context->getRef());
3819 MlirModule module = mlirModuleCreateParseFromFile(
3820 context->get(), toMlirStringRef(path));
3821 if (mlirModuleIsNull(module))
3822 throw MLIRError("Unable to parse module assembly", errors.take());
3823 return PyModule::forModule(module).releaseObject();
3824 },
3825 "path"_a, "context"_a = nb::none(), kModuleParseDocstring)
// Creates an empty module. When `loc` is None, the location is resolved by
// maybeGetTracebackLocation (defined elsewhere in this file).
3826 .def_static(
3827 "create",
3828 [](const std::optional<PyLocation> &loc)
3829 -> nb::typed<nb::object, PyModule> {
3830 PyLocation pyLoc = maybeGetTracebackLocation(loc);
3831 MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
3832 return PyModule::forModule(module).releaseObject();
3833 },
3834 "loc"_a = nb::none(), "Creates an empty module.")
3835 .def_prop_ro(
3836 "context",
3837 [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
3838 return self.getContext().getObject();
3839 },
3840 "Context that created the `Module`.")
3841 .def_prop_ro(
3842 "operation",
3843 [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
3844 return PyOperation::forOperation(self.getContext(),
3845 mlirModuleGetOperation(self.get()),
3846 self.getRef().releaseObject())
3847 .releaseObject();
3848 },
3849 "Accesses the module as an operation.")
3850 .def_prop_ro(
3851 "body",
3852 [](PyModule &self) {
3854 self.getContext(), mlirModuleGetOperation(self.get()),
3855 self.getRef().releaseObject());
3856 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3857 return returnBlock;
3858 },
3859 "Return the block for this module.")
3860 .def(
3861 "dump",
3862 [](PyModule &self) {
3864 },
3866 .def(
3867 "__str__",
3868 [](const nb::object &self) {
3869 // Defer to the operation's __str__.
3870 return self.attr("operation").attr("__str__")();
3871 },
3872 nb::sig("def __str__(self) -> str"),
3873 R"(
3874 Gets the assembly form of the operation with default options.
3875
3876 If more advanced control over the assembly formatting or I/O options is needed,
3877 use the dedicated print or get_asm method, which supports keyword arguments to
3878 customize behavior.
3879 )")
// Equality is delegated to mlirModuleEqual.
// NOTE(review): unlike `_OperationBase`, there is no fallback `__eq__` overload
// accepting an arbitrary object, so comparing a Module with a non-Module does
// not simply return False; confirm the intended behavior.
3880 .def(
3881 "__eq__",
3882 [](PyModule &self, PyModule &other) {
3883 return mlirModuleEqual(self.get(), other.get());
3884 },
3885 "other"_a, "Compares two modules for equality.")
// Hash via mlirModuleHashValue, defined alongside `__eq__` above.
3886 .def(
3887 "__hash__",
3888 [](PyModule &self) { return mlirModuleHashValue(self.get()); },
3889 "Returns the hash value of the module.");
3890
3891 //----------------------------------------------------------------------------
3892 // Mapping of Operation.
3893 //----------------------------------------------------------------------------
// Python `_OperationBase`: the shared base class bound for both `Operation`
// and `OpView`. Every binding below reaches the concrete PyOperation through
// `self.getOperation()`.
3894 nb::class_<PyOperationBase>(m, "_OperationBase")
3895 .def_prop_ro(
3897 [](PyOperationBase &self) {
3898 return self.getOperation().getCapsule();
3899 },
3900 "Gets a capsule wrapping the `MlirOperation`.")
3901 .def(
3902 "__eq__",
3903 [](PyOperationBase &self, PyOperationBase &other) {
3904 return mlirOperationEqual(self.getOperation().get(),
3905 other.getOperation().get());
3906 },
3907 "Compares two operations for equality.")
// Fallback overload: comparing with any non-operation object yields False.
3908 .def(
3909 "__eq__",
3910 [](PyOperationBase &self, nb::object other) { return false; },
3911 "Compares operation with non-operation object (always returns "
3912 "False).")
3913 .def(
3914 "__hash__",
3915 [](PyOperationBase &self) {
3916 return mlirOperationHashValue(self.getOperation().get());
3917 },
3918 "Returns the hash value of the operation.")
3919 .def_prop_ro(
3920 "attributes",
3921 [](PyOperationBase &self) {
3922 return PyOpAttributeMap(self.getOperation().getRef());
3923 },
3924 "Returns a dictionary-like map of operation attributes.")
3925 .def_prop_ro(
3926 "context",
3927 [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
3928 PyOperation &concreteOperation = self.getOperation();
3929 concreteOperation.checkValid();
3930 return concreteOperation.getContext().getObject();
3931 },
3932 "Context that owns the operation.")
3933 .def_prop_ro(
3934 "name",
3935 [](PyOperationBase &self) {
3936 auto &concreteOperation = self.getOperation();
3937 concreteOperation.checkValid();
3938 MlirOperation operation = concreteOperation.get();
3939 return mlirIdentifierStr(mlirOperationGetName(operation));
3940 },
3941 "Returns the fully qualified name of the operation.")
3942 .def_prop_ro(
3943 "operands",
3944 [](PyOperationBase &self) {
3945 return PyOpOperandList(self.getOperation().getRef());
3946 },
3947 "Returns the list of operation operands.")
3948 .def_prop_ro(
3949 "op_operands",
3950 [](PyOperationBase &self) {
3951 return PyOpOperands(self.getOperation().getRef());
3952 },
3953 "Returns the list of op operands.")
3954 .def_prop_ro(
3955 "regions",
3956 [](PyOperationBase &self) {
3957 return PyRegionList(self.getOperation().getRef());
3958 },
3959 "Returns the list of operation regions.")
3960 .def_prop_ro(
3961 "results",
3962 [](PyOperationBase &self) {
3963 return PyOpResultList(self.getOperation().getRef());
3964 },
3965 "Returns the list of Operation results.")
// Single-result shortcut; getUniqueResult (defined elsewhere) is expected to
// raise when the op does not have exactly one result (per the docstring).
3966 .def_prop_ro(
3967 "result",
3968 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
3969 auto &operation = self.getOperation();
3970 return PyOpResult(operation.getRef(), getUniqueResult(operation))
3971 .maybeDownCast();
3972 },
3973 "Shortcut to get an op result if it has only one (throws an error "
3974 "otherwise).")
// Read/write `location` property.
// NOTE(review): unlike `context`/`name`/`attached`, neither the getter nor the
// setter calls checkValid() before touching the operation; confirm whether
// that is intentional.
3975 .def_prop_rw(
3976 "location",
3977 [](PyOperationBase &self) {
3978 PyOperation &operation = self.getOperation();
3979 return PyLocation(operation.getContext(),
3980 mlirOperationGetLocation(operation.get()))
3981 .maybeDownCast();
3982 },
3983 [](PyOperationBase &self, const PyLocation &location) {
3984 PyOperation &operation = self.getOperation();
3985 mlirOperationSetLocation(operation.get(), location.get());
3986 },
3987 nb::for_getter("Returns the source location the operation was "
3988 "defined or derived from."),
3989 nb::for_setter("Sets the source location the operation was defined "
3990 "or derived from."))
3991 .def_prop_ro(
3992 "parent",
3993 [](PyOperationBase &self)
3994 -> std::optional<nb::typed<nb::object, PyOperation>> {
3995 auto parent = self.getOperation().getParentOperation();
3996 if (parent)
3997 return parent->getObject();
3998 return {};
3999 },
4000 "Returns the parent operation, or `None` if at top level.")
// Default textual form: every printing option is left at its default; use
// `print`/`get_asm` for full control.
4001 .def(
4002 "__str__",
4003 [](PyOperationBase &self) {
4004 return self.getAsm(/*binary=*/false,
4005 /*largeElementsLimit=*/std::nullopt,
4006 /*largeResourceLimit=*/std::nullopt,
4007 /*enableDebugInfo=*/false,
4008 /*prettyDebugInfo=*/false,
4009 /*printGenericOpForm=*/false,
4010 /*useLocalScope=*/false,
4011 /*useNameLocAsPrefix=*/false,
4012 /*assumeVerified=*/false,
4013 /*skipRegions=*/false);
4014 },
4015 nb::sig("def __str__(self) -> str"),
4016 "Returns the assembly form of the operation.")
// `print` has two overloads: one driven by an AsmState, and one taking the
// individual printing flags as keyword arguments.
4017 .def("print",
4018 nb::overload_cast<PyAsmState &, nb::object, bool>(
4020 "state"_a, "file"_a = nb::none(), "binary"_a = false,
4021 R"(
4022 Prints the assembly form of the operation to a file like object.
4023
4024 Args:
4025 state: `AsmState` capturing the operation numbering and flags.
4026 file: Optional file like object to write to. Defaults to sys.stdout.
4027 binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
4028 .def("print",
4029 nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
4030 bool, bool, bool, bool, bool, bool, nb::object,
4031 bool, bool>(&PyOperationBase::print),
4032 // Careful: Lots of arguments must match up with print method.
4033 "large_elements_limit"_a = nb::none(),
4034 "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
4035 "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
4036 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
4037 "assume_verified"_a = false, "file"_a = nb::none(),
4038 "binary"_a = false, "skip_regions"_a = false,
4039 R"(
4040 Prints the assembly form of the operation to a file like object.
4041
4042 Args:
4043 large_elements_limit: Whether to elide elements attributes above this
4044 number of elements. Defaults to None (no limit).
4045 large_resource_limit: Whether to elide resource attributes above this
4046 number of characters. Defaults to None (no limit). If large_elements_limit
4047 is set and this is None, the behavior will be to use large_elements_limit
4048 as large_resource_limit.
4049 enable_debug_info: Whether to print debug/location information. Defaults
4050 to False.
4051 pretty_debug_info: Whether to format debug information for easier reading
4052 by a human (warning: the result is unparseable). Defaults to False.
4053 print_generic_op_form: Whether to print the generic assembly forms of all
4054 ops. Defaults to False.
4055 use_local_scope: Whether to print in a way that is more optimized for
4056 multi-threaded access but may not be consistent with how the overall
4057 module prints.
4058 use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
4059 prefixes for the SSA identifiers. Defaults to False.
4060 assume_verified: By default, if not printing generic form, the verifier
4061 will be run and if it fails, generic form will be printed with a comment
4062 about failed verification. While a reasonable default for interactive use,
4063 for systematic use, it is often better for the caller to verify explicitly
4064 and report failures in a more robust fashion. Set this to True if doing this
4065 in order to avoid running a redundant verification. If the IR is actually
4066 invalid, behavior is undefined.
4067 file: The file like object to write to. Defaults to sys.stdout.
4068 binary: Whether to write bytes (True) or str (False). Defaults to False.
4069 skip_regions: Whether to skip printing regions. Defaults to False.)")
4070 .def("write_bytecode", &PyOperationBase::writeBytecode, "file"_a,
4071 "desired_version"_a = nb::none(),
4072 R"(
4073 Write the bytecode form of the operation to a file like object.
4074
4075 Args:
4076 file: The file like object to write to.
4077 desired_version: Optional version of bytecode to emit.
4078 Returns:
4079 The bytecode writer status.)")
4080 .def("get_asm", &PyOperationBase::getAsm,
4081 // Careful: Lots of arguments must match up with get_asm method.
4082 "binary"_a = false, "large_elements_limit"_a = nb::none(),
4083 "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
4084 "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
4085 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
4086 "assume_verified"_a = false, "skip_regions"_a = false,
4087 R"(
4088 Gets the assembly form of the operation with all options available.
4089
4090 Args:
4091 binary: Whether to return a bytes (True) or str (False) object. Defaults to
4092 False.
4093 ... others ...: See the print() method for common keyword arguments for
4094 configuring the printout.
4095 Returns:
4096 Either a bytes or str object, depending on the setting of the `binary`
4097 argument.)")
4098 .def("verify", &PyOperationBase::verify,
4099 "Verify the operation. Raises MLIRError if verification fails, and "
4100 "returns true otherwise.")
4101 .def("move_after", &PyOperationBase::moveAfter, "other"_a,
4102 "Puts self immediately after the other operation in its parent "
4103 "block.")
4104 .def("move_before", &PyOperationBase::moveBefore, "other"_a,
4105 "Puts self immediately before the other operation in its parent "
4106 "block.")
4107 .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, "other"_a,
4108 R"(
4109 Checks if this operation is before another in the same block.
4110
4111 Args:
4112 other: Another operation in the same parent block.
4113
4114 Returns:
4115 True if this operation is before `other` in the operation list of the parent block.)")
4116 .def(
4117 "clone",
4118 [](PyOperationBase &self,
4119 const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
4120 return self.getOperation().clone(ip);
4121 },
4122 "ip"_a = nb::none(),
4123 R"(
4124 Creates a deep copy of the operation.
4125
4126 Args:
4127 ip: Optional insertion point where the cloned operation should be inserted.
4128 If None, the current insertion point is used. If False, the operation
4129 remains detached.
4130
4131 Returns:
4132 A new Operation that is a clone of this operation.)")
// Detaching an already-detached op is rejected with ValueError; on success the
// op is returned wrapped in its OpView.
4133 .def(
4134 "detach_from_parent",
4135 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
4136 PyOperation &operation = self.getOperation();
4137 operation.checkValid();
4138 if (!operation.isAttached())
4139 throw nb::value_error("Detached operation has no parent.");
4140
4141 operation.detachFromParent();
4142 return operation.createOpView();
4143 },
4144 "Detaches the operation from its parent block.")
4145 .def_prop_ro(
4146 "attached",
4147 [](PyOperationBase &self) {
4148 PyOperation &operation = self.getOperation();
4149 operation.checkValid();
4150 return operation.isAttached();
4151 },
4152 "Reports if the operation is attached to its parent block.")
4153 .def(
4154 "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
4155 R"(
4156 Erases the operation and frees its memory.
4157
4158 Note:
4159 After erasing, any Python references to the operation become invalid.)")
// Walk with optional type filtering. Without op_class, the callback is invoked
// for every op. With op_class, each visited op is wrapped in an OpView and the
// callback runs only when that view is an instance of op_class; other ops just
// return PyWalkResult::Advance. The `[&]` capture is used only for the
// duration of the synchronous self.walk call.
4160 .def(
4161 "walk",
4162 [](PyOperationBase &self,
4163 std::function<PyWalkResult(MlirOperation)> callback,
4164 PyWalkOrder walkOrder, std::optional<nb::object> opClass) {
4165 if (!opClass)
4166 return self.walk(callback, walkOrder);
4167 self.walk(
4168 [&](MlirOperation mlirOp) -> PyWalkResult {
4169 nb::object opview =
4171 self.getOperation().getContext(), mlirOp)
4172 ->createOpView();
4173 if (nb::isinstance(opview, *opClass))
4174 return callback(mlirOp);
4175 return PyWalkResult::Advance;
4176 },
4177 walkOrder);
4178 },
4179 "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
4180 "op_class"_a = nb::none(),
4181 // clang-format off
4182 nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = ..., op_class: type[OpView] | None = None) -> None"),
4183 // clang-format on
4184 R"(
4185 Walks the operation tree with a callback function.
4186
4187 If op_class is provided, the callback is only invoked on operations
4188 of that type; all other operations are skipped silently.
4189
4190 Args:
4191 callback: A callable that takes an Operation and returns a WalkResult.
4192 walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
4193 op_class: If provided, only operations of this type are passed to the callback.)")
// Reads the trait's TypeID from the class attribute named by
// PyDynamicOpTrait::typeIDAttr, then checks it against this op's name.
4194 .def(
4195 "has_trait",
4196 [](PyOperationBase &self, nb::type_object &traitCls) {
4197 PyTypeID traitTypeID =
4198 nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
4199 MlirIdentifier opName =
4200 mlirOperationGetName(self.getOperation().get());
4202 mlirIdentifierStr(opName), traitTypeID.get(),
4203 self.getOperation().getContext()->get());
4204 },
4205 "trait_cls"_a, "Checks if the operation has a given trait.");
4206
// Python `Operation`: the concrete operation class, derived from
// `_OperationBase`. `create` rejects None operands with ValueError before
// delegating to PyOperation::create.
4207 nb::class_<PyOperation, PyOperationBase>(m, "Operation")
4208 .def_static(
4209 "create",
4210 [](std::string_view name,
4211 std::optional<std::vector<PyType *>> results,
4212 std::optional<std::vector<PyValue *>> operands,
4213 std::optional<nb::typed<nb::dict, nb::str, PyAttribute>>
4214 attributes,
4215 std::optional<std::vector<PyBlock *>> successors, int regions,
4216 const std::optional<PyLocation> &location,
4217 const nb::object &maybeIp,
4218 bool inferType) -> nb::typed<nb::object, PyOperation> {
4219 // Unpack/validate operands.
4220 std::vector<MlirValue> mlirOperands;
4221 if (operands) {
4222 mlirOperands.reserve(operands->size());
4223 for (PyValue *operand : *operands) {
4224 if (!operand)
4225 throw nb::value_error("operand value cannot be None");
4226 mlirOperands.push_back(operand->get());
4227 }
4228 }
4229
4230 PyLocation pyLoc = maybeGetTracebackLocation(location);
4231 return PyOperation::create(
4232 name, results, mlirOperands.data(), mlirOperands.size(),
4233 attributes, successors, regions, pyLoc, maybeIp, inferType);
4234 },
4235 "name"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
4236 "attributes"_a = nb::none(), "successors"_a = nb::none(),
4237 "regions"_a = 0, "loc"_a = nb::none(), "ip"_a = nb::none(),
4238 "infer_type"_a = false,
4239 R"(
4240 Creates a new operation.
4241
4242 Args:
4243 name: Operation name (e.g. `dialect.operation`).
4244 results: Optional sequence of Type representing op result types.
4245 operands: Optional operands of the operation.
4246 attributes: Optional Dict of {str: Attribute}.
4247 successors: Optional List of Block for the operation's successors.
4248 regions: Number of regions to create (default = 0).
4249 location: Optional Location object (defaults to resolve from context manager).
4250 ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
4251 infer_type: Whether to infer result types (default = False).
4252 Returns:
4253 A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
4254 .def_static(
4255 "parse",
4256 [](const std::string &sourceStr, const std::string &sourceName,
4258 -> nb::typed<nb::object, PyOpView> {
4259 return PyOperation::parse(context->getRef(), sourceStr, sourceName)
4260 ->createOpView();
4261 },
4262 "source"_a, nb::kw_only(), "source_name"_a = "",
4263 "context"_a = nb::none(),
4264 "Parses an operation. Supports both text assembly format and binary "
4265 "bytecode format.")
4267 "Gets a capsule wrapping the MlirOperation.")
4270 "Creates an Operation from a capsule wrapping MlirOperation.")
// `operation` returns self unchanged; `opview` wraps self in an OpView.
4271 .def_prop_ro(
4272 "operation",
4273 [](nb::object self) -> nb::typed<nb::object, PyOperation> {
4274 return self;
4275 },
4276 "Returns self (the operation).")
4277 .def_prop_ro(
4278 "opview",
4279 [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
4280 return self.createOpView();
4281 },
4282 R"(
4283 Returns an OpView of this operation.
4284
4285 Note:
4286 If the operation has a registered and loaded dialect then this OpView will
4287 be concrete wrapper class.)")
4288 .def_prop_ro("block", &PyOperation::getBlock,
4289 "Returns the block containing this operation.")
// The same `successors` binding is also registered on OpView.
4290 .def_prop_ro(
4291 "successors",
4292 [](PyOperationBase &self) {
4293 return PyOpSuccessors(self.getOperation().getRef());
4294 },
4295 "Returns the list of Operation successors.")
4296 .def(
4297 "replace_uses_of_with",
4298 [](PyOperation &self, PyValue &of, PyValue &with) {
4299 mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
4300 },
4301 "of"_a, "with_"_a,
4302 "Replaces uses of the 'of' value with the 'with' value inside the "
4303 "operation.")
4304 .def("_set_invalid", &PyOperation::setInvalid,
4305 "Invalidate the operation.");
4306
// Python `OpView`: the base class for op-specific wrapper classes. An OpView
// can wrap an existing Operation, or build one through the generic __init__
// that takes the ODS region/segment specs explicitly.
4307 auto opViewClass =
4308 nb::class_<PyOpView, PyOperationBase>(m, "OpView")
4309 .def(nb::init<nb::typed<nb::object, PyOperation>>(), "operation"_a)
4310 .def(
4311 "__init__",
4312 [](PyOpView *self, std::string_view name,
4313 std::tuple<int, bool> opRegionSpec,
4314 nb::object operandSegmentSpecObj,
4315 nb::object resultSegmentSpecObj,
4316 std::optional<nb::sequence> resultTypeList,
4317 nb::sequence operandList,
4318 std::optional<nb::typed<nb::dict, nb::str, PyAttribute>>
4319 attributes,
4320 std::optional<std::vector<PyBlock *>> successors,
4321 std::optional<int> regions,
4322 const std::optional<PyLocation> &location,
4323 const nb::object &maybeIp) {
4324 PyLocation pyLoc = maybeGetTracebackLocation(location);
4326 name, opRegionSpec, operandSegmentSpecObj,
4327 resultSegmentSpecObj, resultTypeList, operandList,
4328 attributes, successors, regions, pyLoc, maybeIp));
4329 },
4330 "name"_a, "opRegionSpec"_a,
4331 "operandSegmentSpecObj"_a = nb::none(),
4332 "resultSegmentSpecObj"_a = nb::none(), "results"_a = nb::none(),
4333 "operands"_a = nb::none(), "attributes"_a = nb::none(),
4334 "successors"_a = nb::none(), "regions"_a = nb::none(),
4335 "loc"_a = nb::none(), "ip"_a = nb::none())
4336 .def_prop_ro(
4337 "operation",
4338 [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
4339 return self.getOperationObject();
4340 })
4341 .def_prop_ro("opview",
4342 [](nb::object self) -> nb::typed<nb::object, PyOpView> {
4343 return self;
4344 })
4345 .def(
4346 "__str__",
4347 [](PyOpView &self) { return nb::str(self.getOperationObject()); })
4348 .def_prop_ro(
4349 "successors",
4350 [](PyOperationBase &self) {
4351 return PyOpSuccessors(self.getOperation().getRef());
4352 },
4353 "Returns the list of Operation successors.")
4354 .def(
4355 "_set_invalid",
4356 [](PyOpView &self) { self.getOperation().setInvalid(); },
4357 "Invalidate the operation.");
// Base-class defaults for the ODS metadata read by `build_generic` below
// (presumably overridden by generated subclasses -- confirm in the generator).
4358 opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
4359 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
4360 opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
4361 // It is faster to pass the operation_name, ods_regions, and
4362 // ods_operand_segments/ods_result_segments as arguments to the constructor,
4363 // rather than to access them as attributes.
4364 opViewClass.attr("build_generic") = classmethod(
4365 [](nb::handle cls, std::optional<nb::sequence> resultTypeList,
4366 nb::sequence operandList,
4367 std::optional<nb::typed<nb::dict, nb::str, PyAttribute>> attributes,
4368 std::optional<std::vector<PyBlock *>> successors,
4369 std::optional<int> regions, std::optional<PyLocation> location,
4370 const nb::object &maybeIp) {
4371 std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
4372 std::tuple<int, bool> opRegionSpec =
4373 nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
4374 nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
4375 nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
4376 PyLocation pyLoc = maybeGetTracebackLocation(location);
4377 return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
4378 resultSegmentSpec, resultTypeList,
4379 operandList, attributes, successors,
4380 regions, pyLoc, maybeIp);
4381 },
4382 "cls"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
4383 "attributes"_a = nb::none(), "successors"_a = nb::none(),
4384 "regions"_a = nb::none(), "loc"_a = nb::none(), "ip"_a = nb::none(),
4385 // clang-format off
4386 nb::sig("def build_generic(cls, results: Sequence[Type] | None = None, operands: Sequence[Value] | None = None, attributes: dict[str, Attribute] | None = None, successors: Sequence[Block] | None = None, regions: int | None = None, loc: Location | None = None, ip: InsertionPoint | None = None) -> typing.Self"),
4387 // clang-format on
4388 "Builds a specific, generated OpView based on class level attributes.");
// Class-level `parse`: parses any op, then raises MLIRError unless its name
// matches cls.OPERATION_NAME.
4389 opViewClass.attr("parse") = classmethod(
4390 [](const nb::object &cls, const std::string &sourceStr,
4391 const std::string &sourceName,
4392 DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
4393 PyOperationRef parsed =
4394 PyOperation::parse(context->getRef(), sourceStr, sourceName);
4395
4396 // Check if the expected operation was parsed, and cast to the
4397 // appropriate `OpView` subclass if successful.
4398 // NOTE: This accesses attributes that have been automatically added to
4399 // `OpView` subclasses, and is not intended to be used on `OpView`
4400 // directly.
4401 std::string clsOpName =
4402 nb::cast<std::string>(cls.attr("OPERATION_NAME"));
4403 MlirStringRef identifier =
4405 std::string_view parsedOpName(identifier.data, identifier.length);
4406 if (clsOpName != parsedOpName)
4407 throw MLIRError(join("Expected a '", clsOpName, "' op, got: '",
4408 parsedOpName, "'"));
4409 return PyOpView::constructDerived(cls, parsed.getObject());
4410 },
4411 "cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "",
4412 "context"_a = nb::none(),
4413 // clang-format off
4414 nb::sig("def parse(cls, source: str, *, source_name: str = '', context: Context | None = None) -> typing.Self"),
4415 // clang-format on
4416 "Parses a specific, generated OpView based on class level attributes.");
// Class-level trait query keyed on cls.OPERATION_NAME, so it needs no op
// instance, only a context (defaulted when None).
4417 opViewClass.attr("has_trait") = classmethod(
4418 [](nb::object &self, nb::type_object &traitCls,
4419 DefaultingPyMlirContext &context) {
4420 PyTypeID traitTypeID =
4421 nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
4422 std::string opName = nb::cast<std::string>(self.attr("OPERATION_NAME"));
4424 mlirStringRefCreate(opName.data(), opName.size()),
4425 traitTypeID.get(), context->get());
4426 },
4427 "cls"_a, "trait_cls"_a, "context"_a = nb::none(),
4428 "Checks if the operation has a given trait.");
4429
4431
4432 //----------------------------------------------------------------------------
4433 // Mapping of PyRegion.
4434 //----------------------------------------------------------------------------
4435 nb::class_<PyRegion>(m, "Region")
4436 .def_prop_ro(
4437 "blocks",
4438 [](PyRegion &self) {
4439 return PyBlockList(self.getParentOperation(), self.get());
4440 },
4441 "Returns a forward-optimized sequence of blocks.")
4442 .def_prop_ro(
4443 "owner",
4444 [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
4445 return self.getParentOperation()->createOpView();
4446 },
4447 "Returns the operation owning this region.")
4448 .def(
4449 "__iter__",
4450 [](PyRegion &self) {
4451 self.checkValid();
4452 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
4453 return PyBlockIterator(self.getParentOperation(), firstBlock);
4454 },
4455 "Iterates over blocks in the region.")
4456 .def(
4457 "__eq__",
4458 [](PyRegion &self, PyRegion &other) {
4459 return self.get().ptr == other.get().ptr;
4460 },
4461 "Compares two regions for pointer equality.")
4462 .def(
4463 "__eq__", [](PyRegion &self, nb::object &other) { return false; },
4464 "Compares region with non-region object (always returns False).");
4465
4466 //----------------------------------------------------------------------------
4467 // Mapping of PyBlock.
4468 //----------------------------------------------------------------------------
4469 nb::class_<PyBlock>(m, "Block")
4471 "Gets a capsule wrapping the MlirBlock.")
4472 .def_prop_ro(
4473 "owner",
4474 [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
4475 return self.getParentOperation()->createOpView();
4476 },
4477 "Returns the owning operation of this block.")
4478 .def_prop_ro(
4479 "region",
4480 [](PyBlock &self) {
4481 MlirRegion region = mlirBlockGetParentRegion(self.get());
4482 return PyRegion(self.getParentOperation(), region);
4483 },
4484 "Returns the owning region of this block.")
4485 .def_prop_ro(
4486 "arguments",
4487 [](PyBlock &self) {
4488 return PyBlockArgumentList(self.getParentOperation(), self.get());
4489 },
4490 "Returns a list of block arguments.")
4491 .def(
4492 "add_argument",
4493 [](PyBlock &self, const PyType &type, const PyLocation &loc) {
4494 return PyBlockArgument(self.getParentOperation(),
4495 mlirBlockAddArgument(self.get(), type, loc));
4496 },
4497 "type"_a, "loc"_a,
4498 R"(
4499 Appends an argument of the specified type to the block.
4500
4501 Args:
4502 type: The type of the argument to add.
4503 loc: The source location for the argument.
4504
4505 Returns:
4506 The newly added block argument.)")
4507 .def(
4508 "erase_argument",
4509 [](PyBlock &self, unsigned index) {
4510 return mlirBlockEraseArgument(self.get(), index);
4511 },
4512 "index"_a,
4513 R"(
4514 Erases the argument at the specified index.
4515
4516 Args:
4517 index: The index of the argument to erase.)")
4518 .def_prop_ro(
4519 "operations",
4520 [](PyBlock &self) {
4521 return PyOperationList(self.getParentOperation(), self.get());
4522 },
4523 "Returns a forward-optimized sequence of operations.")
4524 .def_static(
4525 "create_at_start",
4526 [](PyRegion &parent, nb::typed<nb::sequence, PyType> pyArgTypes,
4527 const std::optional<nb::typed<nb::sequence, PyLocation>>
4528 &pyArgLocs) {
4529 parent.checkValid();
4530 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
4531 mlirRegionInsertOwnedBlock(parent, 0, block);
4532 return PyBlock(parent.getParentOperation(), block);
4533 },
4534 "parent"_a, "arg_types"_a = nb::list(), "arg_locs"_a = std::nullopt,
4535 "Creates and returns a new Block at the beginning of the given "
4536 "region (with given argument types and locations).")
4537 .def(
4538 "append_to",
4539 [](PyBlock &self, PyRegion &region) {
4540 MlirBlock b = self.get();
4543 mlirRegionAppendOwnedBlock(region.get(), b);
4544 },
4545 "region"_a,
4546 R"(
4547 Appends this block to a region.
4548
4549 Transfers ownership if the block is currently owned by another region.
4550
4551 Args:
4552 region: The region to append the block to.)")
4553 .def(
4554 "create_before",
4555 [](PyBlock &self, const nb::args &pyArgTypes,
4556 const std::optional<nb::typed<nb::sequence, PyLocation>>
4557 &pyArgLocs) {
4558 self.checkValid();
4559 MlirBlock block =
4560 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4561 MlirRegion region = mlirBlockGetParentRegion(self.get());
4562 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
4563 return PyBlock(self.getParentOperation(), block);
4564 },
4565 "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
4566 "Creates and returns a new Block before this block "
4567 "(with given argument types and locations).")
4568 .def(
4569 "create_after",
4570 [](PyBlock &self, const nb::args &pyArgTypes,
4571 const std::optional<nb::typed<nb::sequence, PyLocation>>
4572 &pyArgLocs) {
4573 self.checkValid();
4574 MlirBlock block =
4575 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4576 MlirRegion region = mlirBlockGetParentRegion(self.get());
4577 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
4578 return PyBlock(self.getParentOperation(), block);
4579 },
4580 "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
4581 "Creates and returns a new Block after this block "
4582 "(with given argument types and locations).")
4583 .def(
4584 "__iter__",
4585 [](PyBlock &self) {
4586 self.checkValid();
4587 MlirOperation firstOperation =
4588 mlirBlockGetFirstOperation(self.get());
4589 return PyOperationIterator(self.getParentOperation(),
4590 firstOperation);
4591 },
4592 "Iterates over operations in the block.")
4593 .def(
4594 "__eq__",
4595 [](PyBlock &self, PyBlock &other) {
4596 return self.get().ptr == other.get().ptr;
4597 },
4598 "Compares two blocks for pointer equality.")
4599 .def(
4600 "__eq__", [](PyBlock &self, nb::object &other) { return false; },
4601 "Compares block with non-block object (always returns False).")
4602 .def(
4603 "__hash__", [](PyBlock &self) { return hash(self.get().ptr); },
4604 "Returns the hash value of the block.")
4605 .def(
4606 "__str__",
4607 [](PyBlock &self) {
4608 self.checkValid();
4609 PyPrintAccumulator printAccum;
4610 mlirBlockPrint(self.get(), printAccum.getCallback(),
4611 printAccum.getUserData());
4612 return printAccum.join();
4613 },
4614 "Returns the assembly form of the block.")
4615 .def(
4616 "append",
4617 [](PyBlock &self, PyOperationBase &operation) {
4618 if (operation.getOperation().isAttached())
4619 operation.getOperation().detachFromParent();
4620
4621 MlirOperation mlirOperation = operation.getOperation().get();
4622 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
4623 operation.getOperation().setAttached(
4624 self.getParentOperation().getObject());
4625 },
4626 "operation"_a,
4627 R"(
4628 Appends an operation to this block.
4629
4630 If the operation is currently in another block, it will be moved.
4631
4632 Args:
4633 operation: The operation to append to the block.)")
4634 .def_prop_ro(
4635 "successors",
4636 [](PyBlock &self) {
4637 return PyBlockSuccessors(self, self.getParentOperation());
4638 },
4639 "Returns the list of Block successors.")
4640 .def_prop_ro(
4641 "predecessors",
4642 [](PyBlock &self) {
4643 return PyBlockPredecessors(self, self.getParentOperation());
4644 },
4645 "Returns the list of Block predecessors.");
4646
4647 //----------------------------------------------------------------------------
4648 // Mapping of PyInsertionPoint.
4649 //----------------------------------------------------------------------------
4650
4651 nb::class_<PyInsertionPoint>(m, "InsertionPoint")
4652 .def(nb::init<PyBlock &>(), "block"_a,
4653 "Inserts after the last operation but still inside the block.")
4654 .def("__enter__", &PyInsertionPoint::contextEnter,
4655 "Enters the insertion point as a context manager.",
4656 nb::sig("def __enter__(self, /) -> InsertionPoint"))
4657 .def("__exit__", &PyInsertionPoint::contextExit, "exc_type"_a.none(),
4658 "exc_value"_a.none(), "traceback"_a.none(),
4659 "Exits the insertion point context manager.")
4660 .def_prop_ro_static(
4661 "current",
4662 [](nb::object & /*class*/) {
4664 if (!ip)
4665 throw nb::value_error("No current InsertionPoint");
4666 return ip;
4667 },
4668 nb::sig("def current(/) -> InsertionPoint"),
4669 "Gets the InsertionPoint bound to the current thread or raises "
4670 "ValueError if none has been set.")
4671 .def(nb::init<PyOperationBase &>(), "beforeOperation"_a,
4672 "Inserts before a referenced operation.")
4673 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, "block"_a,
4674 R"(
4675 Creates an insertion point at the beginning of a block.
4676
4677 Args:
4678 block: The block at whose beginning operations should be inserted.
4679
4680 Returns:
4681 An InsertionPoint at the block's beginning.)")
4682 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
4683 "block"_a,
4684 R"(
4685 Creates an insertion point before a block's terminator.
4686
4687 Args:
4688 block: The block whose terminator to insert before.
4689
4690 Returns:
4691 An InsertionPoint before the terminator.
4692
4693 Raises:
4694 ValueError: If the block has no terminator.)")
4695 .def_static("after", &PyInsertionPoint::after, "operation"_a,
4696 R"(
4697 Creates an insertion point immediately after an operation.
4698
4699 Args:
4700 operation: The operation after which to insert.
4701
4702 Returns:
4703 An InsertionPoint after the operation.)")
4704 .def("insert", &PyInsertionPoint::insert, "operation"_a,
4705 R"(
4706 Inserts an operation at this insertion point.
4707
4708 Args:
4709 operation: The operation to insert.)")
4710 .def_prop_ro(
4711 "block", [](PyInsertionPoint &self) { return self.getBlock(); },
4712 "Returns the block that this `InsertionPoint` points to.")
4713 .def_prop_ro(
4714 "ref_operation",
4715 [](PyInsertionPoint &self)
4716 -> std::optional<nb::typed<nb::object, PyOperation>> {
4717 auto refOperation = self.getRefOperation();
4718 if (refOperation)
4719 return refOperation->getObject();
4720 return {};
4721 },
4722 "The reference operation before which new operations are "
4723 "inserted, or None if the insertion point is at the end of "
4724 "the block.");
4725
4726 //----------------------------------------------------------------------------
4727 // Mapping of PyAttribute.
4728 //----------------------------------------------------------------------------
4729 nb::class_<PyAttribute>(m, "Attribute")
4730 // Delegate to the PyAttribute copy constructor, which will also lifetime
4731 // extend the backing context which owns the MlirAttribute.
4732 .def(nb::init<PyAttribute &>(), "cast_from_type"_a,
4733 "Casts the passed attribute to the generic `Attribute`.")
4735 "Gets a capsule wrapping the MlirAttribute.")
4736 .def_static(
4738 "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
4739 .def_static(
4740 "parse",
4741 [](const std::string &attrSpec, DefaultingPyMlirContext context)
4742 -> nb::typed<nb::object, PyAttribute> {
4743 PyMlirContext::ErrorCapture errors(context->getRef());
4744 MlirAttribute attr = mlirAttributeParseGet(
4745 context->get(), toMlirStringRef(attrSpec));
4746 if (mlirAttributeIsNull(attr))
4747 throw MLIRError("Unable to parse attribute", errors.take());
4748 return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
4749 },
4750 "asm"_a, "context"_a = nb::none(),
4751 "Parses an attribute from an assembly form. Raises an `MLIRError` on "
4752 "failure.")
4753 .def_prop_ro(
4754 "context",
4755 [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
4756 return self.getContext().getObject();
4757 },
4758 "Context that owns the `Attribute`.")
4759 .def_prop_ro(
4760 "type",
4761 [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
4762 return PyType(self.getContext(), mlirAttributeGetType(self))
4763 .maybeDownCast();
4764 },
4765 "Returns the type of the `Attribute`.")
4766 .def(
4767 "get_named",
4768 [](PyAttribute &self, std::string name) {
4769 return PyNamedAttribute(self, std::move(name));
4770 },
4771 nb::keep_alive<0, 1>(),
4772 R"(
4773 Binds a name to the attribute, creating a `NamedAttribute`.
4774
4775 Args:
4776 name: The name to bind to the `Attribute`.
4777
4778 Returns:
4779 A `NamedAttribute` with the given name and this attribute.)")
4780 .def(
4781 "__eq__",
4782 [](PyAttribute &self, PyAttribute &other) { return self == other; },
4783 "Compares two attributes for equality.")
4784 .def(
4785 "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
4786 "Compares attribute with non-attribute object (always returns "
4787 "False).")
4788 .def(
4789 "__hash__", [](PyAttribute &self) { return hash(self.get().ptr); },
4790 "Returns the hash value of the attribute.")
4791 .def(
4792 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
4794 .def(
4795 "__str__",
4796 [](PyAttribute &self) {
4797 PyPrintAccumulator printAccum;
4798 mlirAttributePrint(self, printAccum.getCallback(),
4799 printAccum.getUserData());
4800 return printAccum.join();
4801 },
4802 "Returns the assembly form of the Attribute.")
4803 .def(
4804 "__repr__",
4805 [](PyAttribute &self) {
4806 // Generally, assembly formats are not printed for __repr__ because
4807 // this can cause exceptionally long debug output and exceptions.
4808 // However, attribute values are generally considered useful and
4809 // are printed. This may need to be re-evaluated if debug dumps end
4810 // up being excessive.
4811 PyPrintAccumulator printAccum;
4812 printAccum.parts.append("Attribute(");
4813 mlirAttributePrint(self, printAccum.getCallback(),
4814 printAccum.getUserData());
4815 printAccum.parts.append(")");
4816 return printAccum.join();
4817 },
4818 "Returns a string representation of the attribute.")
4819 .def_prop_ro(
4820 "typeid",
4821 [](PyAttribute &self) {
4822 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
4823 assert(!mlirTypeIDIsNull(mlirTypeID) &&
4824 "mlirTypeID was expected to be non-null.");
4825 return PyTypeID(mlirTypeID);
4826 },
4827 "Returns the `TypeID` of the attribute.")
4828 .def(
4830 [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
4831 return self.maybeDownCast();
4832 },
4833 "Downcasts the attribute to a more specific attribute if possible.");
4834
4835 //----------------------------------------------------------------------------
4836 // Mapping of PyNamedAttribute
4837 //----------------------------------------------------------------------------
4838 nb::class_<PyNamedAttribute>(m, "NamedAttribute")
4839 .def(
4840 "__repr__",
4841 [](PyNamedAttribute &self) {
4842 PyPrintAccumulator printAccum;
4843 printAccum.parts.append("NamedAttribute(");
4844 printAccum.parts.append(
4845 nb::str(mlirIdentifierStr(self.namedAttr.name).data,
4846 mlirIdentifierStr(self.namedAttr.name).length));
4847 printAccum.parts.append("=");
4848 mlirAttributePrint(self.namedAttr.attribute,
4849 printAccum.getCallback(),
4850 printAccum.getUserData());
4851 printAccum.parts.append(")");
4852 return printAccum.join();
4853 },
4854 "Returns a string representation of the named attribute.")
4855 .def_prop_ro(
4856 "name",
4857 [](PyNamedAttribute &self) {
4858 return mlirIdentifierStr(self.namedAttr.name);
4859 },
4860 "The name of the `NamedAttribute` binding.")
4861 .def_prop_ro(
4862 "attr",
4863 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
4864 nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
4865 "The underlying generic attribute of the `NamedAttribute` binding.");
4866
4867 //----------------------------------------------------------------------------
4868 // Mapping of PyType.
4869 //----------------------------------------------------------------------------
4870 nb::class_<PyType>(m, "Type")
4871 // Delegate to the PyType copy constructor, which will also lifetime
4872 // extend the backing context which owns the MlirType.
4873 .def(nb::init<PyType &>(), "cast_from_type"_a,
4874 "Casts the passed type to the generic `Type`.")
4876 "Gets a capsule wrapping the `MlirType`.")
4878 "Creates a Type from a capsule wrapping `MlirType`.")
4879 .def_static(
4880 "parse",
4881 [](std::string typeSpec,
4882 DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
4883 PyMlirContext::ErrorCapture errors(context->getRef());
4884 MlirType type =
4885 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
4886 if (mlirTypeIsNull(type))
4887 throw MLIRError("Unable to parse type", errors.take());
4888 return PyType(context.get()->getRef(), type).maybeDownCast();
4889 },
4890 "asm"_a, "context"_a = nb::none(),
4891 R"(
4892 Parses the assembly form of a type.
4893
4894 Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
4895
4896 See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
4897 .def_prop_ro(
4898 "context",
4899 [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
4900 return self.getContext().getObject();
4901 },
4902 "Context that owns the `Type`.")
4903 .def(
4904 "__eq__", [](PyType &self, PyType &other) { return self == other; },
4905 "Compares two types for equality.")
4906 .def(
4907 "__eq__", [](PyType &self, nb::object &other) { return false; },
4908 "other"_a.none(),
4909 "Compares type with non-type object (always returns False).")
4910 .def(
4911 "__hash__", [](PyType &self) { return hash(self.get().ptr); },
4912 "Returns the hash value of the `Type`.")
4913 .def(
4914 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
4915 .def(
4916 "__str__",
4917 [](PyType &self) {
4918 PyPrintAccumulator printAccum;
4919 mlirTypePrint(self, printAccum.getCallback(),
4920 printAccum.getUserData());
4921 return printAccum.join();
4922 },
4923 "Returns the assembly form of the `Type`.")
4924 .def(
4925 "__repr__",
4926 [](PyType &self) {
4927 // Generally, assembly formats are not printed for __repr__ because
4928 // this can cause exceptionally long debug output and exceptions.
4929 // However, types are an exception as they typically have compact
4930 // assembly forms and printing them is useful.
4931 PyPrintAccumulator printAccum;
4932 printAccum.parts.append("Type(");
4933 mlirTypePrint(self, printAccum.getCallback(),
4934 printAccum.getUserData());
4935 printAccum.parts.append(")");
4936 return printAccum.join();
4937 },
4938 "Returns a string representation of the `Type`.")
4939 .def(
4941 [](PyType &self) -> nb::typed<nb::object, PyType> {
4942 return self.maybeDownCast();
4943 },
4944 "Downcasts the Type to a more specific `Type` if possible.")
4945 .def_prop_ro(
4946 "typeid",
4947 [](PyType &self) {
4948 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
4949 if (!mlirTypeIDIsNull(mlirTypeID))
4950 return PyTypeID(mlirTypeID);
4951 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
4952 throw nb::value_error(join(origRepr, " has no typeid.").c_str());
4953 },
4954 "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
4955 "`Type` has no "
4956 "`TypeID`.");
4957
4958 //----------------------------------------------------------------------------
4959 // Mapping of PyTypeID.
4960 //----------------------------------------------------------------------------
4961 nb::class_<PyTypeID>(m, "TypeID")
4963 "Gets a capsule wrapping the `MlirTypeID`.")
4965 "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
4966 // Note, this tests whether the underlying TypeIDs are the same,
4967 // not whether the wrapper MlirTypeIDs are the same, nor whether
4968 // the Python objects are the same (i.e., PyTypeID is a value type).
4969 .def(
4970 "__eq__",
4971 [](PyTypeID &self, PyTypeID &other) { return self == other; },
4972 "Compares two `TypeID`s for equality.")
4973 .def(
4974 "__eq__",
4975 [](PyTypeID &self, const nb::object &other) { return false; },
4976 "Compares TypeID with non-TypeID object (always returns False).")
4977 // Note, this gives the hash value of the underlying TypeID, not the
4978 // hash value of the Python object, nor the hash value of the
4979 // MlirTypeID wrapper.
4980 .def(
4981 "__hash__",
4982 [](PyTypeID &self) {
4983 return static_cast<size_t>(mlirTypeIDHashValue(self));
4984 },
4985 "Returns the hash value of the `TypeID`.");
4986
4987 //----------------------------------------------------------------------------
4988 // Mapping of Value.
4989 //----------------------------------------------------------------------------
4990 m.attr("_T") = nb::type_var("_T", "bound"_a = m.attr("Type"));
4991
4992 nb::class_<PyValue>(m, "Value", nb::is_generic(),
4993 nb::sig("class Value(typing.Generic[_T])"))
4994 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), "value"_a,
4995 "Creates a Value reference from another `Value`.")
4997 "Gets a capsule wrapping the `MlirValue`.")
4999 "Creates a `Value` from a capsule wrapping `MlirValue`.")
5000 .def_prop_ro(
5001 "context",
5002 [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
5003 return self.getParentOperation()->getContext().getObject();
5004 },
5005 "Context in which the value lives.")
5006 .def(
5007 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
5009 .def_prop_ro(
5010 "owner",
5011 [](PyValue &self)
5012 -> nb::typed<nb::object, std::variant<PyOpView, PyBlock>> {
5013 MlirValue v = self.get();
5014 if (mlirValueIsAOpResult(v)) {
5015 assert(mlirOperationEqual(self.getParentOperation()->get(),
5016 mlirOpResultGetOwner(self.get())) &&
5017 "expected the owner of the value in Python to match "
5018 "that in "
5019 "the IR");
5020 return self.getParentOperation()->createOpView();
5021 }
5022
5024 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
5025 return nb::cast(PyBlock(self.getParentOperation(), block));
5026 }
5027
5028 assert(false && "Value must be a block argument or an op result");
5029 return nb::none();
5030 },
5031 "Returns the owner of the value (`Operation` for results, `Block` "
5032 "for "
5033 "arguments).")
5034 .def_prop_ro(
5035 "uses",
5036 [](PyValue &self) {
5037 return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
5038 },
5039 "Returns an iterator over uses of this value.")
5040 .def(
5041 "__eq__",
5042 [](PyValue &self, PyValue &other) {
5043 return self.get().ptr == other.get().ptr;
5044 },
5045 "Compares two values for pointer equality.")
5046 .def(
5047 "__eq__", [](PyValue &self, nb::object other) { return false; },
5048 "Compares value with non-value object (always returns False).")
5049 .def(
5050 "__hash__", [](PyValue &self) { return hash(self.get().ptr); },
5051 "Returns the hash value of the value.")
5052 .def(
5053 "__str__",
5054 [](PyValue &self) {
5055 PyPrintAccumulator printAccum;
5056 printAccum.parts.append("Value(");
5057 mlirValuePrint(self.get(), printAccum.getCallback(),
5058 printAccum.getUserData());
5059 printAccum.parts.append(")");
5060 return printAccum.join();
5061 },
5062 R"(
5063 Returns the string form of the value.
5064
5065 If the value is a block argument, this is the assembly form of its type and the
5066 position in the argument list. If the value is an operation result, this is
5067 equivalent to printing the operation that produced it.
5068 )")
5069 .def(
5070 "get_name",
5071 [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
5072 PyPrintAccumulator printAccum;
5073 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
5074 if (useLocalScope)
5076 if (useNameLocAsPrefix)
5078 MlirAsmState valueState =
5079 mlirAsmStateCreateForValue(self.get(), flags);
5080 mlirValuePrintAsOperand(self.get(), valueState,
5081 printAccum.getCallback(),
5082 printAccum.getUserData());
5084 mlirAsmStateDestroy(valueState);
5085 return printAccum.join();
5086 },
5087 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
5088 R"(
5089 Returns the string form of value as an operand.
5090
5091 Args:
5092 use_local_scope: Whether to use local scope for naming.
5093 use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
5094
5095 Returns:
5096 The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
5097 .def(
5098 "get_name",
5099 [](PyValue &self, PyAsmState &state) {
5100 PyPrintAccumulator printAccum;
5101 MlirAsmState valueState = state.get();
5102 mlirValuePrintAsOperand(self.get(), valueState,
5103 printAccum.getCallback(),
5104 printAccum.getUserData());
5105 return printAccum.join();
5106 },
5107 "state"_a,
5108 "Returns the string form of value as an operand (i.e., the ValueID).")
5109 .def_prop_ro(
5110 "type",
5111 [](PyValue &self) {
5112 return PyType(self.getParentOperation()->getContext(),
5113 mlirValueGetType(self.get()))
5114 .maybeDownCast();
5115 },
5116 "Returns the type of the value.", nb::sig("def type(self) -> _T"))
5117 .def(
5118 "set_type",
5119 [](PyValue &self, const PyType &type) {
5120 mlirValueSetType(self.get(), type);
5121 },
5122 "type"_a, "Sets the type of the value.",
5123 nb::sig("def set_type(self, type: _T)"))
5124 .def(
5125 "replace_all_uses_with",
5126 [](PyValue &self, PyValue &with) {
5127 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
5128 },
5129 "Replace all uses of value with the new value, updating anything in "
5130 "the IR that uses `self` to use the other value instead.")
5131 .def(
5132 "replace_all_uses_except",
5133 [](PyValue &self, PyValue &with, PyOperation &exception) {
5134 MlirOperation exceptedUser = exception.get();
5135 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
5136 },
5137 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
5138 .def(
5139 "replace_all_uses_except",
5140 [](PyValue &self, PyValue &with,
5141 std::vector<PyOperation> &exceptions) {
5142 // Convert Python list to a std::vector of MlirOperations
5143 std::vector<MlirOperation> exceptionOps;
5144 for (PyOperation &exception : exceptions)
5145 exceptionOps.push_back(exception);
5147 self, with, static_cast<intptr_t>(exceptionOps.size()),
5148 exceptionOps.data());
5149 },
5150 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
5151 .def(
5153 [](PyValue &self) { return self.maybeDownCast(); },
5154 "Downcasts the `Value` to a more specific kind if possible.")
5155 .def_prop_ro(
5156 "location",
5157 [](PyValue self) {
5158 return PyLocation(
5161 .maybeDownCast();
5162 },
5163 "Returns the source location of the value.");
5164
5168
5169 nb::class_<PyAsmState>(m, "AsmState")
5170 .def(nb::init<PyValue &, bool>(), "value"_a, "use_local_scope"_a = false,
5171 R"(
5172 Creates an `AsmState` for consistent SSA value naming.
5173
5174 Args:
5175 value: The value to create state for.
5176 use_local_scope: Whether to use local scope for naming.)")
5177 .def(nb::init<PyOperationBase &, bool>(), "op"_a,
5178 "use_local_scope"_a = false,
5179 R"(
5180 Creates an AsmState for consistent SSA value naming.
5181
5182 Args:
5183 op: The operation to create state for.
5184 use_local_scope: Whether to use local scope for naming.)");
5185
5186 //----------------------------------------------------------------------------
5187 // Mapping of SymbolTable.
5188 //----------------------------------------------------------------------------
5189 nb::class_<PySymbolTable>(m, "SymbolTable")
5190 .def(nb::init<PyOperationBase &>(),
5191 R"(
5192 Creates a symbol table for an operation.
5193
5194 Args:
5195 operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
5196
5197 Raises:
5198 TypeError: If the operation is not a symbol table.)")
5199 .def(
5200 "__getitem__",
5201 [](PySymbolTable &self,
5202 const std::string &name) -> nb::typed<nb::object, PyOpView> {
5203 return self.dunderGetItem(name);
5204 },
5205 R"(
5206 Looks up a symbol by name in the symbol table.
5207
5208 Args:
5209 name: The name of the symbol to look up.
5210
5211 Returns:
5212 The operation defining the symbol.
5213
5214 Raises:
5215 KeyError: If the symbol is not found.)")
5216 .def("insert", &PySymbolTable::insert, "operation"_a,
5217 R"(
5218 Inserts a symbol operation into the symbol table.
5219
5220 Args:
5221 operation: An operation with a symbol name to insert.
5222
5223 Returns:
5224 The symbol name attribute of the inserted operation.
5225
5226 Raises:
5227 ValueError: If the operation does not have a symbol name.)")
5228 .def("erase", &PySymbolTable::erase, "operation"_a,
5229 R"(
5230 Erases a symbol operation from the symbol table.
5231
5232 Args:
5233 operation: The symbol operation to erase.
5234
5235 Note:
5236 The operation is also erased from the IR and invalidated.)")
5237 .def("__delitem__", &PySymbolTable::dunderDel,
5238 "Deletes a symbol by name from the symbol table.")
5239 .def(
5240 "__contains__",
5241 [](PySymbolTable &table, const std::string &name) {
5242 return !mlirOperationIsNull(mlirSymbolTableLookup(
5243 table, mlirStringRefCreate(name.data(), name.length())));
5244 },
5245 "Checks if a symbol with the given name exists in the table.")
5246 // Static helpers.
5247 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, "symbol"_a,
5248 "name"_a, "Sets the symbol name for a symbol operation.")
5249 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, "symbol"_a,
5250 "Gets the symbol name from a symbol operation.")
5251 .def_static("get_visibility", &PySymbolTable::getVisibility, "symbol"_a,
5252 "Gets the visibility attribute of a symbol operation.")
5253 .def_static("set_visibility", &PySymbolTable::setVisibility, "symbol"_a,
5254 "visibility"_a,
5255 "Sets the visibility attribute of a symbol operation.")
5256 .def_static("replace_all_symbol_uses",
5257 &PySymbolTable::replaceAllSymbolUses, "old_symbol"_a,
5258 "new_symbol"_a, "from_op"_a,
5259 "Replaces all uses of a symbol with a new symbol name within "
5260 "the given operation.")
5261 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
5262 "from_op"_a, "all_sym_uses_visible"_a, "callback"_a,
5263 "Walks symbol tables starting from an operation with a "
5264 "callback function.");
5265
5266 // Container bindings.
5281
5282 // Debug bindings.
5284
5285 // Attribute builder getter.
5287
5288 // Extensible Dialect
5292
5293 // MLIRError exception.
5294 MLIRError::bind(m);
5295}
5296} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
5297} // namespace python
5298} // namespace mlir
return success()
void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Definition Debug.cpp:28
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition Debug.cpp:20
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition Debug.cpp:16
static const char kDumpDocstring[]
Definition IRAffine.cpp:32
static const char kModuleParseDocstring[]
Definition IRCore.cpp:33
static size_t hash(const T &value)
Local helper to compute std::hash for a value.
Definition IRCore.cpp:56
static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor)
Definition IRCore.cpp:61
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static const char kValueReplaceAllUsesExceptDocstring[]
Definition IRCore.cpp:44
MlirContext mlirModuleGetContext(MlirModule module)
Definition IR.cpp:453
size_t mlirModuleHashValue(MlirModule mod)
Definition IR.cpp:479
intptr_t mlirBlockGetNumPredecessors(MlirBlock block)
Definition IR.cpp:1113
MlirIdentifier mlirOperationGetName(MlirOperation op)
Definition IR.cpp:682
bool mlirValueIsABlockArgument(MlirValue value)
Definition IR.cpp:1133
intptr_t mlirOperationGetNumRegions(MlirOperation op)
Definition IR.cpp:694
MlirBlock mlirOperationGetBlock(MlirOperation op)
Definition IR.cpp:686
void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Definition IR.cpp:1150
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition IR.cpp:528
MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Definition IR.cpp:749
MlirModule mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName)
Definition IR.cpp:444
bool mlirOperationNameHasTrait(MlirStringRef opName, MlirTypeID traitTypeID, MlirContext context)
Definition IR.cpp:662
MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Definition IR.cpp:177
intptr_t mlirOperationGetNumResults(MlirOperation op)
Definition IR.cpp:745
void mlirOperationDestroy(MlirOperation op)
Definition IR.cpp:646
MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Definition IR.cpp:1298
MlirType mlirValueGetType(MlirValue value)
Definition IR.cpp:1169
void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Definition IR.cpp:1099
MlirOpPrintingFlags mlirOpPrintingFlagsCreate()
Definition IR.cpp:201
bool mlirModuleEqual(MlirModule lhs, MlirModule rhs)
Definition IR.cpp:475
void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Definition IR.cpp:209
void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Definition IR.cpp:810
MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Definition IR.cpp:718
void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Definition IR.cpp:219
MlirOperation mlirModuleGetOperation(MlirModule module)
Definition IR.cpp:467
void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
Definition IR.cpp:214
void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Definition IR.cpp:232
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Definition IR.cpp:1145
MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Definition IR.cpp:757
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Definition IR.cpp:1317
MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags)
Definition IR.cpp:156
bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Definition IR.cpp:650
void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Definition IR.cpp:236
void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Definition IR.cpp:251
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1109
void mlirModuleDestroy(MlirModule module)
Definition IR.cpp:461
MlirModule mlirModuleCreateEmpty(MlirLocation location)
Definition IR.cpp:432
void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Definition IR.cpp:224
MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Definition IR.cpp:690
void mlirValueSetType(MlirValue value, MlirType type)
Definition IR.cpp:1173
intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Definition IR.cpp:753
MlirDialect mlirAttributeGetDialect(MlirAttribute attr)
Definition IR.cpp:1313
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Definition IR.cpp:422
void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Definition IR.cpp:829
void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Definition IR.cpp:734
MlirOperation mlirOpResultGetOwner(MlirValue value)
Definition IR.cpp:1160
MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Definition IR.cpp:436
size_t mlirOperationHashValue(MlirOperation op)
Definition IR.cpp:654
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Definition IR.cpp:511
MlirOperation mlirOperationClone(MlirOperation op)
Definition IR.cpp:642
MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Definition IR.cpp:1141
void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc)
Definition IR.cpp:1155
MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Definition IR.cpp:726
MlirOpOperand mlirOperationGetOpOperand(MlirOperation op, intptr_t pos)
Definition IR.cpp:730
MlirLocation mlirOperationGetLocation(MlirOperation op)
Definition IR.cpp:668
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:824
MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr)
Definition IR.cpp:1309
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition IR.cpp:520
void mlirOperationSetLocation(MlirOperation op, MlirLocation loc)
Definition IR.cpp:672
MlirType mlirAttributeGetType(MlirAttribute attribute)
Definition IR.cpp:1302
bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:834
bool mlirValueIsAOpResult(MlirValue value)
Definition IR.cpp:1137
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1118
MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Definition IR.cpp:698
MlirOperation mlirOperationCreate(MlirOperationState *state)
Definition IR.cpp:596
void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Definition IR.cpp:255
MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Definition IR.cpp:1294
void mlirOperationRemoveFromParent(MlirOperation op)
Definition IR.cpp:648
intptr_t mlirBlockGetNumSuccessors(MlirBlock block)
Definition IR.cpp:1105
MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Definition IR.cpp:819
void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Definition IR.cpp:205
void mlirValueDump(MlirValue value)
Definition IR.cpp:1177
void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Definition IR.cpp:1283
MlirBlock mlirModuleGetBody(MlirModule module)
Definition IR.cpp:457
MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Definition IR.cpp:633
void mlirAsmStateDestroy(MlirAsmState state)
Destroys printing flags created with mlirAsmStateCreate.
Definition IR.cpp:195
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:658
intptr_t mlirOpResultGetResultNumber(MlirValue value)
Definition IR.cpp:1164
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition IR.cpp:524
MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate()
Definition IR.cpp:247
void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags)
Definition IR.cpp:228
void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Definition IR.cpp:240
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition IR.cpp:516
MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Definition IR.cpp:487
intptr_t mlirOperationGetNumOperands(MlirOperation op)
Definition IR.cpp:722
void mlirTypeDump(MlirType type)
Definition IR.cpp:1288
intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Definition IR.cpp:815
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition Interop.h:348
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition Interop.h:216
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition Interop.h:118
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperations from a capsule as produced from mlirPythonOperationToCapsule.
Definition Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition Interop.h:189
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition Interop.h:367
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition Interop.h:330
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition Interop.h:180
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts a MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition Interop.h:357
#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the value caster registration bindin...
Definition Interop.h:142
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition Interop.h:198
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition Interop.h:255
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition Interop.h:245
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition Interop.h:454
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition Interop.h:445
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition Interop.h:273
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition Interop.h:264
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the type caster registration binding...
Definition Interop.h:130
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition Interop.h:235
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static std::string diag(const llvm::Value &value)
Accumulates into a file, either writing text (default) or binary.
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
ReferrentTy * get() const
PyMlirContextRef & getContext()
Accesses the context reference.
Definition IRCore.h:310
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:291
PyAsmState(MlirValue value, bool useLocalScope)
Definition IRCore.cpp:1750
Wrapper around the generic MlirAttribute.
Definition IRCore.h:1018
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition IRCore.h:1020
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition IRCore.cpp:1860
bool operator==(const PyAttribute &other) const
Definition IRCore.cpp:1856
static PyAttribute createFromCapsule(const nanobind::object &capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition IRCore.cpp:1864
nanobind::typed< nanobind::object, PyAttribute > maybeDownCast()
Definition IRCore.cpp:1872
PyBlockArgumentList(PyOperationRef operation, MlirBlock block, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2206
Python wrapper for MlirBlockArgument.
Definition IRCore.h:1758
nanobind::typed< nanobind::object, PyBlock > dunderNext()
Definition IRCore.cpp:217
Blocks are exposed by the C-API as a forward-only linked list.
Definition IRCore.h:1553
PyBlock appendBlock(const nanobind::args &pyArgTypes, const std::optional< nanobind::sequence > &pyArgLocs)
Definition IRCore.cpp:273
PyBlockPredecessors(PyBlock block, PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2366
PyBlockSuccessors(PyBlock block, PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2343
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirBlock.
Definition IRCore.cpp:186
Represents a diagnostic handler attached to the context.
Definition IRCore.h:422
void detach()
Detaches the handler. Does nothing if not attached.
Definition IRCore.cpp:728
PyDiagnosticHandler(MlirContext context, nanobind::object callback)
Definition IRCore.cpp:722
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.h:434
Python class mirroring the C MlirDiagnostic struct.
Definition IRCore.h:372
nanobind::typed< nanobind::tuple, PyDiagnostic > getNotes()
Definition IRCore.cpp:767
Wrapper around an MlirDialectRegistry.
Definition IRCore.h:514
static PyDialectRegistry createFromCapsule(nanobind::object capsule)
Definition IRCore.cpp:812
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition IRCore.h:490
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition IRCore.cpp:795
static bool attach(const nanobind::object &opName, const nanobind::object &target, PyMlirContext &context)
Definition IRCore.cpp:2571
static bool attach(const nanobind::object &opName, PyMlirContext &context)
Definition IRCore.cpp:2623
static bool attach(const nanobind::object &opName, PyMlirContext &context)
Definition IRCore.cpp:2642
CurrentLocAction
Policy for composing Location.current with the computed location.
Definition Globals.h:141
OnExplicitAction
Policy for handling explicit loc= when loc_tracebacks() is active.
Definition Globals.h:133
Globals that are always accessible once the extension has been initialized.
Definition Globals.h:29
void registerOpAdaptorImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds an operation adaptor class.
Definition Globals.cpp:167
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:203
bool loadDialectModule(std::string_view dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition Globals.cpp:64
void registerDialectImpl(const std::string &dialectNamespace, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation dialect class.
Definition Globals.cpp:145
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.cpp:59
std::optional< nanobind::object > lookupOperationClass(std::string_view operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition Globals.cpp:233
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition Globals.cpp:125
void registerOperationImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation operation class.
Definition Globals.cpp:156
void setDialectSearchPrefixes(std::vector< std::string > newValues)
Definition Globals.h:43
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:189
void registerValueCaster(MlirTypeID mlirTypeID, nanobind::callable valueCaster, bool replace=false)
Adds a user-friendly value caster.
Definition Globals.cpp:135
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition Globals.cpp:179
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition Globals.cpp:218
std::vector< std::string > getDialectSearchPrefixes()
Get and set the list of parent modules to search for dialect implementation classes.
Definition Globals.h:39
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false, bool allow_existing=false)
Adds a user-friendly Attribute builder.
Definition Globals.cpp:99
An insertion point maintains a pointer to a Block and a reference operation.
Definition IRCore.h:849
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition IRCore.cpp:1781
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:1846
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition IRCore.cpp:1820
static PyInsertionPoint after(PyOperationBase &op)
Shortcut to create an insertion point to the node after the specified operation.
Definition IRCore.cpp:1829
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition IRCore.cpp:1807
PyInsertionPoint(const PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition IRCore.cpp:1772
static nanobind::object contextEnter(nanobind::object insertionPoint)
Enter and exit the context manager.
Definition IRCore.cpp:1842
static nanobind::object contextEnter(nanobind::object location)
Enter and exit the context manager.
Definition IRCore.cpp:836
static PyLocation createFromCapsule(nanobind::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition IRCore.cpp:828
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition IRCore.cpp:824
nanobind::typed< nanobind::object, PyLocation > maybeDownCast()
Returns the most-derived Location subclass registered for this TypeID, or self.
Definition IRCore.cpp:1890
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:840
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition IRCore.h:319
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:461
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition IRCore.cpp:486
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition IRCore.cpp:454
nanobind::object attachDiagnosticHandler(nanobind::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition IRCore.cpp:501
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:495
MlirContext get()
Accesses the underlying MlirContext.
Definition IRCore.h:224
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition IRCore.cpp:446
void setEmitErrorDiagnostics(bool value)
Controls whether error diagnostics should be propagated to diagnostic handlers, instead of being capt...
Definition IRCore.h:258
static nanobind::object contextEnter(nanobind::object context)
Enter and exit the context manager.
Definition IRCore.cpp:491
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition IRCore.cpp:450
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition IRCore.cpp:1840
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition IRCore.cpp:905
MlirModule get()
Gets the backing MlirModule.
Definition IRCore.h:564
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition IRCore.cpp:873
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition IRCore.cpp:898
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition IRCore.h:1044
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition IRCore.cpp:1907
Template for a reference to a concrete type which captures a python reference to its underlying pytho...
Definition IRCore.h:66
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
Definition IRCore.h:104
void dunderSetItem(const std::string &name, const PyAttribute &attr)
Definition IRCore.cpp:2426
nanobind::typed< nanobind::object, PyAttribute > dunderGetItemNamed(const std::string &name)
Definition IRCore.cpp:2393
nanobind::typed< nanobind::object, std::optional< PyAttribute > > get(const std::string &key, nanobind::object defaultValue)
Definition IRCore.cpp:2403
static void forEachAttr(MlirOperation op, std::function< void(MlirStringRef, MlirAttribute)> fn)
Definition IRCore.cpp:2448
PyNamedAttribute dunderGetItemIndexed(intptr_t index)
Definition IRCore.cpp:2411
nanobind::typed< nanobind::object, PyOpOperand > dunderNext()
Definition IRCore.cpp:382
PyOpOperandList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2238
void dunderSetItem(intptr_t index, PyValue value)
Definition IRCore.cpp:2246
nanobind::typed< nanobind::object, PyOpView > getOwner() const
Definition IRCore.cpp:363
PyOpOperands(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2281
Sliceable< PyOpOperandList, PyOpOperand > SliceableT
Definition IRCore.cpp:2279
PyOpResultList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:1400
PyOpSuccessors(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2310
void dunderSetItem(intptr_t index, PyBlock block)
Definition IRCore.cpp:2318
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition IRCore.h:751
PyOpView(const nanobind::object &operationObject)
Definition IRCore.cpp:1740
static nanobind::typed< nanobind::object, PyOperation > buildGeneric(std::string_view name, std::tuple< int, bool > opRegionSpec, nanobind::object operandSegmentSpecObj, nanobind::object resultSegmentSpecObj, std::optional< nanobind::sequence > resultTypeList, nanobind::sequence operandList, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, std::optional< int > regions, PyLocation &location, const nanobind::object &maybeIp)
Definition IRCore.cpp:1563
static nanobind::object constructDerived(const nanobind::object &cls, const nanobind::object &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition IRCore.cpp:1732
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition IRCore.h:594
bool isBeforeInBlock(PyOperationBase &other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IRCore.cpp:1164
nanobind::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions)
Definition IRCore.cpp:1118
void print(std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
Definition IRCore.cpp:1017
void writeBytecode(const nanobind::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition IRCore.cpp:1065
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition IRCore.cpp:1146
void walk(std::function< PyWalkResult(MlirOperation)> callback, PyWalkOrder walkOrder)
Definition IRCore.cpp:1086
nanobind::typed< nanobind::object, PyOpView > dunderNext()
Definition IRCore.cpp:294
Operations are exposed by the C-API as a forward-only linked list.
Definition IRCore.h:1594
nanobind::typed< nanobind::object, PyOpView > dunderGetItem(intptr_t index)
Definition IRCore.cpp:333
static nanobind::object create(std::string_view name, std::optional< std::vector< PyType * > > results, const MlirValue *operands, size_t numOperands, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, int regions, PyLocation &location, const nanobind::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition IRCore.cpp:1228
void setInvalid()
Invalidate the operation.
Definition IRCore.h:718
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition IRCore.h:651
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition IRCore.cpp:974
nanobind::object clone(const nanobind::object &ip)
Clones this operation.
Definition IRCore.cpp:1343
static nanobind::object createFromCapsule(const nanobind::object &capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition IRCore.cpp:1204
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition IRCore.cpp:1180
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1352
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition IRCore.cpp:1199
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:958
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition IRCore.cpp:986
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition IRCore.cpp:1363
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition IRCore.cpp:1190
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Creates a detached operation.
Definition IRCore.cpp:965
PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
Definition IRCore.cpp:913
void setAttached(const nanobind::object &parent=nanobind::object())
Definition IRCore.cpp:1001
Regions of an op are fixed length and indexed numerically so are represented with a sequence-like con...
Definition IRCore.h:1514
PyRegionList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:194
PyStringAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition IRCore.cpp:2059
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition IRCore.cpp:2023
static PyStringAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition IRCore.cpp:2099
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition IRCore.cpp:2044
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, nanobind::object callback)
Walks all symbol tables under and including 'from'.
Definition IRCore.cpp:2140
nanobind::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition IRCore.cpp:2031
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition IRCore.cpp:2128
static PyStringAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition IRCore.cpp:2071
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition IRCore.cpp:2054
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition IRCore.cpp:2084
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition IRCore.cpp:2110
Tracks an entry in the thread context stack.
Definition IRCore.h:137
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and return nullptr if not defined.
Definition IRCore.cpp:638
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint)
Definition IRCore.cpp:666
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition IRCore.cpp:678
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition IRCore.cpp:643
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition IRCore.cpp:586
static nanobind::object pushLocation(nanobind::object location)
Definition IRCore.cpp:689
static nanobind::object pushContext(nanobind::object context)
Definition IRCore.cpp:648
static PyMlirContext * getDefaultContext()
Gets the top of stack context and return nullptr if not defined.
Definition IRCore.cpp:633
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition IRCore.cpp:581
Wrapper around MlirLlvmThreadPool Python object owns the C++ thread pool.
Definition IRCore.h:193
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition IRCore.h:917
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition IRCore.cpp:1953
bool operator==(const PyTypeID &other) const
Definition IRCore.cpp:1963
static PyTypeID createFromCapsule(nanobind::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition IRCore.cpp:1957
Wrapper around the generic MlirType.
Definition IRCore.h:891
PyType(PyMlirContextRef contextRef, MlirType type)
Definition IRCore.h:893
bool operator==(const PyType &other) const
Definition IRCore.cpp:1919
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition IRCore.cpp:1923
static PyType createFromCapsule(nanobind::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition IRCore.cpp:1927
nanobind::typed< nanobind::object, PyType > maybeDownCast()
Definition IRCore.cpp:1935
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition IRCore.cpp:1971
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition IRCore.h:1310
nanobind::typed< nanobind::object, std::variant< PyBlockArgument, PyOpResult, PyValue > > maybeDownCast()
Definition IRCore.cpp:1990
static PyValue createFromCapsule(nanobind::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition IRCore.cpp:2011
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
struct MlirDiagnostic MlirDiagnostic
Definition Diagnostics.h:29
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitIsTerminatorCreate(void)
Get the dynamic op trait that indicates the operation is a terminator.
MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID(void)
Get the type ID of the dynamic op trait that indicates the operation is a terminator.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitCreate(MlirTypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData)
Create a custom dynamic op trait with the given type ID and callbacks.
MLIR_CAPI_EXPORTED bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait, MlirStringRef opName, MlirContext context)
Attach a dynamic op trait to the given operation name.
MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID(void)
Get the type ID of the dynamic op trait that indicates regions have no terminator.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitNoTerminatorCreate(void)
Get the dynamic op trait that indicates regions have no terminator.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition IR.cpp:264
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition IR.h:866
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v)
Gets the location of the value.
Definition IR.cpp:1220
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context)
Gets the number of threads of the thread pool of the context when multithreading is enabled.
Definition IR.cpp:116
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition IR.cpp:859
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates an File/Line/Column location owned by the given context.
Definition IR.cpp:272
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition IR.cpp:1402
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition IR.cpp:136
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location)
Getter for end_column of FileLineColRange.
Definition IR.cpp:310
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition IR.cpp:1381
MlirWalkOrder
Traversal order for operation walk.
Definition IR.h:859
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition IR.cpp:1329
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGetChildLoc(MlirLocation location)
Getter for childLoc of Name.
Definition IR.cpp:391
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition IR.cpp:1386
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition IR.cpp:83
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition IR.cpp:1350
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition IR.cpp:1263
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition IR.cpp:1246
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition IR.cpp:72
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition IR.cpp:969
MLIR_CAPI_EXPORTED unsigned mlirLocationFusedGetNumLocations(MlirLocation location)
Getter for number of locations fused together.
Definition IR.cpp:354
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition IR.cpp:1202
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition IR.cpp:1185
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition IR.cpp:402
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition IR.cpp:1234
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location)
Getter for filename of FileLineColRange.
Definition IR.cpp:288
MLIR_CAPI_EXPORTED void mlirLocationFusedGetLocations(MlirLocation location, MlirLocation *locationsCPtr)
Getter for locations of Fused.
Definition IR.cpp:360
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding userData to callback`...
Definition IR.cpp:1321
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition IR.cpp:1008
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition IR.cpp:883
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
Definition IR.cpp:1206
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition IR.cpp:850
MlirWalkResult
Operation walk result.
Definition IR.h:852
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition IR.cpp:949
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1171
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e.
Definition IR.cpp:99
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition IR.cpp:1077
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartLine(MlirLocation location)
Getter for start_line of FileLineColRange.
Definition IR.cpp:292
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition IR.cpp:346
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition IR.cpp:1058
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition IR.cpp:94
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition IR.cpp:955
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition IR.cpp:992
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition IR.h:952
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition IR.cpp:1033
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition IR.cpp:1095
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition IR.cpp:1376
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition IR.cpp:1267
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, unsigned start_line, unsigned start_col, unsigned end_line, unsigned end_col)
Creates an File/Line/Column range location owned by the given context.
Definition IR.cpp:280
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition IR.cpp:1232
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition IR.cpp:1366
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition IR.cpp:414
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location)
Getter for start_column of FileLineColRange.
Definition IR.cpp:298
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location)
Checks whether the given location is an Fused.
Definition IR.cpp:374
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition IR.h:376
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition IR.cpp:1081
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition IR.cpp:844
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition IR.cpp:1192
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool)
Sets the thread pool of the context explicitly, enabling multithreading in the process.
Definition IR.cpp:111
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition IR.cpp:875
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition IR.cpp:1279
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition IR.cpp:1242
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location)
Getter for caller of CallSite.
Definition IR.cpp:333
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition IR.cpp:1023
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationNameGetName(MlirLocation location)
Getter for name of Name.
Definition IR.cpp:387
MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IR.cpp:887
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition IR.cpp:268
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition IR.cpp:1271
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition IR.cpp:1362
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition IR.h:182
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationFusedGetMetadata(MlirLocation location)
Getter for metadata of Fused.
Definition IR.cpp:368
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition IR.cpp:1012
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition IR.cpp:324
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location)
Checks whether the given location is an Name.
Definition IR.cpp:398
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition IR.h:244
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition IR.cpp:905
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
Definition IR.cpp:54
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition IR.cpp:1004
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition IR.cpp:418
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition IR.cpp:1086
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition IR.cpp:1327
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition IR.cpp:1391
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition IR.cpp:945
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition IR.cpp:1016
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition IR.h:891
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition IR.cpp:1275
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition IR.cpp:1338
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition IR.cpp:107
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context)
Gets the thread pool of the context when enabled multithreading, otherwise an assertion is raised.
Definition IR.cpp:120
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndLine(MlirLocation location)
Getter for end_line of FileLineColRange.
Definition IR.cpp:304
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition IR.cpp:378
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Set threading mode (must be set to false to mlir-print-ir-after-all).
Definition IR.cpp:103
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location)
Getter for callee of CallSite.
Definition IR.cpp:328
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v)
Gets the context that a value was created with.
Definition IR.cpp:1224
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition IR.cpp:1358
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition IR.cpp:932
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition IR.cpp:1072
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition IR.cpp:873
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition IR.h:1261
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition IR.cpp:76
MLIR_CAPI_EXPORTED void mlirOperationReplaceUsesOfWith(MlirOperation op, MlirValue of, MlirValue with)
Replace uses of 'of' value with the 'with' value inside the 'op' operation.
Definition IR.cpp:923
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition IR.cpp:879
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a block by sending chunks of the string representation and forwarding userData to callback`.
Definition IR.cpp:1179
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition IR.cpp:866
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition IR.cpp:70
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition IR.cpp:938
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:87
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:143
struct MlirLogicalResult MlirLogicalResult
Definition Support.h:124
MLIR_CAPI_EXPORTED int mlirLlvmThreadPoolGetMaxConcurrency(MlirLlvmThreadPool pool)
Returns the maximum number of threads in the thread pool.
Definition Support.cpp:38
MLIR_CAPI_EXPORTED void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool pool)
Destroy an LLVM thread pool.
Definition Support.cpp:34
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirLlvmThreadPoolCreate(void)
Create an LLVM thread pool.
Definition Support.cpp:30
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition Support.cpp:93
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:137
struct MlirStringRef MlirStringRef
Definition Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:132
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition Support.h:201
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition Support.cpp:89
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation)
Definition IRCore.cpp:1529
static std::string formatMLIRError(const MLIRError &e)
Definition IRCore.cpp:2827
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m)
static void maybeInsertOperation(PyOperationRef &op, const nb::object &maybeIp)
Definition IRCore.cpp:1213
static void populateResultTypes(std::string_view name, nb::sequence resultTypeList, const nb::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition IRCore.cpp:1442
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:210
static MlirValue getOpResultOrValue(nb::handle operand)
Definition IRCore.cpp:1544
PyObjectRef< PyOperation > PyOperationRef
Definition IRCore.h:646
MlirStringRef toMlirStringRef(const std::string &s)
Definition IRCore.h:1477
static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait, PyMlirContext &context)
Definition IRCore.cpp:2556
PyObjectRef< PyModule > PyModuleRef
Definition IRCore.h:553
static MlirLogicalResult verifyTraitByMethod(MlirOperation op, void *userData, const char *methodName)
Definition IRCore.cpp:2545
static PyOperationRef getValueOwnerRef(MlirValue value)
Definition IRCore.cpp:1975
MlirBlock MLIR_PYTHON_API_EXPORTED createBlock(const nanobind::typed< nanobind::sequence, PyType > &pyArgTypes, const std::optional< nanobind::typed< nanobind::sequence, PyLocation > > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
static std::vector< nb::typed< nb::object, PyType > > getValueTypes(Container &container, PyMlirContextRef &context)
Returns the list of types of the values held by container.
Definition IRCore.cpp:1389
PyWalkOrder
Traversal order for operation walk.
Definition IRCore.h:363
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m)
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition IRCore.h:2002
Action
The actions performed by @newSparseTensor.
Definition Enums.h:146
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition Diagnostics.h:26
MlirLogicalResult(* verifyTrait)(MlirOperation op, void *userData)
The callback function to verify the operation.
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
MlirLogicalResult(* verifyRegionTrait)(MlirOperation op, void *userData)
The callback function to verify the operation with access to regions.
A logical result value, essentially a boolean with named states.
Definition Support.h:121
Named MLIR attribute.
Definition IR.h:76
MlirAttribute attribute
Definition IR.h:78
MlirIdentifier name
Definition IR.h:77
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition IRCore.h:1459
MLIRError(std::string message, std::vector< PyDiagnostic::DiagnosticInfo > &&errorDiagnostics={})
Definition IRCore.h:1460
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Definition IRCore.h:1470
static void bind(nanobind::module_ &m)
Bind the MLIRError exception class to the given module.
Definition IRCore.cpp:2862
static bool dunderContains(const std::string &attributeKind)
Definition IRCore.cpp:146
static nanobind::callable dunderGetItemNamed(const std::string &attributeKind)
Definition IRCore.cpp:151
static void dunderSetItemNamed(const std::string &attributeKind, nanobind::callable func, bool replace, bool allow_existing)
Definition IRCore.cpp:158
static void set(nanobind::object &o, bool enable)
Definition IRCore.cpp:108
RAII object that captures any error diagnostics emitted to the provided context.
Definition IRCore.h:450
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition IRCore.h:460