MLIR 23.0.0git
IRCore.cpp
Go to the documentation of this file.
//===- IRCore.cpp - IR Submodules of nanobind module ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// clang-format off
14#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
15// clang-format on
17#include "mlir-c/Debug.h"
18#include "mlir-c/Diagnostics.h"
19#include "mlir-c/IR.h"
20#include "mlir-c/Support.h"
21
22#include <optional>
23#include <sstream>
24#include <string>
25
26namespace nb = nanobind;
27using namespace nb::literals;
28using namespace mlir;
29
// Docstring attached to the module-parsing entry point.
static const char kModuleParseDocstring[] =
    "Parses a module's assembly format from a string.\n"
    "\n"
    "Returns a new MlirModule or raises an MLIRError if the parsing fails.\n"
    "\n"
    "See also: https://mlir.llvm.org/docs/LangRef/\n";

// Docstring shared by the `dump` methods.
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
40
42 R"(Replace all uses of this value with the `with` value, except for those
43in `exceptions`. `exceptions` can be either a single operation or a list of
44operations.
45)";
46
47//------------------------------------------------------------------------------
48// Utilities.
49//------------------------------------------------------------------------------
50
/// Streams every argument, left to right, into one buffer and returns the
/// concatenated text as a `std::string`.
template <typename... Ts>
static std::string join(const Ts &...args) {
  std::ostringstream buffer;
  // Comma fold: each argument is inserted into the stream in order.
  ((buffer << args), ...);
  return buffer.str();
}
58
59/// Helper for creating an @classmethod.
60template <class Func, typename... Args>
61static nb::object classmethod(Func f, Args... args) {
62 nb::object cf = nb::cpp_function(f, args...);
63 return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
64}
65
66static nb::object
67createCustomDialectWrapper(const std::string &dialectNamespace,
68 nb::object dialectDescriptor) {
69 auto dialectClass =
71 dialectNamespace);
72 if (!dialectClass) {
73 // Use the base class.
75 std::move(dialectDescriptor)));
76 }
77
78 // Create the custom implementation.
79 return (*dialectClass)(std::move(dialectDescriptor));
80}
81
82namespace mlir {
83namespace python {
85
86MlirBlock createBlock(const nb::sequence &pyArgTypes,
87 const std::optional<nb::sequence> &pyArgLocs) {
88 std::vector<MlirType> argTypes;
89 argTypes.reserve(nb::len(pyArgTypes));
90 for (const auto &pyType : pyArgTypes)
91 argTypes.push_back(
92 nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyType &>(pyType));
93
94 std::vector<MlirLocation> argLocs;
95 if (pyArgLocs) {
96 argLocs.reserve(nb::len(*pyArgLocs));
97 for (const auto &pyLoc : *pyArgLocs)
98 argLocs.push_back(
99 nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyLocation &>(pyLoc));
100 } else if (!argTypes.empty()) {
101 argLocs.assign(
102 argTypes.size(),
104 }
105
106 if (argTypes.size() != argLocs.size())
107 throw nb::value_error(
108 join("Expected ", argTypes.size(), " locations, got: ", argLocs.size())
109 .c_str());
110 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
111}
112
113void PyGlobalDebugFlag::set(nb::object &o, bool enable) {
114 nb::ft_lock_guard lock(mutex);
115 mlirEnableGlobalDebug(enable);
116}
117
118bool PyGlobalDebugFlag::get(const nb::object &) {
119 nb::ft_lock_guard lock(mutex);
121}
122
123void PyGlobalDebugFlag::bind(nb::module_ &m) {
124 // Debug flags.
125 nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
126 .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
127 &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
128 .def_static(
129 "set_types",
130 [](const std::string &type) {
131 nb::ft_lock_guard lock(mutex);
132 mlirSetGlobalDebugType(type.c_str());
133 },
134 "types"_a, "Sets specific debug types to be produced by LLVM.")
135 .def_static(
136 "set_types",
137 [](const std::vector<std::string> &types) {
138 std::vector<const char *> pointers;
139 pointers.reserve(types.size());
140 for (const std::string &str : types)
141 pointers.push_back(str.c_str());
142 nb::ft_lock_guard lock(mutex);
143 mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
144 },
145 "types"_a,
146 "Sets multiple specific debug types to be produced by LLVM.");
147}
148
// Serializes the bindings' calls into LLVM's global debug state: the `flag`
// accessors and both `set_types` overloads lock it before calling the C API.
nb::ft_mutex PyGlobalDebugFlag::mutex;
150
151bool PyAttrBuilderMap::dunderContains(const std::string &attributeKind) {
152 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
153}
154
155nb::callable
156PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
157 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
158 if (!builder)
159 throw nb::key_error(attributeKind.c_str());
160 return *builder;
161}
162
163void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
164 nb::callable func, bool replace) {
165 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
166 replace);
167}
168
169void PyAttrBuilderMap::bind(nb::module_ &m) {
170 nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
171 .def_static("contains", &PyAttrBuilderMap::dunderContains,
172 "attribute_kind"_a,
173 "Checks whether an attribute builder is registered for the "
174 "given attribute kind.")
175 .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
176 "attribute_kind"_a,
177 "Gets the registered attribute builder for the given "
178 "attribute kind.")
179 .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
180 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
181 "Register an attribute builder for building MLIR "
182 "attributes from Python values.");
183}
184
185//------------------------------------------------------------------------------
186// PyBlock
187//------------------------------------------------------------------------------
188
190 return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
191}
192
193//------------------------------------------------------------------------------
194// Collections.
195//------------------------------------------------------------------------------
196
197nb::typed<nb::object, PyRegion> PyRegionIterator::dunderNext() {
198 operation->checkValid();
199 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
200 PyErr_SetNone(PyExc_StopIteration);
201 // python functions should return NULL after setting any exception
202 return nb::object();
203 }
204 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
205 return nb::cast(PyRegion(operation, region));
206}
207
208void PyRegionIterator::bind(nb::module_ &m) {
209 nb::class_<PyRegionIterator>(m, "RegionIterator")
210 .def("__iter__", &PyRegionIterator::dunderIter,
211 "Returns an iterator over the regions in the operation.")
212 .def("__next__", &PyRegionIterator::dunderNext,
213 "Returns the next region in the iteration.");
214}
215
219 length == -1 ? mlirOperationGetNumRegions(operation->get())
220 : length,
221 step),
222 operation(std::move(operation)) {}
223
225 operation->checkValid();
226 return PyRegionIterator(operation, startIndex);
227}
228
230 c.def("__iter__", &PyRegionList::dunderIter,
231 "Returns an iterator over the regions in the sequence.");
232}
233
234intptr_t PyRegionList::getRawNumElements() {
235 operation->checkValid();
236 return mlirOperationGetNumRegions(operation->get());
237}
238
239PyRegion PyRegionList::getRawElement(intptr_t pos) {
240 operation->checkValid();
241 return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
242}
243
244PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
245 intptr_t step) const {
246 return PyRegionList(operation, startIndex, length, step);
247}
248
249nb::typed<nb::object, PyBlock> PyBlockIterator::dunderNext() {
250 operation->checkValid();
251 if (mlirBlockIsNull(next)) {
252 PyErr_SetNone(PyExc_StopIteration);
253 // python functions should return NULL after setting any exception
254 return nb::object();
255 }
256
257 PyBlock returnBlock(operation, next);
258 next = mlirBlockGetNextInRegion(next);
259 return nb::cast(returnBlock);
260}
261
262void PyBlockIterator::bind(nb::module_ &m) {
263 nb::class_<PyBlockIterator>(m, "BlockIterator")
264 .def("__iter__", &PyBlockIterator::dunderIter,
265 "Returns an iterator over the blocks in the operation's region.")
266 .def("__next__", &PyBlockIterator::dunderNext,
267 "Returns the next block in the iteration.");
268}
269
271 operation->checkValid();
272 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
273}
274
276 operation->checkValid();
277 intptr_t count = 0;
278 MlirBlock block = mlirRegionGetFirstBlock(region);
279 while (!mlirBlockIsNull(block)) {
280 count += 1;
281 block = mlirBlockGetNextInRegion(block);
282 }
283 return count;
284}
285
287 operation->checkValid();
288 if (index < 0) {
289 index += dunderLen();
290 }
291 if (index < 0) {
292 throw nb::index_error("attempt to access out of bounds block");
293 }
294 MlirBlock block = mlirRegionGetFirstBlock(region);
295 while (!mlirBlockIsNull(block)) {
296 if (index == 0) {
297 return PyBlock(operation, block);
298 }
299 block = mlirBlockGetNextInRegion(block);
300 index -= 1;
301 }
302 throw nb::index_error("attempt to access out of bounds block");
303}
304
305PyBlock PyBlockList::appendBlock(const nb::args &pyArgTypes,
306 const std::optional<nb::sequence> &pyArgLocs) {
307 operation->checkValid();
308 MlirBlock block = createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
309 mlirRegionAppendOwnedBlock(region, block);
310 return PyBlock(operation, block);
311}
312
313void PyBlockList::bind(nb::module_ &m) {
314 nb::class_<PyBlockList>(m, "BlockList")
315 .def("__getitem__", &PyBlockList::dunderGetItem,
316 "Returns the block at the specified index.")
317 .def("__iter__", &PyBlockList::dunderIter,
318 "Returns an iterator over blocks in the operation's region.")
319 .def("__len__", &PyBlockList::dunderLen,
320 "Returns the number of blocks in the operation's region.")
321 .def("append", &PyBlockList::appendBlock,
322 R"(
323 Appends a new block, with argument types as positional args.
324
325 Returns:
326 The created block.
327 )",
328 "args"_a, nb::kw_only(), "arg_locs"_a = std::nullopt);
329}
330
331nb::typed<nb::object, PyOpView> PyOperationIterator::dunderNext() {
332 parentOperation->checkValid();
333 if (mlirOperationIsNull(next)) {
334 PyErr_SetNone(PyExc_StopIteration);
335 // python functions should return NULL after setting any exception
336 return nb::object();
337 }
338
339 PyOperationRef returnOperation =
340 PyOperation::forOperation(parentOperation->getContext(), next);
341 next = mlirOperationGetNextInBlock(next);
342 return returnOperation->createOpView();
343}
344
345void PyOperationIterator::bind(nb::module_ &m) {
346 nb::class_<PyOperationIterator>(m, "OperationIterator")
347 .def("__iter__", &PyOperationIterator::dunderIter,
348 "Returns an iterator over the operations in an operation's block.")
349 .def("__next__", &PyOperationIterator::dunderNext,
350 "Returns the next operation in the iteration.");
351}
352
354 parentOperation->checkValid();
355 return PyOperationIterator(parentOperation,
357}
358
360 parentOperation->checkValid();
361 intptr_t count = 0;
362 MlirOperation childOp = mlirBlockGetFirstOperation(block);
363 while (!mlirOperationIsNull(childOp)) {
364 count += 1;
365 childOp = mlirOperationGetNextInBlock(childOp);
366 }
367 return count;
368}
369
370nb::typed<nb::object, PyOpView> PyOperationList::dunderGetItem(intptr_t index) {
371 parentOperation->checkValid();
372 if (index < 0) {
373 index += dunderLen();
374 }
375 if (index < 0) {
376 throw nb::index_error("attempt to access out of bounds operation");
377 }
378 MlirOperation childOp = mlirBlockGetFirstOperation(block);
379 while (!mlirOperationIsNull(childOp)) {
380 if (index == 0) {
381 return PyOperation::forOperation(parentOperation->getContext(), childOp)
382 ->createOpView();
383 }
384 childOp = mlirOperationGetNextInBlock(childOp);
385 index -= 1;
386 }
387 throw nb::index_error("attempt to access out of bounds operation");
388}
389
390void PyOperationList::bind(nb::module_ &m) {
391 nb::class_<PyOperationList>(m, "OperationList")
392 .def("__getitem__", &PyOperationList::dunderGetItem,
393 "Returns the operation at the specified index.")
394 .def("__iter__", &PyOperationList::dunderIter,
395 "Returns an iterator over operations in the list.")
396 .def("__len__", &PyOperationList::dunderLen,
397 "Returns the number of operations in the list.");
398}
399
400nb::typed<nb::object, PyOpView> PyOpOperand::getOwner() const {
401 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
405}
407size_t PyOpOperand::getOperandNumber() const {
408 return mlirOpOperandGetOperandNumber(opOperand);
409}
410
411void PyOpOperand::bind(nb::module_ &m) {
412 nb::class_<PyOpOperand>(m, "OpOperand")
413 .def_prop_ro("owner", &PyOpOperand::getOwner,
414 "Returns the operation that owns this operand.")
415 .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
416 "Returns the operand number in the owning operation.");
417}
418
419nb::typed<nb::object, PyOpOperand> PyOpOperandIterator::dunderNext() {
420 if (mlirOpOperandIsNull(opOperand)) {
421 PyErr_SetNone(PyExc_StopIteration);
422 // python functions should return NULL after setting any exception
423 return nb::object();
424 }
425
426 PyOpOperand returnOpOperand(opOperand);
427 opOperand = mlirOpOperandGetNextUse(opOperand);
428 return nb::cast(returnOpOperand);
429}
430
431void PyOpOperandIterator::bind(nb::module_ &m) {
432 nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
433 .def("__iter__", &PyOpOperandIterator::dunderIter,
434 "Returns an iterator over operands.")
435 .def("__next__", &PyOpOperandIterator::dunderNext,
436 "Returns the next operand in the iteration.");
437}
439//------------------------------------------------------------------------------
440// PyThreadPool
441//------------------------------------------------------------------------------
444 ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
445}
446
447std::string PyThreadPool::_mlir_thread_pool_ptr() const {
448 std::stringstream ss;
449 ss << ownedThreadPool.get();
450 return ss.str();
451}
453//------------------------------------------------------------------------------
454// PyMlirContext
455//------------------------------------------------------------------------------
456
457PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
458 nb::gil_scoped_acquire acquire;
459 nb::ft_lock_guard lock(live_contexts_mutex);
460 auto &liveContexts = getLiveContexts();
461 liveContexts[context.ptr] = this;
462}
463
465 // Note that the only public way to construct an instance is via the
466 // forContext method, which always puts the associated handle into
467 // liveContexts.
468 nb::gil_scoped_acquire acquire;
469 {
470 nb::ft_lock_guard lock(live_contexts_mutex);
471 getLiveContexts().erase(context.ptr);
472 }
473 mlirContextDestroy(context);
474}
477 return PyMlirContextRef(this, nb::cast(this));
478}
480nb::object PyMlirContext::getCapsule() {
481 return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
482}
483
484nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
485 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
486 if (mlirContextIsNull(rawContext))
487 throw nb::python_error();
488 return forContext(rawContext).releaseObject();
489}
490
/// Returns a reference to the Python wrapper for `context`. If no wrapper is
/// registered in the live-context map, a new one is created.
///
/// Lookup and creation both run with the GIL held and `live_contexts_mutex`
/// locked.
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  nb::gil_scoped_acquire acquire;
  nb::ft_lock_guard lock(live_contexts_mutex);
  auto &liveContexts = getLiveContexts();
  auto it = liveContexts.find(context.ptr);
  if (it == liveContexts.end()) {
    // Create.
    // NOTE(review): the PyMlirContext constructor itself locks
    // `live_contexts_mutex` and inserts into the live map, but this function
    // still holds that lock at this point. If `nb::ft_mutex` is non-recursive
    // in free-threaded builds, this path may self-deadlock there. TODO
    // confirm. The explicit map assignment below repeats the constructor's
    // registration.
    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
    nb::object pyRef = nb::cast(unownedContextWrapper);
    assert(pyRef && "cast to nb::object failed");
    liveContexts[context.ptr] = unownedContextWrapper;
    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
  }
  // Use existing.
  nb::object pyRef = nb::cast(it->second);
  return PyMlirContextRef(it->second, std::move(pyRef));
}
508
509nb::ft_mutex PyMlirContext::live_contexts_mutex;
510
511PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
512 static LiveContextMap liveContexts;
513 return liveContexts;
514}
515
517 nb::ft_lock_guard lock(live_contexts_mutex);
518 return getLiveContexts().size();
519}
521nb::object PyMlirContext::contextEnter(nb::object context) {
522 return PyThreadContextEntry::pushContext(context);
523}
524
525void PyMlirContext::contextExit(const nb::object &excType,
526 const nb::object &excVal,
527 const nb::object &excTb) {
529}
530
531nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
532 // Note that ownership is transferred to the delete callback below by way of
533 // an explicit inc_ref (borrow).
534 PyDiagnosticHandler *pyHandler =
535 new PyDiagnosticHandler(get(), std::move(callback));
536 nb::object pyHandlerObject =
537 nb::cast(pyHandler, nb::rv_policy::take_ownership);
538 (void)pyHandlerObject.inc_ref();
539
540 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
541 // guaranteed to be known to pybind.
542 auto handlerCallback =
543 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
544 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
545 nb::object pyDiagnosticObject =
546 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
547
548 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
549 bool result = false;
550 {
551 // Since this can be called from arbitrary C++ contexts, always get the
552 // gil.
553 nb::gil_scoped_acquire gil;
554 try {
555 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
556 } catch (std::exception &e) {
557 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
558 e.what());
559 pyHandler->hadError = true;
560 }
561 }
562
563 pyDiagnostic->invalidate();
565 };
566 auto deleteCallback = +[](void *userData) {
567 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
568 assert(pyHandler->registeredID && "handler is not registered");
569 pyHandler->registeredID.reset();
570
571 // Decrement reference, balancing the inc_ref() above.
572 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
573 pyHandlerObject.dec_ref();
574 };
575
576 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
577 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
578 return pyHandlerObject;
579}
580
581MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
582 void *userData) {
583 auto *self = static_cast<ErrorCapture *>(userData);
584 // Check if the context requested we emit errors instead of capturing them.
585 if (self->ctx->emitErrorDiagnostics)
587
589 MlirDiagnosticSeverity::MlirDiagnosticError)
592 self->errors.emplace_back(PyDiagnostic(diag).getInfo());
594}
595
598 if (!context) {
599 throw std::runtime_error(
600 "An MLIR function requires a Context but none was provided in the call "
601 "or from the surrounding environment. Either pass to the function with "
602 "a 'context=' argument or establish a default using 'with Context():'");
603 }
604 return *context;
605}
607//------------------------------------------------------------------------------
608// PyThreadContextEntry management
609//------------------------------------------------------------------------------
610
611std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
612 static thread_local std::vector<PyThreadContextEntry> stack;
613 return stack;
614}
615
617 auto &stack = getStack();
618 if (stack.empty())
619 return nullptr;
620 return &stack.back();
621}
622
623void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
624 nb::object insertionPoint,
625 nb::object location) {
626 auto &stack = getStack();
627 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
628 std::move(location));
629 // If the new stack has more than one entry and the context of the new top
630 // entry matches the previous, copy the insertionPoint and location from the
631 // previous entry if missing from the new top entry.
632 if (stack.size() > 1) {
633 auto &prev = *(stack.rbegin() + 1);
634 auto &current = stack.back();
635 if (current.context.is(prev.context)) {
636 // Default non-context objects from the previous entry.
637 if (!current.insertionPoint)
638 current.insertionPoint = prev.insertionPoint;
639 if (!current.location)
640 current.location = prev.location;
641 }
642 }
643}
644
646 if (!context)
647 return nullptr;
648 return nb::cast<PyMlirContext *>(context);
649}
650
652 if (!insertionPoint)
653 return nullptr;
654 return nb::cast<PyInsertionPoint *>(insertionPoint);
655}
656
658 if (!location)
659 return nullptr;
660 return nb::cast<PyLocation *>(location);
661}
662
664 auto *tos = getTopOfStack();
665 return tos ? tos->getContext() : nullptr;
666}
667
669 auto *tos = getTopOfStack();
670 return tos ? tos->getInsertionPoint() : nullptr;
671}
672
674 auto *tos = getTopOfStack();
675 return tos ? tos->getLocation() : nullptr;
676}
677
678nb::object PyThreadContextEntry::pushContext(nb::object context) {
679 push(FrameKind::Context, /*context=*/context,
680 /*insertionPoint=*/nb::object(),
681 /*location=*/nb::object());
682 return context;
683}
684
686 auto &stack = getStack();
687 if (stack.empty())
688 throw std::runtime_error("Unbalanced Context enter/exit");
689 auto &tos = stack.back();
690 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
691 throw std::runtime_error("Unbalanced Context enter/exit");
692 stack.pop_back();
693}
694
695nb::object
696PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
697 PyInsertionPoint &insertionPoint =
698 nb::cast<PyInsertionPoint &>(insertionPointObj);
699 nb::object contextObj =
700 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
701 push(FrameKind::InsertionPoint,
702 /*context=*/contextObj,
703 /*insertionPoint=*/insertionPointObj,
704 /*location=*/nb::object());
705 return insertionPointObj;
706}
707
709 auto &stack = getStack();
710 if (stack.empty())
711 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
712 auto &tos = stack.back();
713 if (tos.frameKind != FrameKind::InsertionPoint &&
714 tos.getInsertionPoint() != &insertionPoint)
715 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
716 stack.pop_back();
717}
718
719nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
720 PyLocation &location = nb::cast<PyLocation &>(locationObj);
721 nb::object contextObj = location.getContext().getObject();
722 push(FrameKind::Location, /*context=*/contextObj,
723 /*insertionPoint=*/nb::object(),
724 /*location=*/locationObj);
725 return locationObj;
726}
727
729 auto &stack = getStack();
730 if (stack.empty())
731 throw std::runtime_error("Unbalanced Location enter/exit");
732 auto &tos = stack.back();
733 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
734 throw std::runtime_error("Unbalanced Location enter/exit");
735 stack.pop_back();
736}
738//------------------------------------------------------------------------------
739// PyDiagnostic*
740//------------------------------------------------------------------------------
741
743 valid = false;
744 if (materializedNotes) {
745 for (nb::handle noteObject : *materializedNotes) {
746 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
747 note->invalidate();
748 }
749 }
750}
751
753 nb::object callback)
754 : context(context), callback(std::move(callback)) {}
755
757
759 if (!registeredID)
760 return;
761 MlirDiagnosticHandlerID localID = *registeredID;
762 mlirContextDetachDiagnosticHandler(context, localID);
763 assert(!registeredID && "should have unregistered");
764 // Not strictly necessary but keeps stale pointers from being around to cause
765 // issues.
766 context = {nullptr};
767}
768
769void PyDiagnostic::checkValid() {
770 if (!valid) {
771 throw std::invalid_argument(
772 "Diagnostic is invalid (used outside of callback)");
773 }
774}
775
777 checkValid();
778 return static_cast<PyDiagnosticSeverity>(
779 mlirDiagnosticGetSeverity(diagnostic));
780}
781
783 checkValid();
784 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
785 MlirContext context = mlirLocationGetContext(loc);
786 return PyLocation(PyMlirContext::forContext(context), loc);
787}
788
789nb::str PyDiagnostic::getMessage() {
790 checkValid();
791 nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
792 PyFileAccumulator accum(fileObject, /*binary=*/false);
793 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
794 return nb::cast<nb::str>(fileObject.attr("getvalue")());
795}
796
797nb::tuple PyDiagnostic::getNotes() {
798 checkValid();
799 if (materializedNotes)
800 return *materializedNotes;
801 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
802 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
803 for (intptr_t i = 0; i < numNotes; ++i) {
804 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
805 nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
806 PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
807 }
808 materializedNotes = std::move(notes);
809
810 return *materializedNotes;
811}
812
814 std::vector<DiagnosticInfo> notes;
815 for (nb::handle n : getNotes())
816 notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
817 return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
818 std::move(notes)};
819}
821//------------------------------------------------------------------------------
822// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
823//------------------------------------------------------------------------------
824
825MlirDialect PyDialects::getDialectForKey(const std::string &key,
826 bool attrError) {
827 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
828 {key.data(), key.size()});
829 if (mlirDialectIsNull(dialect)) {
830 std::string msg = join("Dialect '", key, "' not found");
831 if (attrError)
832 throw nb::attribute_error(msg.c_str());
833 throw nb::index_error(msg.c_str());
834 }
835 return dialect;
836}
839 return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
840}
841
843 MlirDialectRegistry rawRegistry =
845 if (mlirDialectRegistryIsNull(rawRegistry))
846 throw nb::python_error();
847 return PyDialectRegistry(rawRegistry);
848}
850//------------------------------------------------------------------------------
851// PyLocation
852//------------------------------------------------------------------------------
854nb::object PyLocation::getCapsule() {
855 return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
856}
857
858PyLocation PyLocation::createFromCapsule(nb::object capsule) {
859 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
860 if (mlirLocationIsNull(rawLoc))
861 throw nb::python_error();
863 rawLoc);
864}
866nb::object PyLocation::contextEnter(nb::object locationObj) {
867 return PyThreadContextEntry::pushLocation(locationObj);
868}
869
870void PyLocation::contextExit(const nb::object &excType,
871 const nb::object &excVal,
872 const nb::object &excTb) {
874}
875
878 if (!location) {
879 throw std::runtime_error(
880 "An MLIR function requires a Location but none was provided in the "
881 "call or from the surrounding environment. Either pass to the function "
882 "with a 'loc=' argument or establish a default using 'with loc:'");
883 }
884 return *location;
885}
886
887//------------------------------------------------------------------------------
888// PyModule
889//------------------------------------------------------------------------------
890
/// Wraps `module`, keeping a reference to its owning context.
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
    : BaseContextObject(std::move(contextRef)), module(module) {}
893
895 nb::gil_scoped_acquire acquire;
896 auto &liveModules = getContext()->liveModules;
897 assert(liveModules.count(module.ptr) == 1 &&
898 "destroying module not in live map");
899 liveModules.erase(module.ptr);
900 mlirModuleDestroy(module);
901}
902
903PyModuleRef PyModule::forModule(MlirModule module) {
904 MlirContext context = mlirModuleGetContext(module);
905 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
906
907 nb::gil_scoped_acquire acquire;
908 auto &liveModules = contextRef->liveModules;
909 auto it = liveModules.find(module.ptr);
910 if (it == liveModules.end()) {
911 // Create.
912 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
913 // Note that the default return value policy on cast is automatic_reference,
914 // which does not take ownership (delete will not be called).
915 // Just be explicit.
916 nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
917 unownedModule->handle = pyRef;
918 liveModules[module.ptr] =
919 std::make_pair(unownedModule->handle, unownedModule);
920 return PyModuleRef(unownedModule, std::move(pyRef));
921 }
922 // Use existing.
923 PyModule *existing = it->second.second;
924 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
925 return PyModuleRef(existing, std::move(pyRef));
926}
927
928nb::object PyModule::createFromCapsule(nb::object capsule) {
929 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
930 if (mlirModuleIsNull(rawModule))
931 throw nb::python_error();
932 return forModule(rawModule).releaseObject();
933}
934
935nb::object PyModule::getCapsule() {
936 return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
937}
939//------------------------------------------------------------------------------
940// PyOperation
941//------------------------------------------------------------------------------
942
/// Wraps `operation`, keeping a reference to its owning context.
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
    : BaseContextObject(std::move(contextRef)), operation(operation) {}
945
947 // If the operation has already been invalidated there is nothing to do.
948 if (!valid)
949 return;
950 // Otherwise, invalidate the operation when it is attached.
951 if (isAttached())
952 setInvalid();
953 else {
954 // And destroy it when it is detached, i.e. owned by Python.
955 erase();
956 }
957}
958
959namespace {
960
961// Constructs a new object of type T in-place on the Python heap, returning a
962// PyObjectRef to it, loosely analogous to std::make_shared<T>().
963template <typename T, class... Args>
964PyObjectRef<T> makeObjectRef(Args &&...args) {
965 nb::handle type = nb::type<T>();
966 nb::object instance = nb::inst_alloc(type);
967 T *ptr = nb::inst_ptr<T>(instance);
968 new (ptr) T(std::forward<Args>(args)...);
969 nb::inst_mark_ready(instance);
970 return PyObjectRef<T>(ptr, std::move(instance));
971}
972
973} // namespace
974
975PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
976 MlirOperation operation,
977 nb::object parentKeepAlive) {
978 // Create.
979 PyOperationRef unownedOperation =
980 makeObjectRef<PyOperation>(std::move(contextRef), operation);
981 unownedOperation->handle = unownedOperation.getObject();
982 if (parentKeepAlive) {
983 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
984 }
985 return unownedOperation;
986}
987
989 MlirOperation operation,
990 nb::object parentKeepAlive) {
991 return createInstance(std::move(contextRef), operation,
992 std::move(parentKeepAlive));
993}
994
                                           MlirOperation operation,
                                           nb::object parentKeepAlive) {
  PyOperationRef created = createInstance(std::move(contextRef), operation,
                                          std::move(parentKeepAlive));
  // Mark as detached. Assuming isAttached() reflects this flag, the
  // destructor will then erase (destroy) the operation instead of merely
  // invalidating the wrapper.
  created->attached = false;
  return created;
}
1003
                                   const std::string &sourceStr,
                                   const std::string &sourceName) {
  // Capture diagnostics emitted while parsing so that they can be attached
  // to the MLIRError raised on failure.
  PyMlirContext::ErrorCapture errors(contextRef);
  MlirOperation op =
      mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
                               toMlirStringRef(sourceName));
  if (mlirOperationIsNull(op))
    throw MLIRError("Unable to parse operation assembly", errors.take());
  // The freshly parsed operation is wrapped as detached (owned by Python).
  return PyOperation::createDetached(std::move(contextRef), op);
}
1015
  setDetached();
  // Release the parentKeepAlive reference by resetting it to a null object.
  parentKeepAlive = nb::object();
}
1021
1022MlirOperation PyOperation::get() const {
1023 checkValid();
1024 return operation;
1025}
  // nb::borrow takes a new (reference-incremented) reference to the stored
  // Python handle, so the returned ref co-owns this object.
  return PyOperationRef(this, nb::borrow<nb::object>(handle));
}
1030
1031void PyOperation::setAttached(const nb::object &parent) {
1032 assert(!attached && "operation already attached");
1033 attached = true;
1034}
1035
  // Only flips the flag; detachFromParent additionally drops parentKeepAlive.
  assert(attached && "operation already detached");
  attached = false;
}
1040
1041void PyOperation::checkValid() const {
1042 if (!valid) {
1043 throw std::runtime_error("the operation has been invalidated");
1044 }
1045}
1046
/// Prints the operation to `fileObject`, or to sys.stdout when it is None.
/// The printing flags are assembled from the keyword-style arguments.
/// `binary` is forwarded to PyFileAccumulator, which performs the writes.
/// Throws if the operation has been invalidated.
void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
                            std::optional<int64_t> largeResourceLimit,
                            bool enableDebugInfo, bool prettyDebugInfo,
                            bool printGenericOpForm, bool useLocalScope,
                            bool useNameLocAsPrefix, bool assumeVerified,
                            nb::object fileObject, bool binary,
                            bool skipRegions) {
  PyOperation &operation = getOperation();
  operation.checkValid();
  // Default output stream mirrors Python's print().
  if (fileObject.is_none())
    fileObject = nb::module_::import_("sys").attr("stdout");

  // Translate each requested option into the C API printing flags.
  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
  if (largeElementsLimit)
    mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
  if (largeResourceLimit)
    mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit);
  if (enableDebugInfo)
    mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
                                       /*prettyForm=*/prettyDebugInfo);
  if (printGenericOpForm)
  if (useLocalScope)
  if (assumeVerified)
  if (skipRegions)
  if (useNameLocAsPrefix)

  // Stream the printed text through the accumulator into fileObject.
  PyFileAccumulator accum(fileObject, binary);
  mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
                              accum.getUserData());
}
1083
1084void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1085 bool binary) {
1086 PyOperation &operation = getOperation();
1087 operation.checkValid();
1088 if (fileObject.is_none())
1089 fileObject = nb::module_::import_("sys").attr("stdout");
1090 PyFileAccumulator accum(fileObject, binary);
1091 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1092 accum.getUserData());
1093}
1094
/// Serializes the operation as MLIR bytecode into `fileOrStringObject`.
/// With no `bytecodeVersion`, the default writer is used. Otherwise a writer
/// config is created for the requested version, and a ValueError is raised
/// when that version cannot be honored.
void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
                                    std::optional<int64_t> bytecodeVersion) {
  PyOperation &operation = getOperation();
  operation.checkValid();
  // Bytecode is always written in binary mode.
  PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);

  // No explicit version requested: use the default writer.
  if (!bytecodeVersion.has_value())
    return mlirOperationWriteBytecode(operation, accum.getCallback(),
                                      accum.getUserData());

  MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
      operation, config, accum.getCallback(), accum.getUserData());
    throw nb::value_error(
        join("Unable to honor desired bytecode version ", *bytecodeVersion)
            .c_str());
}
1115
1116void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
1117 PyWalkOrder walkOrder) {
1118 PyOperation &operation = getOperation();
1119 operation.checkValid();
1120 struct UserData {
1121 std::function<PyWalkResult(MlirOperation)> callback;
1122 bool gotException;
1123 std::string exceptionWhat;
1124 nb::object exceptionType;
1125 };
1126 UserData userData{callback, false, {}, {}};
1127 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1128 void *userData) {
1129 UserData *calleeUserData = static_cast<UserData *>(userData);
1130 try {
1131 return static_cast<MlirWalkResult>((calleeUserData->callback)(op));
1132 } catch (nb::python_error &e) {
1133 calleeUserData->gotException = true;
1134 calleeUserData->exceptionWhat = std::string(e.what());
1135 calleeUserData->exceptionType = nb::borrow(e.type());
1136 return MlirWalkResult::MlirWalkResultInterrupt;
1137 }
1138 };
1139 mlirOperationWalk(operation, walkCallback, &userData,
1140 static_cast<MlirWalkOrder>(walkOrder));
1141 if (userData.gotException) {
1142 std::string message("Exception raised in callback: ");
1143 message.append(userData.exceptionWhat);
1144 throw std::runtime_error(message);
1145 }
1146}
1147
1148nb::object PyOperationBase::getAsm(bool binary,
1149 std::optional<int64_t> largeElementsLimit,
1150 std::optional<int64_t> largeResourceLimit,
1151 bool enableDebugInfo, bool prettyDebugInfo,
1152 bool printGenericOpForm, bool useLocalScope,
1153 bool useNameLocAsPrefix, bool assumeVerified,
1154 bool skipRegions) {
1155 nb::object fileObject;
1156 if (binary) {
1157 fileObject = nb::module_::import_("io").attr("BytesIO")();
1158 } else {
1159 fileObject = nb::module_::import_("io").attr("StringIO")();
1160 }
1161 print(/*largeElementsLimit=*/largeElementsLimit,
1162 /*largeResourceLimit=*/largeResourceLimit,
1163 /*enableDebugInfo=*/enableDebugInfo,
1164 /*prettyDebugInfo=*/prettyDebugInfo,
1165 /*printGenericOpForm=*/printGenericOpForm,
1166 /*useLocalScope=*/useLocalScope,
1167 /*useNameLocAsPrefix=*/useNameLocAsPrefix,
1168 /*assumeVerified=*/assumeVerified,
1169 /*fileObject=*/fileObject,
1170 /*binary=*/binary,
1171 /*skipRegions=*/skipRegions);
1172
1173 return fileObject.attr("getvalue")();
1174}
1175
  PyOperation &operation = getOperation();
  PyOperation &otherOp = other.getOperation();
  // Both operations must still be valid before the IR is mutated.
  operation.checkValid();
  otherOp.checkValid();
  mlirOperationMoveAfter(operation, otherOp);
  // The moved operation now sits next to `other`, so adopt its keep-alive.
  operation.parentKeepAlive = otherOp.parentKeepAlive;
}
1184
  PyOperation &operation = getOperation();
  PyOperation &otherOp = other.getOperation();
  // Both operations must still be valid before the IR is mutated.
  operation.checkValid();
  otherOp.checkValid();
  mlirOperationMoveBefore(operation, otherOp);
  // The moved operation now sits next to `other`, so adopt its keep-alive.
  operation.parentKeepAlive = otherOp.parentKeepAlive;
}
1193
  PyOperation &operation = getOperation();
  PyOperation &otherOp = other.getOperation();
  operation.checkValid();
  otherOp.checkValid();
  // Delegates the ordering query entirely to the C API.
  return mlirOperationIsBeforeInBlock(operation, otherOp);
}
1201
  PyOperation &op = getOperation();
    // Failure is reported by throwing MLIRError with the captured diagnostics.
    throw MLIRError("Verification failed", errors.take());
  // Reaching here means verification succeeded.
  return true;
}
1209
1210std::optional<PyOperationRef> PyOperation::getParentOperation() {
1211 checkValid();
1212 if (!isAttached())
1213 throw nb::value_error("Detached operations have no parent");
1214 MlirOperation operation = mlirOperationGetParentOperation(get());
1215 if (mlirOperationIsNull(operation))
1216 return {};
1217 return PyOperation::forOperation(getContext(), operation);
1218}
1219
  checkValid();
  // getParentOperation() throws for detached operations, so past this call
  // the operation is attached.
  std::optional<PyOperationRef> parentOperation = getParentOperation();
  MlirBlock block = mlirOperationGetBlock(get());
  assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
  assert(parentOperation && "Operation has no parent");
  // The PyBlock keeps a reference to its parent operation.
  return PyBlock{std::move(*parentOperation), block};
}
1228
  checkValid();
  // nb::steal adopts the returned capsule reference without incrementing it.
  return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
}
1233
1234nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
1235 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1236 if (mlirOperationIsNull(rawOperation))
1237 throw nb::python_error();
1238 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1239 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1240 .releaseObject();
1241}
1242
/// Inserts the freshly created operation `op` according to `maybeIp`.
/// Passing Python `False` opts out of insertion entirely. Passing `None`
/// presumably selects the ambient/default insertion point. Any other value
/// must be a PyInsertionPoint. A null resolved insertion point means the
/// operation is left detached.
static void maybeInsertOperation(PyOperationRef &op,
                                 const nb::object &maybeIp) {
  // InsertPoint active?
  if (!maybeIp.is(nb::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
    } else {
      ip = nb::cast<PyInsertionPoint *>(maybeIp);
    }
    if (ip)
      ip->insert(*op.get());
  }
}
1257
/// Low-level generic operation constructor backing `Operation.create`.
///
/// Validates and unpacks the optional result types, attributes and
/// successors. It then fills an MlirOperationState (operands, results,
/// attributes, successors, `regions` empty regions, optional result type
/// inference) and creates the operation. The result is wrapped as detached
/// and inserted according to `maybeIp` (see maybeInsertOperation).
///
/// Raises ValueError for negative `regions` or None entries, and TypeError
/// for non-string attribute keys or non-Attribute values. Raises MLIRError
/// (with captured diagnostics) when creation fails.
nb::object PyOperation::create(std::string_view name,
                               std::optional<std::vector<PyType *>> results,
                               const MlirValue *operands, size_t numOperands,
                               std::optional<nb::dict> attributes,
                               std::optional<std::vector<PyBlock *>> successors,
                               int regions, PyLocation &location,
                               const nb::object &maybeIp, bool inferType) {
  std::vector<MlirType> mlirResults;
  std::vector<MlirBlock> mlirSuccessors;
  std::vector<std::pair<std::string, MlirAttribute>> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw nb::value_error("number of regions must be >= 0");

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw nb::value_error("result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (std::pair<nb::handle, nb::handle> it : *attributes) {
      std::string key;
      try {
        key = nb::cast<std::string>(it.first);
      } catch (nb::cast_error &err) {
        std::string msg = join("Invalid attribute key (not a string) when "
                               "attempting to create the operation \"",
                               name, "\" (", err.what(), ")");
        throw nb::type_error(msg.c_str());
      }
      try {
        auto &attribute = nb::cast<PyAttribute &>(it.second);
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (nb::cast_error &err) {
        std::string msg = join("Invalid attribute value for the key \"", key,
                               "\" when attempting to create the operation \"",
                               name, "\" (", err.what(), ")");
        throw nb::type_error(msg.c_str());
      } catch (std::runtime_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg = join(
            "Found an invalid (`None`?) attribute value for the key \"", key,
            "\" when attempting to create the operation \"", name, "\"");
        throw std::runtime_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw nb::value_error("successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (numOperands > 0)
    mlirOperationStateAddOperands(&state, numOperands, operands);
  // When enabled, result types are inferred during mlirOperationCreate.
  state.enableResultTypeInference = inferType;
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    std::vector<MlirNamedAttribute> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
              toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Create `regions` fresh empty regions; ownership passes to the state.
    std::vector<MlirRegion> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  PyMlirContext::ErrorCapture errors(location.getContext());
  MlirOperation operation = mlirOperationCreate(&state);
  if (!operation.ptr)
    throw MLIRError("Operation creation failed", errors.take());
  PyOperationRef created =
      PyOperation::createDetached(location.getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created.getObject();
}
1372
1373nb::object PyOperation::clone(const nb::object &maybeIp) {
1374 MlirOperation clonedOperation = mlirOperationClone(operation);
1375 PyOperationRef cloned =
1376 PyOperation::createDetached(getContext(), clonedOperation);
1377 maybeInsertOperation(cloned, maybeIp);
1378
1379 return cloned->createOpView();
1380}
1381
1382nb::object PyOperation::createOpView() {
1383 checkValid();
1384 MlirIdentifier ident = mlirOperationGetName(get());
1385 MlirStringRef identStr = mlirIdentifierStr(ident);
1386 auto operationCls = PyGlobals::get().lookupOperationClass(
1387 std::string_view(identStr.data, identStr.length));
1388 if (operationCls)
1389 return PyOpView::constructDerived(*operationCls, getRef().getObject());
1390 return nb::cast(PyOpView(getRef().getObject()));
1391}
1392
/// Destroys the underlying MLIR operation.
void PyOperation::erase() {
  // Mark the wrapper invalid before destroying the IR, so later accesses
  // through get()/checkValid() throw instead of touching freed memory.
  setInvalid();
  mlirOperationDestroy(operation);
}
1398
/// Registers the OpResult-specific Python properties `owner` and
/// `result_number`.
void PyOpResult::bindDerived(ClassTy &c) {
  c.def_prop_ro(
      "owner",
      [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
        // Debug-only consistency check: the Python-side parent must match
        // the IR-level owner of this result.
        assert(mlirOperationEqual(self.getParentOperation()->get(),
                                  mlirOpResultGetOwner(self.get())) &&
               "expected the owner of the value in Python to match that in "
               "the IR");
        return self.getParentOperation()->createOpView();
      },
      "Returns the operation that produces this result.");
  c.def_prop_ro(
      "result_number",
      [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); },
      "Returns the position of this result in the operation's result list.");
1415
/// Returns the list of types of the values held by container.
/// `Container` must provide size() and getElement(i). Each element's get()
/// must yield an MlirValue.
template <typename Container>
static std::vector<nb::typed<nb::object, PyType>>
getValueTypes(Container &container, PyMlirContextRef &context) {
  std::vector<nb::typed<nb::object, PyType>> result;
  result.reserve(container.size());
  for (int i = 0, e = container.size(); i < e; ++i) {
    result.push_back(PyType(context->getRef(),
                            mlirValueGetType(container.getElement(i).get()))
  }
  return result;
}
1429
                               intptr_t length, intptr_t step)
    // The Sliceable base tracks the index window; `operation` keeps the
    // owning operation alive for the lifetime of this list.
    : Sliceable(startIndex,
                : length,
                step),
      operation(std::move(operation)) {}
1437
1438void PyOpResultList::bindDerived(ClassTy &c) {
1439 c.def_prop_ro(
1440 "types",
1441 [](PyOpResultList &self) {
1442 return getValueTypes(self, self.operation->getContext());
1443 },
1444 "Returns a list of types for all results in this result list.");
1445 c.def_prop_ro(
1446 "owner",
1447 [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
1448 return self.operation->createOpView();
1449 },
1450 "Returns the operation that owns this result list.");
1451}
1452
1453intptr_t PyOpResultList::getRawNumElements() {
1454 operation->checkValid();
1455 return mlirOperationGetNumResults(operation->get());
1456}
1457
1458PyOpResult PyOpResultList::getRawElement(intptr_t index) {
1459 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1460 return PyOpResult(value);
1461}
1462
1463PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
1464 intptr_t step) const {
1465 return PyOpResultList(operation, startIndex, length, step);
1466}
1468//------------------------------------------------------------------------------
1469// PyOpView
1470//------------------------------------------------------------------------------
1471
/// Unpacks `resultTypeList` for operation `name` into `resultTypes`.
///
/// When `resultSegmentSpecObj` is None, every entry must be a Type.
/// Otherwise it is a list of ints, one per entry of `resultTypeList`:
///   1  -> a single required Type;
///   0  -> an optional Type (None contributes an empty segment);
///   -1 -> a sequence of Types (None is treated as an empty sequence).
/// In that case the length of every segment is appended to
/// `resultSegmentLengths`. Malformed input raises ValueError.
static void populateResultTypes(std::string_view name, nb::list resultTypeList,
                                const nb::object &resultSegmentSpecObj,
                                std::vector<int32_t> &resultSegmentLengths,
                                std::vector<PyType *> &resultTypes) {
  resultTypes.reserve(resultTypeList.size());
  if (resultSegmentSpecObj.is_none()) {
    // Non-variadic result unpacking.
    for (const auto &it : llvm::enumerate(resultTypeList)) {
      try {
        resultTypes.push_back(nb::cast<PyType *>(it.value()));
        // A None entry casts to a null pointer; reject it like a bad cast.
        if (!resultTypes.back())
          throw nb::cast_error();
      } catch (nb::cast_error &err) {
        throw nb::value_error(join("Result ", it.index(), " of operation \"",
                                   name, "\" must be a Type (", err.what(), ")")
                                  .c_str());
      }
    }
  } else {
    // Sized result unpacking.
    auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
    if (resultSegmentSpec.size() != resultTypeList.size()) {
      throw nb::value_error(
          join("Operation \"", name, "\" requires ", resultSegmentSpec.size(),
               " result segments but was provided ", resultTypeList.size())
              .c_str());
    }
    resultSegmentLengths.reserve(resultTypeList.size());
    for (const auto &it :
         llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
      int segmentSpec = std::get<1>(it.value());
      if (segmentSpec == 1 || segmentSpec == 0) {
        // Unpack unary element.
        try {
          auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
          if (resultType) {
            resultTypes.push_back(resultType);
            resultSegmentLengths.push_back(1);
          } else if (segmentSpec == 0) {
            // Allowed to be optional.
            resultSegmentLengths.push_back(0);
          } else {
            throw nb::value_error(
                join("Result ", it.index(), " of operation \"", name,
                     "\" must be a Type (was None and result is not optional)")
                    .c_str());
          }
        } catch (nb::cast_error &err) {
          throw nb::value_error(join("Result ", it.index(), " of operation \"",
                                     name, "\" must be a Type (", err.what(),
                                     ")")
                                    .c_str());
        }
      } else if (segmentSpec == -1) {
        // Unpack sequence by appending.
        try {
          if (std::get<0>(it.value()).is_none()) {
            // Treat it as an empty list.
            resultSegmentLengths.push_back(0);
          } else {
            // Unpack the list.
            auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
            for (nb::handle segmentItem : segment) {
              resultTypes.push_back(nb::cast<PyType *>(segmentItem));
              if (!resultTypes.back()) {
                throw nb::type_error("contained a None item");
              }
            }
            resultSegmentLengths.push_back(nb::len(segment));
          }
        } catch (std::exception &err) {
          // NOTE: Sloppy to be using a catch-all here, but there are at least
          // three different unrelated exceptions that can be thrown in the
          // above "casts". Just keep the scope above small and catch them all.
          throw nb::value_error(join("Result ", it.index(), " of operation \"",
                                     name, "\" must be a Sequence of Types (",
                                     err.what(), ")")
                                    .c_str());
        }
      } else {
        throw nb::value_error("Unexpected segment spec");
    }
  }
}
1557
1558MlirValue getUniqueResult(MlirOperation operation) {
1559 auto numResults = mlirOperationGetNumResults(operation);
1560 if (numResults != 1) {
1561 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1562 throw nb::value_error(
1563 join("Cannot call .result on operation ",
1564 std::string_view(name.data, name.length), " which has ",
1565 numResults,
1566 " results (it is only valid for operations with a "
1567 "single result)")
1568 .c_str());
1569 }
1570 return mlirOperationGetResult(operation, 0);
1571}
1572
1573static MlirValue getOpResultOrValue(nb::handle operand) {
1574 if (operand.is_none()) {
1575 throw nb::value_error("contained a None item");
1576 }
1577 PyOperationBase *op;
1578 if (nb::try_cast<PyOperationBase *>(operand, op)) {
1579 return getUniqueResult(op->getOperation());
1580 }
1581 PyOpResultList *opResultList;
1582 if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1583 return getUniqueResult(opResultList->getOperation()->get());
1584 }
1585 PyValue *value;
1586 if (nb::try_cast<PyValue *>(operand, value)) {
1587 return value->get();
1588 }
1589 throw nb::value_error("is not a Value");
1590}
1591
1592nb::object PyOpView::buildGeneric(
1593 std::string_view name, std::tuple<int, bool> opRegionSpec,
1594 nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1595 std::optional<nb::list> resultTypeList, nb::list operandList,
1596 std::optional<nb::dict> attributes,
1597 std::optional<std::vector<PyBlock *>> successors,
1598 std::optional<int> regions, PyLocation &location,
1599 const nb::object &maybeIp) {
1600 PyMlirContextRef context = location.getContext();
1601
1602 // Class level operation construction metadata.
1603 // Operand and result segment specs are either none, which does no
1604 // variadic unpacking, or a list of ints with segment sizes, where each
1605 // element is either a positive number (typically 1 for a scalar) or -1 to
1606 // indicate that it is derived from the length of the same-indexed operand
1607 // or result (implying that it is a list at that position).
1608 std::vector<int32_t> operandSegmentLengths;
1609 std::vector<int32_t> resultSegmentLengths;
1610
1611 // Validate/determine region count.
1612 int opMinRegionCount = std::get<0>(opRegionSpec);
1613 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1614 if (!regions) {
1615 regions = opMinRegionCount;
1616 }
1617 if (*regions < opMinRegionCount) {
1618 throw nb::value_error(join("Operation \"", name,
1619 "\" requires a minimum of ", opMinRegionCount,
1620 " regions but was built with regions=", *regions)
1621 .c_str());
1622 }
1623 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1624 throw nb::value_error(join("Operation \"", name,
1625 "\" requires a maximum of ", opMinRegionCount,
1626 " regions but was built with regions=", *regions)
1627 .c_str());
1628 }
1629
1630 // Unpack results.
1631 std::vector<PyType *> resultTypes;
1632 if (resultTypeList.has_value()) {
1633 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1634 resultSegmentLengths, resultTypes);
1635 }
1636
1637 // Unpack operands.
1638 std::vector<MlirValue> operands;
1639 operands.reserve(operands.size());
1640 if (operandSegmentSpecObj.is_none()) {
1641 // Non-sized operand unpacking.
1642 for (const auto &it : llvm::enumerate(operandList)) {
1643 try {
1644 operands.push_back(getOpResultOrValue(it.value()));
1645 } catch (nb::builtin_exception &err) {
1646 throw nb::value_error(join("Operand ", it.index(), " of operation \"",
1647 name, "\" must be a Value (", err.what(),
1648 ")")
1649 .c_str());
1650 }
1651 }
1652 } else {
1653 // Sized operand unpacking.
1654 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1655 if (operandSegmentSpec.size() != operandList.size()) {
1656 throw nb::value_error(
1657 join("Operation \"", name, "\" requires ", operandSegmentSpec.size(),
1658 "operand segments but was provided ", operandList.size())
1659 .c_str());
1660 }
1661 operandSegmentLengths.reserve(operandList.size());
1662 for (const auto &it :
1663 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1664 int segmentSpec = std::get<1>(it.value());
1665 if (segmentSpec == 1 || segmentSpec == 0) {
1666 // Unpack unary element.
1667 auto &operand = std::get<0>(it.value());
1668 if (!operand.is_none()) {
1669 try {
1670
1671 operands.push_back(getOpResultOrValue(operand));
1672 } catch (nb::builtin_exception &err) {
1673 throw nb::value_error(join("Operand ", it.index(),
1674 " of operation \"", name,
1675 "\" must be a Value (", err.what(), ")")
1676 .c_str());
1677 }
1678
1679 operandSegmentLengths.push_back(1);
1680 } else if (segmentSpec == 0) {
1681 // Allowed to be optional.
1682 operandSegmentLengths.push_back(0);
1683 } else {
1684 throw nb::value_error(
1685 join("Operand ", it.index(), " of operation \"", name,
1686 "\" must be a Value (was None and operand is not optional)")
1687 .c_str());
1688 }
1689 } else if (segmentSpec == -1) {
1690 // Unpack sequence by appending.
1691 try {
1692 if (std::get<0>(it.value()).is_none()) {
1693 // Treat it as an empty list.
1694 operandSegmentLengths.push_back(0);
1695 } else {
1696 // Unpack the list.
1697 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1698 for (nb::handle segmentItem : segment) {
1699 operands.push_back(getOpResultOrValue(segmentItem));
1700 }
1701 operandSegmentLengths.push_back(nb::len(segment));
1702 }
1703 } catch (std::exception &err) {
1704 // NOTE: Sloppy to be using a catch-all here, but there are at least
1705 // three different unrelated exceptions that can be thrown in the
1706 // above "casts". Just keep the scope above small and catch them all.
1707 throw nb::value_error(join("Operand ", it.index(), " of operation \"",
1708 name, "\" must be a Sequence of Values (",
1709 err.what(), ")")
1710 .c_str());
1711 }
1712 } else {
1713 throw nb::value_error("Unexpected segment spec");
1714 }
1715 }
1716 }
1717
1718 // Merge operand/result segment lengths into attributes if needed.
1719 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1720 // Dup.
1721 if (attributes) {
1722 attributes = nb::dict(*attributes);
1723 } else {
1724 attributes = nb::dict();
1725 }
1726 if (attributes->contains("resultSegmentSizes") ||
1727 attributes->contains("operandSegmentSizes")) {
1728 throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
1729 "'operandSegmentSizes' attribute is unsupported. "
1730 "Use Operation.create for such low-level access.");
1731 }
1732
1733 // Add resultSegmentSizes attribute.
1734 if (!resultSegmentLengths.empty()) {
1735 MlirAttribute segmentLengthAttr =
1736 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1737 resultSegmentLengths.data());
1738 (*attributes)["resultSegmentSizes"] =
1739 PyAttribute(context, segmentLengthAttr);
1740 }
1741
1742 // Add operandSegmentSizes attribute.
1743 if (!operandSegmentLengths.empty()) {
1744 MlirAttribute segmentLengthAttr =
1745 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1746 operandSegmentLengths.data());
1747 (*attributes)["operandSegmentSizes"] =
1748 PyAttribute(context, segmentLengthAttr);
1749 }
1750 }
1751
1752 // Delegate to create.
1753 return PyOperation::create(name,
1754 /*results=*/std::move(resultTypes),
1755 /*operands=*/operands.data(),
1756 /*numOperands=*/operands.size(),
1757 /*attributes=*/std::move(attributes),
1758 /*successors=*/std::move(successors),
1759 /*regions=*/*regions, location, maybeIp,
1760 !resultTypeList);
1761}
1762
1763nb::object PyOpView::constructDerived(const nb::object &cls,
1764 const nb::object &operation) {
1765 nb::handle opViewType = nb::type<PyOpView>();
1766 nb::object instance = cls.attr("__new__")(cls);
1767 opViewType.attr("__init__")(instance, operation);
1768 return instance;
1769}
1770
/// Wraps any PyOperationBase-derived Python object as an OpView.
/// The stored `operationObject` member is re-derived from the underlying
/// PyOperation (via getRef()) rather than copied from the argument. It
/// therefore refers to the PyOperation itself even when an OpView is
/// passed in.
/// NOTE(review): this relies on `operation` being declared before
/// `operationObject` in the class — confirm in the header.
PyOpView::PyOpView(const nb::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1777//------------------------------------------------------------------------------
1778// PyAsmState
1779//------------------------------------------------------------------------------
1780
/// Builds an AsmState for printing `value`, owning a locally created set of
/// printing flags.
PyAsmState::PyAsmState(MlirValue value, bool useLocalScope) {
  flags = mlirOpPrintingFlagsCreate();
  // The OpPrintingFlags are not exposed Python side, create locally and
  // associate lifetime with the state.
  if (useLocalScope)
  state = mlirAsmStateCreateForValue(value, flags);
}
1789
/// Builds an AsmState for printing `operation`, owning a locally created set
/// of printing flags.
PyAsmState::PyAsmState(PyOperationBase &operation, bool useLocalScope) {
  flags = mlirOpPrintingFlagsCreate();
  // The OpPrintingFlags are not exposed Python side, create locally and
  // associate lifetime with the state.
  if (useLocalScope)
  state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
}
1799//------------------------------------------------------------------------------
1800// PyInsertionPoint.
1801//------------------------------------------------------------------------------
1802
1803PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
    // Insert before the given operation; the block is taken from that
    // operation (relies on `refOperation` being initialized before `block`).
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}

    // Same as above, but starting from an existing operation reference.
    : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
1811
/// Inserts a detached operation at this insertion point: before the reference
/// operation if one is set, otherwise at the end of the block.
/// Raises ValueError if the operation is already attached. Raises IndexError
/// when appending to a block that already has a terminator.
void PyInsertionPoint::insert(PyOperationBase &operationBase) {
  PyOperation &operation = operationBase.getOperation();
  if (operation.isAttached())
    throw nb::value_error(
        "Attempt to insert operation that is already attached");
  block.getParentOperation()->checkValid();
  MlirOperation beforeOp = {nullptr};
  if (refOperation) {
    // Insert before operation.
    (*refOperation)->checkValid();
    beforeOp = (*refOperation)->get();
  } else {
    // Insert at end (before null) is only valid if the block does not
    // already end in a known terminator (violating this will cause assertion
    // failures later).
    if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
      throw nb::index_error("Cannot insert operation at the end of a block "
                            "that already has a terminator. Did you mean to "
                            "use 'InsertionPoint.at_block_terminator(block)' "
                            "versus 'InsertionPoint(block)'?");
    }
  // The block takes ownership of the operation (the "Owned" C API call), so
  // the wrapper is flagged as attached.
  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
  operation.setAttached();
}
1837
  MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
  if (mlirOperationIsNull(firstOp)) {
    // Just insert at end.
    // (Empty block: the beginning and the end coincide.)
    return PyInsertionPoint(block);
  }

  // Insert before first op.
      block.getParentOperation()->getContext(), firstOp);
  return PyInsertionPoint{block, std::move(firstOpRef)};
}
1850
  // Insertion point just before the block's terminator; raises ValueError if
  // the block has none.
  MlirOperation terminator = mlirBlockGetTerminator(block.get());
  if (mlirOperationIsNull(terminator))
    throw nb::value_error("Block has no terminator");
      block.getParentOperation()->getContext(), terminator);
  return PyInsertionPoint{block, std::move(terminatorOpRef)};
}
1859
  // Insertion point immediately after `op`. Expressed as "before the next
  // operation", or "end of block" when `op` is the last one.
  PyOperation &operation = op.getOperation();
  PyBlock block = operation.getBlock();
  MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
  if (mlirOperationIsNull(nextOperation))
    return PyInsertionPoint(block);
      block.getParentOperation()->getContext(), nextOperation);
  return PyInsertionPoint{block, std::move(nextOpRef)};
}
1870
1871size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
1873nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
1874 return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
1875}
1876
/// `__exit__` for `with InsertionPoint(...)`. The exception arguments follow
/// Python's context-manager protocol.
void PyInsertionPoint::contextExit(const nb::object &excType,
                                   const nb::object &excVal,
                                   const nb::object &excTb) {
}
1883//------------------------------------------------------------------------------
1884// PyAttribute.
1885//------------------------------------------------------------------------------
1887bool PyAttribute::operator==(const PyAttribute &other) const {
1888 return mlirAttributeEqual(attr, other.attr);
1889}
1891nb::object PyAttribute::getCapsule() {
1892 return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
1893}
1894
/// Reconstructs a PyAttribute from an attribute PyCapsule.
PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) {
  MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
  // A null attribute signals a failed conversion; rethrow the current Python
  // error (presumably set by the conversion).
  if (mlirAttributeIsNull(rawAttr))
    throw nb::python_error();
  return PyAttribute(
}
1903nb::typed<nb::object, PyAttribute> PyAttribute::maybeDownCast() {
1904 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get());
1905 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1906 "mlirTypeID was expected to be non-null.");
1907 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
1908 mlirTypeID, mlirAttributeGetDialect(this->get()));
1909 // nb::rv_policy::move means use std::move to move the return value
1910 // contents into a new instance that will be owned by Python.
1911 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
1912 if (!typeCaster)
1913 return thisObj;
1914 return typeCaster.value()(thisObj);
1915}
1917//------------------------------------------------------------------------------
1918// PyNamedAttribute.
1919//------------------------------------------------------------------------------
1920
1921PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1922 : ownedName(new std::string(std::move(ownedName))) {
1925 toMlirStringRef(*this->ownedName)),
1926 attr);
1927}
1929//------------------------------------------------------------------------------
1930// PyType.
1931//------------------------------------------------------------------------------
1933bool PyType::operator==(const PyType &other) const {
1934 return mlirTypeEqual(type, other.type);
1935}
1937nb::object PyType::getCapsule() {
1938 return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
1939}
1940
1941PyType PyType::createFromCapsule(nb::object capsule) {
1942 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1943 if (mlirTypeIsNull(rawType))
1944 throw nb::python_error();
1946 rawType);
1947}
1948
1949nb::typed<nb::object, PyType> PyType::maybeDownCast() {
1950 MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get());
1951 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1952 "mlirTypeID was expected to be non-null.");
1953 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
1954 mlirTypeID, mlirTypeGetDialect(this->get()));
1955 // nb::rv_policy::move means use std::move to move the return value
1956 // contents into a new instance that will be owned by Python.
1957 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
1958 if (!typeCaster)
1959 return thisObj;
1960 return typeCaster.value()(thisObj);
1961}
1963//------------------------------------------------------------------------------
1964// PyTypeID.
1965//------------------------------------------------------------------------------
1967nb::object PyTypeID::getCapsule() {
1968 return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
1969}
1970
1971PyTypeID PyTypeID::createFromCapsule(nb::object capsule) {
1972 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
1973 if (mlirTypeIDIsNull(mlirTypeID))
1974 throw nb::python_error();
1975 return PyTypeID(mlirTypeID);
1976}
1977bool PyTypeID::operator==(const PyTypeID &other) const {
1978 return mlirTypeIDEqual(typeID, other.typeID);
1979}
1981//------------------------------------------------------------------------------
1982// PyValue and subclasses.
1983//------------------------------------------------------------------------------
1985nb::object PyValue::getCapsule() {
1986 return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
1987}
1988
1989static PyOperationRef getValueOwnerRef(MlirValue value) {
1990 MlirOperation owner;
1991 if (mlirValueIsAOpResult(value))
1992 owner = mlirOpResultGetOwner(value);
1993 else if (mlirValueIsABlockArgument(value))
1995 else
1996 assert(false && "Value must be an block arg or op result.");
1997 if (mlirOperationIsNull(owner))
1998 throw nb::python_error();
1999 MlirContext ctx = mlirOperationGetContext(owner);
2001}
2002
2003nb::typed<nb::object, std::variant<PyBlockArgument, PyOpResult, PyValue>>
2005 MlirType type = mlirValueGetType(get());
2006 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2007 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2008 "mlirTypeID was expected to be non-null.");
2009 std::optional<nb::callable> valueCaster =
2011 // nb::rv_policy::move means use std::move to move the return value
2012 // contents into a new instance that will be owned by Python.
2013 nb::object thisObj;
2014 if (mlirValueIsAOpResult(value))
2015 thisObj = nb::cast<PyOpResult>(*this, nb::rv_policy::move);
2016 else if (mlirValueIsABlockArgument(value))
2017 thisObj = nb::cast<PyBlockArgument>(*this, nb::rv_policy::move);
2018 else
2019 assert(false && "Value must be an block arg or op result.");
2020 if (valueCaster)
2021 return valueCaster.value()(thisObj);
2022 return thisObj;
2023}
2024
2025PyValue PyValue::createFromCapsule(nb::object capsule) {
2026 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2027 if (mlirValueIsNull(value))
2028 throw nb::python_error();
2029 PyOperationRef ownerRef = getValueOwnerRef(value);
2030 return PyValue(ownerRef, value);
2031}
2033//------------------------------------------------------------------------------
2034// PySymbolTable.
2035//------------------------------------------------------------------------------
2036
2038 : operation(operation.getOperation().getRef()) {
2039 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2040 if (mlirSymbolTableIsNull(symbolTable)) {
2041 throw nb::type_error("Operation is not a Symbol Table.");
2042 }
2043}
2044
2045nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2046 operation->checkValid();
2047 MlirOperation symbol = mlirSymbolTableLookup(
2048 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2049 if (mlirOperationIsNull(symbol))
2050 throw nb::key_error(
2051 join("Symbol '", name, "' not in the symbol table.").c_str());
2052
2053 return PyOperation::forOperation(operation->getContext(), symbol,
2054 operation.getObject())
2055 ->createOpView();
2056}
2057
2059 operation->checkValid();
2060 symbol.getOperation().checkValid();
2061 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2062 // The operation is also erased, so we must invalidate it. There may be Python
2063 // references to this operation so we don't want to delete it from the list of
2064 // live operations here.
2065 symbol.getOperation().valid = false;
2066}
2067
2068void PySymbolTable::dunderDel(const std::string &name) {
2069 nb::object operation = dunderGetItem(name);
2070 erase(nb::cast<PyOperationBase &>(operation));
2071}
2072
2074 operation->checkValid();
2075 symbol.getOperation().checkValid();
2076 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2078 if (mlirAttributeIsNull(symbolAttr))
2079 throw nb::value_error("Expected operation to have a symbol name.");
2081 symbol.getOperation().getContext(),
2082 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
2083}
2084
2086 // Op must already be a symbol.
2087 PyOperation &operation = symbol.getOperation();
2088 operation.checkValid();
2090 MlirAttribute existingNameAttr =
2091 mlirOperationGetAttributeByName(operation.get(), attrName);
2092 if (mlirAttributeIsNull(existingNameAttr))
2093 throw nb::value_error("Expected operation to have a symbol name.");
2094 return PyStringAttribute(symbol.getOperation().getContext(),
2095 existingNameAttr);
2096}
2097
2099 const std::string &name) {
2100 // Op must already be a symbol.
2101 PyOperation &operation = symbol.getOperation();
2102 operation.checkValid();
2104 MlirAttribute existingNameAttr =
2105 mlirOperationGetAttributeByName(operation.get(), attrName);
2106 if (mlirAttributeIsNull(existingNameAttr))
2107 throw nb::value_error("Expected operation to have a symbol name.");
2108 MlirAttribute newNameAttr =
2109 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2110 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2111}
2112
2114 PyOperation &operation = symbol.getOperation();
2115 operation.checkValid();
2117 MlirAttribute existingVisAttr =
2118 mlirOperationGetAttributeByName(operation.get(), attrName);
2119 if (mlirAttributeIsNull(existingVisAttr))
2120 throw nb::value_error("Expected operation to have a symbol visibility.");
2121 return PyStringAttribute(symbol.getOperation().getContext(), existingVisAttr);
2122}
2123
2125 const std::string &visibility) {
2126 if (visibility != "public" && visibility != "private" &&
2127 visibility != "nested")
2128 throw nb::value_error(
2129 "Expected visibility to be 'public', 'private' or 'nested'");
2130 PyOperation &operation = symbol.getOperation();
2131 operation.checkValid();
2133 MlirAttribute existingVisAttr =
2134 mlirOperationGetAttributeByName(operation.get(), attrName);
2135 if (mlirAttributeIsNull(existingVisAttr))
2136 throw nb::value_error("Expected operation to have a symbol visibility.");
2137 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2138 toMlirStringRef(visibility));
2139 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2140}
2141
2142void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2143 const std::string &newSymbol,
2144 PyOperationBase &from) {
2145 PyOperation &fromOperation = from.getOperation();
2146 fromOperation.checkValid();
2148 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2150
2151 throw nb::value_error("Symbol rename failed");
2152}
2153
2155 bool allSymUsesVisible,
2156 nb::object callback) {
2157 PyOperation &fromOperation = from.getOperation();
2158 fromOperation.checkValid();
2159 struct UserData {
2160 PyMlirContextRef context;
2161 nb::object callback;
2162 bool gotException;
2163 std::string exceptionWhat;
2164 nb::object exceptionType;
2165 };
2166 UserData userData{
2167 fromOperation.getContext(), std::move(callback), false, {}, {}};
2169 fromOperation.get(), allSymUsesVisible,
2170 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2171 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2172 auto pyFoundOp =
2173 PyOperation::forOperation(calleeUserData->context, foundOp);
2174 if (calleeUserData->gotException)
2175 return;
2176 try {
2177 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2178 } catch (nb::python_error &e) {
2179 calleeUserData->gotException = true;
2180 calleeUserData->exceptionWhat = e.what();
2181 calleeUserData->exceptionType = nb::borrow(e.type());
2182 }
2183 },
2184 static_cast<void *>(&userData));
2185 if (userData.gotException) {
2186 std::string message("Exception raised in callback: ");
2187 message.append(userData.exceptionWhat);
2188 throw std::runtime_error(message);
2189 }
2190}
2191
2192void PyBlockArgument::bindDerived(ClassTy &c) {
2193 c.def_prop_ro(
2194 "owner",
2195 [](PyBlockArgument &self) {
2196 return PyBlock(self.getParentOperation(),
2198 },
2199 "Returns the block that owns this argument.");
2200 c.def_prop_ro(
2201 "arg_number",
2202 [](PyBlockArgument &self) {
2203 return mlirBlockArgumentGetArgNumber(self.get());
2204 },
2205 "Returns the position of this argument in the block's argument list.");
2206 c.def(
2207 "set_type",
2208 [](PyBlockArgument &self, PyType type) {
2209 return mlirBlockArgumentSetType(self.get(), type);
2210 },
2211 "type"_a, "Sets the type of this block argument.");
2212 c.def(
2213 "set_location",
2214 [](PyBlockArgument &self, PyLocation loc) {
2216 },
2217 "loc"_a, "Sets the location of this block argument.");
2218}
2219
2221 MlirBlock block, intptr_t startIndex,
2224 length == -1 ? mlirBlockGetNumArguments(block) : length, step),
2225 operation(std::move(operation)), block(block) {}
2226
2227void PyBlockArgumentList::bindDerived(ClassTy &c) {
2228 c.def_prop_ro(
2229 "types",
2230 [](PyBlockArgumentList &self) {
2231 return getValueTypes(self, self.operation->getContext());
2232 },
2233 "Returns a list of types for all arguments in this argument list.");
2234}
2235
2236intptr_t PyBlockArgumentList::getRawNumElements() {
2237 operation->checkValid();
2238 return mlirBlockGetNumArguments(block);
2239}
2240
2241PyBlockArgument PyBlockArgumentList::getRawElement(intptr_t pos) const {
2242 MlirValue argument = mlirBlockGetArgument(block, pos);
2243 return PyBlockArgument(operation, argument);
2244}
2245
2246PyBlockArgumentList PyBlockArgumentList::slice(intptr_t startIndex,
2248 intptr_t step) const {
2249 return PyBlockArgumentList(operation, block, startIndex, length, step);
2250}
2251
2253 intptr_t length, intptr_t step)
2254 : Sliceable(startIndex,
2256 : length,
2257 step),
2258 operation(operation) {}
2259
2262 mlirOperationSetOperand(operation->get(), index, value.get());
2263}
2264
2265void PyOpOperandList::bindDerived(ClassTy &c) {
2266 c.def("__setitem__", &PyOpOperandList::dunderSetItem, "index"_a, "value"_a,
2267 "Sets the operand at the specified index to a new value.");
2268}
2269
2270intptr_t PyOpOperandList::getRawNumElements() {
2271 operation->checkValid();
2272 return mlirOperationGetNumOperands(operation->get());
2273}
2274
2275PyValue PyOpOperandList::getRawElement(intptr_t pos) {
2276 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2277 PyOperationRef pyOwner = getValueOwnerRef(operand);
2278 return PyValue(pyOwner, operand);
2279}
2280
2282 intptr_t step) const {
2283 return PyOpOperandList(operation, startIndex, length, step);
2284}
2285
2287 intptr_t length, intptr_t step)
2288 : Sliceable(startIndex,
2290 : length,
2291 step),
2292 operation(operation) {}
2293
2296 mlirOperationSetSuccessor(operation->get(), index, block.get());
2297}
2298
2299void PyOpSuccessors::bindDerived(ClassTy &c) {
2300 c.def("__setitem__", &PyOpSuccessors::dunderSetItem, "index"_a, "block"_a,
2301 "Sets the successor block at the specified index.");
2302}
2303
2304intptr_t PyOpSuccessors::getRawNumElements() {
2305 operation->checkValid();
2306 return mlirOperationGetNumSuccessors(operation->get());
2307}
2308
2309PyBlock PyOpSuccessors::getRawElement(intptr_t pos) {
2310 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2311 return PyBlock(operation, block);
2312}
2313
2315 intptr_t step) const {
2316 return PyOpSuccessors(operation, startIndex, length, step);
2317}
2318
2320 intptr_t startIndex, intptr_t length,
2321 intptr_t step)
2322 : Sliceable(startIndex,
2323 length == -1 ? mlirBlockGetNumSuccessors(block.get()) : length,
2324 step),
2325 operation(operation), block(block) {}
2326
2327intptr_t PyBlockSuccessors::getRawNumElements() {
2328 block.checkValid();
2329 return mlirBlockGetNumSuccessors(block.get());
2330}
2331
2332PyBlock PyBlockSuccessors::getRawElement(intptr_t pos) {
2333 MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2334 return PyBlock(operation, block);
2335}
2336
2338 intptr_t step) const {
2339 return PyBlockSuccessors(block, operation, startIndex, length, step);
2340}
2341
2343 PyOperationRef operation,
2344 intptr_t startIndex, intptr_t length,
2345 intptr_t step)
2346 : Sliceable(startIndex,
2347 length == -1 ? mlirBlockGetNumPredecessors(block.get())
2348 : length,
2349 step),
2350 operation(operation), block(block) {}
2351
2352intptr_t PyBlockPredecessors::getRawNumElements() {
2353 block.checkValid();
2354 return mlirBlockGetNumPredecessors(block.get());
2355}
2356
2357PyBlock PyBlockPredecessors::getRawElement(intptr_t pos) {
2358 MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2359 return PyBlock(operation, block);
2360}
2361
2362PyBlockPredecessors PyBlockPredecessors::slice(intptr_t startIndex,
2363 intptr_t length,
2364 intptr_t step) const {
2365 return PyBlockPredecessors(block, operation, startIndex, length, step);
2366}
2367
2368nb::typed<nb::object, PyAttribute>
2369PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
2370 MlirAttribute attr =
2372 if (mlirAttributeIsNull(attr)) {
2373 throw nb::key_error("attempt to access a non-existent attribute");
2375 return PyAttribute(operation->getContext(), attr).maybeDownCast();
2376}
2377
2378nb::typed<nb::object, std::optional<PyAttribute>>
2379PyOpAttributeMap::get(const std::string &key, nb::object defaultValue) {
2380 MlirAttribute attr =
2382 if (mlirAttributeIsNull(attr))
2383 return defaultValue;
2384 return PyAttribute(operation->getContext(), attr).maybeDownCast();
2385}
2386
2388 if (index < 0) {
2389 index += dunderLen();
2390 }
2391 if (index < 0 || index >= dunderLen()) {
2392 throw nb::index_error("attempt to access out of bounds attribute");
2393 }
2394 MlirNamedAttribute namedAttr =
2395 mlirOperationGetAttribute(operation->get(), index);
2396 return PyNamedAttribute(
2397 namedAttr.attribute,
2398 std::string(mlirIdentifierStr(namedAttr.name).data,
2399 mlirIdentifierStr(namedAttr.name).length));
2400}
2401
2402void PyOpAttributeMap::dunderSetItem(const std::string &name,
2403 const PyAttribute &attr) {
2404 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2405 attr);
2406}
2407
2408void PyOpAttributeMap::dunderDelItem(const std::string &name) {
2409 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2411 if (!removed)
2412 throw nb::key_error("attempt to delete a non-existent attribute");
2413}
2416 return mlirOperationGetNumAttributes(operation->get());
2417}
2418
2419bool PyOpAttributeMap::dunderContains(const std::string &name) {
2420 return !mlirAttributeIsNull(
2421 mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)));
2422}
2423
2425 MlirOperation op, std::function<void(MlirStringRef, MlirAttribute)> fn) {
2427 for (intptr_t i = 0; i < n; ++i) {
2430 fn(name, na.attribute);
2431 }
2432}
2433
2434void PyOpAttributeMap::bind(nb::module_ &m) {
2435 nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2436 .def("__contains__", &PyOpAttributeMap::dunderContains, "name"_a,
2437 "Checks if an attribute with the given name exists in the map.")
2438 .def("__len__", &PyOpAttributeMap::dunderLen,
2439 "Returns the number of attributes in the map.")
2440 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed, "name"_a,
2441 "Gets an attribute by name.")
2442 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed, "index"_a,
2443 "Gets a named attribute by index.")
2444 .def("__setitem__", &PyOpAttributeMap::dunderSetItem, "name"_a, "attr"_a,
2445 "Sets an attribute with the given name.")
2446 .def("__delitem__", &PyOpAttributeMap::dunderDelItem, "name"_a,
2447 "Deletes an attribute with the given name.")
2448 .def("get", &PyOpAttributeMap::get, nb::arg("key"),
2449 nb::arg("default") = nb::none(),
2450 "Gets an attribute by name or the default value, if it does not "
2451 "exist.")
2452 .def(
2453 "__iter__",
2454 [](PyOpAttributeMap &self) {
2455 nb::list keys;
2457 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
2458 keys.append(nb::str(name.data, name.length));
2459 });
2460 return nb::iter(keys);
2461 },
2462 "Iterates over attribute names.")
2463 .def(
2464 "keys",
2465 [](PyOpAttributeMap &self) {
2466 nb::list out;
2468 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
2469 out.append(nb::str(name.data, name.length));
2470 });
2471 return out;
2472 },
2473 "Returns a list of attribute names.")
2474 .def(
2475 "values",
2476 [](PyOpAttributeMap &self) {
2477 nb::list out;
2479 self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
2480 out.append(PyAttribute(self.operation->getContext(), attr)
2481 .maybeDownCast());
2482 });
2483 return out;
2484 },
2485 "Returns a list of attribute values.")
2486 .def(
2487 "items",
2488 [](PyOpAttributeMap &self) {
2489 nb::list out;
2491 self.operation->get(),
2492 [&](MlirStringRef name, MlirAttribute attr) {
2493 out.append(nb::make_tuple(
2494 nb::str(name.data, name.length),
2495 PyAttribute(self.operation->getContext(), attr)
2496 .maybeDownCast()));
2497 });
2498 return out;
2499 },
2500 "Returns a list of `(name, attribute)` tuples.");
2501}
2502
2503} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
2504} // namespace python
2505} // namespace mlir
2506
2507namespace {
2508// see
2509// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
2510
// Backports of newer Python C-API helpers, adapted from pythoncapi_compat.h
// (see the URL above). Each piece is guarded by #ifndef / PY_VERSION_HEX so it
// is only compiled when the running Python headers lack it. Keep this text
// aligned with upstream pythoncapi_compat.h rather than restyling it.

// Fallback for _Py_CAST: a plain C-style cast.
#ifndef _Py_CAST
#define _Py_CAST(type, expr) ((type)(expr))
#endif

// Static inline functions should use _Py_NULL rather than using directly NULL
// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
// _Py_NULL is defined as nullptr.
#ifndef _Py_NULL
#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
    (defined(__cplusplus) && __cplusplus >= 201103)
#define _Py_NULL nullptr
#else
#define _Py_NULL NULL
#endif
#endif

// Python 3.10.0a3
#if PY_VERSION_HEX < 0x030A00A3

// bpo-42262 added Py_XNewRef()
// Increments the refcount (tolerating NULL) and returns the argument, i.e.
// the caller receives a new (strong) reference it must later release.
#if !defined(Py_XNewRef)
[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
  Py_XINCREF(obj);
  return obj;
}
#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
#endif

// bpo-42262 added Py_NewRef()
// Same as Py_XNewRef but uses Py_INCREF, so `obj` must not be NULL.
#if !defined(Py_NewRef)
[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
  Py_INCREF(obj);
  return obj;
}
#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
#endif

#endif // Python 3.10.0a3

// Python 3.9.0b1
#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)

// bpo-40429 added PyThreadState_GetFrame()
// Returns a new reference to the thread's current frame (may be NULL);
// the caller must Py_XDECREF it.
PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
  assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
  return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
}

// bpo-40421 added PyFrame_GetBack()
// Returns a new reference to the calling frame (may be NULL); the caller must
// Py_XDECREF it.
PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
  return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
}

// bpo-40421 added PyFrame_GetCode()
// Returns a new (non-NULL) reference to the frame's code object; the caller
// must Py_DECREF it, or the code object leaks.
PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
  assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
  return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
}

#endif // Python 3.9.0b1
2573
2574using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
2575
2576MlirLocation tracebackToLocation(MlirContext ctx) {
2577 size_t framesLimit =
2579 // Use a thread_local here to avoid requiring a large amount of space.
2580 thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
2581 frames;
2582 size_t count = 0;
2583
2584 nb::gil_scoped_acquire acquire;
2585 PyThreadState *tstate = PyThreadState_GET();
2586 PyFrameObject *next;
2587 PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
2588 // In the increment expression:
2589 // 1. get the next prev frame;
2590 // 2. decrement the ref count on the current frame (in order that it can get
2591 // gc'd, along with any objects in its closure and etc);
2592 // 3. set current = next.
2593 for (; pyFrame != nullptr && count < framesLimit;
2594 next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
2595 PyCodeObject *code = PyFrame_GetCode(pyFrame);
2596 auto fileNameStr =
2597 nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
2598 std::string_view fileName(fileNameStr);
2599 if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
2600 continue;
2601
2602 // co_qualname and PyCode_Addr2Location added in py3.11
2603#if PY_VERSION_HEX < 0x030B00F0
2604 std::string name =
2605 nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
2606 std::string_view funcName(name);
2607 int startLine = PyFrame_GetLineNumber(pyFrame);
2608 MlirLocation loc =
2609 mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
2610#else
2611 std::string name =
2612 nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
2613 std::string_view funcName(name);
2614 int startLine, startCol, endLine, endCol;
2615 int lasti = PyFrame_GetLasti(pyFrame);
2616 if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
2617 &endCol)) {
2618 throw nb::python_error();
2619 }
2620 MlirLocation loc = mlirLocationFileLineColRangeGet(
2621 ctx, wrap(fileName), startLine, startCol, endLine, endCol);
2622#endif
2623
2624 frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
2625 ++count;
2626 }
2627 // When the loop breaks (after the last iter), current frame (if non-null)
2628 // is leaked without this.
2629 Py_XDECREF(pyFrame);
2630
2631 if (count == 0)
2632 return mlirLocationUnknownGet(ctx);
2633
2634 MlirLocation callee = frames[0];
2635 assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
2636 if (count == 1)
2637 return callee;
2638
2639 MlirLocation caller = frames[count - 1];
2640 assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
2641 for (int i = count - 2; i >= 1; i--)
2642 caller = mlirLocationCallSiteGet(frames[i], caller);
2643
2644 return mlirLocationCallSiteGet(callee, caller);
2645}
2646
2647PyLocation
2648maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
2649 if (location.has_value())
2650 return location.value();
2651 if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
2653
2654 PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
2655 MlirLocation mlirLoc = tracebackToLocation(ctx.get());
2657 return {ref, mlirLoc};
2658}
2659} // namespace
2660
2661namespace mlir {
2662namespace python {
2664
2665void populateRoot(nb::module_ &m) {
2666 m.attr("T") = nb::type_var("T");
2667 m.attr("U") = nb::type_var("U");
2668
2669 nb::class_<PyGlobals>(m, "_Globals")
2670 .def_prop_rw("dialect_search_modules",
2673 .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
2674 "module_name"_a)
2675 .def(
2676 "_check_dialect_module_loaded",
2677 [](PyGlobals &self, const std::string &dialectNamespace) {
2678 return self.loadDialectModule(dialectNamespace);
2679 },
2680 "dialect_namespace"_a)
2681 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
2682 "dialect_namespace"_a, "dialect_class"_a,
2683 "Testing hook for directly registering a dialect")
2684 .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
2685 "operation_name"_a, "operation_class"_a, nb::kw_only(),
2686 "replace"_a = false,
2687 "Testing hook for directly registering an operation")
2688 .def("loc_tracebacks_enabled",
2689 [](PyGlobals &self) {
2690 return self.getTracebackLoc().locTracebacksEnabled();
2691 })
2692 .def("set_loc_tracebacks_enabled",
2693 [](PyGlobals &self, bool enabled) {
2695 })
2696 .def("loc_tracebacks_frame_limit",
2697 [](PyGlobals &self) {
2699 })
2700 .def("set_loc_tracebacks_frame_limit",
2701 [](PyGlobals &self, std::optional<int> n) {
2704 })
2705 .def("register_traceback_file_inclusion",
2706 [](PyGlobals &self, const std::string &filename) {
2708 })
2709 .def("register_traceback_file_exclusion",
2710 [](PyGlobals &self, const std::string &filename) {
2712 });
2713
2714 // Aside from making the globals accessible to python, having python manage
2715 // it is necessary to make sure it is destroyed (and releases its python
2716 // resources) properly.
2717 m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
2718
2719 // Registration decorators.
2720 m.def(
2721 "register_dialect",
2722 [](nb::type_object pyClass) {
2723 std::string dialectNamespace =
2724 nb::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
2725 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
2726 return pyClass;
2727 },
2728 "dialect_class"_a,
2729 "Class decorator for registering a custom Dialect wrapper");
2730 m.def(
2731 "register_operation",
2732 [](const nb::type_object &dialectClass, bool replace) -> nb::object {
2733 return nb::cpp_function(
2734 [dialectClass,
2735 replace](nb::type_object opClass) -> nb::type_object {
2736 std::string operationName =
2737 nb::cast<std::string>(opClass.attr("OPERATION_NAME"));
2738 PyGlobals::get().registerOperationImpl(operationName, opClass,
2739 replace);
2740 // Dict-stuff the new opClass by name onto the dialect class.
2741 nb::object opClassName = opClass.attr("__name__");
2742 dialectClass.attr(opClassName) = opClass;
2743 return opClass;
2744 });
2745 },
2746 // clang-format off
2747 nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) "
2748 "-> typing.Callable[[type[T]], type[T]]"),
2749 // clang-format on
2750 "dialect_class"_a, nb::kw_only(), "replace"_a = false,
2751 "Produce a class decorator for registering an Operation class as part of "
2752 "a dialect");
2753 m.def(
2755 [](PyTypeID mlirTypeID, bool replace) -> nb::object {
2756 return nb::cpp_function([mlirTypeID, replace](
2757 nb::callable typeCaster) -> nb::object {
2758 PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
2759 return typeCaster;
2760 });
2761 },
2762 // clang-format off
2763 nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
2764 "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
2765 // clang-format on
2766 "typeid"_a, nb::kw_only(), "replace"_a = false,
2767 "Register a type caster for casting MLIR types to custom user types.");
2768 m.def(
2770 [](PyTypeID mlirTypeID, bool replace) -> nb::object {
2771 return nb::cpp_function(
2772 [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
2773 PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
2774 replace);
2775 return valueCaster;
2776 });
2777 },
2778 // clang-format off
2779 nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
2780 "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
2781 // clang-format on
2782 "typeid"_a, nb::kw_only(), "replace"_a = false,
2783 "Register a value caster for casting MLIR values to custom user values.");
2784}
2785
2786//------------------------------------------------------------------------------
2787// Populates the core exports of the 'ir' submodule.
2788//------------------------------------------------------------------------------
2789void populateIRCore(nb::module_ &m) {
2790 //----------------------------------------------------------------------------
2791 // Enums.
2792 //----------------------------------------------------------------------------
2793 nb::enum_<PyDiagnosticSeverity>(m, "DiagnosticSeverity")
2794 .value("ERROR", PyDiagnosticSeverity::Error)
2795 .value("WARNING", PyDiagnosticSeverity::Warning)
2796 .value("NOTE", PyDiagnosticSeverity::Note)
2797 .value("REMARK", PyDiagnosticSeverity::Remark);
2798
2799 nb::enum_<PyWalkOrder>(m, "WalkOrder")
2800 .value("PRE_ORDER", PyWalkOrder::PreOrder)
2801 .value("POST_ORDER", PyWalkOrder::PostOrder);
2802 nb::enum_<PyWalkResult>(m, "WalkResult")
2803 .value("ADVANCE", PyWalkResult::Advance)
2804 .value("INTERRUPT", PyWalkResult::Interrupt)
2805 .value("SKIP", PyWalkResult::Skip);
2806
2807 //----------------------------------------------------------------------------
2808 // Mapping of Diagnostics.
2809 //----------------------------------------------------------------------------
2810 nb::class_<PyDiagnostic>(m, "Diagnostic")
2811 .def_prop_ro("severity", &PyDiagnostic::getSeverity,
2812 "Returns the severity of the diagnostic.")
2813 .def_prop_ro("location", &PyDiagnostic::getLocation,
2814 "Returns the location associated with the diagnostic.")
2815 .def_prop_ro("message", &PyDiagnostic::getMessage,
2816 "Returns the message text of the diagnostic.")
2817 .def_prop_ro("notes", &PyDiagnostic::getNotes,
2818 "Returns a tuple of attached note diagnostics.")
2819 .def(
2820 "__str__",
2821 [](PyDiagnostic &self) -> nb::str {
2822 if (!self.isValid())
2823 return nb::str("<Invalid Diagnostic>");
2824 return self.getMessage();
2825 },
2826 "Returns the diagnostic message as a string.");
2827
2828 nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2829 .def(
2830 "__init__",
2832 new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2833 },
2834 "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
2835 .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
2836 "The severity level of the diagnostic.")
2837 .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
2838 "The location associated with the diagnostic.")
2839 .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
2840 "The message text of the diagnostic.")
2841 .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
2842 "List of attached note diagnostics.")
2843 .def(
2844 "__str__",
2845 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
2846 "Returns the diagnostic message as a string.");
2847
2848 nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
2849 .def("detach", &PyDiagnosticHandler::detach,
2850 "Detaches the diagnostic handler from the context.")
2851 .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
2852 "Returns True if the handler is attached to a context.")
2853 .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
2854 "Returns True if an error was encountered during diagnostic "
2855 "handling.")
2856 .def("__enter__", &PyDiagnosticHandler::contextEnter,
2857 "Enters the diagnostic handler as a context manager.")
2858 .def("__exit__", &PyDiagnosticHandler::contextExit, "exc_type"_a.none(),
2859 "exc_value"_a.none(), "traceback"_a.none(),
2860 "Exits the diagnostic handler context manager.");
2861
2862 // Expose DefaultThreadPool to python
2863 nb::class_<PyThreadPool>(m, "ThreadPool")
2864 .def(
2865 "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
2866 "Creates a new thread pool with default concurrency.")
2867 .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
2868 "Returns the maximum number of threads in the pool.")
2869 .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
2870 "Returns the raw pointer to the LLVM thread pool as a string.");
2871
2872 nb::class_<PyMlirContext>(m, "Context")
2873 .def(
2874 "__init__",
2875 [](PyMlirContext &self) {
2876 MlirContext context = mlirContextCreateWithThreading(false);
2877 new (&self) PyMlirContext(context);
2878 },
2879 R"(
2880 Creates a new MLIR context.
2881
2882 The context is the top-level container for all MLIR objects. It owns the storage
2883 for types, attributes, locations, and other core IR objects. A context can be
2884 configured to allow or disallow unregistered dialects and can have dialects
2885 loaded on-demand.)")
2886 .def_static("_get_live_count", &PyMlirContext::getLiveCount,
2887 "Gets the number of live Context objects.")
2888 .def(
2889 "_get_context_again",
2890 [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
2892 return ref.releaseObject();
2893 },
2894 "Gets another reference to the same context.")
2895 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
2896 "Gets the number of live modules owned by this context.")
2898 "Gets a capsule wrapping the MlirContext.")
2901 "Creates a Context from a capsule wrapping MlirContext.")
2902 .def("__enter__", &PyMlirContext::contextEnter,
2903 "Enters the context as a context manager.")
2904 .def("__exit__", &PyMlirContext::contextExit, "exc_type"_a.none(),
2905 "exc_value"_a.none(), "traceback"_a.none(),
2906 "Exits the context manager.")
2907 .def_prop_ro_static(
2908 "current",
2909 [](nb::object & /*class*/)
2910 -> std::optional<nb::typed<nb::object, PyMlirContext>> {
2912 if (!context)
2913 return {};
2914 return nb::cast(context);
2915 },
2916 nb::sig("def current(/) -> Context | None"),
2917 "Gets the Context bound to the current thread or returns None if no "
2918 "context is set.")
2919 .def_prop_ro(
2920 "dialects",
2921 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2922 "Gets a container for accessing dialects by name.")
2923 .def_prop_ro(
2924 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2925 "Alias for `dialects`.")
2926 .def(
2927 "get_dialect_descriptor",
2928 [=](PyMlirContext &self, std::string &name) {
2929 MlirDialect dialect = mlirContextGetOrLoadDialect(
2930 self.get(), {name.data(), name.size()});
2931 if (mlirDialectIsNull(dialect)) {
2932 throw nb::value_error(
2933 join("Dialect '", name, "' not found").c_str());
2934 }
2935 return PyDialectDescriptor(self.getRef(), dialect);
2936 },
2937 "dialect_name"_a,
2938 "Gets or loads a dialect by name, returning its descriptor object.")
2939 .def_prop_rw(
2940 "allow_unregistered_dialects",
2941 [](PyMlirContext &self) -> bool {
2942 return mlirContextGetAllowUnregisteredDialects(self.get());
2943 },
2944 [](PyMlirContext &self, bool value) {
2945 mlirContextSetAllowUnregisteredDialects(self.get(), value);
2946 },
2947 "Controls whether unregistered dialects are allowed in this context.")
2948 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2949 "callback"_a,
2950 "Attaches a diagnostic handler that will receive callbacks.")
2951 .def(
2952 "enable_multithreading",
2953 [](PyMlirContext &self, bool enable) {
2954 mlirContextEnableMultithreading(self.get(), enable);
2955 },
2956 "enable"_a,
2957 R"(
2958 Enables or disables multi-threading support in the context.
2959
2960 Args:
2961 enable: Whether to enable (True) or disable (False) multi-threading.
2962 )")
2963 .def(
2964 "set_thread_pool",
2965 [](PyMlirContext &self, PyThreadPool &pool) {
2966 // we should disable multi-threading first before setting
2967 // new thread pool otherwise the assert in
2968 // MLIRContext::setThreadPool will be raised.
2969 mlirContextEnableMultithreading(self.get(), false);
2970 mlirContextSetThreadPool(self.get(), pool.get());
2971 },
2972 R"(
2973 Sets a custom thread pool for the context to use.
2974
2975 Args:
2976 pool: A ThreadPool object to use for parallel operations.
2977
2978 Note:
2979 Multi-threading is automatically disabled before setting the thread pool.)")
2980 .def(
2981 "get_num_threads",
2982 [](PyMlirContext &self) {
2983 return mlirContextGetNumThreads(self.get());
2984 },
2985 "Gets the number of threads in the context's thread pool.")
2986 .def(
2987 "_mlir_thread_pool_ptr",
2988 [](PyMlirContext &self) {
2989 MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
2990 std::stringstream ss;
2991 ss << pool.ptr;
2992 return ss.str();
2993 },
2994 "Gets the raw pointer to the LLVM thread pool as a string.")
2995 .def(
2996 "is_registered_operation",
2997 [](PyMlirContext &self, std::string &name) {
2999 self.get(), MlirStringRef{name.data(), name.size()});
3000 },
3001 "operation_name"_a,
3002 R"(
3003 Checks whether an operation with the given name is registered.
3004
3005 Args:
3006 operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
3007
3008 Returns:
3009 True if the operation is registered, False otherwise.)")
3010 .def(
3011 "append_dialect_registry",
3012 [](PyMlirContext &self, PyDialectRegistry &registry) {
3013 mlirContextAppendDialectRegistry(self.get(), registry);
3014 },
3015 "registry"_a,
3016 R"(
3017 Appends the contents of a dialect registry to the context.
3018
3019 Args:
3020 registry: A DialectRegistry containing dialects to append.)")
3021 .def_prop_rw("emit_error_diagnostics",
3024 R"(
3025 Controls whether error diagnostics are emitted to diagnostic handlers.
3026
3027 By default, error diagnostics are captured and reported through MLIRError exceptions.)")
3028 .def(
3029 "load_all_available_dialects",
3030 [](PyMlirContext &self) {
3032 },
3033 R"(
3034 Loads all dialects available in the registry into the context.
3035
3036 This eagerly loads all dialects that have been registered, making them
3037 immediately available for use.)");
3038
3039 //----------------------------------------------------------------------------
3040 // Mapping of PyDialectDescriptor
3041 //----------------------------------------------------------------------------
3042 nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
3043 .def_prop_ro(
3044 "namespace",
3045 [](PyDialectDescriptor &self) {
3046 MlirStringRef ns = mlirDialectGetNamespace(self.get());
3047 return nb::str(ns.data, ns.length);
3048 },
3049 "Returns the namespace of the dialect.")
3050 .def(
3051 "__repr__",
3052 [](PyDialectDescriptor &self) {
3053 MlirStringRef ns = mlirDialectGetNamespace(self.get());
3054 std::string repr("<DialectDescriptor ");
3055 repr.append(ns.data, ns.length);
3056 repr.append(">");
3057 return repr;
3058 },
3059 nb::sig("def __repr__(self) -> str"),
3060 "Returns a string representation of the dialect descriptor.");
3061
3062 //----------------------------------------------------------------------------
3063 // Mapping of PyDialects
3064 //----------------------------------------------------------------------------
3065 nb::class_<PyDialects>(m, "Dialects")
3066 .def(
3067 "__getitem__",
3068 [=](PyDialects &self, std::string keyName) {
3069 MlirDialect dialect =
3070 self.getDialectForKey(keyName, /*attrError=*/false);
3071 nb::object descriptor =
3072 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3073 return createCustomDialectWrapper(keyName, std::move(descriptor));
3074 },
3075 "Gets a dialect by name using subscript notation.")
3076 .def(
3077 "__getattr__",
3078 [=](PyDialects &self, std::string attrName) {
3079 MlirDialect dialect =
3080 self.getDialectForKey(attrName, /*attrError=*/true);
3081 nb::object descriptor =
3082 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3083 return createCustomDialectWrapper(attrName, std::move(descriptor));
3084 },
3085 "Gets a dialect by name using attribute notation.");
3086
3087 //----------------------------------------------------------------------------
3088 // Mapping of PyDialect
3089 //----------------------------------------------------------------------------
3090 nb::class_<PyDialect>(m, "Dialect")
3091 .def(nb::init<nb::object>(), "descriptor"_a,
3092 "Creates a Dialect from a DialectDescriptor.")
3093 .def_prop_ro(
3094 "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
3095 "Returns the DialectDescriptor for this dialect.")
3096 .def(
3097 "__repr__",
3098 [](const nb::object &self) {
3099 auto clazz = self.attr("__class__");
3100 return nb::str("<Dialect ") +
3101 self.attr("descriptor").attr("namespace") +
3102 nb::str(" (class ") + clazz.attr("__module__") +
3103 nb::str(".") + clazz.attr("__name__") + nb::str(")>");
3104 },
3105 nb::sig("def __repr__(self) -> str"),
3106 "Returns a string representation of the dialect.");
3107
3108 //----------------------------------------------------------------------------
3109 // Mapping of PyDialectRegistry
3110 //----------------------------------------------------------------------------
3111 nb::class_<PyDialectRegistry>(m, "DialectRegistry")
3113 "Gets a capsule wrapping the MlirDialectRegistry.")
3116 "Creates a DialectRegistry from a capsule wrapping "
3117 "`MlirDialectRegistry`.")
3118 .def(nb::init<>(), "Creates a new empty dialect registry.");
3119
3120 //----------------------------------------------------------------------------
3121 // Mapping of Location
3122 //----------------------------------------------------------------------------
3123 nb::class_<PyLocation>(m, "Location")
3125 "Gets a capsule wrapping the MlirLocation.")
3127 "Creates a Location from a capsule wrapping MlirLocation.")
3128 .def("__enter__", &PyLocation::contextEnter,
3129 "Enters the location as a context manager.")
3130 .def("__exit__", &PyLocation::contextExit, "exc_type"_a.none(),
3131 "exc_value"_a.none(), "traceback"_a.none(),
3132 "Exits the location context manager.")
3133 .def(
3134 "__eq__",
3135 [](PyLocation &self, PyLocation &other) -> bool {
3136 return mlirLocationEqual(self, other);
3137 },
3138 "Compares two locations for equality.")
3139 .def(
3140 "__eq__", [](PyLocation &self, nb::object other) { return false; },
3141 "Compares location with non-location object (always returns False).")
3142 .def_prop_ro_static(
3143 "current",
3144 [](nb::object & /*class*/) -> std::optional<PyLocation *> {
3146 if (!loc)
3147 return std::nullopt;
3148 return loc;
3149 },
3150 // clang-format off
3151 nb::sig("def current(/) -> Location | None"),
3152 // clang-format on
3153 "Gets the Location bound to the current thread or raises ValueError.")
3154 .def_static(
3155 "unknown",
3156 [](DefaultingPyMlirContext context) {
3157 return PyLocation(context->getRef(),
3158 mlirLocationUnknownGet(context->get()));
3159 },
3160 "context"_a = nb::none(),
3161 "Gets a Location representing an unknown location.")
3162 .def_static(
3163 "callsite",
3164 [](PyLocation callee, const std::vector<PyLocation> &frames,
3165 DefaultingPyMlirContext context) {
3166 if (frames.empty())
3167 throw nb::value_error("No caller frames provided.");
3168 MlirLocation caller = frames.back().get();
3169 for (const PyLocation &frame :
3170 llvm::reverse(llvm::ArrayRef(frames).drop_back()))
3171 caller = mlirLocationCallSiteGet(frame.get(), caller);
3172 return PyLocation(context->getRef(),
3173 mlirLocationCallSiteGet(callee.get(), caller));
3174 },
3175 "callee"_a, "frames"_a, "context"_a = nb::none(),
3176 "Gets a Location representing a caller and callsite.")
3177 .def("is_a_callsite", mlirLocationIsACallSite,
3178 "Returns True if this location is a CallSiteLoc.")
3179 .def_prop_ro(
3180 "callee",
3181 [](PyLocation &self) {
3182 return PyLocation(self.getContext(),
3184 },
3185 "Gets the callee location from a CallSiteLoc.")
3186 .def_prop_ro(
3187 "caller",
3188 [](PyLocation &self) {
3189 return PyLocation(self.getContext(),
3191 },
3192 "Gets the caller location from a CallSiteLoc.")
3193 .def_static(
3194 "file",
3195 [](std::string filename, int line, int col,
3196 DefaultingPyMlirContext context) {
3197 return PyLocation(
3198 context->getRef(),
3200 context->get(), toMlirStringRef(filename), line, col));
3201 },
3202 "filename"_a, "line"_a, "col"_a, "context"_a = nb::none(),
3203 "Gets a Location representing a file, line and column.")
3204 .def_static(
3205 "file",
3206 [](std::string filename, int startLine, int startCol, int endLine,
3207 int endCol, DefaultingPyMlirContext context) {
3208 return PyLocation(context->getRef(),
3210 context->get(), toMlirStringRef(filename),
3211 startLine, startCol, endLine, endCol));
3212 },
3213 "filename"_a, "start_line"_a, "start_col"_a, "end_line"_a,
3214 "end_col"_a, "context"_a = nb::none(),
3215 "Gets a Location representing a file, line and column range.")
3216 .def("is_a_file", mlirLocationIsAFileLineColRange,
3217 "Returns True if this location is a FileLineColLoc.")
3218 .def_prop_ro(
3219 "filename",
3220 [](PyLocation loc) {
3221 return mlirIdentifierStr(
3223 },
3224 "Gets the filename from a FileLineColLoc.")
3225 .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
3226 "Gets the start line number from a `FileLineColLoc`.")
3227 .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
3228 "Gets the start column number from a `FileLineColLoc`.")
3229 .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
3230 "Gets the end line number from a `FileLineColLoc`.")
3231 .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
3232 "Gets the end column number from a `FileLineColLoc`.")
3233 .def_static(
3234 "fused",
3235 [](const std::vector<PyLocation> &pyLocations,
3236 std::optional<PyAttribute> metadata,
3237 DefaultingPyMlirContext context) {
3238 std::vector<MlirLocation> locations;
3239 locations.reserve(pyLocations.size());
3240 for (auto &pyLocation : pyLocations)
3241 locations.push_back(pyLocation.get());
3242 MlirLocation location = mlirLocationFusedGet(
3243 context->get(), locations.size(), locations.data(),
3244 metadata ? metadata->get() : MlirAttribute{0});
3245 return PyLocation(context->getRef(), location);
3246 },
3247 "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(),
3248 "Gets a Location representing a fused location with optional "
3249 "metadata.")
3250 .def("is_a_fused", mlirLocationIsAFused,
3251 "Returns True if this location is a `FusedLoc`.")
3252 .def_prop_ro(
3253 "locations",
3254 [](PyLocation &self) {
3255 unsigned numLocations = mlirLocationFusedGetNumLocations(self);
3256 std::vector<MlirLocation> locations(numLocations);
3257 if (numLocations)
3258 mlirLocationFusedGetLocations(self, locations.data());
3259 std::vector<PyLocation> pyLocations{};
3260 pyLocations.reserve(numLocations);
3261 for (unsigned i = 0; i < numLocations; ++i)
3262 pyLocations.emplace_back(self.getContext(), locations[i]);
3263 return pyLocations;
3264 },
3265 "Gets the list of locations from a `FusedLoc`.")
3266 .def_static(
3267 "name",
3268 [](std::string name, std::optional<PyLocation> childLoc,
3269 DefaultingPyMlirContext context) {
3270 return PyLocation(
3271 context->getRef(),
3273 context->get(), toMlirStringRef(name),
3274 childLoc ? childLoc->get()
3275 : mlirLocationUnknownGet(context->get())));
3276 },
3277 "name"_a, "childLoc"_a = nb::none(), "context"_a = nb::none(),
3278 "Gets a Location representing a named location with optional child "
3279 "location.")
3280 .def("is_a_name", mlirLocationIsAName,
3281 "Returns True if this location is a `NameLoc`.")
3282 .def_prop_ro(
3283 "name_str",
3284 [](PyLocation loc) {
3286 },
3287 "Gets the name string from a `NameLoc`.")
3288 .def_prop_ro(
3289 "child_loc",
3290 [](PyLocation &self) {
3291 return PyLocation(self.getContext(),
3293 },
3294 "Gets the child location from a `NameLoc`.")
3295 .def_static(
3296 "from_attr",
3297 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
3298 return PyLocation(context->getRef(),
3299 mlirLocationFromAttribute(attribute));
3300 },
3301 "attribute"_a, "context"_a = nb::none(),
3302 "Gets a Location from a `LocationAttr`.")
3303 .def_prop_ro(
3304 "context",
3305 [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
3306 return self.getContext().getObject();
3307 },
3308 "Context that owns the `Location`.")
3309 .def_prop_ro(
3310 "attr",
3311 [](PyLocation &self) {
3312 return PyAttribute(self.getContext(),
3314 },
3315 "Get the underlying `LocationAttr`.")
3316 .def(
3317 "emit_error",
3318 [](PyLocation &self, std::string message) {
3319 mlirEmitError(self, message.c_str());
3320 },
3321 "message"_a,
3322 R"(
3323 Emits an error diagnostic at this location.
3324
3325 Args:
3326 message: The error message to emit.)")
3327 .def(
3328 "__repr__",
3329 [](PyLocation &self) {
3330 PyPrintAccumulator printAccum;
3331 mlirLocationPrint(self, printAccum.getCallback(),
3332 printAccum.getUserData());
3333 return printAccum.join();
3334 },
3335 "Returns the assembly representation of the location.");
3336
3337 //----------------------------------------------------------------------------
3338 // Mapping of Module
3339 //----------------------------------------------------------------------------
3340 nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3342 "Gets a capsule wrapping the MlirModule.")
3344 R"(
3345 Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
3346
3347 This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
3348 prevent double-frees (of the underlying `mlir::Module`).)")
3349 .def("_clear_mlir_module", &PyModule::clearMlirModule,
3350 R"(
3351 Clears the internal MLIR module reference.
3352
3353 This is used internally to prevent double-free when ownership is transferred
3354 via the C API capsule mechanism. Not intended for normal use.)")
3355 .def_static(
3356 "parse",
3357 [](const std::string &moduleAsm, DefaultingPyMlirContext context)
3358 -> nb::typed<nb::object, PyModule> {
3359 PyMlirContext::ErrorCapture errors(context->getRef());
3360 MlirModule module = mlirModuleCreateParse(
3361 context->get(), toMlirStringRef(moduleAsm));
3362 if (mlirModuleIsNull(module))
3363 throw MLIRError("Unable to parse module assembly", errors.take());
3364 return PyModule::forModule(module).releaseObject();
3365 },
3366 "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
3367 .def_static(
3368 "parse",
3369 [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
3370 -> nb::typed<nb::object, PyModule> {
3371 PyMlirContext::ErrorCapture errors(context->getRef());
3372 MlirModule module = mlirModuleCreateParse(
3373 context->get(), toMlirStringRef(moduleAsm));
3374 if (mlirModuleIsNull(module))
3375 throw MLIRError("Unable to parse module assembly", errors.take());
3376 return PyModule::forModule(module).releaseObject();
3377 },
3378 "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
3379 .def_static(
3380 "parseFile",
3381 [](const std::string &path, DefaultingPyMlirContext context)
3382 -> nb::typed<nb::object, PyModule> {
3383 PyMlirContext::ErrorCapture errors(context->getRef());
3384 MlirModule module = mlirModuleCreateParseFromFile(
3385 context->get(), toMlirStringRef(path));
3386 if (mlirModuleIsNull(module))
3387 throw MLIRError("Unable to parse module assembly", errors.take());
3388 return PyModule::forModule(module).releaseObject();
3389 },
3390 "path"_a, "context"_a = nb::none(), kModuleParseDocstring)
3391 .def_static(
3392 "create",
3393 [](const std::optional<PyLocation> &loc)
3394 -> nb::typed<nb::object, PyModule> {
3395 PyLocation pyLoc = maybeGetTracebackLocation(loc);
3396 MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
3397 return PyModule::forModule(module).releaseObject();
3398 },
3399 "loc"_a = nb::none(), "Creates an empty module.")
3400 .def_prop_ro(
3401 "context",
3402 [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
3403 return self.getContext().getObject();
3404 },
3405 "Context that created the `Module`.")
3406 .def_prop_ro(
3407 "operation",
3408 [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
3409 return PyOperation::forOperation(self.getContext(),
3410 mlirModuleGetOperation(self.get()),
3411 self.getRef().releaseObject())
3412 .releaseObject();
3413 },
3414 "Accesses the module as an operation.")
3415 .def_prop_ro(
3416 "body",
3417 [](PyModule &self) {
3419 self.getContext(), mlirModuleGetOperation(self.get()),
3420 self.getRef().releaseObject());
3421 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3422 return returnBlock;
3423 },
3424 "Return the block for this module.")
3425 .def(
3426 "dump",
3427 [](PyModule &self) {
3429 },
3431 .def(
3432 "__str__",
3433 [](const nb::object &self) {
3434 // Defer to the operation's __str__.
3435 return self.attr("operation").attr("__str__")();
3436 },
3437 nb::sig("def __str__(self) -> str"),
3438 R"(
3439 Gets the assembly form of the operation with default options.
3440
3441 If more advanced control over the assembly formatting or I/O options is needed,
3442 use the dedicated print or get_asm method, which supports keyword arguments to
3443 customize behavior.
3444 )")
3445 .def(
3446 "__eq__",
3447 [](PyModule &self, PyModule &other) {
3448 return mlirModuleEqual(self.get(), other.get());
3449 },
3450 "other"_a, "Compares two modules for equality.")
3451 .def(
3452 "__hash__",
3453 [](PyModule &self) { return mlirModuleHashValue(self.get()); },
3454 "Returns the hash value of the module.");
3455
3456 //----------------------------------------------------------------------------
3457 // Mapping of Operation.
3458 //----------------------------------------------------------------------------
3459 nb::class_<PyOperationBase>(m, "_OperationBase")
3460 .def_prop_ro(
3462 [](PyOperationBase &self) {
3463 return self.getOperation().getCapsule();
3464 },
3465 "Gets a capsule wrapping the `MlirOperation`.")
3466 .def(
3467 "__eq__",
3468 [](PyOperationBase &self, PyOperationBase &other) {
3469 return mlirOperationEqual(self.getOperation().get(),
3470 other.getOperation().get());
3471 },
3472 "Compares two operations for equality.")
3473 .def(
3474 "__eq__",
3475 [](PyOperationBase &self, nb::object other) { return false; },
3476 "Compares operation with non-operation object (always returns "
3477 "False).")
3478 .def(
3479 "__hash__",
3480 [](PyOperationBase &self) {
3481 return mlirOperationHashValue(self.getOperation().get());
3482 },
3483 "Returns the hash value of the operation.")
3484 .def_prop_ro(
3485 "attributes",
3486 [](PyOperationBase &self) {
3487 return PyOpAttributeMap(self.getOperation().getRef());
3488 },
3489 "Returns a dictionary-like map of operation attributes.")
3490 .def_prop_ro(
3491 "context",
3492 [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
3493 PyOperation &concreteOperation = self.getOperation();
3494 concreteOperation.checkValid();
3495 return concreteOperation.getContext().getObject();
3496 },
3497 "Context that owns the operation.")
3498 .def_prop_ro(
3499 "name",
3500 [](PyOperationBase &self) {
3501 auto &concreteOperation = self.getOperation();
3502 concreteOperation.checkValid();
3503 MlirOperation operation = concreteOperation.get();
3504 return mlirIdentifierStr(mlirOperationGetName(operation));
3505 },
3506 "Returns the fully qualified name of the operation.")
3507 .def_prop_ro(
3508 "operands",
3509 [](PyOperationBase &self) {
3510 return PyOpOperandList(self.getOperation().getRef());
3511 },
3512 "Returns the list of operation operands.")
3513 .def_prop_ro(
3514 "regions",
3515 [](PyOperationBase &self) {
3516 return PyRegionList(self.getOperation().getRef());
3517 },
3518 "Returns the list of operation regions.")
3519 .def_prop_ro(
3520 "results",
3521 [](PyOperationBase &self) {
3522 return PyOpResultList(self.getOperation().getRef());
3523 },
3524 "Returns the list of Operation results.")
3525 .def_prop_ro(
3526 "result",
3527 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
3528 auto &operation = self.getOperation();
3529 return PyOpResult(operation.getRef(), getUniqueResult(operation))
3530 .maybeDownCast();
3531 },
3532 "Shortcut to get an op result if it has only one (throws an error "
3533 "otherwise).")
3534 .def_prop_rw(
3535 "location",
3536 [](PyOperationBase &self) {
3537 PyOperation &operation = self.getOperation();
3538 return PyLocation(operation.getContext(),
3539 mlirOperationGetLocation(operation.get()));
3540 },
3541 [](PyOperationBase &self, const PyLocation &location) {
3542 PyOperation &operation = self.getOperation();
3543 mlirOperationSetLocation(operation.get(), location.get());
3544 },
3545 nb::for_getter("Returns the source location the operation was "
3546 "defined or derived from."),
3547 nb::for_setter("Sets the source location the operation was defined "
3548 "or derived from."))
3549 .def_prop_ro(
3550 "parent",
3551 [](PyOperationBase &self)
3552 -> std::optional<nb::typed<nb::object, PyOperation>> {
3553 auto parent = self.getOperation().getParentOperation();
3554 if (parent)
3555 return parent->getObject();
3556 return {};
3557 },
3558 "Returns the parent operation, or `None` if at top level.")
3559 .def(
3560 "__str__",
3561 [](PyOperationBase &self) {
3562 return self.getAsm(/*binary=*/false,
3563 /*largeElementsLimit=*/std::nullopt,
3564 /*largeResourceLimit=*/std::nullopt,
3565 /*enableDebugInfo=*/false,
3566 /*prettyDebugInfo=*/false,
3567 /*printGenericOpForm=*/false,
3568 /*useLocalScope=*/false,
3569 /*useNameLocAsPrefix=*/false,
3570 /*assumeVerified=*/false,
3571 /*skipRegions=*/false);
3572 },
3573 nb::sig("def __str__(self) -> str"),
3574 "Returns the assembly form of the operation.")
3575 .def("print",
3576 nb::overload_cast<PyAsmState &, nb::object, bool>(
3578 "state"_a, "file"_a = nb::none(), "binary"_a = false,
3579 R"(
3580 Prints the assembly form of the operation to a file like object.
3581
3582 Args:
3583 state: `AsmState` capturing the operation numbering and flags.
3584 file: Optional file like object to write to. Defaults to sys.stdout.
3585 binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
3586 .def("print",
3587 nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
3588 bool, bool, bool, bool, bool, bool, nb::object,
3589 bool, bool>(&PyOperationBase::print),
3590 // Careful: Lots of arguments must match up with print method.
3591 "large_elements_limit"_a = nb::none(),
3592 "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
3593 "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
3594 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
3595 "assume_verified"_a = false, "file"_a = nb::none(),
3596 "binary"_a = false, "skip_regions"_a = false,
3597 R"(
3598 Prints the assembly form of the operation to a file like object.
3599
3600 Args:
3601 large_elements_limit: Whether to elide elements attributes above this
3602 number of elements. Defaults to None (no limit).
3603 large_resource_limit: Whether to elide resource attributes above this
3604 number of characters. Defaults to None (no limit). If large_elements_limit
3605 is set and this is None, the behavior will be to use large_elements_limit
3606 as large_resource_limit.
3607 enable_debug_info: Whether to print debug/location information. Defaults
3608 to False.
3609 pretty_debug_info: Whether to format debug information for easier reading
3610 by a human (warning: the result is unparseable). Defaults to False.
3611 print_generic_op_form: Whether to print the generic assembly forms of all
3612 ops. Defaults to False.
3613 use_local_scope: Whether to print in a way that is more optimized for
3614 multi-threaded access but may not be consistent with how the overall
3615 module prints.
3616 use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
3617 prefixes for the SSA identifiers. Defaults to False.
3618 assume_verified: By default, if not printing generic form, the verifier
3619 will be run and if it fails, generic form will be printed with a comment
3620 about failed verification. While a reasonable default for interactive use,
3621 for systematic use, it is often better for the caller to verify explicitly
3622 and report failures in a more robust fashion. Set this to True if doing this
3623 in order to avoid running a redundant verification. If the IR is actually
3624 invalid, behavior is undefined.
3625 file: The file like object to write to. Defaults to sys.stdout.
3626 binary: Whether to write bytes (True) or str (False). Defaults to False.
3627 skip_regions: Whether to skip printing regions. Defaults to False.)")
3628 .def("write_bytecode", &PyOperationBase::writeBytecode, "file"_a,
3629 "desired_version"_a = nb::none(),
3630 R"(
3631 Write the bytecode form of the operation to a file like object.
3632
3633 Args:
3634 file: The file like object to write to.
3635 desired_version: Optional version of bytecode to emit.
3636 Returns:
3637 The bytecode writer status.)")
3638 .def("get_asm", &PyOperationBase::getAsm,
3639 // Careful: Lots of arguments must match up with get_asm method.
3640 "binary"_a = false, "large_elements_limit"_a = nb::none(),
3641 "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
3642 "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
3643 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
3644 "assume_verified"_a = false, "skip_regions"_a = false,
3645 R"(
3646 Gets the assembly form of the operation with all options available.
3647
3648 Args:
3649 binary: Whether to return a bytes (True) or str (False) object. Defaults to
3650 False.
3651 ... others ...: See the print() method for common keyword arguments for
3652 configuring the printout.
3653 Returns:
3654 Either a bytes or str object, depending on the setting of the `binary`
3655 argument.)")
3656 .def("verify", &PyOperationBase::verify,
3657 "Verify the operation. Raises MLIRError if verification fails, and "
3658 "returns true otherwise.")
3659 .def("move_after", &PyOperationBase::moveAfter, "other"_a,
3660 "Puts self immediately after the other operation in its parent "
3661 "block.")
3662 .def("move_before", &PyOperationBase::moveBefore, "other"_a,
3663 "Puts self immediately before the other operation in its parent "
3664 "block.")
3665 .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, "other"_a,
3666 R"(
3667 Checks if this operation is before another in the same block.
3668
3669 Args:
3670 other: Another operation in the same parent block.
3671
3672 Returns:
3673 True if this operation is before `other` in the operation list of the parent block.)")
3674 .def(
3675 "clone",
3676 [](PyOperationBase &self,
3677 const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
3678 return self.getOperation().clone(ip);
3679 },
3680 "ip"_a = nb::none(),
3681 R"(
3682 Creates a deep copy of the operation.
3683
3684 Args:
3685 ip: Optional insertion point where the cloned operation should be inserted.
3686 If None, the current insertion point is used. If False, the operation
3687 remains detached.
3688
3689 Returns:
3690 A new Operation that is a clone of this operation.)")
3691 .def(
3692 "detach_from_parent",
3693 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
3694 PyOperation &operation = self.getOperation();
3695 operation.checkValid();
3696 if (!operation.isAttached())
3697 throw nb::value_error("Detached operation has no parent.");
3698
3699 operation.detachFromParent();
3700 return operation.createOpView();
3701 },
3702 "Detaches the operation from its parent block.")
3703 .def_prop_ro(
3704 "attached",
3705 [](PyOperationBase &self) {
3706 PyOperation &operation = self.getOperation();
3707 operation.checkValid();
3708 return operation.isAttached();
3709 },
3710 "Reports if the operation is attached to its parent block.")
3711 .def(
3712 "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
3713 R"(
3714 Erases the operation and frees its memory.
3715
3716 Note:
3717 After erasing, any Python references to the operation become invalid.)")
3718 .def("walk", &PyOperationBase::walk, "callback"_a,
3719 "walk_order"_a = PyWalkOrder::PostOrder,
3720 // clang-format off
3721 nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
3722 // clang-format on
3723 R"(
3724 Walks the operation tree with a callback function.
3725
3726 Args:
3727 callback: A callable that takes an Operation and returns a WalkResult.
3728 walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
3729
// Python `Operation` class: concrete bindings for PyOperation layered on the
// PyOperationBase API.
// NOTE(review): this copy of the statement appears to be missing lines: the
// `parse` lambda's parameter list ends without declaring the `context` its
// body uses, and two capsule-related entries survive only as dangling
// docstrings. It does not compile as written; restore from upstream.
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
    // Operation.create(...): validates the operand list, resolves the
    // location, and forwards everything to PyOperation::create.
    .def_static(
        "create",
        [](std::string_view name,
           std::optional<std::vector<PyType *>> results,
           std::optional<std::vector<PyValue *>> operands,
           std::optional<nb::dict> attributes,
           std::optional<std::vector<PyBlock *>> successors, int regions,
           const std::optional<PyLocation> &location,
           const nb::object &maybeIp,
           bool inferType) -> nb::typed<nb::object, PyOperation> {
          // Unpack/validate operands.
          std::vector<MlirValue> mlirOperands;
          if (operands) {
            mlirOperands.reserve(operands->size());
            for (PyValue *operand : *operands) {
              // Null entries are rejected (reported to the user as None).
              if (!operand)
                throw nb::value_error("operand value cannot be None");
              mlirOperands.push_back(operand->get());
            }
          }

          // `location` goes through maybeGetTracebackLocation (defined
          // elsewhere) to produce the PyLocation actually used.
          PyLocation pyLoc = maybeGetTracebackLocation(location);
          return PyOperation::create(
              name, results, mlirOperands.data(), mlirOperands.size(),
              attributes, successors, regions, pyLoc, maybeIp, inferType);
        },
        "name"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
        "attributes"_a = nb::none(), "successors"_a = nb::none(),
        "regions"_a = 0, "loc"_a = nb::none(), "ip"_a = nb::none(),
        "infer_type"_a = false,
        // NOTE(review): the docstring below documents a `location` argument,
        // but the keyword is bound as "loc" above. Its `ip` entry also
        // describes insertion, so "A new detached Operation" may not always
        // hold — confirm and fix the docstring text.
        R"(
Creates a new operation.

Args:
name: Operation name (e.g. `dialect.operation`).
results: Optional sequence of Type representing op result types.
operands: Optional operands of the operation.
attributes: Optional Dict of {str: Attribute}.
successors: Optional List of Block for the operation's successors.
regions: Number of regions to create (default = 0).
location: Optional Location object (defaults to resolve from context manager).
ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
infer_type: Whether to infer result types (default = False).
Returns:
A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
    // Operation.parse(source, *, source_name="", context=None): parses an
    // operation and returns it wrapped in its OpView.
    .def_static(
        "parse",
        [](const std::string &sourceStr, const std::string &sourceName,
           // NOTE(review): parameter(s) appear to be missing here; the body
           // below uses `context`.
           -> nb::typed<nb::object, PyOpView> {
          return PyOperation::parse(context->getRef(), sourceStr, sourceName)
              ->createOpView();
        },
        "source"_a, nb::kw_only(), "source_name"_a = "",
        "context"_a = nb::none(),
        "Parses an operation. Supports both text assembly format and binary "
        "bytecode format.")
        // NOTE(review): the two lines below are orphaned docstrings; the
        // `.def...(` calls they belonged to are missing from this copy.
        "Gets a capsule wrapping the MlirOperation.")
        "Creates an Operation from a capsule wrapping MlirOperation.")
    // On a concrete Operation, `.operation` is the object itself.
    .def_prop_ro(
        "operation",
        [](nb::object self) -> nb::typed<nb::object, PyOperation> {
          return self;
        },
        "Returns self (the operation).")
    .def_prop_ro(
        "opview",
        [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
          return self.createOpView();
        },
        R"(
Returns an OpView of this operation.

Note:
If the operation has a registered and loaded dialect then this OpView will
be concrete wrapper class.)")
    .def_prop_ro("block", &PyOperation::getBlock,
                 "Returns the block containing this operation.")
    .def_prop_ro(
        "successors",
        [](PyOperationBase &self) {
          return PyOpSuccessors(self.getOperation().getRef());
        },
        "Returns the list of Operation successors.")
    // Thin wrapper over the C API mlirOperationReplaceUsesOfWith.
    .def(
        "replace_uses_of_with",
        [](PyOperation &self, PyValue &of, PyValue &with) {
          mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
        },
        "of"_a, "with_"_a,
        "Replaces uses of the 'of' value with the 'with' value inside the "
        "operation.")
    .def("_set_invalid", &PyOperation::setInvalid,
         "Invalidate the operation.");
3828
// Python `OpView` class: the base class for (generated) op-specific views.
// The class object is kept in `opViewClass` so that class attributes and
// classmethods can be attached to it afterwards.
auto opViewClass =
    nb::class_<PyOpView, PyOperationBase>(m, "OpView")
        // OpView(operation): wrap an existing Operation.
        .def(nb::init<nb::typed<nb::object, PyOperation>>(), "operation"_a)
        // OpView(name, opRegionSpec, ...): build a new operation in place,
        // receiving the ODS metadata as explicit arguments.
        .def(
            "__init__",
            [](PyOpView *self, std::string_view name,
               std::tuple<int, bool> opRegionSpec,
               nb::object operandSegmentSpecObj,
               nb::object resultSegmentSpecObj,
               std::optional<nb::list> resultTypeList, nb::list operandList,
               std::optional<nb::dict> attributes,
               std::optional<std::vector<PyBlock *>> successors,
               std::optional<int> regions,
               const std::optional<PyLocation> &location,
               const nb::object &maybeIp) {
              PyLocation pyLoc = maybeGetTracebackLocation(location);
              // NOTE(review): the statement that consumes the argument list
              // below appears to be missing from this copy (note the
              // unbalanced `));`); restore from upstream.
                  name, opRegionSpec, operandSegmentSpecObj,
                  resultSegmentSpecObj, resultTypeList, operandList,
                  attributes, successors, regions, pyLoc, maybeIp));
            },
            "name"_a, "opRegionSpec"_a,
            "operandSegmentSpecObj"_a = nb::none(),
            "resultSegmentSpecObj"_a = nb::none(), "results"_a = nb::none(),
            "operands"_a = nb::none(), "attributes"_a = nb::none(),
            "successors"_a = nb::none(), "regions"_a = nb::none(),
            "loc"_a = nb::none(), "ip"_a = nb::none())
        // `.operation` returns the wrapped Operation object.
        .def_prop_ro(
            "operation",
            [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
              return self.getOperationObject();
            })
        // `.opview` on an OpView is the object itself.
        .def_prop_ro("opview",
                     [](nb::object self) -> nb::typed<nb::object, PyOpView> {
                       return self;
                     })
        // str(view) is str() of the wrapped Operation object.
        .def(
            "__str__",
            [](PyOpView &self) { return nb::str(self.getOperationObject()); })
        .def_prop_ro(
            "successors",
            [](PyOperationBase &self) {
              return PyOpSuccessors(self.getOperation().getRef());
            },
            "Returns the list of Operation successors.")
        .def(
            "_set_invalid",
            [](PyOpView &self) { self.getOperation().setInvalid(); },
            "Invalidate the operation.");
3878 opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
3879 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
3880 opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
3881 // It is faster to pass the operation_name, ods_regions, and
3882 // ods_operand_segments/ods_result_segments as arguments to the constructor,
3883 // rather than to access them as attributes.
3884 opViewClass.attr("build_generic") = classmethod(
3885 [](nb::handle cls, std::optional<nb::list> resultTypeList,
3886 nb::list operandList, std::optional<nb::dict> attributes,
3887 std::optional<std::vector<PyBlock *>> successors,
3888 std::optional<int> regions, std::optional<PyLocation> location,
3889 const nb::object &maybeIp) {
3890 std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3891 std::tuple<int, bool> opRegionSpec =
3892 nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
3893 nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
3894 nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3895 PyLocation pyLoc = maybeGetTracebackLocation(location);
3896 return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3897 resultSegmentSpec, resultTypeList,
3898 operandList, attributes, successors,
3899 regions, pyLoc, maybeIp);
3900 },
3901 "cls"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
3902 "attributes"_a = nb::none(), "successors"_a = nb::none(),
3903 "regions"_a = nb::none(), "loc"_a = nb::none(), "ip"_a = nb::none(),
3904 "Builds a specific, generated OpView based on class level attributes.");
// `OpView.parse` classmethod: parses `source` into an operation, checks the
// parsed op's name against `cls.OPERATION_NAME` (raising MLIRError on a
// mismatch), and constructs the `cls` view around the result.
opViewClass.attr("parse") = classmethod(
    [](const nb::object &cls, const std::string &sourceStr,
       const std::string &sourceName,
       DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
      PyOperationRef parsed =
          PyOperation::parse(context->getRef(), sourceStr, sourceName);

      // Check if the expected operation was parsed, and cast to the
      // appropriate `OpView` subclass if successful.
      // NOTE: This accesses attributes that have been automatically added to
      // `OpView` subclasses, and is not intended to be used on `OpView`
      // directly.
      std::string clsOpName =
          nb::cast<std::string>(cls.attr("OPERATION_NAME"));
      // NOTE(review): the initializer of `identifier` appears to be missing
      // from this copy (presumably the parsed op's name); restore upstream.
      MlirStringRef identifier =
      std::string_view parsedOpName(identifier.data, identifier.length);
      if (clsOpName != parsedOpName)
        throw MLIRError(join("Expected a '", clsOpName, "' op, got: '",
                             parsedOpName, "'"));
      return PyOpView::constructDerived(cls, parsed.getObject());
    },
    "cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "",
    "context"_a = nb::none(),
    "Parses a specific, generated OpView based on class level attributes.");
3930
3931 //----------------------------------------------------------------------------
3932 // Mapping of PyRegion.
3933 //----------------------------------------------------------------------------
3934 nb::class_<PyRegion>(m, "Region")
3935 .def_prop_ro(
3936 "blocks",
3937 [](PyRegion &self) {
3938 return PyBlockList(self.getParentOperation(), self.get());
3939 },
3940 "Returns a forward-optimized sequence of blocks.")
3941 .def_prop_ro(
3942 "owner",
3943 [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
3944 return self.getParentOperation()->createOpView();
3945 },
3946 "Returns the operation owning this region.")
3947 .def(
3948 "__iter__",
3949 [](PyRegion &self) {
3950 self.checkValid();
3951 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3952 return PyBlockIterator(self.getParentOperation(), firstBlock);
3953 },
3954 "Iterates over blocks in the region.")
3955 .def(
3956 "__eq__",
3957 [](PyRegion &self, PyRegion &other) {
3958 return self.get().ptr == other.get().ptr;
3959 },
3960 "Compares two regions for pointer equality.")
3961 .def(
3962 "__eq__", [](PyRegion &self, nb::object &other) { return false; },
3963 "Compares region with non-region object (always returns False).");
3964
3965 //----------------------------------------------------------------------------
3966 // Mapping of PyBlock.
3967 //----------------------------------------------------------------------------
// Python `Block` class. Every factory below builds a PyBlock from the
// parent operation (`getParentOperation()`) plus an MlirBlock.
// NOTE(review): the first `.def...(` of this chain appears to be missing from
// this copy; only its docstring survives on the next line, so the statement
// does not compile as written.
nb::class_<PyBlock>(m, "Block")
    "Gets a capsule wrapping the MlirBlock.")
    .def_prop_ro(
        "owner",
        [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
          return self.getParentOperation()->createOpView();
        },
        "Returns the owning operation of this block.")
    .def_prop_ro(
        "region",
        [](PyBlock &self) {
          MlirRegion region = mlirBlockGetParentRegion(self.get());
          return PyRegion(self.getParentOperation(), region);
        },
        "Returns the owning region of this block.")
    .def_prop_ro(
        "arguments",
        [](PyBlock &self) {
          return PyBlockArgumentList(self.getParentOperation(), self.get());
        },
        "Returns a list of block arguments.")
    .def(
        "add_argument",
        [](PyBlock &self, const PyType &type, const PyLocation &loc) {
          return PyBlockArgument(self.getParentOperation(),
                                 mlirBlockAddArgument(self.get(), type, loc));
        },
        "type"_a, "loc"_a,
        R"(
Appends an argument of the specified type to the block.

Args:
type: The type of the argument to add.
loc: The source location for the argument.

Returns:
The newly added block argument.)")
    .def(
        "erase_argument",
        [](PyBlock &self, unsigned index) {
          return mlirBlockEraseArgument(self.get(), index);
        },
        "index"_a,
        R"(
Erases the argument at the specified index.

Args:
index: The index of the argument to erase.)")
    .def_prop_ro(
        "operations",
        [](PyBlock &self) {
          return PyOperationList(self.getParentOperation(), self.get());
        },
        "Returns a forward-optimized sequence of operations.")
    .def_static(
        "create_at_start",
        [](PyRegion &parent, const nb::sequence &pyArgTypes,
           const std::optional<nb::sequence> &pyArgLocs) {
          parent.checkValid();
          MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
          // Position 0: the new block is inserted at the start of `parent`.
          mlirRegionInsertOwnedBlock(parent, 0, block);
          return PyBlock(parent.getParentOperation(), block);
        },
        "parent"_a, "arg_types"_a = nb::list(), "arg_locs"_a = std::nullopt,
        "Creates and returns a new Block at the beginning of the given "
        "region (with given argument types and locations).")
    .def(
        "append_to",
        [](PyBlock &self, PyRegion &region) {
          MlirBlock b = self.get();
          // NOTE(review): the docstring promises an ownership transfer from
          // another region, but the visible code appends `b` without first
          // detaching it. Lines may be missing here; confirm against upstream.
          mlirRegionAppendOwnedBlock(region.get(), b);
        },
        "region"_a,
        R"(
Appends this block to a region.

Transfers ownership if the block is currently owned by another region.

Args:
region: The region to append the block to.)")
    .def(
        "create_before",
        [](PyBlock &self, const nb::args &pyArgTypes,
           const std::optional<nb::sequence> &pyArgLocs) {
          self.checkValid();
          // Argument types are received as positional varargs.
          MlirBlock block =
              createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
          MlirRegion region = mlirBlockGetParentRegion(self.get());
          mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
          return PyBlock(self.getParentOperation(), block);
        },
        "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
        "Creates and returns a new Block before this block "
        "(with given argument types and locations).")
    .def(
        "create_after",
        [](PyBlock &self, const nb::args &pyArgTypes,
           const std::optional<nb::sequence> &pyArgLocs) {
          self.checkValid();
          // Argument types are received as positional varargs.
          MlirBlock block =
              createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
          MlirRegion region = mlirBlockGetParentRegion(self.get());
          mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
          return PyBlock(self.getParentOperation(), block);
        },
        "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
        "Creates and returns a new Block after this block "
        "(with given argument types and locations).")
    .def(
        "__iter__",
        [](PyBlock &self) {
          self.checkValid();
          MlirOperation firstOperation =
              mlirBlockGetFirstOperation(self.get());
          return PyOperationIterator(self.getParentOperation(),
                                     firstOperation);
        },
        "Iterates over operations in the block.")
    .def(
        "__eq__",
        [](PyBlock &self, PyBlock &other) {
          return self.get().ptr == other.get().ptr;
        },
        "Compares two blocks for pointer equality.")
    .def(
        "__eq__", [](PyBlock &self, nb::object &other) { return false; },
        "Compares block with non-block object (always returns False).")
    // Hashes the same pointer `__eq__` compares, so equal blocks hash equally.
    .def(
        "__hash__",
        [](PyBlock &self) {
          return static_cast<size_t>(llvm::hash_value(self.get().ptr));
        },
        "Returns the hash value of the block.")
    .def(
        "__str__",
        [](PyBlock &self) {
          self.checkValid();
          PyPrintAccumulator printAccum;
          mlirBlockPrint(self.get(), printAccum.getCallback(),
                         printAccum.getUserData());
          return printAccum.join();
        },
        "Returns the assembly form of the block.")
    .def(
        "append",
        [](PyBlock &self, PyOperationBase &operation) {
          // An operation that is already attached is detached first, so
          // appending moves it.
          if (operation.getOperation().isAttached())
            operation.getOperation().detachFromParent();

          MlirOperation mlirOperation = operation.getOperation().get();
          mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
          // Mark the Python-side operation as attached, passing this block's
          // parent operation object.
          operation.getOperation().setAttached(
              self.getParentOperation().getObject());
        },
        "operation"_a,
        R"(
Appends an operation to this block.

If the operation is currently in another block, it will be moved.

Args:
operation: The operation to append to the block.)")
    .def_prop_ro(
        "successors",
        [](PyBlock &self) {
          return PyBlockSuccessors(self, self.getParentOperation());
        },
        "Returns the list of Block successors.")
    .def_prop_ro(
        "predecessors",
        [](PyBlock &self) {
          return PyBlockPredecessors(self, self.getParentOperation());
        },
        "Returns the list of Block predecessors.");
4145
4146 //----------------------------------------------------------------------------
4147 // Mapping of PyInsertionPoint.
4148 //----------------------------------------------------------------------------
4149
// Python `InsertionPoint` class. It can be used as a context manager
// (`__enter__`/`__exit__`), and `InsertionPoint.current` reads the active one.
nb::class_<PyInsertionPoint>(m, "InsertionPoint")
    .def(nb::init<PyBlock &>(), "block"_a,
         "Inserts after the last operation but still inside the block.")
    .def("__enter__", &PyInsertionPoint::contextEnter,
         "Enters the insertion point as a context manager.")
    .def("__exit__", &PyInsertionPoint::contextExit, "exc_type"_a.none(),
         "exc_value"_a.none(), "traceback"_a.none(),
         "Exits the insertion point context manager.")
    .def_prop_ro_static(
        "current",
        [](nb::object & /*class*/) {
          // NOTE(review): `ip` is used below without a visible declaration;
          // the line that looks it up appears to be missing from this copy.
          if (!ip)
            throw nb::value_error("No current InsertionPoint");
          return ip;
        },
        nb::sig("def current(/) -> InsertionPoint"),
        "Gets the InsertionPoint bound to the current thread or raises "
        "ValueError if none has been set.")
    .def(nb::init<PyOperationBase &>(), "beforeOperation"_a,
         "Inserts before a referenced operation.")
    .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, "block"_a,
                R"(
Creates an insertion point at the beginning of a block.

Args:
block: The block at whose beginning operations should be inserted.

Returns:
An InsertionPoint at the block's beginning.)")
    .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
                "block"_a,
                R"(
Creates an insertion point before a block's terminator.

Args:
block: The block whose terminator to insert before.

Returns:
An InsertionPoint before the terminator.

Raises:
ValueError: If the block has no terminator.)")
    .def_static("after", &PyInsertionPoint::after, "operation"_a,
                R"(
Creates an insertion point immediately after an operation.

Args:
operation: The operation after which to insert.

Returns:
An InsertionPoint after the operation.)")
    .def("insert", &PyInsertionPoint::insert, "operation"_a,
         R"(
Inserts an operation at this insertion point.

Args:
operation: The operation to insert.)")
    .def_prop_ro(
        "block", [](PyInsertionPoint &self) { return self.getBlock(); },
        "Returns the block that this `InsertionPoint` points to.")
    .def_prop_ro(
        "ref_operation",
        [](PyInsertionPoint &self)
            -> std::optional<nb::typed<nb::object, PyOperation>> {
          // An empty optional is returned (None in Python) when there is no
          // reference operation, i.e. the insertion point is at block end.
          auto refOperation = self.getRefOperation();
          if (refOperation)
            return refOperation->getObject();
          return {};
        },
        "The reference operation before which new operations are "
        "inserted, or None if the insertion point is at the end of "
        "the block.");
4223
4224 //----------------------------------------------------------------------------
4225 // Mapping of PyAttribute.
4226 //----------------------------------------------------------------------------
// Python `Attribute` class: generic wrapper around an MlirAttribute.
// NOTE(review): several lines of this chain appear to be missing from this
// copy: the capsule accessor/factory `.def`s, the docstring argument of
// `dump`, and the method name of the final `.def(`. The statement does not
// compile as written; restore from upstream.
nb::class_<PyAttribute>(m, "Attribute")
    // Delegate to the PyAttribute copy constructor, which will also lifetime
    // extend the backing context which owns the MlirAttribute.
    .def(nb::init<PyAttribute &>(), "cast_from_type"_a,
         "Casts the passed attribute to the generic `Attribute`.")
         "Gets a capsule wrapping the MlirAttribute.")
    .def_static(
        "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
    .def_static(
        "parse",
        [](const std::string &attrSpec, DefaultingPyMlirContext context)
            -> nb::typed<nb::object, PyAttribute> {
          // `errors` captures errors reported while parsing; they are handed
          // to the MLIRError raised on failure.
          PyMlirContext::ErrorCapture errors(context->getRef());
          MlirAttribute attr = mlirAttributeParseGet(
              context->get(), toMlirStringRef(attrSpec));
          if (mlirAttributeIsNull(attr))
            throw MLIRError("Unable to parse attribute", errors.take());
          return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
        },
        "asm"_a, "context"_a = nb::none(),
        "Parses an attribute from an assembly form. Raises an `MLIRError` on "
        "failure.")
    .def_prop_ro(
        "context",
        [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
          return self.getContext().getObject();
        },
        "Context that owns the `Attribute`.")
    .def_prop_ro(
        "type",
        [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
          return PyType(self.getContext(), mlirAttributeGetType(self))
              .maybeDownCast();
        },
        "Returns the type of the `Attribute`.")
    .def(
        "get_named",
        [](PyAttribute &self, std::string name) {
          return PyNamedAttribute(self, std::move(name));
        },
        // keep_alive<0, 1>: the attribute (argument 1, `self`) is kept alive
        // for as long as the returned NamedAttribute (index 0) is alive.
        nb::keep_alive<0, 1>(),
        R"(
Binds a name to the attribute, creating a `NamedAttribute`.

Args:
name: The name to bind to the `Attribute`.

Returns:
A `NamedAttribute` with the given name and this attribute.)")
    .def(
        "__eq__",
        [](PyAttribute &self, PyAttribute &other) { return self == other; },
        "Compares two attributes for equality.")
    .def(
        "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
        "Compares attribute with non-attribute object (always returns "
        "False).")
    .def(
        "__hash__",
        [](PyAttribute &self) {
          return static_cast<size_t>(llvm::hash_value(self.get().ptr));
        },
        "Returns the hash value of the attribute.")
    .def(
        "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
    .def(
        "__str__",
        [](PyAttribute &self) {
          PyPrintAccumulator printAccum;
          mlirAttributePrint(self, printAccum.getCallback(),
                             printAccum.getUserData());
          return printAccum.join();
        },
        "Returns the assembly form of the Attribute.")
    .def(
        "__repr__",
        [](PyAttribute &self) {
          // Generally, assembly formats are not printed for __repr__ because
          // this can cause exceptionally long debug output and exceptions.
          // However, attribute values are generally considered useful and
          // are printed. This may need to be re-evaluated if debug dumps end
          // up being excessive.
          PyPrintAccumulator printAccum;
          printAccum.parts.append("Attribute(");
          mlirAttributePrint(self, printAccum.getCallback(),
                             printAccum.getUserData());
          printAccum.parts.append(")");
          return printAccum.join();
        },
        "Returns a string representation of the attribute.")
    .def_prop_ro(
        "typeid",
        [](PyAttribute &self) {
          MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
          // Only checked in assert-enabled builds.
          assert(!mlirTypeIDIsNull(mlirTypeID) &&
                 "mlirTypeID was expected to be non-null.");
          return PyTypeID(mlirTypeID);
        },
        "Returns the `TypeID` of the attribute.")
    .def(
        [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
          return self.maybeDownCast();
        },
        "Downcasts the attribute to a more specific attribute if possible.");
4335
4336 //----------------------------------------------------------------------------
4337 // Mapping of PyNamedAttribute
4338 //----------------------------------------------------------------------------
4339 nb::class_<PyNamedAttribute>(m, "NamedAttribute")
4340 .def(
4341 "__repr__",
4342 [](PyNamedAttribute &self) {
4343 PyPrintAccumulator printAccum;
4344 printAccum.parts.append("NamedAttribute(");
4345 printAccum.parts.append(
4346 nb::str(mlirIdentifierStr(self.namedAttr.name).data,
4347 mlirIdentifierStr(self.namedAttr.name).length));
4348 printAccum.parts.append("=");
4349 mlirAttributePrint(self.namedAttr.attribute,
4350 printAccum.getCallback(),
4351 printAccum.getUserData());
4352 printAccum.parts.append(")");
4353 return printAccum.join();
4354 },
4355 "Returns a string representation of the named attribute.")
4356 .def_prop_ro(
4357 "name",
4358 [](PyNamedAttribute &self) {
4359 return mlirIdentifierStr(self.namedAttr.name);
4360 },
4361 "The name of the `NamedAttribute` binding.")
4362 .def_prop_ro(
4363 "attr",
4364 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
4365 nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
4366 "The underlying generic attribute of the `NamedAttribute` binding.");
4367
4368 //----------------------------------------------------------------------------
4369 // Mapping of PyType.
4370 //----------------------------------------------------------------------------
// Python `Type` class: generic wrapper around an MlirType.
// NOTE(review): several lines of this chain appear to be missing from this
// copy: the capsule accessor/factory `.def`s, which survive only as dangling
// docstrings, and the method name of the downcast `.def(`. The statement does
// not compile as written; restore from upstream.
nb::class_<PyType>(m, "Type")
    // Delegate to the PyType copy constructor, which will also lifetime
    // extend the backing context which owns the MlirType.
    .def(nb::init<PyType &>(), "cast_from_type"_a,
         "Casts the passed type to the generic `Type`.")
         "Gets a capsule wrapping the `MlirType`.")
         "Creates a Type from a capsule wrapping `MlirType`.")
    .def_static(
        "parse",
        [](std::string typeSpec,
           DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
          // `errors` captures errors reported while parsing; they are handed
          // to the MLIRError raised on failure.
          PyMlirContext::ErrorCapture errors(context->getRef());
          MlirType type =
              mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
          if (mlirTypeIsNull(type))
            throw MLIRError("Unable to parse type", errors.take());
          return PyType(context.get()->getRef(), type).maybeDownCast();
        },
        "asm"_a, "context"_a = nb::none(),
        R"(
Parses the assembly form of a type.

Returns a Type object or raises an `MLIRError` if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
    .def_prop_ro(
        "context",
        [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
          return self.getContext().getObject();
        },
        "Context that owns the `Type`.")
    .def(
        "__eq__", [](PyType &self, PyType &other) { return self == other; },
        "Compares two types for equality.")
    // Fallback overload for non-Type operands; `.none()` lets None bind here.
    .def(
        "__eq__", [](PyType &self, nb::object &other) { return false; },
        "other"_a.none(),
        "Compares type with non-type object (always returns False).")
    .def(
        "__hash__",
        [](PyType &self) {
          return static_cast<size_t>(llvm::hash_value(self.get().ptr));
        },
        "Returns the hash value of the `Type`.")
    .def(
        "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
    .def(
        "__str__",
        [](PyType &self) {
          PyPrintAccumulator printAccum;
          mlirTypePrint(self, printAccum.getCallback(),
                        printAccum.getUserData());
          return printAccum.join();
        },
        "Returns the assembly form of the `Type`.")
    .def(
        "__repr__",
        [](PyType &self) {
          // Generally, assembly formats are not printed for __repr__ because
          // this can cause exceptionally long debug output and exceptions.
          // However, types are an exception as they typically have compact
          // assembly forms and printing them is useful.
          PyPrintAccumulator printAccum;
          printAccum.parts.append("Type(");
          mlirTypePrint(self, printAccum.getCallback(),
                        printAccum.getUserData());
          printAccum.parts.append(")");
          return printAccum.join();
        },
        "Returns a string representation of the `Type`.")
    .def(
        [](PyType &self) -> nb::typed<nb::object, PyType> {
          return self.maybeDownCast();
        },
        "Downcasts the Type to a more specific `Type` if possible.")
    .def_prop_ro(
        "typeid",
        [](PyType &self) {
          MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
          if (!mlirTypeIDIsNull(mlirTypeID))
            return PyTypeID(mlirTypeID);
          // No TypeID: raise ValueError naming the type via its Python repr.
          auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
          throw nb::value_error(join(origRepr, " has no typeid.").c_str());
        },
        "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
        "`Type` has no "
        "`TypeID`.");
4461
4462 //----------------------------------------------------------------------------
4463 // Mapping of PyTypeID.
4464 //----------------------------------------------------------------------------
// Python `TypeID` class: a value-type wrapper around MlirTypeID.
// NOTE(review): the capsule accessor/factory `.def`s appear to be missing from
// this copy; only their docstrings survive on the next two lines, so the
// statement does not compile as written.
nb::class_<PyTypeID>(m, "TypeID")
    "Gets a capsule wrapping the `MlirTypeID`.")
    "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
    // Note, this tests whether the underlying TypeIDs are the same,
    // not whether the wrapper MlirTypeIDs are the same, nor whether
    // the Python objects are the same (i.e., PyTypeID is a value type).
    .def(
        "__eq__",
        [](PyTypeID &self, PyTypeID &other) { return self == other; },
        "Compares two `TypeID`s for equality.")
    .def(
        "__eq__",
        [](PyTypeID &self, const nb::object &other) { return false; },
        "Compares TypeID with non-TypeID object (always returns False).")
    // Note, this gives the hash value of the underlying TypeID, not the
    // hash value of the Python object, nor the hash value of the
    // MlirTypeID wrapper.
    .def(
        "__hash__",
        [](PyTypeID &self) {
          return static_cast<size_t>(mlirTypeIDHashValue(self));
        },
        "Returns the hash value of the `TypeID`.");
4490
4491 //----------------------------------------------------------------------------
4492 // Mapping of Value.
4493 //----------------------------------------------------------------------------
4494 m.attr("_T") = nb::type_var("_T", "bound"_a = m.attr("Type"));
4495
4496 nb::class_<PyValue>(m, "Value", nb::is_generic(),
4497 nb::sig("class Value(Generic[_T])"))
4498 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), "value"_a,
4499 "Creates a Value reference from another `Value`.")
4501 "Gets a capsule wrapping the `MlirValue`.")
4503 "Creates a `Value` from a capsule wrapping `MlirValue`.")
4504 .def_prop_ro(
4505 "context",
4506 [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
4507 return self.getParentOperation()->getContext().getObject();
4508 },
4509 "Context in which the value lives.")
4510 .def(
4511 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
4513 .def_prop_ro(
4514 "owner",
4515 [](PyValue &self)
4516 -> nb::typed<nb::object, std::variant<PyOpView, PyBlock>> {
4517 MlirValue v = self.get();
4518 if (mlirValueIsAOpResult(v)) {
4519 assert(mlirOperationEqual(self.getParentOperation()->get(),
4520 mlirOpResultGetOwner(self.get())) &&
4521 "expected the owner of the value in Python to match "
4522 "that in "
4523 "the IR");
4524 return self.getParentOperation()->createOpView();
4525 }
4526
4528 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
4529 return nb::cast(PyBlock(self.getParentOperation(), block));
4530 }
4531
4532 assert(false && "Value must be a block argument or an op result");
4533 return nb::none();
4534 },
4535 "Returns the owner of the value (`Operation` for results, `Block` "
4536 "for "
4537 "arguments).")
4538 .def_prop_ro(
4539 "uses",
4540 [](PyValue &self) {
4541 return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
4542 },
4543 "Returns an iterator over uses of this value.")
4544 .def(
4545 "__eq__",
4546 [](PyValue &self, PyValue &other) {
4547 return self.get().ptr == other.get().ptr;
4548 },
4549 "Compares two values for pointer equality.")
4550 .def(
4551 "__eq__", [](PyValue &self, nb::object other) { return false; },
4552 "Compares value with non-value object (always returns False).")
4553 .def(
4554 "__hash__",
4555 [](PyValue &self) {
4556 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4557 },
4558 "Returns the hash value of the value.")
4559 .def(
4560 "__str__",
4561 [](PyValue &self) {
4562 PyPrintAccumulator printAccum;
4563 printAccum.parts.append("Value(");
4564 mlirValuePrint(self.get(), printAccum.getCallback(),
4565 printAccum.getUserData());
4566 printAccum.parts.append(")");
4567 return printAccum.join();
4568 },
4569 R"(
4570 Returns the string form of the value.
4571
4572 If the value is a block argument, this is the assembly form of its type and the
4573 position in the argument list. If the value is an operation result, this is
4574 equivalent to printing the operation that produced it.
4575 )")
4576 .def(
4577 "get_name",
4578 [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
4579 PyPrintAccumulator printAccum;
4580 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
4581 if (useLocalScope)
4583 if (useNameLocAsPrefix)
4585 MlirAsmState valueState =
4586 mlirAsmStateCreateForValue(self.get(), flags);
4587 mlirValuePrintAsOperand(self.get(), valueState,
4588 printAccum.getCallback(),
4589 printAccum.getUserData());
4591 mlirAsmStateDestroy(valueState);
4592 return printAccum.join();
4593 },
4594 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
4595 R"(
4596 Returns the string form of value as an operand.
4597
4598 Args:
4599 use_local_scope: Whether to use local scope for naming.
4600 use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
4601
4602 Returns:
4603 The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
4604 .def(
4605 "get_name",
4606 [](PyValue &self, PyAsmState &state) {
4607 PyPrintAccumulator printAccum;
4608 MlirAsmState valueState = state.get();
4609 mlirValuePrintAsOperand(self.get(), valueState,
4610 printAccum.getCallback(),
4611 printAccum.getUserData());
4612 return printAccum.join();
4613 },
4614 "state"_a,
4615 "Returns the string form of value as an operand (i.e., the ValueID).")
4616 .def_prop_ro(
4617 "type",
4618 [](PyValue &self) -> nb::typed<nb::object, PyType> {
4619 return PyType(self.getParentOperation()->getContext(),
4620 mlirValueGetType(self.get()))
4621 .maybeDownCast();
4622 },
4623 "Returns the type of the value.")
4624 .def(
4625 "set_type",
4626 [](PyValue &self, const PyType &type) {
4627 mlirValueSetType(self.get(), type);
4628 },
4629 "type"_a, "Sets the type of the value.",
4630 nb::sig("def set_type(self, type: _T)"))
4631 .def(
4632 "replace_all_uses_with",
4633 [](PyValue &self, PyValue &with) {
4634 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
4635 },
4636 "Replace all uses of value with the new value, updating anything in "
4637 "the IR that uses `self` to use the other value instead.")
4638 .def(
4639 "replace_all_uses_except",
4640 [](PyValue &self, PyValue &with, PyOperation &exception) {
4641 MlirOperation exceptedUser = exception.get();
4642 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4643 },
4644 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4645 .def(
4646 "replace_all_uses_except",
4647 [](PyValue &self, PyValue &with, const nb::list &exceptions) {
4648 // Convert Python list to a std::vector of MlirOperations
4649 std::vector<MlirOperation> exceptionOps;
4650 for (nb::handle exception : exceptions) {
4651 exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
4652 }
4653
4655 self, with, static_cast<intptr_t>(exceptionOps.size()),
4656 exceptionOps.data());
4657 },
4658 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4659 .def(
4660 "replace_all_uses_except",
4661 [](PyValue &self, PyValue &with, PyOperation &exception) {
4662 MlirOperation exceptedUser = exception.get();
4663 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4664 },
4665 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4666 .def(
4667 "replace_all_uses_except",
4668 [](PyValue &self, PyValue &with,
4669 std::vector<PyOperation> &exceptions) {
4670 // Convert Python list to a std::vector of MlirOperations
4671 std::vector<MlirOperation> exceptionOps;
4672 for (PyOperation &exception : exceptions)
4673 exceptionOps.push_back(exception);
4675 self, with, static_cast<intptr_t>(exceptionOps.size()),
4676 exceptionOps.data());
4677 },
4678 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4679 .def(
4681 [](PyValue &self) { return self.maybeDownCast(); },
4682 "Downcasts the `Value` to a more specific kind if possible.")
4683 .def_prop_ro(
4684 "location",
4685 [](PyValue self) {
4686 return PyLocation(
4688 mlirValueGetLocation(self));
4689 },
4690 "Returns the source location of the value.");
4691
4695
4696 nb::class_<PyAsmState>(m, "AsmState")
4697 .def(nb::init<PyValue &, bool>(), "value"_a, "use_local_scope"_a = false,
4698 R"(
4699 Creates an `AsmState` for consistent SSA value naming.
4700
4701 Args:
4702 value: The value to create state for.
4703 use_local_scope: Whether to use local scope for naming.)")
4704 .def(nb::init<PyOperationBase &, bool>(), "op"_a,
4705 "use_local_scope"_a = false,
4706 R"(
4707 Creates an AsmState for consistent SSA value naming.
4708
4709 Args:
4710 op: The operation to create state for.
4711 use_local_scope: Whether to use local scope for naming.)");
4712
4713 //----------------------------------------------------------------------------
4714 // Mapping of SymbolTable.
4715 //----------------------------------------------------------------------------
4716 nb::class_<PySymbolTable>(m, "SymbolTable")
4717 .def(nb::init<PyOperationBase &>(),
4718 R"(
4719 Creates a symbol table for an operation.
4720
4721 Args:
4722 operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
4723
4724 Raises:
4725 TypeError: If the operation is not a symbol table.)")
4726 .def(
4727 "__getitem__",
4728 [](PySymbolTable &self,
4729 const std::string &name) -> nb::typed<nb::object, PyOpView> {
4730 return self.dunderGetItem(name);
4731 },
4732 R"(
4733 Looks up a symbol by name in the symbol table.
4734
4735 Args:
4736 name: The name of the symbol to look up.
4737
4738 Returns:
4739 The operation defining the symbol.
4740
4741 Raises:
4742 KeyError: If the symbol is not found.)")
4743 .def("insert", &PySymbolTable::insert, "operation"_a,
4744 R"(
4745 Inserts a symbol operation into the symbol table.
4746
4747 Args:
4748 operation: An operation with a symbol name to insert.
4749
4750 Returns:
4751 The symbol name attribute of the inserted operation.
4752
4753 Raises:
4754 ValueError: If the operation does not have a symbol name.)")
4755 .def("erase", &PySymbolTable::erase, "operation"_a,
4756 R"(
4757 Erases a symbol operation from the symbol table.
4758
4759 Args:
4760 operation: The symbol operation to erase.
4761
4762 Note:
4763 The operation is also erased from the IR and invalidated.)")
4764 .def("__delitem__", &PySymbolTable::dunderDel,
4765 "Deletes a symbol by name from the symbol table.")
4766 .def(
4767 "__contains__",
4768 [](PySymbolTable &table, const std::string &name) {
4769 return !mlirOperationIsNull(mlirSymbolTableLookup(
4770 table, mlirStringRefCreate(name.data(), name.length())));
4771 },
4772 "Checks if a symbol with the given name exists in the table.")
4773 // Static helpers.
4774 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, "symbol"_a,
4775 "name"_a, "Sets the symbol name for a symbol operation.")
4776 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, "symbol"_a,
4777 "Gets the symbol name from a symbol operation.")
4778 .def_static("get_visibility", &PySymbolTable::getVisibility, "symbol"_a,
4779 "Gets the visibility attribute of a symbol operation.")
4780 .def_static("set_visibility", &PySymbolTable::setVisibility, "symbol"_a,
4781 "visibility"_a,
4782 "Sets the visibility attribute of a symbol operation.")
4783 .def_static("replace_all_symbol_uses",
4784 &PySymbolTable::replaceAllSymbolUses, "old_symbol"_a,
4785 "new_symbol"_a, "from_op"_a,
4786 "Replaces all uses of a symbol with a new symbol name within "
4787 "the given operation.")
4788 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
4789 "from_op"_a, "all_sym_uses_visible"_a, "callback"_a,
4790 "Walks symbol tables starting from an operation with a "
4791 "callback function.");
4792
4793 // Container bindings.
4808
4809 // Debug bindings.
4811
4812 // Attribute builder getter.
4814}
4815} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
4816} // namespace python
4817} // namespace mlir
void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Definition Debug.cpp:28
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition Debug.cpp:20
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition Debug.cpp:16
static const char kDumpDocstring[]
Definition IRAffine.cpp:39
static const char kModuleParseDocstring[]
Definition IRCore.cpp:30
static nb::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition IRCore.cpp:61
#define Py_XNewRef(obj)
Definition IRCore.cpp:2536
#define _Py_CAST(type, expr)
Definition IRCore.cpp:2512
static std::string join(const Ts &...args)
Local helper to concatenate arguments into a std::string.
Definition IRCore.cpp:53
#define Py_NewRef(obj)
Definition IRCore.cpp:2545
#define _Py_NULL
Definition IRCore.cpp:2523
static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor)
Definition IRCore.cpp:67
static const char kValueReplaceAllUsesExceptDocstring[]
Definition IRCore.cpp:41
MlirContext mlirModuleGetContext(MlirModule module)
Definition IR.cpp:446
size_t mlirModuleHashValue(MlirModule mod)
Definition IR.cpp:472
intptr_t mlirBlockGetNumPredecessors(MlirBlock block)
Definition IR.cpp:1096
MlirIdentifier mlirOperationGetName(MlirOperation op)
Definition IR.cpp:669
bool mlirValueIsABlockArgument(MlirValue value)
Definition IR.cpp:1116
intptr_t mlirOperationGetNumRegions(MlirOperation op)
Definition IR.cpp:681
MlirBlock mlirOperationGetBlock(MlirOperation op)
Definition IR.cpp:673
void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Definition IR.cpp:1133
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition IR.cpp:521
MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Definition IR.cpp:732
MlirModule mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName)
Definition IR.cpp:437
MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Definition IR.cpp:178
intptr_t mlirOperationGetNumResults(MlirOperation op)
Definition IR.cpp:728
void mlirOperationDestroy(MlirOperation op)
Definition IR.cpp:639
MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Definition IR.cpp:1281
MlirType mlirValueGetType(MlirValue value)
Definition IR.cpp:1152
void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Definition IR.cpp:1082
MlirOpPrintingFlags mlirOpPrintingFlagsCreate()
Definition IR.cpp:202
bool mlirModuleEqual(MlirModule lhs, MlirModule rhs)
Definition IR.cpp:468
void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Definition IR.cpp:210
void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Definition IR.cpp:793
MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Definition IR.cpp:705
void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Definition IR.cpp:220
MlirOperation mlirModuleGetOperation(MlirModule module)
Definition IR.cpp:460
void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
Definition IR.cpp:215
void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Definition IR.cpp:233
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Definition IR.cpp:1128
MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Definition IR.cpp:740
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Definition IR.cpp:1300
MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags)
Definition IR.cpp:157
bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Definition IR.cpp:643
void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Definition IR.cpp:237
void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Definition IR.cpp:252
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1092
void mlirModuleDestroy(MlirModule module)
Definition IR.cpp:454
MlirModule mlirModuleCreateEmpty(MlirLocation location)
Definition IR.cpp:425
void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Definition IR.cpp:225
MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Definition IR.cpp:677
void mlirValueSetType(MlirValue value, MlirType type)
Definition IR.cpp:1156
intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Definition IR.cpp:736
MlirDialect mlirAttributeGetDialect(MlirAttribute attr)
Definition IR.cpp:1296
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Definition IR.cpp:415
void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Definition IR.cpp:812
void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Definition IR.cpp:717
MlirOperation mlirOpResultGetOwner(MlirValue value)
Definition IR.cpp:1143
MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Definition IR.cpp:429
size_t mlirOperationHashValue(MlirOperation op)
Definition IR.cpp:647
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Definition IR.cpp:504
MlirOperation mlirOperationClone(MlirOperation op)
Definition IR.cpp:635
MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Definition IR.cpp:1124
void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc)
Definition IR.cpp:1138
MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Definition IR.cpp:713
MlirLocation mlirOperationGetLocation(MlirOperation op)
Definition IR.cpp:655
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:807
MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr)
Definition IR.cpp:1292
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition IR.cpp:513
void mlirOperationSetLocation(MlirOperation op, MlirLocation loc)
Definition IR.cpp:659
MlirType mlirAttributeGetType(MlirAttribute attribute)
Definition IR.cpp:1285
bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:817
bool mlirValueIsAOpResult(MlirValue value)
Definition IR.cpp:1120
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1101
MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Definition IR.cpp:685
MlirOperation mlirOperationCreate(MlirOperationState *state)
Definition IR.cpp:589
void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Definition IR.cpp:256
MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Definition IR.cpp:1277
void mlirOperationRemoveFromParent(MlirOperation op)
Definition IR.cpp:641
intptr_t mlirBlockGetNumSuccessors(MlirBlock block)
Definition IR.cpp:1088
MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Definition IR.cpp:802
void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Definition IR.cpp:206
void mlirValueDump(MlirValue value)
Definition IR.cpp:1160
void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Definition IR.cpp:1266
MlirBlock mlirModuleGetBody(MlirModule module)
Definition IR.cpp:450
MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Definition IR.cpp:626
void mlirAsmStateDestroy(MlirAsmState state)
Destroys an asm state created with mlirAsmStateCreateForOperation or mlirAsmStateCreateForValue.
Definition IR.cpp:196
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:651
intptr_t mlirOpResultGetResultNumber(MlirValue value)
Definition IR.cpp:1147
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition IR.cpp:517
MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate()
Definition IR.cpp:248
void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags)
Definition IR.cpp:229
void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Definition IR.cpp:241
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition IR.cpp:509
MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Definition IR.cpp:480
intptr_t mlirOperationGetNumOperands(MlirOperation op)
Definition IR.cpp:709
void mlirTypeDump(MlirType type)
Definition IR.cpp:1271
intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Definition IR.cpp:798
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition Interop.h:348
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition Interop.h:216
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition Interop.h:118
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition Interop.h:189
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition Interop.h:367
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition Interop.h:330
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition Interop.h:180
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts an MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition Interop.h:357
#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the value caster registration bindin...
Definition Interop.h:142
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition Interop.h:198
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition Interop.h:255
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition Interop.h:245
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition Interop.h:454
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition Interop.h:445
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition Interop.h:273
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition Interop.h:264
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the type caster registration binding...
Definition Interop.h:130
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition Interop.h:235
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static std::string diag(const llvm::Value &value)
Accumulates into a file, either writing text (default) or binary.
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
nanobind::class_< PyRegionList > ClassTy
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
ReferrentTy * get() const
PyMlirContextRef & getContext()
Accesses the context reference.
Definition IRCore.h:298
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:279
PyAsmState(MlirValue value, bool useLocalScope)
Definition IRCore.cpp:1776
Wrapper around the generic MlirAttribute.
Definition IRCore.h:1006
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition IRCore.h:1008
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition IRCore.cpp:1886
bool operator==(const PyAttribute &other) const
Definition IRCore.cpp:1882
static PyAttribute createFromCapsule(const nanobind::object &capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition IRCore.cpp:1890
nanobind::typed< nanobind::object, PyAttribute > maybeDownCast()
Definition IRCore.cpp:1898
PyBlockArgumentList(PyOperationRef operation, MlirBlock block, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2215
Python wrapper for MlirBlockArgument.
Definition IRCore.h:1636
nanobind::typed< nanobind::object, PyBlock > dunderNext()
Definition IRCore.cpp:249
Blocks are exposed by the C-API as a forward-only linked list.
Definition IRCore.h:1433
PyBlock appendBlock(const nanobind::args &pyArgTypes, const std::optional< nanobind::sequence > &pyArgLocs)
Definition IRCore.cpp:305
PyBlockPredecessors(PyBlock block, PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2337
PyBlockSuccessors(PyBlock block, PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2314
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirBlock.
Definition IRCore.cpp:189
Represents a diagnostic handler attached to the context.
Definition IRCore.h:406
void detach()
Detaches the handler. Does nothing if not attached.
Definition IRCore.cpp:753
PyDiagnosticHandler(MlirContext context, nanobind::object callback)
Definition IRCore.cpp:747
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.h:418
Python class mirroring the C MlirDiagnostic struct.
Definition IRCore.h:356
Wrapper around an MlirDialectRegistry.
Definition IRCore.h:498
static PyDialectRegistry createFromCapsule(nanobind::object capsule)
Definition IRCore.cpp:837
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition IRCore.h:474
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition IRCore.cpp:820
Globals that are always accessible once the extension has been initialized.
Definition Globals.h:34
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:164
bool loadDialectModule(llvm::StringRef dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition Globals.cpp:49
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.cpp:44
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition Globals.cpp:96
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false)
Adds a user-friendly Attribute builder.
Definition Globals.cpp:82
void registerOperationImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation operation class.
Definition Globals.cpp:128
void setDialectSearchPrefixes(std::vector< std::string > newValues)
Definition Globals.h:48
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:151
void registerValueCaster(MlirTypeID mlirTypeID, nanobind::callable valueCaster, bool replace=false)
Adds a user-friendly value caster.
Definition Globals.cpp:106
std::optional< nanobind::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition Globals.cpp:193
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition Globals.cpp:141
void registerDialectImpl(const std::string &dialectNamespace, nanobind::object pyClass)
Adds a concrete implementation dialect class.
Definition Globals.cpp:116
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition Globals.cpp:178
std::vector< std::string > getDialectSearchPrefixes()
Get and set the list of parent modules to search for dialect implementation classes.
Definition Globals.h:44
An insertion point maintains a pointer to a Block and a reference operation.
Definition IRCore.h:833
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition IRCore.cpp:1807
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:1872
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition IRCore.cpp:1846
static PyInsertionPoint after(PyOperationBase &op)
Shortcut to create an insertion point to the node after the specified operation.
Definition IRCore.cpp:1855
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition IRCore.cpp:1833
PyInsertionPoint(const PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition IRCore.cpp:1798
static nanobind::object contextEnter(nanobind::object insertionPoint)
Enter and exit the context manager.
Definition IRCore.cpp:1868
static nanobind::object contextEnter(nanobind::object location)
Enter and exit the context manager.
Definition IRCore.cpp:861
static PyLocation createFromCapsule(nanobind::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition IRCore.cpp:853
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition IRCore.cpp:849
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:865
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition IRCore.h:307
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:486
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition IRCore.cpp:511
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition IRCore.cpp:479
nanobind::object attachDiagnosticHandler(nanobind::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition IRCore.cpp:526
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:520
MlirContext get()
Accesses the underlying MlirContext.
Definition IRCore.h:212
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition IRCore.cpp:471
void setEmitErrorDiagnostics(bool value)
Controls whether error diagnostics should be propagated to diagnostic handlers, instead of being capt...
Definition IRCore.h:246
static nanobind::object contextEnter(nanobind::object context)
Enter and exit the context manager.
Definition IRCore.cpp:516
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition IRCore.cpp:475
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition IRCore.cpp:1866
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition IRCore.cpp:930
MlirModule get()
Gets the backing MlirModule.
Definition IRCore.h:548
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition IRCore.cpp:898
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition IRCore.cpp:923
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition IRCore.h:1032
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition IRCore.cpp:1916
Template for a reference to a concrete type which captures a python reference to its underlying pytho...
Definition IRCore.h:67
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
Definition IRCore.h:93
void dunderSetItem(const std::string &name, const PyAttribute &attr)
Definition IRCore.cpp:2397
nanobind::typed< nanobind::object, PyAttribute > dunderGetItemNamed(const std::string &name)
Definition IRCore.cpp:2364
nanobind::typed< nanobind::object, std::optional< PyAttribute > > get(const std::string &key, nanobind::object defaultValue)
Definition IRCore.cpp:2374
static void forEachAttr(MlirOperation op, std::function< void(MlirStringRef, MlirAttribute)> fn)
Definition IRCore.cpp:2419
PyNamedAttribute dunderGetItemIndexed(intptr_t index)
Definition IRCore.cpp:2382
nanobind::typed< nanobind::object, PyOpOperand > dunderNext()
Definition IRCore.cpp:414
PyOpOperandList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2247
void dunderSetItem(intptr_t index, PyValue value)
Definition IRCore.cpp:2255
nanobind::typed< nanobind::object, PyOpView > getOwner() const
Definition IRCore.cpp:395
PyOpResultList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:1425
PyOpSuccessors(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2281
void dunderSetItem(intptr_t index, PyBlock block)
Definition IRCore.cpp:2289
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition IRCore.h:735
PyOpView(const nanobind::object &operationObject)
Definition IRCore.cpp:1766
static nanobind::object buildGeneric(std::string_view name, std::tuple< int, bool > opRegionSpec, nanobind::object operandSegmentSpecObj, nanobind::object resultSegmentSpecObj, std::optional< nanobind::list > resultTypeList, nanobind::list operandList, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, std::optional< int > regions, PyLocation &location, const nanobind::object &maybeIp)
Definition IRCore.cpp:1587
static nanobind::object constructDerived(const nanobind::object &cls, const nanobind::object &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition IRCore.cpp:1758
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition IRCore.h:578
bool isBeforeInBlock(PyOperationBase &other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IRCore.cpp:1189
nanobind::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions)
Definition IRCore.cpp:1143
void print(std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
Definition IRCore.cpp:1042
void writeBytecode(const nanobind::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition IRCore.cpp:1090
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition IRCore.cpp:1171
void walk(std::function< PyWalkResult(MlirOperation)> callback, PyWalkOrder walkOrder)
Definition IRCore.cpp:1111
nanobind::typed< nanobind::object, PyOpView > dunderNext()
Definition IRCore.cpp:326
Operations are exposed by the C-API as a forward-only linked list.
Definition IRCore.h:1474
nanobind::typed< nanobind::object, PyOpView > dunderGetItem(intptr_t index)
Definition IRCore.cpp:365
static nanobind::object create(std::string_view name, std::optional< std::vector< PyType * > > results, const MlirValue *operands, size_t numOperands, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, int regions, PyLocation &location, const nanobind::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition IRCore.cpp:1253
void setInvalid()
Invalidate the operation.
Definition IRCore.h:702
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition IRCore.h:635
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition IRCore.cpp:999
nanobind::object clone(const nanobind::object &ip)
Clones this operation.
Definition IRCore.cpp:1368
static nanobind::object createFromCapsule(const nanobind::object &capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition IRCore.cpp:1229
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition IRCore.cpp:1205
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1377
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition IRCore.cpp:1224
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:983
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition IRCore.cpp:1011
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition IRCore.cpp:1388
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition IRCore.cpp:1215
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Creates a detached operation.
Definition IRCore.cpp:990
PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
Definition IRCore.cpp:938
void setAttached(const nanobind::object &parent=nanobind::object())
Definition IRCore.cpp:1026
nanobind::typed< nanobind::object, PyRegion > dunderNext()
Definition IRCore.cpp:197
Regions of an op are fixed length and indexed numerically so are represented with a sequence-like con...
Definition IRCore.h:1390
PyRegionList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:216
PyStringAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition IRCore.cpp:2068
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition IRCore.cpp:2032
static PyStringAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition IRCore.cpp:2108
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition IRCore.cpp:2053
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, nanobind::object callback)
Walks all symbol tables under and including 'from'.
Definition IRCore.cpp:2149
nanobind::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition IRCore.cpp:2040
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition IRCore.cpp:2137
static PyStringAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition IRCore.cpp:2080
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition IRCore.cpp:2063
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition IRCore.cpp:2093
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition IRCore.cpp:2119
Tracks an entry in the thread context stack.
Definition IRCore.h:126
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top-of-stack insertion point and returns nullptr if not defined.
Definition IRCore.cpp:663
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint)
Definition IRCore.cpp:691
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition IRCore.cpp:703
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition IRCore.cpp:668
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition IRCore.cpp:611
static nanobind::object pushLocation(nanobind::object location)
Definition IRCore.cpp:714
static nanobind::object pushContext(nanobind::object context)
Definition IRCore.cpp:673
static PyMlirContext * getDefaultContext()
Gets the top-of-stack context and returns nullptr if not defined.
Definition IRCore.cpp:658
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition IRCore.cpp:606
Wrapper around MlirLlvmThreadPool; the Python object owns the C++ thread pool.
Definition IRCore.h:182
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition IRCore.h:901
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition IRCore.cpp:1962
bool operator==(const PyTypeID &other) const
Definition IRCore.cpp:1972
static PyTypeID createFromCapsule(nanobind::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition IRCore.cpp:1966
Wrapper around the generic MlirType.
Definition IRCore.h:875
PyType(PyMlirContextRef contextRef, MlirType type)
Definition IRCore.h:877
bool operator==(const PyType &other) const
Definition IRCore.cpp:1928
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition IRCore.cpp:1932
static PyType createFromCapsule(nanobind::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition IRCore.cpp:1936
nanobind::typed< nanobind::object, PyType > maybeDownCast()
Definition IRCore.cpp:1944
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition IRCore.cpp:1980
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition IRCore.h:1177
nanobind::typed< nanobind::object, std::variant< PyBlockArgument, PyOpResult, PyValue > > maybeDownCast()
Definition IRCore.cpp:1999
static PyValue createFromCapsule(nanobind::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition IRCore.cpp:2020
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
struct MlirDiagnostic MlirDiagnostic
Definition Diagnostics.h:29
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition Diagnostics.h:24
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition IR.cpp:265
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition IR.h:851
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v)
Gets the location of the value.
Definition IR.cpp:1203
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context)
Gets the number of threads of the thread pool of the context when multithreading is enabled.
Definition IR.cpp:117
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition IR.cpp:842
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates a File/Line/Column location owned by the given context.
Definition IR.cpp:273
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition IR.cpp:1385
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition IR.cpp:137
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location)
Getter for end_column of FileLineColRange.
Definition IR.cpp:311
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition IR.cpp:1364
MlirWalkOrder
Traversal order for operation walk.
Definition IR.h:844
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition IR.cpp:1312
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGetChildLoc(MlirLocation location)
Getter for childLoc of Name.
Definition IR.cpp:392
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition IR.cpp:1369
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition IR.cpp:84
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition IR.cpp:1333
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition IR.cpp:1246
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition IR.cpp:1229
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition IR.cpp:73
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition IR.cpp:952
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location)
Checks whether the given location is a FileLineColRange.
Definition IR.cpp:321
MLIR_CAPI_EXPORTED unsigned mlirLocationFusedGetNumLocations(MlirLocation location)
Getter for number of locations fused together.
Definition IR.cpp:355
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition IR.cpp:1185
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition IR.cpp:1168
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition IR.cpp:403
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition IR.cpp:1217
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location)
Getter for filename of FileLineColRange.
Definition IR.cpp:289
MLIR_CAPI_EXPORTED void mlirLocationFusedGetLocations(MlirLocation location, MlirLocation *locationsCPtr)
Getter for locations of Fused.
Definition IR.cpp:361
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding userData to callback...
Definition IR.cpp:1304
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition IR.cpp:991
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition IR.cpp:866
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
Definition IR.cpp:1189
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition IR.cpp:833
MlirWalkResult
Operation walk result.
Definition IR.h:837
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition IR.cpp:932
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1156
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e.
Definition IR.cpp:100
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition IR.cpp:1060
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartLine(MlirLocation location)
Getter for start_line of FileLineColRange.
Definition IR.cpp:293
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition IR.cpp:347
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition IR.cpp:1041
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition IR.cpp:95
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location)
Checks whether the given location is a CallSite.
Definition IR.cpp:343
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition IR.cpp:938
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition IR.cpp:975
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition IR.h:937
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition IR.cpp:1016
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition IR.cpp:1078
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition IR.cpp:1359
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition IR.cpp:1250
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, unsigned start_line, unsigned start_col, unsigned end_line, unsigned end_col)
Creates a File/Line/Column range location owned by the given context.
Definition IR.cpp:281
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition IR.cpp:1215
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition IR.cpp:1349
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition IR.cpp:407
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location)
Getter for start_column of FileLineColRange.
Definition IR.cpp:299
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location)
Checks whether the given location is a Fused location.
Definition IR.cpp:375
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition IR.h:370
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition IR.cpp:1064
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition IR.cpp:827
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition IR.cpp:1175
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool)
Sets the thread pool of the context explicitly, enabling multithreading in the process.
Definition IR.cpp:112
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition IR.cpp:858
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition IR.cpp:1262
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition IR.cpp:1225
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location)
Getter for caller of CallSite.
Definition IR.cpp:334
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition IR.cpp:1006
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationNameGetName(MlirLocation location)
Getter for name of Name.
Definition IR.cpp:388
MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IR.cpp:870
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition IR.cpp:269
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition IR.cpp:1254
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition IR.cpp:1345
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition IR.h:182
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition IR.cpp:995
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition IR.cpp:325
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location)
Checks whether the given location is a Name location.
Definition IR.cpp:399
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition IR.h:244
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition IR.cpp:888
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
Definition IR.cpp:55
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition IR.cpp:987
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition IR.cpp:411
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition IR.cpp:1069
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition IR.cpp:1310
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition IR.cpp:1374
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition IR.cpp:928
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition IR.cpp:999
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition IR.h:876
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition IR.cpp:1258
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition IR.cpp:1321
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition IR.cpp:108
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context)
Gets the thread pool of the context when multithreading is enabled; otherwise an assertion is raised.
Definition IR.cpp:121
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndLine(MlirLocation location)
Getter for end_line of FileLineColRange.
Definition IR.cpp:305
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition IR.cpp:379
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Sets the threading mode (must be set to false to use mlir-print-ir-after-all).
Definition IR.cpp:104
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location)
Getter for callee of CallSite.
Definition IR.cpp:329
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v)
Gets the context that a value was created with.
Definition IR.cpp:1207
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition IR.cpp:1341
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition IR.cpp:915
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition IR.cpp:1055
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition IR.cpp:856
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition IR.h:1246
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition IR.cpp:77
MLIR_CAPI_EXPORTED void mlirOperationReplaceUsesOfWith(MlirOperation op, MlirValue of, MlirValue with)
Replace uses of 'of' value with the 'with' value inside the 'op' operation.
Definition IR.cpp:906
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition IR.cpp:862
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding userData to callback.
Definition IR.cpp:1162
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition IR.cpp:849
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition IR.cpp:71
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition IR.cpp:921
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:84
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:140
struct MlirLogicalResult MlirLogicalResult
Definition Support.h:121
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition Support.cpp:51
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:134
struct MlirStringRef MlirStringRef
Definition Support.h:79
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:129
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition Support.h:165
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition Support.cpp:47
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation)
Definition IRCore.cpp:1553
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m)
static void maybeInsertOperation(PyOperationRef &op, const nb::object &maybeIp)
Definition IRCore.cpp:1238
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:198
static MlirValue getOpResultOrValue(nb::handle operand)
Definition IRCore.cpp:1568
PyObjectRef< PyOperation > PyOperationRef
Definition IRCore.h:630
static void populateResultTypes(std::string_view name, nb::list resultTypeList, const nb::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition IRCore.cpp:1467
MlirStringRef toMlirStringRef(const std::string &s)
Definition IRCore.h:1339
PyObjectRef< PyModule > PyModuleRef
Definition IRCore.h:537
static PyOperationRef getValueOwnerRef(MlirValue value)
Definition IRCore.cpp:1984
MlirBlock MLIR_PYTHON_API_EXPORTED createBlock(const nanobind::sequence &pyArgTypes, const std::optional< nanobind::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
static std::vector< nb::typed< nb::object, PyType > > getValueTypes(Container &container, PyMlirContextRef &context)
Returns the list of types of the values held by container.
Definition IRCore.cpp:1414
PyWalkOrder
Traversal order for operation walk.
Definition IRCore.h:347
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m)
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition Diagnostics.h:26
A logical result value, essentially a boolean with named states.
Definition Support.h:118
Named MLIR attribute.
Definition IR.h:76
MlirAttribute attribute
Definition IR.h:78
MlirIdentifier name
Definition IR.h:77
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:75
const char * data
Pointer to the first symbol.
Definition Support.h:76
size_t length
Length of the fragment.
Definition Support.h:77
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition IRCore.h:1326
static bool dunderContains(const std::string &attributeKind)
Definition IRCore.cpp:151
static nanobind::callable dunderGetItemNamed(const std::string &attributeKind)
Definition IRCore.cpp:156
static void dunderSetItemNamed(const std::string &attributeKind, nanobind::callable func, bool replace)
Definition IRCore.cpp:163
static void set(nanobind::object &o, bool enable)
Definition IRCore.cpp:113
RAII object that captures any error diagnostics emitted to the provided context.
Definition IRCore.h:434
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition IRCore.h:444