MLIR 23.0.0git
IRCore.cpp
Go to the documentation of this file.
1//===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// clang-format off
13#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
14// clang-format on
16#include "mlir-c/Debug.h"
17#include "mlir-c/Diagnostics.h"
19#include "mlir-c/IR.h"
20#include "mlir-c/Support.h"
21
22#include <array>
23#include <functional>
24#include <optional>
25#include <string>
26
27namespace nb = nanobind;
28using namespace nb::literals;
29using namespace mlir;
31
/// Python docstring for the module assembly-format parsing entry point.
static const char kModuleParseDocstring[] =
    "Parses a module's assembly format from a string.\n"
    "\n"
    "Returns a new MlirModule or raises an MLIRError if the parsing fails.\n"
    "\n"
    "See also: https://mlir.llvm.org/docs/LangRef/\n";

/// Python docstring for the `dump` debug-printing methods.
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
42
44 R"(Replace all uses of this value with the `with` value, except for those
45in `exceptions`. `exceptions` can be either a single operation or a list of
46operations.
47)";
48
49//------------------------------------------------------------------------------
50// Utilities.
51//------------------------------------------------------------------------------
52
/// Returns `std::hash<T>` applied to `value`; a small shorthand so call sites
/// need not spell out the hasher type.
template <typename T>
static size_t hash(const T &value) {
  const std::hash<T> hasher{};
  return hasher(value);
}
58
59static nb::object
60createCustomDialectWrapper(const std::string &dialectNamespace,
61 nb::object dialectDescriptor) {
62 auto dialectClass =
64 dialectNamespace);
65 if (!dialectClass) {
66 // Use the base class.
68 std::move(dialectDescriptor)));
69 }
70
71 // Create the custom implementation.
72 return (*dialectClass)(std::move(dialectDescriptor));
73}
74
75namespace mlir {
76namespace python {
78
79MlirBlock createBlock(
80 const nb::typed<nb::sequence, PyType> &pyArgTypes,
81 const std::optional<nb::typed<nb::sequence, PyLocation>> &pyArgLocs) {
82 std::vector<MlirType> argTypes;
83 argTypes.reserve(nb::len(pyArgTypes));
84 for (nb::handle pyType : pyArgTypes)
85 argTypes.push_back(
86 nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyType &>(pyType));
87
88 std::vector<MlirLocation> argLocs;
89 if (pyArgLocs) {
90 argLocs.reserve(nb::len(*pyArgLocs));
91 for (nb::handle pyLoc : *pyArgLocs)
92 argLocs.push_back(
93 nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyLocation &>(pyLoc));
94 } else if (!argTypes.empty()) {
95 argLocs.assign(
96 argTypes.size(),
98 }
99
100 if (argTypes.size() != argLocs.size())
101 throw nb::value_error(
102 join("Expected ", argTypes.size(), " locations, got: ", argLocs.size())
103 .c_str());
104 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
105}
106
107void PyGlobalDebugFlag::set(nb::object &o, bool enable) {
108 nb::ft_lock_guard lock(mutex);
109 mlirEnableGlobalDebug(enable);
110}
111
112bool PyGlobalDebugFlag::get(const nb::object &) {
113 nb::ft_lock_guard lock(mutex);
115}
116
117void PyGlobalDebugFlag::bind(nb::module_ &m) {
118 // Debug flags.
119 nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
120 .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
121 &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
122 .def_static(
123 "set_types",
124 [](const std::string &type) {
125 nb::ft_lock_guard lock(mutex);
126 mlirSetGlobalDebugType(type.c_str());
127 },
128 "types"_a, "Sets specific debug types to be produced by LLVM.")
129 .def_static(
130 "set_types",
131 [](const std::vector<std::string> &types) {
132 std::vector<const char *> pointers;
133 pointers.reserve(types.size());
134 for (const std::string &str : types)
135 pointers.push_back(str.c_str());
136 nb::ft_lock_guard lock(mutex);
137 mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
138 },
139 "types"_a,
140 "Sets multiple specific debug types to be produced by LLVM.");
141}
142
143nb::ft_mutex PyGlobalDebugFlag::mutex;
144
145bool PyAttrBuilderMap::dunderContains(const std::string &attributeKind) {
146 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
147}
148
149nb::callable
150PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
151 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
152 if (!builder)
153 throw nb::key_error(attributeKind.c_str());
154 return *builder;
155}
156
157void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
158 nb::callable func, bool replace,
159 bool allow_existing) {
160 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
161 replace, allow_existing);
162}
163
164void PyAttrBuilderMap::bind(nb::module_ &m) {
165 nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
166 .def_static("contains", &PyAttrBuilderMap::dunderContains,
167 "attribute_kind"_a,
168 "Checks whether an attribute builder is registered for the "
169 "given attribute kind.")
170 .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
171 "attribute_kind"_a,
172 "Gets the registered attribute builder for the given "
173 "attribute kind.")
174 .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
175 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
176 "allow_existing"_a = false,
177 "Register an attribute builder for building MLIR "
178 "attributes from Python values.");
179}
180
181//------------------------------------------------------------------------------
182// PyBlock
183//------------------------------------------------------------------------------
184
186 return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
187}
188
189//------------------------------------------------------------------------------
190// Collections.
191//------------------------------------------------------------------------------
192
196 length == -1 ? mlirOperationGetNumRegions(operation->get())
197 : length,
198 step),
199 operation(std::move(operation)) {}
200
201intptr_t PyRegionList::getRawNumElements() {
202 operation->checkValid();
203 return mlirOperationGetNumRegions(operation->get());
204}
205
206PyRegion PyRegionList::getRawElement(intptr_t pos) {
207 operation->checkValid();
208 return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
209}
210
211PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
212 intptr_t step) const {
213 return PyRegionList(operation, startIndex, length, step);
214}
215
216nb::typed<nb::object, PyBlock> PyBlockIterator::dunderNext() {
217 operation->checkValid();
218 if (mlirBlockIsNull(next)) {
219 PyErr_SetNone(PyExc_StopIteration);
220 // python functions should return NULL after setting any exception
221 return nb::object();
222 }
223
224 PyBlock returnBlock(operation, next);
225 next = mlirBlockGetNextInRegion(next);
226 return nb::cast(returnBlock);
227}
228
229void PyBlockIterator::bind(nb::module_ &m) {
230 nb::class_<PyBlockIterator>(m, "BlockIterator")
231 .def("__iter__", &PyBlockIterator::dunderIter,
232 "Returns an iterator over the blocks in the operation's region.")
233 .def("__next__", &PyBlockIterator::dunderNext,
234 "Returns the next block in the iteration.");
235}
236
238 operation->checkValid();
239 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
240}
241
243 operation->checkValid();
244 intptr_t count = 0;
245 MlirBlock block = mlirRegionGetFirstBlock(region);
246 while (!mlirBlockIsNull(block)) {
247 count += 1;
248 block = mlirBlockGetNextInRegion(block);
249 }
250 return count;
251}
252
254 operation->checkValid();
255 if (index < 0) {
256 index += dunderLen();
257 }
258 if (index < 0) {
259 throw nb::index_error("attempt to access out of bounds block");
260 }
261 MlirBlock block = mlirRegionGetFirstBlock(region);
262 while (!mlirBlockIsNull(block)) {
263 if (index == 0) {
264 return PyBlock(operation, block);
265 }
266 block = mlirBlockGetNextInRegion(block);
267 index -= 1;
268 }
269 throw nb::index_error("attempt to access out of bounds block");
270}
271
272PyBlock PyBlockList::appendBlock(const nb::args &pyArgTypes,
273 const std::optional<nb::sequence> &pyArgLocs) {
274 operation->checkValid();
275 MlirBlock block = createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
276 mlirRegionAppendOwnedBlock(region, block);
277 return PyBlock(operation, block);
278}
279
/// Registers the `BlockList` class exposing the blocks of a region:
/// `__getitem__`, `__iter__`, `__len__`, and `append`. `append` takes the block
/// argument types as positional args and an optional keyword-only `arg_locs`.
280void PyBlockList::bind(nb::module_ &m) {
281 nb::class_<PyBlockList>(m, "BlockList")
282 .def("__getitem__", &PyBlockList::dunderGetItem,
283 "Returns the block at the specified index.")
284 .def("__iter__", &PyBlockList::dunderIter,
285 "Returns an iterator over blocks in the operation's region.")
286 .def("__len__", &PyBlockList::dunderLen,
287 "Returns the number of blocks in the operation's region.")
288 .def("append", &PyBlockList::appendBlock,
289 R"(
290 Appends a new block, with argument types as positional args.
291
292 Returns:
293 The created block.
294 )",
295 "args"_a, nb::kw_only(), "arg_locs"_a = std::nullopt);
296}
297
298nb::typed<nb::object, PyOpView> PyOperationIterator::dunderNext() {
299 parentOperation->checkValid();
300 if (mlirOperationIsNull(next)) {
301 PyErr_SetNone(PyExc_StopIteration);
302 // python functions should return NULL after setting any exception
303 return nb::object();
304 }
305
306 PyOperationRef returnOperation =
307 PyOperation::forOperation(parentOperation->getContext(), next);
308 next = mlirOperationGetNextInBlock(next);
309 return returnOperation->createOpView();
310}
311
312void PyOperationIterator::bind(nb::module_ &m) {
313 nb::class_<PyOperationIterator>(m, "OperationIterator")
314 .def("__iter__", &PyOperationIterator::dunderIter,
315 "Returns an iterator over the operations in an operation's block.")
316 .def("__next__", &PyOperationIterator::dunderNext,
317 "Returns the next operation in the iteration.");
318}
319
321 parentOperation->checkValid();
322 return PyOperationIterator(parentOperation,
324}
325
327 parentOperation->checkValid();
328 intptr_t count = 0;
329 MlirOperation childOp = mlirBlockGetFirstOperation(block);
330 while (!mlirOperationIsNull(childOp)) {
331 count += 1;
332 childOp = mlirOperationGetNextInBlock(childOp);
333 }
334 return count;
335}
336
337nb::typed<nb::object, PyOpView> PyOperationList::dunderGetItem(intptr_t index) {
338 parentOperation->checkValid();
339 if (index < 0) {
340 index += dunderLen();
341 }
342 if (index < 0) {
343 throw nb::index_error("attempt to access out of bounds operation");
344 }
345 MlirOperation childOp = mlirBlockGetFirstOperation(block);
346 while (!mlirOperationIsNull(childOp)) {
347 if (index == 0) {
348 return PyOperation::forOperation(parentOperation->getContext(), childOp)
349 ->createOpView();
350 }
351 childOp = mlirOperationGetNextInBlock(childOp);
352 index -= 1;
353 }
354 throw nb::index_error("attempt to access out of bounds operation");
355}
356
357void PyOperationList::bind(nb::module_ &m) {
358 nb::class_<PyOperationList>(m, "OperationList")
359 .def("__getitem__", &PyOperationList::dunderGetItem,
360 "Returns the operation at the specified index.")
361 .def("__iter__", &PyOperationList::dunderIter,
362 "Returns an iterator over operations in the list.")
363 .def("__len__", &PyOperationList::dunderLen,
364 "Returns the number of operations in the list.");
365}
366
367nb::typed<nb::object, PyOpView> PyOpOperand::getOwner() const {
368 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
372}
374size_t PyOpOperand::getOperandNumber() const {
375 return mlirOpOperandGetOperandNumber(opOperand);
376}
377
378void PyOpOperand::bind(nb::module_ &m) {
379 nb::class_<PyOpOperand>(m, "OpOperand")
380 .def_prop_ro("owner", &PyOpOperand::getOwner,
381 "Returns the operation that owns this operand.")
382 .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
383 "Returns the operand number in the owning operation.");
384}
385
386nb::typed<nb::object, PyOpOperand> PyOpOperandIterator::dunderNext() {
387 if (mlirOpOperandIsNull(opOperand)) {
388 PyErr_SetNone(PyExc_StopIteration);
389 // python functions should return NULL after setting any exception
390 return nb::object();
391 }
392
393 PyOpOperand returnOpOperand(opOperand);
394 opOperand = mlirOpOperandGetNextUse(opOperand);
395 return nb::cast(returnOpOperand);
396}
397
398void PyOpOperandIterator::bind(nb::module_ &m) {
399 nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
400 .def("__iter__", &PyOpOperandIterator::dunderIter,
401 "Returns an iterator over operands.")
402 .def("__next__", &PyOpOperandIterator::dunderNext,
403 "Returns the next operand in the iteration.");
404}
406//------------------------------------------------------------------------------
407// PyThreadPool
408//------------------------------------------------------------------------------
409
411
413 if (threadPool.ptr)
414 mlirLlvmThreadPoolDestroy(threadPool);
415}
418 return mlirLlvmThreadPoolGetMaxConcurrency(threadPool);
419}
420
421std::string PyThreadPool::_mlir_thread_pool_ptr() const {
422 std::stringstream ss;
423 ss << threadPool.ptr;
424 return ss.str();
425}
427//------------------------------------------------------------------------------
428// PyMlirContext
429//------------------------------------------------------------------------------
430
431PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
432 nb::gil_scoped_acquire acquire;
433 nb::ft_lock_guard lock(live_contexts_mutex);
434 auto &liveContexts = getLiveContexts();
435 liveContexts[context.ptr] = this;
436}
437
439 // Note that the only public way to construct an instance is via the
440 // forContext method, which always puts the associated handle into
441 // liveContexts.
442 nb::gil_scoped_acquire acquire;
443 {
444 nb::ft_lock_guard lock(live_contexts_mutex);
445 getLiveContexts().erase(context.ptr);
446 }
447 mlirContextDestroy(context);
448}
451 return PyMlirContextRef(this, nb::cast(this));
452}
454nb::object PyMlirContext::getCapsule() {
455 return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
456}
457
458nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
459 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
460 if (mlirContextIsNull(rawContext))
461 throw nb::python_error();
462 return forContext(rawContext).releaseObject();
463}
464
465PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
466 nb::gil_scoped_acquire acquire;
467 nb::ft_lock_guard lock(live_contexts_mutex);
468 auto &liveContexts = getLiveContexts();
469 auto it = liveContexts.find(context.ptr);
470 if (it == liveContexts.end()) {
471 // Create.
472 PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
473 nb::object pyRef = nb::cast(unownedContextWrapper);
474 assert(pyRef && "cast to nb::object failed");
475 liveContexts[context.ptr] = unownedContextWrapper;
476 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
477 }
478 // Use existing.
479 nb::object pyRef = nb::cast(it->second);
480 return PyMlirContextRef(it->second, std::move(pyRef));
481}
482
483nb::ft_mutex PyMlirContext::live_contexts_mutex;
484
485PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
486 static LiveContextMap liveContexts;
487 return liveContexts;
488}
489
491 nb::ft_lock_guard lock(live_contexts_mutex);
492 return getLiveContexts().size();
493}
495nb::object PyMlirContext::contextEnter(nb::object context) {
496 return PyThreadContextEntry::pushContext(context);
497}
498
499void PyMlirContext::contextExit(const nb::object &excType,
500 const nb::object &excVal,
501 const nb::object &excTb) {
503}
504
505nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
506 // Note that ownership is transferred to the delete callback below by way of
507 // an explicit inc_ref (borrow).
508 PyDiagnosticHandler *pyHandler =
509 new PyDiagnosticHandler(get(), std::move(callback));
510 nb::object pyHandlerObject =
511 nb::cast(pyHandler, nb::rv_policy::take_ownership);
512 (void)pyHandlerObject.inc_ref();
513
514 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
515 // guaranteed to be known to pybind.
516 auto handlerCallback =
517 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
518 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
519 nb::object pyDiagnosticObject =
520 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
521
522 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
523 bool result = false;
524 {
525 // Since this can be called from arbitrary C++ contexts, always get the
526 // gil.
527 nb::gil_scoped_acquire gil;
528 try {
529 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
530 } catch (std::exception &e) {
531 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
532 e.what());
533 pyHandler->hadError = true;
534 }
535 }
536
537 pyDiagnostic->invalidate();
539 };
540 auto deleteCallback = +[](void *userData) {
541 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
542 assert(pyHandler->registeredID && "handler is not registered");
543 pyHandler->registeredID.reset();
544
545 // Decrement reference, balancing the inc_ref() above.
546 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
547 pyHandlerObject.dec_ref();
548 };
549
550 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
551 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
552 return pyHandlerObject;
553}
554
555MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
556 void *userData) {
557 auto *self = static_cast<ErrorCapture *>(userData);
558 // Check if the context requested we emit errors instead of capturing them.
559 if (self->ctx->emitErrorDiagnostics)
561
563 MlirDiagnosticSeverity::MlirDiagnosticError)
566 self->errors.emplace_back(PyDiagnostic(diag).getInfo());
568}
569
572 if (!context) {
573 throw std::runtime_error(
574 "An MLIR function requires a Context but none was provided in the call "
575 "or from the surrounding environment. Either pass to the function with "
576 "a 'context=' argument or establish a default using 'with Context():'");
577 }
578 return *context;
579}
581//------------------------------------------------------------------------------
582// PyThreadContextEntry management
583//------------------------------------------------------------------------------
584
585std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
586 static thread_local std::vector<PyThreadContextEntry> stack;
587 return stack;
588}
589
591 auto &stack = getStack();
592 if (stack.empty())
593 return nullptr;
594 return &stack.back();
595}
596
597void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
598 nb::object insertionPoint,
599 nb::object location) {
600 auto &stack = getStack();
601 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
602 std::move(location));
603 // If the new stack has more than one entry and the context of the new top
604 // entry matches the previous, copy the insertionPoint and location from the
605 // previous entry if missing from the new top entry.
606 if (stack.size() > 1) {
607 auto &prev = *(stack.rbegin() + 1);
608 auto &current = stack.back();
609 if (current.context.is(prev.context)) {
610 // Default non-context objects from the previous entry.
611 if (!current.insertionPoint)
612 current.insertionPoint = prev.insertionPoint;
613 if (!current.location)
614 current.location = prev.location;
615 }
616 }
617}
618
620 if (!context)
621 return nullptr;
622 return nb::cast<PyMlirContext *>(context);
623}
624
626 if (!insertionPoint)
627 return nullptr;
628 return nb::cast<PyInsertionPoint *>(insertionPoint);
629}
630
632 if (!location)
633 return nullptr;
634 return nb::cast<PyLocation *>(location);
635}
636
638 auto *tos = getTopOfStack();
639 return tos ? tos->getContext() : nullptr;
640}
641
643 auto *tos = getTopOfStack();
644 return tos ? tos->getInsertionPoint() : nullptr;
645}
646
648 auto *tos = getTopOfStack();
649 return tos ? tos->getLocation() : nullptr;
650}
651
652nb::object PyThreadContextEntry::pushContext(nb::object context) {
653 push(FrameKind::Context, /*context=*/context,
654 /*insertionPoint=*/nb::object(),
655 /*location=*/nb::object());
656 return context;
657}
658
660 auto &stack = getStack();
661 if (stack.empty())
662 throw std::runtime_error("Unbalanced Context enter/exit");
663 auto &tos = stack.back();
664 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
665 throw std::runtime_error("Unbalanced Context enter/exit");
666 stack.pop_back();
667}
668
669nb::object
670PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
671 PyInsertionPoint &insertionPoint =
672 nb::cast<PyInsertionPoint &>(insertionPointObj);
673 nb::object contextObj =
674 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
675 push(FrameKind::InsertionPoint,
676 /*context=*/contextObj,
677 /*insertionPoint=*/insertionPointObj,
678 /*location=*/nb::object());
679 return insertionPointObj;
680}
681
683 auto &stack = getStack();
684 if (stack.empty())
685 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
686 auto &tos = stack.back();
687 if (tos.frameKind != FrameKind::InsertionPoint &&
688 tos.getInsertionPoint() != &insertionPoint)
689 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
690 stack.pop_back();
691}
692
693nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
694 PyLocation &location = nb::cast<PyLocation &>(locationObj);
695 nb::object contextObj = location.getContext().getObject();
696 push(FrameKind::Location, /*context=*/contextObj,
697 /*insertionPoint=*/nb::object(),
698 /*location=*/locationObj);
699 return locationObj;
700}
701
703 auto &stack = getStack();
704 if (stack.empty())
705 throw std::runtime_error("Unbalanced Location enter/exit");
706 auto &tos = stack.back();
707 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
708 throw std::runtime_error("Unbalanced Location enter/exit");
709 stack.pop_back();
710}
712//------------------------------------------------------------------------------
713// PyDiagnostic*
714//------------------------------------------------------------------------------
715
717 valid = false;
718 if (materializedNotes) {
719 for (nb::handle noteObject : *materializedNotes) {
720 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
721 note->invalidate();
722 }
723 }
724}
725
727 nb::object callback)
728 : context(context), callback(std::move(callback)) {}
729
731
733 if (!registeredID)
734 return;
735 MlirDiagnosticHandlerID localID = *registeredID;
736 mlirContextDetachDiagnosticHandler(context, localID);
737 assert(!registeredID && "should have unregistered");
738 // Not strictly necessary but keeps stale pointers from being around to cause
739 // issues.
740 context = {nullptr};
741}
742
743void PyDiagnostic::checkValid() {
744 if (!valid) {
745 throw std::invalid_argument(
746 "Diagnostic is invalid (used outside of callback)");
747 }
748}
749
751 checkValid();
752 return static_cast<PyDiagnosticSeverity>(
753 mlirDiagnosticGetSeverity(diagnostic));
754}
755
757 checkValid();
758 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
759 MlirContext context = mlirLocationGetContext(loc);
760 return PyLocation(PyMlirContext::forContext(context), loc);
761}
762
763nb::str PyDiagnostic::getMessage() {
764 checkValid();
765 nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
766 PyFileAccumulator accum(fileObject, /*binary=*/false);
767 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
768 return nb::cast<nb::str>(fileObject.attr("getvalue")());
769}
770
771nb::typed<nb::tuple, PyDiagnostic> PyDiagnostic::getNotes() {
772 checkValid();
773 if (materializedNotes)
774 return *materializedNotes;
775 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
776 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
777 for (intptr_t i = 0; i < numNotes; ++i) {
778 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
779 nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
780 PyTuple_SetItem(notes.ptr(), i, diagnostic.release().ptr());
781 }
782 materializedNotes = std::move(notes);
783
784 return *materializedNotes;
785}
786
788 std::vector<DiagnosticInfo> notes;
789 for (nb::handle n : getNotes())
790 notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
791 return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
792 std::move(notes)};
793}
795//------------------------------------------------------------------------------
796// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
797//------------------------------------------------------------------------------
798
799MlirDialect PyDialects::getDialectForKey(const std::string &key,
800 bool attrError) {
801 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
802 {key.data(), key.size()});
803 if (mlirDialectIsNull(dialect)) {
804 std::string msg = join("Dialect '", key, "' not found");
805 if (attrError)
806 throw nb::attribute_error(msg.c_str());
807 throw nb::index_error(msg.c_str());
808 }
809 return dialect;
810}
813 return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
814}
815
817 MlirDialectRegistry rawRegistry =
819 if (mlirDialectRegistryIsNull(rawRegistry))
820 throw nb::python_error();
821 return PyDialectRegistry(rawRegistry);
822}
824//------------------------------------------------------------------------------
825// PyLocation
826//------------------------------------------------------------------------------
828nb::object PyLocation::getCapsule() {
829 return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
830}
831
832PyLocation PyLocation::createFromCapsule(nb::object capsule) {
833 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
834 if (mlirLocationIsNull(rawLoc))
835 throw nb::python_error();
837 rawLoc);
838}
840nb::object PyLocation::contextEnter(nb::object locationObj) {
841 return PyThreadContextEntry::pushLocation(locationObj);
842}
843
844void PyLocation::contextExit(const nb::object &excType,
845 const nb::object &excVal,
846 const nb::object &excTb) {
848}
849
852 if (!location) {
853 throw std::runtime_error(
854 "An MLIR function requires a Location but none was provided in the "
855 "call or from the surrounding environment. Either pass to the function "
856 "with a 'loc=' argument or establish a default using 'with loc:'");
857 }
858 return *location;
859}
860
861//------------------------------------------------------------------------------
862// PyModule
863//------------------------------------------------------------------------------
864
865PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
866 : BaseContextObject(std::move(contextRef)), module(module) {}
867
869 nb::gil_scoped_acquire acquire;
870 auto &liveModules = getContext()->liveModules;
871 assert(liveModules.count(module.ptr) == 1 &&
872 "destroying module not in live map");
873 liveModules.erase(module.ptr);
874 mlirModuleDestroy(module);
875}
876
877PyModuleRef PyModule::forModule(MlirModule module) {
878 MlirContext context = mlirModuleGetContext(module);
879 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
880
881 nb::gil_scoped_acquire acquire;
882 auto &liveModules = contextRef->liveModules;
883 auto it = liveModules.find(module.ptr);
884 if (it == liveModules.end()) {
885 // Create.
886 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
887 // Note that the default return value policy on cast is automatic_reference,
888 // which does not take ownership (delete will not be called).
889 // Just be explicit.
890 nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
891 unownedModule->handle = pyRef;
892 liveModules[module.ptr] =
893 std::make_pair(unownedModule->handle, unownedModule);
894 return PyModuleRef(unownedModule, std::move(pyRef));
895 }
896 // Use existing.
897 PyModule *existing = it->second.second;
898 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
899 return PyModuleRef(existing, std::move(pyRef));
900}
901
902nb::object PyModule::createFromCapsule(nb::object capsule) {
903 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
904 if (mlirModuleIsNull(rawModule))
905 throw nb::python_error();
906 return forModule(rawModule).releaseObject();
907}
908
909nb::object PyModule::getCapsule() {
910 return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
911}
913//------------------------------------------------------------------------------
914// PyOperation
915//------------------------------------------------------------------------------
916
917PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
918 : BaseContextObject(std::move(contextRef)), operation(operation) {}
919
921 // If the operation has already been invalidated there is nothing to do.
922 if (!valid)
923 return;
924 // Otherwise, invalidate the operation when it is attached.
925 if (isAttached())
926 setInvalid();
927 else {
928 // And destroy it when it is detached, i.e. owned by Python.
929 erase();
930 }
931}
932
933namespace {
934
935// Constructs a new object of type T in-place on the Python heap, returning a
936// PyObjectRef to it, loosely analogous to std::make_shared<T>().
937template <typename T, class... Args>
938PyObjectRef<T> makeObjectRef(Args &&...args) {
939 nb::handle type = nb::type<T>();
940 nb::object instance = nb::inst_alloc(type);
941 T *ptr = nb::inst_ptr<T>(instance);
942 new (ptr) T(std::forward<Args>(args)...);
943 nb::inst_mark_ready(instance);
944 return PyObjectRef<T>(ptr, std::move(instance));
945}
946
947} // namespace
948
949PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
950 MlirOperation operation,
951 nb::object parentKeepAlive) {
952 // Create.
953 PyOperationRef unownedOperation =
954 makeObjectRef<PyOperation>(std::move(contextRef), operation);
955 unownedOperation->handle = unownedOperation.getObject();
956 if (parentKeepAlive) {
957 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
958 }
959 return unownedOperation;
960}
961
963 MlirOperation operation,
964 nb::object parentKeepAlive) {
965 return createInstance(std::move(contextRef), operation,
966 std::move(parentKeepAlive));
967}
968
970 MlirOperation operation,
971 nb::object parentKeepAlive) {
972 PyOperationRef created = createInstance(std::move(contextRef), operation,
973 std::move(parentKeepAlive));
974 created->attached = false;
975 return created;
976}
977
979 const std::string &sourceStr,
980 const std::string &sourceName) {
981 PyMlirContext::ErrorCapture errors(contextRef);
982 MlirOperation op =
983 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
984 toMlirStringRef(sourceName));
985 if (mlirOperationIsNull(op))
986 throw MLIRError("Unable to parse operation assembly", errors.take());
987 return PyOperation::createDetached(std::move(contextRef), op);
988}
989
992 setDetached();
993 parentKeepAlive = nb::object();
994}
995
996MlirOperation PyOperation::get() const {
997 checkValid();
998 return operation;
999}
1002 return PyOperationRef(this, nb::borrow<nb::object>(handle));
1003}
1004
1005void PyOperation::setAttached(const nb::object &parent) {
1006 assert(!attached && "operation already attached");
1007 attached = true;
1008}
1009
1011 assert(attached && "operation already detached");
1012 attached = false;
1013}
1014
1015void PyOperation::checkValid() const {
1016 if (!valid) {
1017 throw std::runtime_error("the operation has been invalidated");
1018 }
1019}
1020
1021void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1022 std::optional<int64_t> largeResourceLimit,
1023 bool enableDebugInfo, bool prettyDebugInfo,
1024 bool printGenericOpForm, bool useLocalScope,
1025 bool useNameLocAsPrefix, bool assumeVerified,
1026 nb::object fileObject, bool binary,
1027 bool skipRegions) {
1028 PyOperation &operation = getOperation();
1029 operation.checkValid();
1030 if (fileObject.is_none())
1031 fileObject = nb::module_::import_("sys").attr("stdout");
1032
1033 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1034 if (largeElementsLimit)
1035 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1036 if (largeResourceLimit)
1037 mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit);
1038 if (enableDebugInfo)
1039 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1040 /*prettyForm=*/prettyDebugInfo);
1041 if (printGenericOpForm)
1043 if (useLocalScope)
1045 if (assumeVerified)
1047 if (skipRegions)
1049 if (useNameLocAsPrefix)
1051
1052 PyFileAccumulator accum(fileObject, binary);
1053 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1054 accum.getUserData());
1056}
1057
1058void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1059 bool binary) {
1060 PyOperation &operation = getOperation();
1061 operation.checkValid();
1062 if (fileObject.is_none())
1063 fileObject = nb::module_::import_("sys").attr("stdout");
1064 PyFileAccumulator accum(fileObject, binary);
1065 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1066 accum.getUserData());
1067}
1068
1069void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
1070 std::optional<int64_t> bytecodeVersion) {
1071 PyOperation &operation = getOperation();
1072 operation.checkValid();
1073 PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
1074
1075 if (!bytecodeVersion.has_value())
1076 return mlirOperationWriteBytecode(operation, accum.getCallback(),
1077 accum.getUserData());
1078
1079 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1080 mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1082 operation, config, accum.getCallback(), accum.getUserData());
1085 throw nb::value_error(
1086 join("Unable to honor desired bytecode version ", *bytecodeVersion)
1087 .c_str());
1088}
1089
1090void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
1091 PyWalkOrder walkOrder) {
1092 PyOperation &operation = getOperation();
1093 operation.checkValid();
1094 struct UserData {
1095 std::function<PyWalkResult(MlirOperation)> callback;
1096 bool gotException;
1097 std::string exceptionWhat;
1098 nb::object exceptionType;
1099 };
1100 UserData userData{callback, false, {}, {}};
1101 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1102 void *userData) {
1103 UserData *calleeUserData = static_cast<UserData *>(userData);
1104 try {
1105 return static_cast<MlirWalkResult>((calleeUserData->callback)(op));
1106 } catch (nb::python_error &e) {
1107 calleeUserData->gotException = true;
1108 calleeUserData->exceptionWhat = std::string(e.what());
1109 calleeUserData->exceptionType = nb::borrow(e.type());
1110 return MlirWalkResult::MlirWalkResultInterrupt;
1111 }
1112 };
1113 mlirOperationWalk(operation, walkCallback, &userData,
1114 static_cast<MlirWalkOrder>(walkOrder));
1115 if (userData.gotException) {
1116 std::string message("Exception raised in callback: ");
1117 message.append(userData.exceptionWhat);
1118 throw std::runtime_error(message);
1119 }
1120}
1121
1122nb::object PyOperationBase::getAsm(bool binary,
1123 std::optional<int64_t> largeElementsLimit,
1124 std::optional<int64_t> largeResourceLimit,
1125 bool enableDebugInfo, bool prettyDebugInfo,
1126 bool printGenericOpForm, bool useLocalScope,
1127 bool useNameLocAsPrefix, bool assumeVerified,
1128 bool skipRegions) {
1129 nb::object fileObject;
1130 if (binary) {
1131 fileObject = nb::module_::import_("io").attr("BytesIO")();
1132 } else {
1133 fileObject = nb::module_::import_("io").attr("StringIO")();
1134 }
1135 print(/*largeElementsLimit=*/largeElementsLimit,
1136 /*largeResourceLimit=*/largeResourceLimit,
1137 /*enableDebugInfo=*/enableDebugInfo,
1138 /*prettyDebugInfo=*/prettyDebugInfo,
1139 /*printGenericOpForm=*/printGenericOpForm,
1140 /*useLocalScope=*/useLocalScope,
1141 /*useNameLocAsPrefix=*/useNameLocAsPrefix,
1142 /*assumeVerified=*/assumeVerified,
1143 /*fileObject=*/fileObject,
1144 /*binary=*/binary,
1145 /*skipRegions=*/skipRegions);
1146
1147 return fileObject.attr("getvalue")();
1148}
1149
1151 PyOperation &operation = getOperation();
1152 PyOperation &otherOp = other.getOperation();
1153 operation.checkValid();
1154 otherOp.checkValid();
1155 mlirOperationMoveAfter(operation, otherOp);
1156 operation.parentKeepAlive = otherOp.parentKeepAlive;
1157}
1158
1160 PyOperation &operation = getOperation();
1161 PyOperation &otherOp = other.getOperation();
1162 operation.checkValid();
1163 otherOp.checkValid();
1164 mlirOperationMoveBefore(operation, otherOp);
1165 operation.parentKeepAlive = otherOp.parentKeepAlive;
1166}
1167
1169 PyOperation &operation = getOperation();
1170 PyOperation &otherOp = other.getOperation();
1171 operation.checkValid();
1172 otherOp.checkValid();
1173 return mlirOperationIsBeforeInBlock(operation, otherOp);
1174}
1175
1177 PyOperation &op = getOperation();
1180 throw MLIRError("Verification failed", errors.take());
1181 return true;
1182}
1183
1184std::optional<PyOperationRef> PyOperation::getParentOperation() {
1185 checkValid();
1186 if (!isAttached())
1187 throw nb::value_error("Detached operations have no parent");
1188 MlirOperation operation = mlirOperationGetParentOperation(get());
1189 if (mlirOperationIsNull(operation))
1190 return {};
1191 return PyOperation::forOperation(getContext(), operation);
1192}
1193
1195 checkValid();
1196 std::optional<PyOperationRef> parentOperation = getParentOperation();
1197 MlirBlock block = mlirOperationGetBlock(get());
1198 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1199 assert(parentOperation && "Operation has no parent");
1200 return PyBlock{std::move(*parentOperation), block};
1201}
1202
1204 checkValid();
1205 return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1206}
1207
1208nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
1209 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1210 if (mlirOperationIsNull(rawOperation))
1211 throw nb::python_error();
1212 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1213 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1214 .releaseObject();
1215}
1216
1217static void maybeInsertOperation(PyOperationRef &op,
1218 const nb::object &maybeIp) {
1219 // InsertPoint active?
1220 if (!maybeIp.is(nb::cast(false))) {
1221 PyInsertionPoint *ip;
1222 if (maybeIp.is_none()) {
1224 } else {
1225 ip = nb::cast<PyInsertionPoint *>(maybeIp);
1226 }
1227 if (ip)
1228 ip->insert(*op.get());
1229 }
1230}
1231
1232nb::object PyOperation::create(std::string_view name,
1233 std::optional<std::vector<PyType *>> results,
1234 const MlirValue *operands, size_t numOperands,
1235 std::optional<nb::dict> attributes,
1236 std::optional<std::vector<PyBlock *>> successors,
1237 int regions, PyLocation &location,
1238 const nb::object &maybeIp, bool inferType) {
1239 std::vector<MlirType> mlirResults;
1240 std::vector<MlirBlock> mlirSuccessors;
1241 std::vector<std::pair<std::string, MlirAttribute>> mlirAttributes;
1242
1243 // General parameter validation.
1244 if (regions < 0)
1245 throw nb::value_error("number of regions must be >= 0");
1246
1247 // Unpack/validate results.
1248 if (results) {
1249 mlirResults.reserve(results->size());
1250 for (PyType *result : *results) {
1251 // TODO: Verify result type originate from the same context.
1252 if (!result)
1253 throw nb::value_error("result type cannot be None");
1254 mlirResults.push_back(*result);
1255 }
1256 }
1257 // Unpack/validate attributes.
1258 if (attributes) {
1259 mlirAttributes.reserve(attributes->size());
1260 for (std::pair<nb::handle, nb::handle> it : *attributes) {
1261 std::string key;
1262 try {
1263 key = nb::cast<std::string>(it.first);
1264 } catch (nb::cast_error &err) {
1265 std::string msg = join("Invalid attribute key (not a string) when "
1266 "attempting to create the operation \"",
1267 name, "\" (", err.what(), ")");
1268 throw nb::type_error(msg.c_str());
1269 }
1270 try {
1271 auto &attribute = nb::cast<PyAttribute &>(it.second);
1272 // TODO: Verify attribute originates from the same context.
1273 mlirAttributes.emplace_back(std::move(key), attribute);
1274 } catch (nb::cast_error &err) {
1275 std::string msg = join("Invalid attribute value for the key \"", key,
1276 "\" when attempting to create the operation \"",
1277 name, "\" (", err.what(), ")");
1278 throw nb::type_error(msg.c_str());
1279 } catch (std::runtime_error &) {
1280 // This exception seems thrown when the value is "None".
1281 std::string msg = join(
1282 "Found an invalid (`None`?) attribute value for the key \"", key,
1283 "\" when attempting to create the operation \"", name, "\"");
1284 throw std::runtime_error(msg);
1285 }
1286 }
1287 }
1288 // Unpack/validate successors.
1289 if (successors) {
1290 mlirSuccessors.reserve(successors->size());
1291 for (PyBlock *successor : *successors) {
1292 // TODO: Verify successor originate from the same context.
1293 if (!successor)
1294 throw nb::value_error("successor block cannot be None");
1295 mlirSuccessors.push_back(successor->get());
1296 }
1297 }
1298
1299 // Apply unpacked/validated to the operation state. Beyond this
1300 // point, exceptions cannot be thrown or else the state will leak.
1301 MlirOperationState state =
1302 mlirOperationStateGet(toMlirStringRef(name), location);
1303 if (numOperands > 0)
1304 mlirOperationStateAddOperands(&state, numOperands, operands);
1305 state.enableResultTypeInference = inferType;
1306 if (!mlirResults.empty())
1307 mlirOperationStateAddResults(&state, mlirResults.size(),
1308 mlirResults.data());
1309 if (!mlirAttributes.empty()) {
1310 // Note that the attribute names directly reference bytes in
1311 // mlirAttributes, so that vector must not be changed from here
1312 // on.
1313 std::vector<MlirNamedAttribute> mlirNamedAttributes;
1314 mlirNamedAttributes.reserve(mlirAttributes.size());
1315 for (const std::pair<std::string, MlirAttribute> &it : mlirAttributes)
1316 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1318 toMlirStringRef(it.first)),
1319 it.second));
1320 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1321 mlirNamedAttributes.data());
1322 }
1323 if (!mlirSuccessors.empty())
1324 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1325 mlirSuccessors.data());
1326 if (regions) {
1327 std::vector<MlirRegion> mlirRegions;
1328 mlirRegions.resize(regions);
1329 for (int i = 0; i < regions; ++i)
1330 mlirRegions[i] = mlirRegionCreate();
1331 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1332 mlirRegions.data());
1333 }
1334
1335 // Construct the operation.
1336 PyMlirContext::ErrorCapture errors(location.getContext());
1337 MlirOperation operation = mlirOperationCreate(&state);
1338 if (!operation.ptr)
1339 throw MLIRError("Operation creation failed", errors.take());
1340 PyOperationRef created =
1341 PyOperation::createDetached(location.getContext(), operation);
1342 maybeInsertOperation(created, maybeIp);
1343
1344 return created.getObject();
1345}
1346
1347nb::object PyOperation::clone(const nb::object &maybeIp) {
1348 MlirOperation clonedOperation = mlirOperationClone(operation);
1349 PyOperationRef cloned =
1350 PyOperation::createDetached(getContext(), clonedOperation);
1351 maybeInsertOperation(cloned, maybeIp);
1352
1353 return cloned->createOpView();
1354}
1355
1356nb::object PyOperation::createOpView() {
1357 checkValid();
1358 MlirIdentifier ident = mlirOperationGetName(get());
1359 MlirStringRef identStr = mlirIdentifierStr(ident);
1360 auto operationCls = PyGlobals::get().lookupOperationClass(
1361 std::string_view(identStr.data, identStr.length));
1362 if (operationCls)
1363 return PyOpView::constructDerived(*operationCls, getRef().getObject());
1364 return nb::cast(PyOpView(getRef().getObject()));
1365}
1366
1367void PyOperation::erase() {
1369 setInvalid();
1370 mlirOperationDestroy(operation);
1371}
1372
1373void PyOpResult::bindDerived(ClassTy &c) {
1374 c.def_prop_ro(
1375 "owner",
1376 [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
1377 assert(mlirOperationEqual(self.getParentOperation()->get(),
1378 mlirOpResultGetOwner(self.get())) &&
1379 "expected the owner of the value in Python to match that in "
1380 "the IR");
1381 return self.getParentOperation()->createOpView();
1382 },
1383 "Returns the operation that produces this result.");
1384 c.def_prop_ro(
1385 "result_number",
1386 [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); },
1387 "Returns the position of this result in the operation's result list.");
1389
1390/// Returns the list of types of the values held by container.
1391template <typename Container>
1392static std::vector<nb::typed<nb::object, PyType>>
1393getValueTypes(Container &container, PyMlirContextRef &context) {
1394 std::vector<nb::typed<nb::object, PyType>> result;
1395 result.reserve(container.size());
1396 for (int i = 0, e = container.size(); i < e; ++i) {
1397 result.push_back(PyType(context->getRef(),
1398 mlirValueGetType(container.getElement(i).get()))
1400 }
1401 return result;
1402}
1403
1405 intptr_t length, intptr_t step)
1406 : Sliceable(startIndex,
1408 : length,
1409 step),
1410 operation(std::move(operation)) {}
1411
1412void PyOpResultList::bindDerived(ClassTy &c) {
1413 c.def_prop_ro(
1414 "types",
1415 [](PyOpResultList &self) {
1416 return getValueTypes(self, self.operation->getContext());
1417 },
1418 "Returns a list of types for all results in this result list.");
1419 c.def_prop_ro(
1420 "owner",
1421 [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
1422 return self.operation->createOpView();
1423 },
1424 "Returns the operation that owns this result list.");
1425}
1426
1427intptr_t PyOpResultList::getRawNumElements() {
1428 operation->checkValid();
1429 return mlirOperationGetNumResults(operation->get());
1430}
1431
1432PyOpResult PyOpResultList::getRawElement(intptr_t index) {
1433 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1434 return PyOpResult(value);
1435}
1436
/// Returns a new result list covering the given (start, length, step) range,
/// sharing this list's owning operation.
PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
                                     intptr_t step) const {
  return PyOpResultList(operation, startIndex, length, step);
}
1442//------------------------------------------------------------------------------
1443// PyOpView
1444//------------------------------------------------------------------------------
1445
1446static void populateResultTypes(std::string_view name,
1447 nb::sequence resultTypeList,
1448 const nb::object &resultSegmentSpecObj,
1449 std::vector<int32_t> &resultSegmentLengths,
1450 std::vector<PyType *> &resultTypes) {
1451 resultTypes.reserve(nb::len(resultTypeList));
1452 if (resultSegmentSpecObj.is_none()) {
1453 // Non-variadic result unpacking.
1454 size_t index = 0;
1455 for (nb::handle resultType : resultTypeList) {
1456 try {
1457 resultTypes.push_back(nb::cast<PyType *>(resultType));
1458 if (!resultTypes.back())
1459 throw nb::cast_error();
1460 } catch (nb::cast_error &err) {
1461 throw nb::value_error(join("Result ", index, " of operation \"", name,
1462 "\" must be a Type (", err.what(), ")")
1463 .c_str());
1464 }
1465 ++index;
1466 }
1467 } else {
1468 // Sized result unpacking.
1469 auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1470 if (resultSegmentSpec.size() != nb::len(resultTypeList)) {
1471 throw nb::value_error(
1472 join("Operation \"", name, "\" requires ", resultSegmentSpec.size(),
1473 " result segments but was provided ", nb::len(resultTypeList))
1474 .c_str());
1475 }
1476 resultSegmentLengths.reserve(nb::len(resultTypeList));
1477 for (size_t i = 0, e = resultSegmentSpec.size(); i < e; ++i) {
1478 int segmentSpec = resultSegmentSpec[i];
1479 if (segmentSpec == 1 || segmentSpec == 0) {
1480 // Unpack unary element.
1481 try {
1482 auto *resultType = nb::cast<PyType *>(resultTypeList[i]);
1483 if (resultType) {
1484 resultTypes.push_back(resultType);
1485 resultSegmentLengths.push_back(1);
1486 } else if (segmentSpec == 0) {
1487 // Allowed to be optional.
1488 resultSegmentLengths.push_back(0);
1489 } else {
1490 throw nb::value_error(
1491 join("Result ", i, " of operation \"", name,
1492 "\" must be a Type (was None and result is not optional)")
1493 .c_str());
1494 }
1495 } catch (nb::cast_error &err) {
1496 throw nb::value_error(join("Result ", i, " of operation \"", name,
1497 "\" must be a Type (", err.what(), ")")
1498 .c_str());
1499 }
1500 } else if (segmentSpec == -1) {
1501 // Unpack sequence by appending.
1502 try {
1503 if (resultTypeList[i].is_none()) {
1504 // Treat it as an empty list.
1505 resultSegmentLengths.push_back(0);
1506 } else {
1507 // Unpack the list.
1508 auto segment = nb::cast<nb::sequence>(resultTypeList[i]);
1509 for (nb::handle segmentItem : segment) {
1510 resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1511 if (!resultTypes.back()) {
1512 throw nb::type_error("contained a None item");
1513 }
1514 }
1515 resultSegmentLengths.push_back(nb::len(segment));
1516 }
1517 } catch (std::exception &err) {
1518 // NOTE: Sloppy to be using a catch-all here, but there are at least
1519 // three different unrelated exceptions that can be thrown in the
1520 // above "casts". Just keep the scope above small and catch them all.
1521 throw nb::value_error(join("Result ", i, " of operation \"", name,
1522 "\" must be a Sequence of Types (",
1523 err.what(), ")")
1524 .c_str());
1525 }
1526 } else {
1527 throw nb::value_error("Unexpected segment spec");
1529 }
1530 }
1531}
1532
1533MlirValue getUniqueResult(MlirOperation operation) {
1534 auto numResults = mlirOperationGetNumResults(operation);
1535 if (numResults != 1) {
1536 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1537 throw nb::value_error(
1538 join("Cannot call .result on operation ",
1539 std::string_view(name.data, name.length), " which has ",
1540 numResults,
1541 " results (it is only valid for operations with a "
1542 "single result)")
1543 .c_str());
1544 }
1545 return mlirOperationGetResult(operation, 0);
1546}
1547
1548static MlirValue getOpResultOrValue(nb::handle operand) {
1549 if (operand.is_none()) {
1550 throw nb::value_error("contained a None item");
1551 }
1552 PyOperationBase *op;
1553 if (nb::try_cast<PyOperationBase *>(operand, op)) {
1554 return getUniqueResult(op->getOperation());
1555 }
1556 PyOpResultList *opResultList;
1557 if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1558 return getUniqueResult(opResultList->getOperation()->get());
1559 }
1560 PyValue *value;
1561 if (nb::try_cast<PyValue *>(operand, value)) {
1562 return value->get();
1563 }
1564 throw nb::value_error("is not a Value");
1565}
1566
1567nb::typed<nb::object, PyOperation> PyOpView::buildGeneric(
1568 std::string_view name, std::tuple<int, bool> opRegionSpec,
1569 nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1570 std::optional<nb::sequence> resultTypeList, nb::sequence operandList,
1571 std::optional<nb::dict> attributes,
1572 std::optional<std::vector<PyBlock *>> successors,
1573 std::optional<int> regions, PyLocation &location,
1574 const nb::object &maybeIp) {
1575 PyMlirContextRef context = location.getContext();
1576
1577 // Class level operation construction metadata.
1578 // Operand and result segment specs are either none, which does no
1579 // variadic unpacking, or a list of ints with segment sizes, where each
1580 // element is either a positive number (typically 1 for a scalar) or -1 to
1581 // indicate that it is derived from the length of the same-indexed operand
1582 // or result (implying that it is a list at that position).
1583 std::vector<int32_t> operandSegmentLengths;
1584 std::vector<int32_t> resultSegmentLengths;
1585
1586 // Validate/determine region count.
1587 int opMinRegionCount = std::get<0>(opRegionSpec);
1588 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1589 if (!regions) {
1590 regions = opMinRegionCount;
1591 }
1592 if (*regions < opMinRegionCount) {
1593 throw nb::value_error(join("Operation \"", name,
1594 "\" requires a minimum of ", opMinRegionCount,
1595 " regions but was built with regions=", *regions)
1596 .c_str());
1597 }
1598 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1599 throw nb::value_error(join("Operation \"", name,
1600 "\" requires a maximum of ", opMinRegionCount,
1601 " regions but was built with regions=", *regions)
1602 .c_str());
1603 }
1604
1605 // Unpack results.
1606 std::vector<PyType *> resultTypes;
1607 if (resultTypeList.has_value()) {
1608 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1609 resultSegmentLengths, resultTypes);
1610 }
1611
1612 // Unpack operands.
1613 std::vector<MlirValue> operands;
1614 operands.reserve(operands.size());
1615 size_t index = 0;
1616 if (operandSegmentSpecObj.is_none()) {
1617 // Non-sized operand unpacking.
1618 for (nb::handle operand : operandList) {
1619 try {
1620 operands.push_back(getOpResultOrValue(operand));
1621 } catch (nb::builtin_exception &err) {
1622 throw nb::value_error(join("Operand ", index, " of operation \"", name,
1623 "\" must be a Value (", err.what(), ")")
1624 .c_str());
1625 }
1626 ++index;
1627 }
1628 } else {
1629 // Sized operand unpacking.
1630 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1631 if (operandSegmentSpec.size() != nb::len(operandList)) {
1632 throw nb::value_error(
1633 join("Operation \"", name, "\" requires ", operandSegmentSpec.size(),
1634 "operand segments but was provided ", nb::len(operandList))
1635 .c_str());
1636 }
1637 operandSegmentLengths.reserve(nb::len(operandList));
1638 for (size_t i = 0, e = operandSegmentSpec.size(); i < e; ++i) {
1639 int segmentSpec = operandSegmentSpec[i];
1640 if (segmentSpec == 1 || segmentSpec == 0) {
1641 // Unpack unary element.
1642 const nanobind::handle operand = operandList[i];
1643 if (!operand.is_none()) {
1644 try {
1645 operands.push_back(getOpResultOrValue(operand));
1646 } catch (nb::builtin_exception &err) {
1647 throw nb::value_error(join("Operand ", i, " of operation \"", name,
1648 "\" must be a Value (", err.what(), ")")
1649 .c_str());
1650 }
1651
1652 operandSegmentLengths.push_back(1);
1653 } else if (segmentSpec == 0) {
1654 // Allowed to be optional.
1655 operandSegmentLengths.push_back(0);
1656 } else {
1657 throw nb::value_error(
1658 join("Operand ", i, " of operation \"", name,
1659 "\" must be a Value (was None and operand is not optional)")
1660 .c_str());
1661 }
1662 } else if (segmentSpec == -1) {
1663 // Unpack sequence by appending.
1664 try {
1665 if (operandList[i].is_none()) {
1666 // Treat it as an empty list.
1667 operandSegmentLengths.push_back(0);
1668 } else {
1669 // Unpack the list.
1670 auto segment = nb::cast<nb::sequence>(operandList[i]);
1671 for (nb::handle segmentItem : segment) {
1672 operands.push_back(getOpResultOrValue(segmentItem));
1673 }
1674 operandSegmentLengths.push_back(nb::len(segment));
1675 }
1676 } catch (std::exception &err) {
1677 // NOTE: Sloppy to be using a catch-all here, but there are at least
1678 // three different unrelated exceptions that can be thrown in the
1679 // above "casts". Just keep the scope above small and catch them all.
1680 throw nb::value_error(join("Operand ", i, " of operation \"", name,
1681 "\" must be a Sequence of Values (",
1682 err.what(), ")")
1683 .c_str());
1684 }
1685 } else {
1686 throw nb::value_error("Unexpected segment spec");
1687 }
1688 }
1689 }
1690
1691 // Merge operand/result segment lengths into attributes if needed.
1692 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1693 // Dup.
1694 if (attributes) {
1695 attributes = nb::dict(*attributes);
1696 } else {
1697 attributes = nb::dict();
1698 }
1699 if (attributes->contains("resultSegmentSizes") ||
1700 attributes->contains("operandSegmentSizes")) {
1701 throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
1702 "'operandSegmentSizes' attribute is unsupported. "
1703 "Use Operation.create for such low-level access.");
1704 }
1705
1706 // Add resultSegmentSizes attribute.
1707 if (!resultSegmentLengths.empty()) {
1708 MlirAttribute segmentLengthAttr =
1709 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1710 resultSegmentLengths.data());
1711 (*attributes)["resultSegmentSizes"] =
1712 PyAttribute(context, segmentLengthAttr);
1713 }
1714
1715 // Add operandSegmentSizes attribute.
1716 if (!operandSegmentLengths.empty()) {
1717 MlirAttribute segmentLengthAttr =
1718 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1719 operandSegmentLengths.data());
1720 (*attributes)["operandSegmentSizes"] =
1721 PyAttribute(context, segmentLengthAttr);
1722 }
1723 }
1724
1725 // Delegate to create.
1726 return PyOperation::create(name,
1727 /*results=*/std::move(resultTypes),
1728 /*operands=*/operands.data(),
1729 /*numOperands=*/operands.size(),
1730 /*attributes=*/std::move(attributes),
1731 /*successors=*/std::move(successors),
1732 /*regions=*/*regions, location, maybeIp,
1733 !resultTypeList);
1734}
1735
1736nb::object PyOpView::constructDerived(const nb::object &cls,
1737 const nb::object &operation) {
1738 nb::handle opViewType = nb::type<PyOpView>();
1739 nb::object instance = cls.attr("__new__")(cls);
1740 opViewType.attr("__init__")(instance, operation);
1741 return instance;
1742}
1743
/// Creates a generic OpView over the operation held by `operationObject`.
/// The stored `operationObject` member is taken from the underlying
/// PyOperation's own reference, not from the argument.
/// NOTE(review): the second initializer reads the `operation` member, so this
/// relies on `operation` being declared before `operationObject` in the
/// class. Verify the member order if the class layout changes.
PyOpView::PyOpView(const nb::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1750//------------------------------------------------------------------------------
1751// PyAsmState
1752//------------------------------------------------------------------------------
1753
1754PyAsmState::PyAsmState(MlirValue value, bool useLocalScope) {
1755 flags = mlirOpPrintingFlagsCreate();
1756 // The OpPrintingFlags are not exposed Python side, create locally and
1757 // associate lifetime with the state.
1758 if (useLocalScope)
1760 state = mlirAsmStateCreateForValue(value, flags);
1761}
1762
1763PyAsmState::PyAsmState(PyOperationBase &operation, bool useLocalScope) {
1764 flags = mlirOpPrintingFlagsCreate();
1765 // The OpPrintingFlags are not exposed Python side, create locally and
1766 // associate lifetime with the state.
1767 if (useLocalScope)
1769 state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
1770}
1772//------------------------------------------------------------------------------
1773// PyInsertionPoint.
1774//------------------------------------------------------------------------------
1775
/// Creates an insertion point with no reference operation. insert() treats
/// this as "append at the end of `block`".
PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
1779 : refOperation(beforeOperationBase.getOperation().getRef()),
1780 block((*refOperation)->getBlock()) {}
1781
1783 : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
1784
1785void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1786 PyOperation &operation = operationBase.getOperation();
1787 if (operation.isAttached())
1788 throw nb::value_error(
1789 "Attempt to insert operation that is already attached");
1790 block.getParentOperation()->checkValid();
1791 MlirOperation beforeOp = {nullptr};
1792 if (refOperation) {
1793 // Insert before operation.
1794 (*refOperation)->checkValid();
1795 beforeOp = (*refOperation)->get();
1796 } else {
1797 // Insert at end (before null) is only valid if the block does not
1798 // already end in a known terminator (violating this will cause assertion
1799 // failures later).
1800 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1801 throw nb::index_error("Cannot insert operation at the end of a block "
1802 "that already has a terminator. Did you mean to "
1803 "use 'InsertionPoint.at_block_terminator(block)' "
1804 "versus 'InsertionPoint(block)'?");
1805 }
1807 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1808 operation.setAttached();
1809}
1810
1812 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1813 if (mlirOperationIsNull(firstOp)) {
1814 // Just insert at end.
1815 return PyInsertionPoint(block);
1816 }
1817
1818 // Insert before first op.
1820 block.getParentOperation()->getContext(), firstOp);
1821 return PyInsertionPoint{block, std::move(firstOpRef)};
1822}
1823
1825 MlirOperation terminator = mlirBlockGetTerminator(block.get());
1826 if (mlirOperationIsNull(terminator))
1827 throw nb::value_error("Block has no terminator");
1829 block.getParentOperation()->getContext(), terminator);
1830 return PyInsertionPoint{block, std::move(terminatorOpRef)};
1831}
1832
1834 PyOperation &operation = op.getOperation();
1835 PyBlock block = operation.getBlock();
1836 MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
1837 if (mlirOperationIsNull(nextOperation))
1838 return PyInsertionPoint(block);
1840 block.getParentOperation()->getContext(), nextOperation);
1841 return PyInsertionPoint{block, std::move(nextOpRef)};
1842}
1843
/// Returns the number of entries in `liveModules` (presumably the modules of
/// this context whose Python wrappers are still alive; verify).
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
/// Context-manager entry: pushes `insertPoint` through
/// PyThreadContextEntry::pushInsertionPoint and returns that call's result.
/// Presumably bound as InsertionPoint.__enter__ — confirm in the bindings.
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
  return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
}
1849
1850void PyInsertionPoint::contextExit(const nb::object &excType,
1851 const nb::object &excVal,
1852 const nb::object &excTb) {
1854}
1856//------------------------------------------------------------------------------
1857// PyAttribute.
1858//------------------------------------------------------------------------------
/// Two PyAttributes are equal when their underlying MlirAttributes compare
/// equal under mlirAttributeEqual.
bool PyAttribute::operator==(const PyAttribute &other) const {
  return mlirAttributeEqual(attr, other.attr);
}
/// Returns a Python capsule wrapping this attribute for C-API interop.
/// The reference returned by mlirPythonAttributeToCapsule is adopted
/// (nb::steal), not re-incremented.
nb::object PyAttribute::getCapsule() {
  return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
}
1867
1868PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) {
1869 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1870 if (mlirAttributeIsNull(rawAttr))
1871 throw nb::python_error();
1872 return PyAttribute(
1874}
1875
1876nb::typed<nb::object, PyAttribute> PyAttribute::maybeDownCast() {
1877 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get());
1878 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1879 "mlirTypeID was expected to be non-null.");
1880 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
1881 mlirTypeID, mlirAttributeGetDialect(this->get()));
1882 // nb::rv_policy::move means use std::move to move the return value
1883 // contents into a new instance that will be owned by Python.
1884 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
1885 if (!typeCaster)
1886 return thisObj;
1887 return typeCaster.value()(thisObj);
1888}
1890//------------------------------------------------------------------------------
1891// PyNamedAttribute.
1892//------------------------------------------------------------------------------
1893
1894PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1895 : ownedName(new std::string(std::move(ownedName))) {
1898 toMlirStringRef(*this->ownedName)),
1899 attr);
1900}
1902//------------------------------------------------------------------------------
1903// PyType.
1904//------------------------------------------------------------------------------
1906bool PyType::operator==(const PyType &other) const {
1907 return mlirTypeEqual(type, other.type);
1908}
1910nb::object PyType::getCapsule() {
1911 return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
1912}
1913
/// Rebuilds a PyType from a Python capsule. This is the counterpart of
/// getCapsule(), which uses mlirPythonTypeToCapsule.
/// Throws nb::python_error when the capsule unwraps to a null type.
1914PyType PyType::createFromCapsule(nb::object capsule) {
1915 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1916 if (mlirTypeIsNull(rawType))
1917 throw nb::python_error();
// NOTE(review): the start of the return statement (original line 1918) is
// not present in this view.
1919 rawType);
1920}
1921
1922nb::typed<nb::object, PyType> PyType::maybeDownCast() {
1923 MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get());
1924 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1925 "mlirTypeID was expected to be non-null.");
1926 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
1927 mlirTypeID, mlirTypeGetDialect(this->get()));
1928 // nb::rv_policy::move means use std::move to move the return value
1929 // contents into a new instance that will be owned by Python.
1930 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
1931 if (!typeCaster)
1932 return thisObj;
1933 return typeCaster.value()(thisObj);
1934}
1936//------------------------------------------------------------------------------
1937// PyTypeID.
1938//------------------------------------------------------------------------------
1940nb::object PyTypeID::getCapsule() {
1941 return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
1942}
1943
1944PyTypeID PyTypeID::createFromCapsule(nb::object capsule) {
1945 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
1946 if (mlirTypeIDIsNull(mlirTypeID))
1947 throw nb::python_error();
1948 return PyTypeID(mlirTypeID);
1949}
1950bool PyTypeID::operator==(const PyTypeID &other) const {
1951 return mlirTypeIDEqual(typeID, other.typeID);
1952}
1954//------------------------------------------------------------------------------
1955// PyValue and subclasses.
1956//------------------------------------------------------------------------------
1958nb::object PyValue::getCapsule() {
1959 return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
1960}
1961
/// Returns a reference to the Python wrapper of the operation that owns
/// `value`.
/// - For an op result, the owner is the defining operation
///   (mlirOpResultGetOwner).
/// - For a block argument, the owner is assigned on a line not visible in
///   this view (original line 1967); presumably it is the parent operation
///   of the argument's block.
/// Throws nb::python_error if the resolved owner is null.
/// NOTE(review): on the `assert(false)` path `owner` is never assigned. In
/// NDEBUG builds, a value that is neither an op result nor a block argument
/// therefore reads an uninitialized MlirOperation. Consider initializing
/// `owner` to a null operation or throwing explicitly.
/// NOTE(review): `throw nb::python_error()` assumes a Python error is
/// currently set. Confirm that holds when the owner is null.
1962static PyOperationRef getValueOwnerRef(MlirValue value) {
1963 MlirOperation owner;
1964 if (mlirValueIsAOpResult(value))
1965 owner = mlirOpResultGetOwner(value);
1966 else if (mlirValueIsABlockArgument(value))
1968 else
1969 assert(false && "Value must be an block arg or op result.");
1970 if (mlirOperationIsNull(owner))
1971 throw nb::python_error();
1972 MlirContext ctx = mlirOperationGetContext(owner);
1974}
1975
/// Returns this value as the most specific Python value class.
/// - The value is first wrapped as OpResult or BlockArgument, depending on
///   which kind of MLIR value it is.
/// - If a value caster was found (the lookup is on original line 1983, not
///   visible here; its inputs presumably include the value's type TypeID
///   computed above), the caster's result is returned instead.
/// NOTE(review): the function-name line (original line 1977) is not present
/// in this view.
/// NOTE(review): on the `assert(false)` path `thisObj` stays a
/// default-constructed (null) nb::object. In NDEBUG builds that null object
/// would be passed to the caster or returned.
1976nb::typed<nb::object, std::variant<PyBlockArgument, PyOpResult, PyValue>>
1978 MlirType type = mlirValueGetType(get());
1979 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
1980 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1981 "mlirTypeID was expected to be non-null.");
1982 std::optional<nb::callable> valueCaster =
1984 // nb::rv_policy::move means use std::move to move the return value
1985 // contents into a new instance that will be owned by Python.
1986 nb::object thisObj;
1987 if (mlirValueIsAOpResult(value))
1988 thisObj = nb::cast<PyOpResult>(*this, nb::rv_policy::move);
1989 else if (mlirValueIsABlockArgument(value))
1990 thisObj = nb::cast<PyBlockArgument>(*this, nb::rv_policy::move);
1991 else
1992 assert(false && "Value must be an block arg or op result.");
1993 if (valueCaster)
1994 return valueCaster.value()(thisObj);
1995 return thisObj;
1996}
1997
1998PyValue PyValue::createFromCapsule(nb::object capsule) {
1999 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2000 if (mlirValueIsNull(value))
2001 throw nb::python_error();
2002 PyOperationRef ownerRef = getValueOwnerRef(value);
2003 return PyValue(ownerRef, value);
2004}
2006//------------------------------------------------------------------------------
2007// PySymbolTable.
2008//------------------------------------------------------------------------------
2009
// Constructs a symbol table over the given operation. The constructor
// signature (original line 2010) is not present in this view.
// - Stores a reference to the operation's Python wrapper, presumably so the
//   operation stays alive for the lifetime of the table.
// - Throws TypeError if mlirSymbolTableCreate returns a null table, i.e. the
//   operation is not a symbol table.
2011 : operation(operation.getOperation().getRef()) {
2012 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2013 if (mlirSymbolTableIsNull(symbolTable)) {
2014 throw nb::type_error("Operation is not a Symbol Table.");
2015 }
2016}
2017
2018nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2019 operation->checkValid();
2020 MlirOperation symbol = mlirSymbolTableLookup(
2021 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2022 if (mlirOperationIsNull(symbol))
2023 throw nb::key_error(
2024 join("Symbol '", name, "' not in the symbol table.").c_str());
2025
2026 return PyOperation::forOperation(operation->getContext(), symbol,
2027 operation.getObject())
2028 ->createOpView();
2029}
2030
// Erases `symbol` from this symbol table. The signature line (original line
// 2031) is not present in this view. Both the table's operation and the
// symbol operation are validated first, and the symbol's Python wrapper is
// marked invalid afterwards.
2032 operation->checkValid();
2033 symbol.getOperation().checkValid();
2034 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2035 // The operation is also erased, so we must invalidate it. There may be Python
2036 // references to this operation so we don't want to delete it from the list of
2037 // live operations here.
2038 symbol.getOperation().valid = false;
2039}
2040
2041void PySymbolTable::dunderDel(const std::string &name) {
2042 nb::object operation = dunderGetItem(name);
2043 erase(nb::cast<PyOperationBase &>(operation));
2044}
2045
// Inserts `symbol` into this symbol table. The signature line (original line
// 2046) is not present in this view.
// - Both operations are validated first.
// - Raises ValueError if the operation has no symbol-name attribute. The
//   attribute name argument (original line 2050) is not visible.
// - The result wraps the value returned by mlirSymbolTableInsert. The start
//   of the return statement (original line 2053) is not visible; presumably
//   it builds a string attribute holding the (possibly renamed) symbol name.
2047 operation->checkValid();
2048 symbol.getOperation().checkValid();
2049 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2051 if (mlirAttributeIsNull(symbolAttr))
2052 throw nb::value_error("Expected operation to have a symbol name.");
2054 symbol.getOperation().getContext(),
2055 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
2056}
2057
// Returns the symbol-name attribute of `symbol` as a PyStringAttribute.
// The signature line (original line 2058) and the definition of `attrName`
// (original line 2062) are not present in this view.
// Raises ValueError if the operation has no such attribute.
2059 // Op must already be a symbol.
2060 PyOperation &operation = symbol.getOperation();
2061 operation.checkValid();
2063 MlirAttribute existingNameAttr =
2064 mlirOperationGetAttributeByName(operation.get(), attrName);
2065 if (mlirAttributeIsNull(existingNameAttr))
2066 throw nb::value_error("Expected operation to have a symbol name.");
2067 return PyStringAttribute(symbol.getOperation().getContext(),
2068 existingNameAttr);
2069}
2070
// Renames `symbol` to `name` by overwriting its existing symbol-name
// attribute with a new string attribute. Only this operation's attribute is
// changed; uses elsewhere are not rewritten here.
// The leading signature line (original line 2071) and the definition of
// `attrName` (original line 2076) are not present in this view.
// Raises ValueError if the operation does not already carry the attribute.
2072 const std::string &name) {
2073 // Op must already be a symbol.
2074 PyOperation &operation = symbol.getOperation();
2075 operation.checkValid();
2077 MlirAttribute existingNameAttr =
2078 mlirOperationGetAttributeByName(operation.get(), attrName);
2079 if (mlirAttributeIsNull(existingNameAttr))
2080 throw nb::value_error("Expected operation to have a symbol name.");
2081 MlirAttribute newNameAttr =
2082 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2083 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2084}
2085
// Returns the symbol-visibility attribute of `symbol` as a PyStringAttribute.
// The signature line (original line 2086) and the definition of `attrName`
// (original line 2089) are not present in this view.
// Raises ValueError if the operation has no visibility attribute.
2087 PyOperation &operation = symbol.getOperation();
2088 operation.checkValid();
2090 MlirAttribute existingVisAttr =
2091 mlirOperationGetAttributeByName(operation.get(), attrName);
2092 if (mlirAttributeIsNull(existingVisAttr))
2093 throw nb::value_error("Expected operation to have a symbol visibility.");
2094 return PyStringAttribute(symbol.getOperation().getContext(), existingVisAttr);
2095}
2096
// Sets the visibility of `symbol` to "public", "private" or "nested".
// - The value is validated before the operation is touched.
// - An existing visibility attribute is required (ValueError otherwise); it
//   is overwritten, never created.
// The leading signature line (original line 2097) and the definition of
// `attrName` (original line 2105) are not present in this view.
2098 const std::string &visibility) {
2099 if (visibility != "public" && visibility != "private" &&
2100 visibility != "nested")
2101 throw nb::value_error(
2102 "Expected visibility to be 'public', 'private' or 'nested'");
2103 PyOperation &operation = symbol.getOperation();
2104 operation.checkValid();
2106 MlirAttribute existingVisAttr =
2107 mlirOperationGetAttributeByName(operation.get(), attrName);
2108 if (mlirAttributeIsNull(existingVisAttr))
2109 throw nb::value_error("Expected operation to have a symbol visibility.");
2110 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2111 toMlirStringRef(visibility));
2112 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2113}
2114
/// Rewrites uses of symbol `oldSymbol` to `newSymbol` within `from`.
/// `from` is validated first.
/// Throws ValueError("Symbol rename failed"). The condition guarding that
/// throw and the start of the replace call (original lines 2120 and 2122)
/// are not present in this view; presumably it fires only when the replace
/// reports failure. Confirm.
2115void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2116 const std::string &newSymbol,
2117 PyOperationBase &from) {
2118 PyOperation &fromOperation = from.getOperation();
2119 fromOperation.checkValid();
2121 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2123
2124 throw nb::value_error("Symbol rename failed");
2125}
2126
// Walks symbol tables reachable from `from` and calls the Python
// `callback(op, isVisible)` for each operation the C walker reports.
// The function's first signature line (original line 2127) and the walker
// call itself (original line 2141) are not present in this view.
// Exceptions raised by the callback are not allowed to escape the C
// callback. They are recorded in UserData, later callbacks become no-ops,
// and a std::runtime_error is thrown after the walk finishes.
// NOTE(review): `exceptionType` is recorded but never used. The original
// Python exception type (and traceback) is therefore lost and callers always
// see a RuntimeError; consider re-raising the captured exception instead.
2128 bool allSymUsesVisible,
2129 nb::object callback) {
2130 PyOperation &fromOperation = from.getOperation();
2131 fromOperation.checkValid();
2132 struct UserData {
2133 PyMlirContextRef context;
2134 nb::object callback;
2135 bool gotException;
2136 std::string exceptionWhat;
2137 nb::object exceptionType;
2138 };
2139 UserData userData{
2140 fromOperation.getContext(), std::move(callback), false, {}, {}};
2142 fromOperation.get(), allSymUsesVisible,
2143 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2144 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
// NOTE(review): the wrapper is created before the gotException check,
// so ops visited after a failure still get Python wrappers; moving the
// check first would avoid that work.
2145 auto pyFoundOp =
2146 PyOperation::forOperation(calleeUserData->context, foundOp);
2147 if (calleeUserData->gotException)
2148 return;
2149 try {
2150 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2151 } catch (nb::python_error &e) {
2152 calleeUserData->gotException = true;
2153 calleeUserData->exceptionWhat = e.what();
2154 calleeUserData->exceptionType = nb::borrow(e.type());
2155 }
2156 },
2157 static_cast<void *>(&userData));
2158 if (userData.gotException) {
2159 std::string message("Exception raised in callback: ");
2160 message.append(userData.exceptionWhat);
2161 throw std::runtime_error(message);
2162 }
2163}
2164
/// Adds BlockArgument-specific members to the Python class:
/// - `owner` (the owning block),
/// - `arg_number` (the argument's position),
/// - `set_type`,
/// - `set_location`.
/// NOTE(review): the second argument of the PyBlock construction (original
/// line 2170) and the body of `set_location` (original line 2188) are not
/// present in this view.
2165void PyBlockArgument::bindDerived(ClassTy &c) {
2166 c.def_prop_ro(
2167 "owner",
2168 [](PyBlockArgument &self) {
2169 return PyBlock(self.getParentOperation(),
2171 },
2172 "Returns the block that owns this argument.");
2173 c.def_prop_ro(
2174 "arg_number",
2175 [](PyBlockArgument &self) {
2176 return mlirBlockArgumentGetArgNumber(self.get());
2177 },
2178 "Returns the position of this argument in the block's argument list.");
2179 c.def(
2180 "set_type",
2181 [](PyBlockArgument &self, PyType type) {
2182 return mlirBlockArgumentSetType(self.get(), type);
2183 },
2184 "type"_a, "Sets the type of this block argument.");
2185 c.def(
2186 "set_location",
2187 [](PyBlockArgument &self, PyLocation loc) {
2189 },
2190 "loc"_a, "Sets the location of this block argument.");
2191}
2192
// Constructs a (possibly sliced) view over `block`'s arguments.
// A `length` of -1 means "all arguments of the block".
// The first and middle lines of the signature/initializer (original lines
// 2193, 2195, 2196) are not present in this view.
2194 MlirBlock block, intptr_t startIndex,
2197 length == -1 ? mlirBlockGetNumArguments(block) : length, step),
2198 operation(std::move(operation)), block(block) {}
2199
2200void PyBlockArgumentList::bindDerived(ClassTy &c) {
2201 c.def_prop_ro(
2202 "types",
2203 [](PyBlockArgumentList &self) {
2204 return getValueTypes(self, self.operation->getContext());
2205 },
2206 "Returns a list of types for all arguments in this argument list.");
2207}
2208
2209intptr_t PyBlockArgumentList::getRawNumElements() {
2210 operation->checkValid();
2211 return mlirBlockGetNumArguments(block);
2212}
2213
2214PyBlockArgument PyBlockArgumentList::getRawElement(intptr_t pos) const {
2215 MlirValue argument = mlirBlockGetArgument(block, pos);
2216 return PyBlockArgument(operation, argument);
2217}
2218
/// Sliceable hook: returns a new view over the same block and operation
/// with the given start/length/step.
/// NOTE(review): the `length` parameter line (original line 2220) is not
/// present in this view.
2219PyBlockArgumentList PyBlockArgumentList::slice(intptr_t startIndex,
2221 intptr_t step) const {
2222 return PyBlockArgumentList(operation, block, startIndex, length, step);
2223}
2224
// Constructs a (possibly sliced) view over an operation's operands.
// The first signature line (original line 2225) and the length computation
// (original line 2228, presumably "all operands" when length == -1) are not
// present in this view.
2226 intptr_t length, intptr_t step)
2227 : Sliceable(startIndex,
2229 : length,
2230 step),
2231 operation(operation) {}
2232
// Implements `operands[index] = value` by replacing the operand in place.
// The signature and any preceding checks (original lines 2233-2234) are not
// present in this view.
2235 mlirOperationSetOperand(operation->get(), index, value.get());
2236}
2237
2238void PyOpOperandList::bindDerived(ClassTy &c) {
2239 c.def("__setitem__", &PyOpOperandList::dunderSetItem, "index"_a, "value"_a,
2240 "Sets the operand at the specified index to a new value.");
2241}
2242
2243intptr_t PyOpOperandList::getRawNumElements() {
2244 operation->checkValid();
2245 return mlirOperationGetNumOperands(operation->get());
2246}
2247
2248PyValue PyOpOperandList::getRawElement(intptr_t pos) {
2249 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2250 PyOperationRef pyOwner = getValueOwnerRef(operand);
2251 return PyValue(pyOwner, operand);
2252}
2253
2254PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
2255 intptr_t step) const {
2256 return PyOpOperandList(operation, startIndex, length, step);
2257}
2259/// A list of OpOperands. Internally, these are stored as consecutive elements,
2260/// so random access is cheap. The returned list is associated with the
2261/// operation whose operands these are, and thus extends the lifetime of that
2262/// operation (it holds a PyOperationRef).
2263class PyOpOperands : public Sliceable<PyOpOperands, PyOpOperand> {
2264public:
2265 static constexpr const char *pyClassName = "OpOperands";
// NOTE(review): several declaration lines (original 2266, 2268, 2270, 2290,
// 2292: the constructor head, the Sliceable base-initializer head, and the
// slice() signature/closing brace) are not present in this view.
// A `length` of -1 means "all operands of the operation".
2267
2269 intptr_t length = -1, intptr_t step = 1)
2271 length == -1 ? mlirOperationGetNumOperands(operation->get())
2272 : length,
2273 step),
2274 operation(operation) {}
2275
2276private:
2277 /// Give the parent CRTP class access to hook implementations below.
2278 friend class Sliceable<PyOpOperands, PyOpOperand>;
2279
2280 intptr_t getRawNumElements() {
2281 operation->checkValid();
2282 return mlirOperationGetNumOperands(operation->get());
2283 }
2284
// Wraps the raw OpOperand handle at `pos` (no owner reference is attached).
2285 PyOpOperand getRawElement(intptr_t pos) {
2286 MlirOpOperand opOperand = mlirOperationGetOpOperand(operation->get(), pos);
2287 return PyOpOperand(opOperand);
2288 }
2289
2291 return PyOpOperands(operation, startIndex, length, step);
2293
2294 PyOperationRef operation;
2295};
2296
// Constructs a (possibly sliced) view over an operation's successor blocks.
// The first signature line (original line 2297) and the length computation
// (original line 2300, presumably "all successors" when length == -1) are
// not present in this view.
2298 intptr_t length, intptr_t step)
2299 : Sliceable(startIndex,
2301 : length,
2302 step),
2303 operation(operation) {}
2304
// Implements `successors[index] = block` by redirecting that successor.
// The signature and any preceding checks (original lines 2305-2306) are not
// present in this view.
2307 mlirOperationSetSuccessor(operation->get(), index, block.get());
2308}
2309
2310void PyOpSuccessors::bindDerived(ClassTy &c) {
2311 c.def("__setitem__", &PyOpSuccessors::dunderSetItem, "index"_a, "block"_a,
2312 "Sets the successor block at the specified index.");
2313}
2314
2315intptr_t PyOpSuccessors::getRawNumElements() {
2316 operation->checkValid();
2317 return mlirOperationGetNumSuccessors(operation->get());
2318}
2319
2320PyBlock PyOpSuccessors::getRawElement(intptr_t pos) {
2321 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2322 return PyBlock(operation, block);
2323}
2324
// Sliceable hook: returns a new successor view over the same operation.
// The leading signature line (original line 2325) is not present in this
// view.
2326 intptr_t step) const {
2327 return PyOpSuccessors(operation, startIndex, length, step);
2328}
2329
// Constructs a (possibly sliced) view over a block's successors.
// A `length` of -1 means "all successors of the block". The leading
// signature line (original line 2330) is not present in this view.
2331 intptr_t startIndex, intptr_t length,
2332 intptr_t step)
2333 : Sliceable(startIndex,
2334 length == -1 ? mlirBlockGetNumSuccessors(block.get()) : length,
2335 step),
2336 operation(operation), block(block) {}
2337
2338intptr_t PyBlockSuccessors::getRawNumElements() {
2339 block.checkValid();
2340 return mlirBlockGetNumSuccessors(block.get());
2341}
2342
2343PyBlock PyBlockSuccessors::getRawElement(intptr_t pos) {
2344 MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2345 return PyBlock(operation, block);
2346}
2347
// Sliceable hook: returns a new successor view over the same block and
// operation. The leading signature line (original line 2348) is not present
// in this view.
2349 intptr_t step) const {
2350 return PyBlockSuccessors(block, operation, startIndex, length, step);
2351}
2352
// Constructs a (possibly sliced) view over a block's predecessors.
// A `length` of -1 means "all predecessors of the block". The leading
// signature line (original line 2353) is not present in this view.
2354 PyOperationRef operation,
2355 intptr_t startIndex, intptr_t length,
2356 intptr_t step)
2357 : Sliceable(startIndex,
2358 length == -1 ? mlirBlockGetNumPredecessors(block.get())
2359 : length,
2360 step),
2361 operation(operation), block(block) {}
2362
2363intptr_t PyBlockPredecessors::getRawNumElements() {
2364 block.checkValid();
2365 return mlirBlockGetNumPredecessors(block.get());
2366}
2367
2368PyBlock PyBlockPredecessors::getRawElement(intptr_t pos) {
2369 MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2370 return PyBlock(operation, block);
2371}
2372
2373PyBlockPredecessors PyBlockPredecessors::slice(intptr_t startIndex,
2374 intptr_t length,
2375 intptr_t step) const {
2376 return PyBlockPredecessors(block, operation, startIndex, length, step);
2377}
2378
/// Implements `attributes[name]`.
/// Returns the attribute downcast via maybeDownCast, or raises KeyError when
/// no attribute with that name exists.
/// NOTE(review): the lookup call (original line 2382) and the closing brace
/// of the `if` (original line 2385) are not present in this view.
2379nb::typed<nb::object, PyAttribute>
2380PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
2381 MlirAttribute attr =
2383 if (mlirAttributeIsNull(attr)) {
2384 throw nb::key_error("attempt to access a non-existent attribute");
2386 return PyAttribute(operation->getContext(), attr).maybeDownCast();
2387}
2388
/// Implements `attributes.get(key, default=None)`.
/// Returns the downcast attribute named `key`, or `defaultValue` unchanged
/// when it is absent.
/// NOTE(review): the lookup call (original line 2392) is not present in this
/// view.
2389nb::typed<nb::object, std::optional<PyAttribute>>
2390PyOpAttributeMap::get(const std::string &key, nb::object defaultValue) {
2391 MlirAttribute attr =
2393 if (mlirAttributeIsNull(attr))
2394 return defaultValue;
2395 return PyAttribute(operation->getContext(), attr).maybeDownCast();
2396}
2397
// Implements `attributes[index]` and returns a PyNamedAttribute.
// - Negative indices count from the end, Python-style.
// - Out-of-range indices raise IndexError.
// - The attribute name is copied into an owned std::string.
// The signature line (original line 2398) is not present in this view.
2399 if (index < 0) {
2400 index += dunderLen();
2401 }
2402 if (index < 0 || index >= dunderLen()) {
2403 throw nb::index_error("attempt to access out of bounds attribute");
2404 }
2405 MlirNamedAttribute namedAttr =
2406 mlirOperationGetAttribute(operation->get(), index);
2407 return PyNamedAttribute(
2408 namedAttr.attribute,
2409 std::string(mlirIdentifierStr(namedAttr.name).data,
2410 mlirIdentifierStr(namedAttr.name).length));
2411}
2412
2413void PyOpAttributeMap::dunderSetItem(const std::string &name,
2414 const PyAttribute &attr) {
2415 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2416 attr);
2417}
2418
/// Implements `del attributes[name]`.
/// Raises KeyError when mlirOperationRemoveAttributeByName reports that
/// nothing was removed.
/// NOTE(review): the name argument line (original line 2421) is not present
/// in this view.
2419void PyOpAttributeMap::dunderDelItem(const std::string &name) {
2420 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2422 if (!removed)
2423 throw nb::key_error("attempt to delete a non-existent attribute");
2424}
// Implements `len(attributes)`: the number of attributes on the operation.
// The signature and any preceding checks (original lines 2425-2426) are not
// present in this view.
2427 return mlirOperationGetNumAttributes(operation->get());
2428}
2429
2430bool PyOpAttributeMap::dunderContains(const std::string &name) {
2431 return !mlirAttributeIsNull(
2432 mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)));
2433}
2434
// Calls `fn(name, attribute)` for each of the operation's attributes, in
// index order 0..n-1.
// The signature head, the count `n`, and the per-index fetch of `na`/`name`
// (original lines 2435, 2437, 2439, 2440) are not present in this view.
2436 MlirOperation op, std::function<void(MlirStringRef, MlirAttribute)> fn) {
2438 for (intptr_t i = 0; i < n; ++i) {
2441 fn(name, na.attribute);
2442 }
2443}
2444
/// Registers the Python `OpAttributeMap` class, a dict-like view of an
/// operation's attributes.
/// - It supports `in`, `len`, `[]` (by name or by index), assignment,
///   deletion, `get`, iteration, `keys`, `values` and `items`.
/// - The iteration helpers call forEachAttr; the call heads (original lines
///   2467, 2478, 2489, 2503) are not present in this view.
2445void PyOpAttributeMap::bind(nb::module_ &m) {
2446 nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2447 .def("__contains__", &PyOpAttributeMap::dunderContains, "name"_a,
2448 "Checks if an attribute with the given name exists in the map.")
2449 .def("__len__", &PyOpAttributeMap::dunderLen,
2450 "Returns the number of attributes in the map.")
// Two `__getitem__` overloads are registered: by name (listed first) and by
// integer index.
2451 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed, "name"_a,
2452 "Gets an attribute by name.")
2453 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed, "index"_a,
2454 "Gets a named attribute by index.")
2455 .def("__setitem__", &PyOpAttributeMap::dunderSetItem, "name"_a, "attr"_a,
2456 "Sets an attribute with the given name.")
2457 .def("__delitem__", &PyOpAttributeMap::dunderDelItem, "name"_a,
2458 "Deletes an attribute with the given name.")
2459 .def("get", &PyOpAttributeMap::get, nb::arg("key"),
2460 nb::arg("default") = nb::none(),
2461 "Gets an attribute by name or the default value, if it does not "
2462 "exist.")
// `__iter__` snapshots the names into a list before returning an iterator
// over it, so later mutation of the map does not affect the iterator.
2463 .def(
2464 "__iter__",
2465 [](PyOpAttributeMap &self) -> nb::typed<nb::iterator, nb::str> {
2466 nb::list keys;
2468 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
2469 keys.append(nb::str(name.data, name.length));
2470 });
2471 return nb::iter(keys);
2472 },
2473 "Iterates over attribute names.")
2474 .def(
2475 "keys",
2476 [](PyOpAttributeMap &self) -> nb::typed<nb::list, nb::str> {
2477 nb::list out;
2479 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
2480 out.append(nb::str(name.data, name.length));
2481 });
2482 return out;
2483 },
2484 "Returns a list of attribute names.")
// `values` and `items` downcast each attribute via maybeDownCast.
2485 .def(
2486 "values",
2487 [](PyOpAttributeMap &self) -> nb::typed<nb::list, PyAttribute> {
2488 nb::list out;
2490 self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
2491 out.append(PyAttribute(self.operation->getContext(), attr)
2492 .maybeDownCast());
2493 });
2494 return out;
2495 },
2496 "Returns a list of attribute values.")
2497 .def(
2498 "items",
2499 [](PyOpAttributeMap &self)
2500 -> nb::typed<nb::list,
2501 nb::typed<nb::tuple, nb::str, PyAttribute>> {
2502 nb::list out;
2504 self.operation->get(),
2505 [&](MlirStringRef name, MlirAttribute attr) {
2506 out.append(nb::make_tuple(
2507 nb::str(name.data, name.length),
2508 PyAttribute(self.operation->getContext(), attr)
2509 .maybeDownCast()));
2510 });
2511 return out;
2512 },
2513 "Returns a list of `(name, attribute)` tuples.");
2514}
2515
2516void PyOpAdaptor::bind(nb::module_ &m) {
2517 nb::class_<PyOpAdaptor>(m, "OpAdaptor")
2518 .def(nb::init<nb::typed<nb::list, PyValue>, PyOpAttributeMap>(),
2519 "Creates an OpAdaptor with the given operands and attributes.",
2520 "operands"_a, "attributes"_a)
2521 .def(nb::init<nb::typed<nb::list, PyValue>, PyOpView &>(),
2522 "Creates an OpAdaptor with the given operands and operation view.",
2523 "operands"_a, "opview"_a)
2524 .def_prop_ro(
2525 "operands", [](PyOpAdaptor &self) { return self.operands; },
2526 "Returns the operands of the adaptor.")
2527 .def_prop_ro(
2528 "attributes", [](PyOpAdaptor &self) { return self.attributes; },
2529 "Returns the attributes of the adaptor.");
2530}
2531
/// Shared verifier callback for Python-defined dynamic op traits.
/// - `userData` is the Python target object.
/// - If the target lacks `methodName`, verification trivially succeeds.
/// - Otherwise the method is called with an OpView of `op`, and its result is
///   converted to bool. The line deriving `ctx` and the final return
///   (original lines 2537 and 2540) are not present in this view; the return
///   presumably maps `success` to a logical result.
/// NOTE(review): nothing visible here catches a Python exception raised by
/// the method (or by the bool conversion), so it would propagate as a C++
/// exception out of a C callback. Confirm the caller tolerates that.
/// NOTE(review): no GIL acquisition is visible. Presumably verification is
/// only triggered while the GIL is held; confirm.
/// NOTE(review): the trailing `;` after the closing brace is a stray empty
/// declaration.
2532static MlirLogicalResult verifyTraitByMethod(MlirOperation op, void *userData,
2533 const char *methodName) {
2534 nb::handle targetObj(static_cast<PyObject *>(userData));
2535 if (!nb::hasattr(targetObj, methodName))
2536 return mlirLogicalResultSuccess();
2538 nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
2539 bool success = nb::cast<bool>(targetObj.attr(methodName)(opView));
2541};
2542
/// Attaches `trait` to the operation named by `opName` in `context`.
/// `opName` may be:
/// - a type, in which case its `OPERATION_NAME` attribute is used, or
/// - a Python str.
/// Anything else raises TypeError. The attach call head (original lines
/// 2553-2554) is not present in this view.
/// NOTE(review): on the TypeError path the already-created `trait` is never
/// attached. Depending on who owns an unattached MlirDynamicOpTrait, that may
/// leak it; confirm the ownership contract.
2543static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait,
2544 PyMlirContext &context) {
2545 std::string opNameStr;
2546 if (opName.is_type()) {
2547 opNameStr = nb::cast<std::string>(opName.attr("OPERATION_NAME"));
2548 } else if (nb::isinstance<nb::str>(opName)) {
2549 opNameStr = nb::cast<std::string>(opName);
2550 } else {
2551 throw nb::type_error("the root argument must be a type or a string");
2552 }
2555 trait, MlirStringRef{opNameStr.data(), opNameStr.size()}, context.get());
2556}
2557
/// Creates a dynamic op trait backed by the Python object `target` and
/// attaches it to `opName`.
/// - `target` must define `verify_invariants` and/or
///   `verify_region_invariants` (TypeError otherwise).
/// - The construct/destruct callbacks increment/decrement `target`'s
///   refcount, so the trait keeps the Python object alive.
/// - The declaration of `callbacks` (original line 2567) is not present in
///   this view.
/// NOTE(review): `nb::hasattr(target, typeIDAttr)` also finds an attribute
/// inherited from a base class. A subclass of an already-attached trait class
/// would therefore reuse its parent's TypeID; checking the class's own
/// `__dict__` would avoid that. Confirm the intended behavior.
/// NOTE(review): the destruct callback calls dec_ref without visibly holding
/// the GIL; confirm trait destruction only happens with the GIL held.
2558bool PyDynamicOpTrait::attach(const nb::object &opName,
2559 const nb::object &target,
2560 PyMlirContext &context) {
2561 if (!nb::hasattr(target, "verify_invariants") &&
2562 !nb::hasattr(target, "verify_region_invariants"))
2563 throw nb::type_error(
2564 "the target object must have at least one of 'verify_invariants' or "
2565 "'verify_region_invariants' methods");
2566
2568 callbacks.construct = [](void *userData) {
2569 nb::handle(static_cast<PyObject *>(userData)).inc_ref();
2570 };
2571 callbacks.destruct = [](void *userData) {
2572 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
2573 };
2574
2575 callbacks.verifyTrait = [](MlirOperation op,
2576 void *userData) -> MlirLogicalResult {
2577 return verifyTraitByMethod(op, userData, "verify_invariants");
2578 };
2579 callbacks.verifyRegionTrait = [](MlirOperation op,
2580 void *userData) -> MlirLogicalResult {
2581 return verifyTraitByMethod(op, userData, "verify_region_invariants");
2582 };
2583
2584 // To ensure that the same dynamic trait gets the same TypeID despite how many
2585 // times `attach` is called, we store it as an attribute on the target class.
2586 if (!nb::hasattr(target, typeIDAttr)) {
2587 nb::setattr(target, typeIDAttr,
2588 nb::cast(PyTypeID(PyGlobals::get().allocateTypeID())));
2589 }
// `target.ptr()` is passed as the callbacks' userData; it is the object
// verifyTraitByMethod looks the verify methods up on.
2590 MlirDynamicOpTrait trait = mlirDynamicOpTraitCreate(
2591 nb::cast<PyTypeID>(target.attr(typeIDAttr)).get(), callbacks,
2592 static_cast<void *>(target.ptr()));
2593 return attachOpTrait(opName, trait, context);
2594}
2595
2596void PyDynamicOpTrait::bind(nb::module_ &m) {
2597 nb::class_<PyDynamicOpTrait> cls(m, "DynamicOpTrait");
2598 cls.attr("attach") = classmethod(
2599 [](const nb::object &cls, const nb::object &opName, nb::object target,
2600 DefaultingPyMlirContext context) {
2601 if (target.is_none())
2602 target = cls;
2603 return PyDynamicOpTrait::attach(opName, target, *context.get());
2604 },
2605 nb::arg("cls"), nb::arg("op_name"), nb::arg("target").none() = nb::none(),
2606 nb::arg("context").none() = nb::none(),
2607 "Attach the dynamic op trait subclass to the given operation name.");
2608}
2609
2610bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName,
2611 PyMlirContext &context) {
2612 MlirDynamicOpTrait trait = mlirDynamicOpTraitIsTerminatorCreate();
2613 return attachOpTrait(opName, trait, context);
2614}
2615
/// Registers `IsTerminatorTrait` (a DynamicOpTrait subclass).
/// - The builtin trait's TypeID is stored under `typeIDAttr`, the same
///   attribute PyDynamicOpTrait::attach uses for Python-defined traits.
/// - An `attach(op_name, context=None)` classmethod is added.
/// NOTE(review): the line closing the lambda (original line 2624) is not
/// present in this view.
2616void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) {
2617 nb::class_<PyDynamicOpTraits::IsTerminator, PyDynamicOpTrait> cls(
2618 m, "IsTerminatorTrait");
2619 cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitIsTerminatorGetTypeID());
2620 cls.attr("attach") = classmethod(
2621 [](const nb::object &cls, const nb::object &opName,
2622 DefaultingPyMlirContext context) {
2623 return PyDynamicOpTraits::IsTerminator::attach(opName, *context.get());
2625 "Attach IsTerminator trait to the given operation name.", nb::arg("cls"),
2626 nb::arg("op_name"), nb::arg("context").none() = nb::none());
2627}
2628
2629bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName,
2630 PyMlirContext &context) {
2631 MlirDynamicOpTrait trait = mlirDynamicOpTraitNoTerminatorCreate();
2632 return attachOpTrait(opName, trait, context);
2633}
2634
2635void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) {
2636 nb::class_<PyDynamicOpTraits::NoTerminator, PyDynamicOpTrait> cls(
2637 m, "NoTerminatorTrait");
2638 cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitNoTerminatorGetTypeID());
2639 cls.attr("attach") = classmethod(
2640 [](const nb::object &cls, const nb::object &opName,
2641 DefaultingPyMlirContext context) {
2642 return PyDynamicOpTraits::NoTerminator::attach(opName, *context.get());
2643 },
2644 "Attach NoTerminator trait to the given operation name.", nb::arg("cls"),
2645 nb::arg("op_name"), nb::arg("context").none() = nb::none());
2646}
2647
2648} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
2649} // namespace python
2650} // namespace mlir
2651
2652namespace {
2653
2654using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
2655
/// Builds an MLIR location from the current Python call stack.
/// - Frames are walked from the innermost outwards. Only frames whose
///   filename passes isUserTracebackFilename are kept, up to `framesLimit`.
/// - Each kept frame becomes a NameLoc (function name) over a FileLineCol
///   loc (Python < 3.11) or a FileLineColRange loc (>= 3.11).
/// - The kept frames are chained into CallSiteLocs. The innermost frame is
///   the callee and the outermost frame is the root caller.
/// - Returns an unknown location when no frame is kept, or when built
///   against the limited API.
/// NOTE(review): the initializer of `framesLimit` (original line 2663) is not
/// present in this view. `frames` has kMaxFrames slots, so the writes below
/// are in bounds only if framesLimit <= kMaxFrames; confirm that is enforced
/// where the limit is set.
/// NOTE(review): PyFrame_GetCode returns a strong reference per the CPython
/// docs, and `code` is never released here. That looks like one leaked
/// code-object reference per visited frame; a Py_DECREF(code) is likely
/// needed.
/// NOTE(review): the `throw nb::python_error()` inside the loop exits without
/// releasing `pyFrame` (or `code`).
2656MlirLocation tracebackToLocation(MlirContext ctx) {
2657#if defined(Py_LIMITED_API)
2658 // Frame introspection C APIs are not available under the limited API.
2659 // Traceback-based auto-location is not supported; return unknown.
2660 return mlirLocationUnknownGet(ctx);
2661#else
2662 size_t framesLimit =
2664 // Use a thread_local here to avoid requiring a large amount of space.
2665 thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
2666 frames;
2667 size_t count = 0;
2668
2669 nb::gil_scoped_acquire acquire;
2670
2671 PyThreadState *tstate = PyThreadState_GET();
2672 PyFrameObject *next;
2673 PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
2674 // In the increment expression:
2675 // 1. get the next prev frame;
2676 // 2. decrement the ref count on the current frame (in order that it can get
2677 // gc'd, along with any objects in its closure and etc);
2678 // 3. set current = next.
2679 for (; pyFrame != nullptr && count < framesLimit;
2680 next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
2681 PyCodeObject *code = PyFrame_GetCode(pyFrame);
2682 auto fileNameStr =
2683 nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
2684 std::string_view fileName(fileNameStr);
// `continue` still runs the increment expression, so skipped frames are
// released too.
2685 if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
2686 continue;
2687
2688 // co_qualname and PyCode_Addr2Location added in py3.11
2689#if PY_VERSION_HEX < 0x030B00F0
2690 std::string name =
2691 nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
2692 std::string_view funcName(name);
2693 int startLine = PyFrame_GetLineNumber(pyFrame);
2694 MlirLocation loc = mlirLocationFileLineColGet(
2695 ctx, mlirStringRefCreate(fileName.data(), fileName.size()), startLine,
2696 0);
2697#else
2698 std::string name =
2699 nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
2700 std::string_view funcName(name);
2701 int startLine, startCol, endLine, endCol;
2702 int lasti = PyFrame_GetLasti(pyFrame);
2703 if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
2704 &endCol)) {
2705 throw nb::python_error();
2706 }
2707 MlirLocation loc = mlirLocationFileLineColRangeGet(
2708 ctx, mlirStringRefCreate(fileName.data(), fileName.size()), startLine,
2709 startCol, endLine, endCol);
2710#endif
2711
2712 frames[count] = mlirLocationNameGet(
2713 ctx, mlirStringRefCreate(funcName.data(), funcName.size()), loc);
2714 ++count;
2715 }
2716 // When the loop breaks (after the last iter), current frame (if non-null)
2717 // is leaked without this.
2718 Py_XDECREF(pyFrame);
2719
2720 if (count == 0)
2721 return mlirLocationUnknownGet(ctx);
2722
2723 MlirLocation callee = frames[0];
2724 assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
2725 if (count == 1)
2726 return callee;
2727
// Fold from the outermost frame inwards: frames[count-1] is the root caller,
// and each inner frame (except the callee) wraps it as a call site.
2728 MlirLocation caller = frames[count - 1];
2729 assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
2730 for (int i = count - 2; i >= 1; i--)
2731 caller = mlirLocationCallSiteGet(frames[i], caller);
2732
2733 return mlirLocationCallSiteGet(callee, caller);
2734#endif
2735}
2736
/// Returns `location` if one was supplied. Otherwise, when traceback-based
/// locations are enabled, it builds a location from the current Python stack
/// in the default context.
/// NOTE(review): the statement guarded by the `if` (original line 2742,
/// presumably the fallback when tracebacks are disabled) and the line
/// defining `ref` (original line 2746) are not present in this view. As
/// written here, the `if` would govern the following declaration.
2737PyLocation
2738maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
2739 if (location.has_value())
2740 return location.value();
2741 if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
2743
2744 PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
2745 MlirLocation mlirLoc = tracebackToLocation(ctx.get());
2747 return {ref, mlirLoc};
2748}
2749} // namespace
2750
2751namespace mlir {
2752namespace python {
2754
2755static std::string formatMLIRError(const MLIRError &e) {
2756 auto locStr = [](const PyLocation &loc) {
2757 PyPrintAccumulator accum;
2758 mlirLocationPrint(loc, accum.getCallback(), accum.getUserData());
2759 std::string s = nb::cast<std::string>(nb::str(accum.join()));
2760 std::string_view sv(s);
2761 if (sv.size() > 5) {
2762 sv.remove_prefix(4); // "loc("
2763 sv.remove_suffix(1); // ")"
2764 }
2765 return std::string(sv);
2766 };
2767 auto indent = [](std::string s) {
2768 size_t pos = 0;
2769 while ((pos = s.find('\n', pos)) != std::string::npos) {
2770 s.replace(pos, 1, "\n ");
2771 pos += 3;
2772 }
2773 return s;
2774 };
2775
2776 std::ostringstream os;
2777 os << e.message;
2778 if (!e.errorDiagnostics.empty())
2779 os << ":";
2780 for (const auto &diag : e.errorDiagnostics) {
2781 os << "\nerror: " << locStr(diag.location) << ": " << indent(diag.message);
2782 for (const auto &note : diag.notes) {
2783 os << "\n note: " << locStr(note.location) << ": "
2784 << indent(note.message);
2785 }
2786 }
2787 return os.str();
2788}
2789
2790void MLIRError::bind(nb::module_ &m) {
2791 auto cls = nb::exception<MLIRError>(m, "MLIRError", PyExc_Exception);
2792 nb::register_exception_translator(
2793 [](const std::exception_ptr &p, void *payload) {
2794 try {
2795 if (p)
2796 std::rethrow_exception(p);
2797 } catch (MLIRError &e) {
2798 std::string formatted = formatMLIRError(e);
2799 nb::object ty = nb::borrow(static_cast<PyObject *>(payload));
2800 nb::object obj = ty(formatted);
2801 obj.attr("_message") = nb::cast(std::move(e.message));
2802 obj.attr("_error_diagnostics") =
2803 nb::cast(std::move(e.errorDiagnostics));
2804 PyErr_SetObject(static_cast<PyObject *>(payload), obj.ptr());
2805 }
2806 },
2807 cls.ptr());
2808 auto propertyType = nb::borrow<nb::type_object>(
2809 reinterpret_cast<PyObject *>(&PyProperty_Type));
2810 nb::setattr(
2811 cls, "message",
2812 propertyType(nb::cpp_function(
2813 [](nb::object self) -> nb::str { return self.attr("_message"); },
2814 nb::is_method())));
2815 nb::setattr(cls, "error_diagnostics",
2816 propertyType(nb::cpp_function(
2817 [](nb::object self)
2818 -> nb::typed<nb::list, PyDiagnostic::DiagnosticInfo> {
2819 return self.attr("_error_diagnostics");
2820 },
2821 nb::is_method())));
2822}
2823
/// Populates the root extension module `m` with:
///   - the `T` / `U` type variables referenced by the `nb::sig` strings below,
///   - the `_Globals` class binding PyGlobals, plus a `globals` instance of it,
///   - the registration functions `register_dialect`, `register_operation`,
///     `register_op_adaptor`, `register_type_caster` and
///     `register_value_caster`.
2824void populateRoot(nb::module_ &m) {
  // Typing variables exported so the hand-written `nb::sig` signatures of the
  // decorator factories below can refer to `T` and `U`.
2825 m.attr("T") = nb::type_var("T");
2826 m.attr("U") = nb::type_var("U");
2827
  // Binding for PyGlobals: dialect search-module configuration, underscore-
  // prefixed internal/testing registration hooks, and traceback-location
  // settings.
2828 nb::class_<PyGlobals>(m, "_Globals")
2829 .def_prop_rw("dialect_search_modules",
2832 .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
2833 "module_name"_a)
2834 .def(
2835 "_check_dialect_module_loaded",
2836 [](PyGlobals &self, const std::string &dialectNamespace) {
2837 return self.loadDialectModule(dialectNamespace);
2838 },
2839 "dialect_namespace"_a)
2840 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
2841 "dialect_namespace"_a, "dialect_class"_a, nb::kw_only(),
2842 "replace"_a = false,
2843 "Testing hook for directly registering a dialect")
2844 .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
2845 "operation_name"_a, "operation_class"_a, nb::kw_only(),
2846 "replace"_a = false,
2847 "Testing hook for directly registering an operation")
      // Traceback-location settings. `loc_tracebacks_enabled` reads
      // `self.getTracebackLoc()`. NOTE(review): the other accessors presumably
      // forward to the same object — confirm against their bodies.
2848 .def("loc_tracebacks_enabled",
2849 [](PyGlobals &self) {
2850 return self.getTracebackLoc().locTracebacksEnabled();
2851 })
2852 .def("set_loc_tracebacks_enabled",
2853 [](PyGlobals &self, bool enabled) {
2855 })
2856 .def("loc_tracebacks_frame_limit",
2857 [](PyGlobals &self) {
2859 })
2860 .def("set_loc_tracebacks_frame_limit",
2861 [](PyGlobals &self, std::optional<int> n) {
2864 })
2865 .def("register_traceback_file_inclusion",
2866 [](PyGlobals &self, const std::string &filename) {
2868 })
2869 .def("register_traceback_file_exclusion",
2870 [](PyGlobals &self, const std::string &filename) {
2872 });
2873
2874 // Aside from making the globals accessible to python, having python manage
2875 // it is necessary to make sure it is destroyed (and releases its python
2876 // resources) properly.
  // `take_ownership` hands the heap-allocated PyGlobals to the Python object,
  // which deletes it when collected.
2877 m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
2878
2879 // Registration decorators.
  // `@register_dialect`: reads `DIALECT_NAMESPACE` from the decorated class,
  // registers the class as that namespace's dialect wrapper via PyGlobals, and
  // returns the class unchanged so it works as a class decorator.
2880 m.def(
2881 "register_dialect",
2882 [](nb::type_object pyClass) {
2883 std::string dialectNamespace =
2884 nb::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
2885 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
2886 return pyClass;
2887 },
2888 "dialect_class"_a,
2889 "Class decorator for registering a custom Dialect wrapper");
  // `register_operation(dialect_class, *, replace=False)` is a decorator
  // factory. The returned callable captures `dialect_class` and `replace` by
  // value, then:
  //   - reads `OPERATION_NAME` from the decorated op class,
  //   - registers the op class (honouring `replace`),
  //   - attaches it to `dialect_class` under the op class's `__name__`,
  //   - returns the op class.
2890 m.def(
2891 "register_operation",
2892 [](const nb::type_object &dialectClass, bool replace) -> nb::object {
2893 return nb::cpp_function(
2894 [dialectClass,
2895 replace](nb::type_object opClass) -> nb::type_object {
2896 std::string operationName =
2897 nb::cast<std::string>(opClass.attr("OPERATION_NAME"));
2898 PyGlobals::get().registerOperationImpl(operationName, opClass,
2899 replace);
2900 // Dict-stuff the new opClass by name onto the dialect class.
2901 nb::object opClassName = opClass.attr("__name__");
2902 dialectClass.attr(opClassName) = opClass;
2903 return opClass;
2904 });
2905 },
2906 // clang-format off
2907 nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) "
2908 "-> typing.Callable[[type[T]], type[T]]"),
2909 // clang-format on
2910 "dialect_class"_a, nb::kw_only(), "replace"_a = false,
2911 "Produce a class decorator for registering an Operation class as part of "
2912 "a dialect");
  // `register_op_adaptor(op_class, *, replace=False)` is a decorator factory.
  // The returned callable:
  //   - registers the decorated adaptor class under its `OPERATION_NAME`
  //     (honouring `replace`),
  //   - sets it as `op_class.Adaptor`,
  //   - returns the adaptor class.
2913 m.def(
2914 "register_op_adaptor",
2915 [](const nb::type_object &opClass, bool replace) -> nb::object {
2916 return nb::cpp_function(
2917 [opClass,
2918 replace](nb::type_object adaptorClass) -> nb::type_object {
2919 std::string operationName =
2920 nb::cast<std::string>(adaptorClass.attr("OPERATION_NAME"));
2921 PyGlobals::get().registerOpAdaptorImpl(operationName,
2922 adaptorClass, replace);
2923 // Dict-stuff the new adaptorClass by name onto the opClass.
2924 opClass.attr("Adaptor") = adaptorClass;
2925 return adaptorClass;
2926 });
2927 },
2928 // clang-format off
2929 nb::sig("def register_op_adaptor(op_class: type, *, replace: bool = False) "
2930 "-> typing.Callable[[type[T]], type[T]]"),
2931 // clang-format on
2932 "op_class"_a, nb::kw_only(), "replace"_a = false,
2933 "Produce a class decorator for registering an OpAdaptor class for an "
2934 "operation.");
  // `register_type_caster(typeid, *, replace=False)`: decorator factory whose
  // returned callable registers the decorated callable as the type caster for
  // `typeid` and returns it unchanged.
2935 m.def(
2937 [](PyTypeID mlirTypeID, bool replace) -> nb::object {
2938 return nb::cpp_function([mlirTypeID, replace](
2939 nb::callable typeCaster) -> nb::object {
2940 PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
2941 return typeCaster;
2942 });
2943 },
2944 // clang-format off
2945 nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
2946 "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
2947 // clang-format on
2948 "typeid"_a, nb::kw_only(), "replace"_a = false,
2949 "Register a type caster for casting MLIR types to custom user types.");
  // `register_value_caster(typeid, *, replace=False)`: same shape as
  // `register_type_caster`, but registers the decorated callable as the value
  // caster for `typeid`.
2950 m.def(
2952 [](PyTypeID mlirTypeID, bool replace) -> nb::object {
2953 return nb::cpp_function(
2954 [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
2955 PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
2956 replace);
2957 return valueCaster;
2958 });
2959 },
2960 // clang-format off
2961 nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
2962 "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
2963 // clang-format on
2964 "typeid"_a, nb::kw_only(), "replace"_a = false,
2965 "Register a value caster for casting MLIR values to custom user values.");
2966}
2967
2968//------------------------------------------------------------------------------
2969// Populates the core exports of the 'ir' submodule.
2970//------------------------------------------------------------------------------
2971void populateIRCore(nb::module_ &m) {
2972 //----------------------------------------------------------------------------
2973 // Enums.
2974 //----------------------------------------------------------------------------
2975 nb::enum_<PyDiagnosticSeverity>(m, "DiagnosticSeverity")
2976 .value("ERROR", PyDiagnosticSeverity::Error)
2977 .value("WARNING", PyDiagnosticSeverity::Warning)
2978 .value("NOTE", PyDiagnosticSeverity::Note)
2979 .value("REMARK", PyDiagnosticSeverity::Remark);
2980
2981 nb::enum_<PyWalkOrder>(m, "WalkOrder")
2982 .value("PRE_ORDER", PyWalkOrder::PreOrder)
2983 .value("POST_ORDER", PyWalkOrder::PostOrder);
2984 nb::enum_<PyWalkResult>(m, "WalkResult")
2985 .value("ADVANCE", PyWalkResult::Advance)
2986 .value("INTERRUPT", PyWalkResult::Interrupt)
2987 .value("SKIP", PyWalkResult::Skip);
2988
2989 //----------------------------------------------------------------------------
2990 // Mapping of Diagnostics.
2991 //----------------------------------------------------------------------------
2992 nb::class_<PyDiagnostic>(m, "Diagnostic")
2993 .def_prop_ro("severity", &PyDiagnostic::getSeverity,
2994 "Returns the severity of the diagnostic.")
2995 .def_prop_ro("location", &PyDiagnostic::getLocation,
2996 "Returns the location associated with the diagnostic.")
2997 .def_prop_ro("message", &PyDiagnostic::getMessage,
2998 "Returns the message text of the diagnostic.")
2999 .def_prop_ro("notes", &PyDiagnostic::getNotes,
3000 "Returns a tuple of attached note diagnostics.")
3001 .def(
3002 "__str__",
3003 [](PyDiagnostic &self) -> nb::str {
3004 if (!self.isValid())
3005 return nb::str("<Invalid Diagnostic>");
3006 return self.getMessage();
3007 },
3008 "Returns the diagnostic message as a string.");
3009
3010 nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
3011 .def(
3012 "__init__",
3014 new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
3015 },
3016 "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
3017 .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
3018 "The severity level of the diagnostic.")
3019 .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
3020 "The location associated with the diagnostic.")
3021 .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
3022 "The message text of the diagnostic.")
3023 .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
3024 "List of attached note diagnostics.")
3025 .def(
3026 "__str__",
3027 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
3028 "Returns the diagnostic message as a string.");
3029
3030 nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
3031 .def("detach", &PyDiagnosticHandler::detach,
3032 "Detaches the diagnostic handler from the context.")
3033 .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
3034 "Returns True if the handler is attached to a context.")
3035 .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
3036 "Returns True if an error was encountered during diagnostic "
3037 "handling.")
3038 .def("__enter__", &PyDiagnosticHandler::contextEnter,
3039 "Enters the diagnostic handler as a context manager.",
3040 nb::sig("def __enter__(self, /) -> DiagnosticHandler"))
3041 .def("__exit__", &PyDiagnosticHandler::contextExit, "exc_type"_a.none(),
3042 "exc_value"_a.none(), "traceback"_a.none(),
3043 "Exits the diagnostic handler context manager.");
3044
3045 // Expose DefaultThreadPool to python
3046 nb::class_<PyThreadPool>(m, "ThreadPool")
3047 .def(
3048 "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
3049 "Creates a new thread pool with default concurrency.")
3050 .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
3051 "Returns the maximum number of threads in the pool.")
3052 .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
3053 "Returns the raw pointer to the LLVM thread pool as a string.");
3054
3055 nb::class_<PyMlirContext>(m, "Context")
3056 .def(
3057 "__init__",
3058 [](PyMlirContext &self) {
3059 MlirContext context = mlirContextCreateWithThreading(false);
3060 new (&self) PyMlirContext(context);
3061 },
3062 R"(
3063 Creates a new MLIR context.
3064
3065 The context is the top-level container for all MLIR objects. It owns the storage
3066 for types, attributes, locations, and other core IR objects. A context can be
3067 configured to allow or disallow unregistered dialects and can have dialects
3068 loaded on-demand.)")
3069 .def_static("_get_live_count", &PyMlirContext::getLiveCount,
3070 "Gets the number of live Context objects.")
3071 .def(
3072 "_get_context_again",
3073 [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
3075 return ref.releaseObject();
3076 },
3077 "Gets another reference to the same context.")
3078 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
3079 "Gets the number of live modules owned by this context.")
3081 "Gets a capsule wrapping the MlirContext.")
3084 "Creates a Context from a capsule wrapping MlirContext.")
3085 .def("__enter__", &PyMlirContext::contextEnter,
3086 "Enters the context as a context manager.",
3087 nb::sig("def __enter__(self, /) -> Context"))
3088 .def("__exit__", &PyMlirContext::contextExit, "exc_type"_a.none(),
3089 "exc_value"_a.none(), "traceback"_a.none(),
3090 "Exits the context manager.")
3091 .def_prop_ro_static(
3092 "current",
3093 [](nb::object & /*class*/)
3094 -> std::optional<nb::typed<nb::object, PyMlirContext>> {
3096 if (!context)
3097 return {};
3098 return nb::cast(context);
3099 },
3100 nb::sig("def current(/) -> Context | None"),
3101 "Gets the Context bound to the current thread or returns None if no "
3102 "context is set.")
3103 .def_prop_ro(
3104 "dialects",
3105 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3106 "Gets a container for accessing dialects by name.")
3107 .def_prop_ro(
3108 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3109 "Alias for `dialects`.")
3110 .def(
3111 "get_dialect_descriptor",
3112 [=](PyMlirContext &self, std::string &name) {
3113 MlirDialect dialect = mlirContextGetOrLoadDialect(
3114 self.get(), {name.data(), name.size()});
3115 if (mlirDialectIsNull(dialect)) {
3116 throw nb::value_error(
3117 join("Dialect '", name, "' not found").c_str());
3118 }
3119 return PyDialectDescriptor(self.getRef(), dialect);
3120 },
3121 "dialect_name"_a,
3122 "Gets or loads a dialect by name, returning its descriptor object.")
3123 .def_prop_rw(
3124 "allow_unregistered_dialects",
3125 [](PyMlirContext &self) -> bool {
3126 return mlirContextGetAllowUnregisteredDialects(self.get());
3127 },
3128 [](PyMlirContext &self, bool value) {
3129 mlirContextSetAllowUnregisteredDialects(self.get(), value);
3130 },
3131 "Controls whether unregistered dialects are allowed in this context.")
3132 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
3133 "callback"_a,
3134 "Attaches a diagnostic handler that will receive callbacks.")
3135 .def(
3136 "enable_multithreading",
3137 [](PyMlirContext &self, bool enable) {
3138 mlirContextEnableMultithreading(self.get(), enable);
3139 },
3140 "enable"_a,
3141 R"(
3142 Enables or disables multi-threading support in the context.
3143
3144 Args:
3145 enable: Whether to enable (True) or disable (False) multi-threading.
3146 )")
3147 .def(
3148 "set_thread_pool",
3149 [](PyMlirContext &self, PyThreadPool &pool) {
3150 // we should disable multi-threading first before setting
3151 // new thread pool otherwise the assert in
3152 // MLIRContext::setThreadPool will be raised.
3153 mlirContextEnableMultithreading(self.get(), false);
3154 mlirContextSetThreadPool(self.get(), pool.get());
3155 },
3156 R"(
3157 Sets a custom thread pool for the context to use.
3158
3159 Args:
3160 pool: A ThreadPool object to use for parallel operations.
3161
3162 Note:
3163 Multi-threading is automatically disabled before setting the thread pool.)")
3164 .def(
3165 "get_num_threads",
3166 [](PyMlirContext &self) {
3167 return mlirContextGetNumThreads(self.get());
3168 },
3169 "Gets the number of threads in the context's thread pool.")
3170 .def(
3171 "_mlir_thread_pool_ptr",
3172 [](PyMlirContext &self) {
3173 MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
3174 std::stringstream ss;
3175 ss << pool.ptr;
3176 return ss.str();
3177 },
3178 "Gets the raw pointer to the LLVM thread pool as a string.")
3179 .def(
3180 "is_registered_operation",
3181 [](PyMlirContext &self, std::string &name) {
3183 self.get(), MlirStringRef{name.data(), name.size()});
3184 },
3185 "operation_name"_a,
3186 R"(
3187 Checks whether an operation with the given name is registered.
3188
3189 Args:
3190 operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
3191
3192 Returns:
3193 True if the operation is registered, False otherwise.)")
3194 .def(
3195 "append_dialect_registry",
3196 [](PyMlirContext &self, PyDialectRegistry &registry) {
3197 mlirContextAppendDialectRegistry(self.get(), registry);
3198 },
3199 "registry"_a,
3200 R"(
3201 Appends the contents of a dialect registry to the context.
3202
3203 Args:
3204 registry: A DialectRegistry containing dialects to append.)")
3205 .def_prop_rw("emit_error_diagnostics",
3208 R"(
3209 Controls whether error diagnostics are emitted to diagnostic handlers.
3210
3211 By default, error diagnostics are captured and reported through MLIRError exceptions.)")
3212 .def(
3213 "load_all_available_dialects",
3214 [](PyMlirContext &self) {
3216 },
3217 R"(
3218 Loads all dialects available in the registry into the context.
3219
3220 This eagerly loads all dialects that have been registered, making them
3221 immediately available for use.)");
3222
3223 //----------------------------------------------------------------------------
3224 // Mapping of PyDialectDescriptor
3225 //----------------------------------------------------------------------------
3226 nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
3227 .def_prop_ro(
3228 "namespace",
3229 [](PyDialectDescriptor &self) {
3230 MlirStringRef ns = mlirDialectGetNamespace(self.get());
3231 return nb::str(ns.data, ns.length);
3232 },
3233 "Returns the namespace of the dialect.")
3234 .def(
3235 "__repr__",
3236 [](PyDialectDescriptor &self) {
3237 MlirStringRef ns = mlirDialectGetNamespace(self.get());
3238 std::string repr("<DialectDescriptor ");
3239 repr.append(ns.data, ns.length);
3240 repr.append(">");
3241 return repr;
3242 },
3243 nb::sig("def __repr__(self) -> str"),
3244 "Returns a string representation of the dialect descriptor.");
3245
3246 //----------------------------------------------------------------------------
3247 // Mapping of PyDialects
3248 //----------------------------------------------------------------------------
3249 nb::class_<PyDialects>(m, "Dialects")
3250 .def(
3251 "__getitem__",
3252 [=](PyDialects &self, std::string keyName) {
3253 MlirDialect dialect =
3254 self.getDialectForKey(keyName, /*attrError=*/false);
3255 nb::object descriptor =
3256 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3257 return createCustomDialectWrapper(keyName, std::move(descriptor));
3258 },
3259 "Gets a dialect by name using subscript notation.")
3260 .def(
3261 "__getattr__",
3262 [=](PyDialects &self, std::string attrName) {
3263 MlirDialect dialect =
3264 self.getDialectForKey(attrName, /*attrError=*/true);
3265 nb::object descriptor =
3266 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3267 return createCustomDialectWrapper(attrName, std::move(descriptor));
3268 },
3269 "Gets a dialect by name using attribute notation.");
3270
3271 //----------------------------------------------------------------------------
3272 // Mapping of PyDialect
3273 //----------------------------------------------------------------------------
3274 nb::class_<PyDialect>(m, "Dialect")
3275 .def(nb::init<nb::object>(), "descriptor"_a,
3276 "Creates a Dialect from a DialectDescriptor.")
3277 .def_prop_ro(
3278 "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
3279 "Returns the DialectDescriptor for this dialect.")
3280 .def(
3281 "__repr__",
3282 [](const nb::object &self) {
3283 auto clazz = self.attr("__class__");
3284 return nb::str("<Dialect ") +
3285 self.attr("descriptor").attr("namespace") +
3286 nb::str(" (class ") + clazz.attr("__module__") +
3287 nb::str(".") + clazz.attr("__name__") + nb::str(")>");
3288 },
3289 nb::sig("def __repr__(self) -> str"),
3290 "Returns a string representation of the dialect.");
3291
3292 //----------------------------------------------------------------------------
3293 // Mapping of PyDialectRegistry
3294 //----------------------------------------------------------------------------
3295 nb::class_<PyDialectRegistry>(m, "DialectRegistry")
3297 "Gets a capsule wrapping the MlirDialectRegistry.")
3300 "Creates a DialectRegistry from a capsule wrapping "
3301 "`MlirDialectRegistry`.")
3302 .def(nb::init<>(), "Creates a new empty dialect registry.");
3303
3304 //----------------------------------------------------------------------------
3305 // Mapping of Location
3306 //----------------------------------------------------------------------------
3307 nb::class_<PyLocation>(m, "Location")
3309 "Gets a capsule wrapping the MlirLocation.")
3311 "Creates a Location from a capsule wrapping MlirLocation.")
3312 .def("__enter__", &PyLocation::contextEnter,
3313 "Enters the location as a context manager.",
3314 nb::sig("def __enter__(self, /) -> Location"))
3315 .def("__exit__", &PyLocation::contextExit, "exc_type"_a.none(),
3316 "exc_value"_a.none(), "traceback"_a.none(),
3317 "Exits the location context manager.")
3318 .def(
3319 "__eq__",
3320 [](PyLocation &self, PyLocation &other) -> bool {
3321 return mlirLocationEqual(self, other);
3322 },
3323 "Compares two locations for equality.")
3324 .def(
3325 "__eq__", [](PyLocation &self, nb::object other) { return false; },
3326 "Compares location with non-location object (always returns False).")
3327 .def_prop_ro_static(
3328 "current",
3329 [](nb::object & /*class*/) -> std::optional<PyLocation *> {
3331 if (!loc)
3332 return std::nullopt;
3333 return loc;
3334 },
3335 // clang-format off
3336 nb::sig("def current(/) -> Location | None"),
3337 // clang-format on
3338 "Gets the Location bound to the current thread or raises ValueError.")
3339 .def_static(
3340 "unknown",
3341 [](DefaultingPyMlirContext context) {
3342 return PyLocation(context->getRef(),
3343 mlirLocationUnknownGet(context->get()));
3344 },
3345 "context"_a = nb::none(),
3346 "Gets a Location representing an unknown location.")
3347 .def_static(
3348 "callsite",
3349 [](PyLocation callee, const std::vector<PyLocation> &frames,
3350 DefaultingPyMlirContext context) {
3351 if (frames.empty())
3352 throw nb::value_error("No caller frames provided.");
3353 MlirLocation caller = frames.back().get();
3354 for (size_t index = frames.size() - 1; index-- > 0;) {
3355 caller = mlirLocationCallSiteGet(frames[index].get(), caller);
3356 }
3357 return PyLocation(context->getRef(),
3358 mlirLocationCallSiteGet(callee.get(), caller));
3359 },
3360 "callee"_a, "frames"_a, "context"_a = nb::none(),
3361 "Gets a Location representing a caller and callsite.")
3362 .def("is_a_callsite", mlirLocationIsACallSite,
3363 "Returns True if this location is a CallSiteLoc.")
3364 .def_prop_ro(
3365 "callee",
3366 [](PyLocation &self) {
3367 return PyLocation(self.getContext(),
3369 },
3370 "Gets the callee location from a CallSiteLoc.")
3371 .def_prop_ro(
3372 "caller",
3373 [](PyLocation &self) {
3374 return PyLocation(self.getContext(),
3376 },
3377 "Gets the caller location from a CallSiteLoc.")
3378 .def_static(
3379 "file",
3380 [](std::string filename, int line, int col,
3381 DefaultingPyMlirContext context) {
3382 return PyLocation(
3383 context->getRef(),
3385 context->get(), toMlirStringRef(filename), line, col));
3386 },
3387 "filename"_a, "line"_a, "col"_a, "context"_a = nb::none(),
3388 "Gets a Location representing a file, line and column.")
3389 .def_static(
3390 "file",
3391 [](std::string filename, int startLine, int startCol, int endLine,
3392 int endCol, DefaultingPyMlirContext context) {
3393 return PyLocation(context->getRef(),
3395 context->get(), toMlirStringRef(filename),
3396 startLine, startCol, endLine, endCol));
3397 },
3398 "filename"_a, "start_line"_a, "start_col"_a, "end_line"_a,
3399 "end_col"_a, "context"_a = nb::none(),
3400 "Gets a Location representing a file, line and column range.")
3401 .def("is_a_file", mlirLocationIsAFileLineColRange,
3402 "Returns True if this location is a FileLineColLoc.")
3403 .def_prop_ro(
3404 "filename",
3405 [](PyLocation loc) {
3406 return mlirIdentifierStr(
3408 },
3409 "Gets the filename from a FileLineColLoc.")
3410 .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
3411 "Gets the start line number from a `FileLineColLoc`.")
3412 .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
3413 "Gets the start column number from a `FileLineColLoc`.")
3414 .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
3415 "Gets the end line number from a `FileLineColLoc`.")
3416 .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
3417 "Gets the end column number from a `FileLineColLoc`.")
3418 .def_static(
3419 "fused",
3420 [](const std::vector<PyLocation> &pyLocations,
3421 std::optional<PyAttribute> metadata,
3422 DefaultingPyMlirContext context) {
3423 std::vector<MlirLocation> locations;
3424 locations.reserve(pyLocations.size());
3425 for (const PyLocation &pyLocation : pyLocations)
3426 locations.push_back(pyLocation.get());
3427 MlirLocation location = mlirLocationFusedGet(
3428 context->get(), locations.size(), locations.data(),
3429 metadata ? metadata->get() : MlirAttribute{0});
3430 return PyLocation(context->getRef(), location);
3431 },
3432 "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(),
3433 "Gets a Location representing a fused location with optional "
3434 "metadata.")
3435 .def("is_a_fused", mlirLocationIsAFused,
3436 "Returns True if this location is a `FusedLoc`.")
3437 .def_prop_ro(
3438 "locations",
3439 [](PyLocation &self) {
3440 unsigned numLocations = mlirLocationFusedGetNumLocations(self);
3441 std::vector<MlirLocation> locations(numLocations);
3442 if (numLocations)
3443 mlirLocationFusedGetLocations(self, locations.data());
3444 std::vector<PyLocation> pyLocations{};
3445 pyLocations.reserve(numLocations);
3446 for (unsigned i = 0; i < numLocations; ++i)
3447 pyLocations.emplace_back(self.getContext(), locations[i]);
3448 return pyLocations;
3449 },
3450 "Gets the list of locations from a `FusedLoc`.")
3451 .def_static(
3452 "name",
3453 [](std::string name, std::optional<PyLocation> childLoc,
3454 DefaultingPyMlirContext context) {
3455 return PyLocation(
3456 context->getRef(),
3458 context->get(), toMlirStringRef(name),
3459 childLoc ? childLoc->get()
3460 : mlirLocationUnknownGet(context->get())));
3461 },
3462 "name"_a, "childLoc"_a = nb::none(), "context"_a = nb::none(),
3463 "Gets a Location representing a named location with optional child "
3464 "location.")
3465 .def("is_a_name", mlirLocationIsAName,
3466 "Returns True if this location is a `NameLoc`.")
3467 .def_prop_ro(
3468 "name_str",
3469 [](PyLocation loc) {
3471 },
3472 "Gets the name string from a `NameLoc`.")
3473 .def_prop_ro(
3474 "child_loc",
3475 [](PyLocation &self) {
3476 return PyLocation(self.getContext(),
3478 },
3479 "Gets the child location from a `NameLoc`.")
3480 .def_static(
3481 "from_attr",
3482 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
3483 return PyLocation(context->getRef(),
3484 mlirLocationFromAttribute(attribute));
3485 },
3486 "attribute"_a, "context"_a = nb::none(),
3487 "Gets a Location from a `LocationAttr`.")
3488 .def_prop_ro(
3489 "context",
3490 [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
3491 return self.getContext().getObject();
3492 },
3493 "Context that owns the `Location`.")
3494 .def_prop_ro(
3495 "attr",
3496 [](PyLocation &self) {
3497 return PyAttribute(self.getContext(),
3499 },
3500 "Get the underlying `LocationAttr`.")
3501 .def(
3502 "emit_error",
3503 [](PyLocation &self, std::string message) {
3504 mlirEmitError(self, message.c_str());
3505 },
3506 "message"_a,
3507 R"(
3508 Emits an error diagnostic at this location.
3509
3510 Args:
3511 message: The error message to emit.)")
3512 .def(
3513 "__repr__",
3514 [](PyLocation &self) {
3515 PyPrintAccumulator printAccum;
3516 mlirLocationPrint(self, printAccum.getCallback(),
3517 printAccum.getUserData());
3518 return printAccum.join();
3519 },
3520 "Returns the assembly representation of the location.");
3521
3522 //----------------------------------------------------------------------------
3523 // Mapping of Module
3524 //----------------------------------------------------------------------------
3525 nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3527 "Gets a capsule wrapping the MlirModule.")
3529 R"(
3530 Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
3531
3532 This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
3533 prevent double-frees (of the underlying `mlir::Module`).)")
3534 .def("_clear_mlir_module", &PyModule::clearMlirModule,
3535 R"(
3536 Clears the internal MLIR module reference.
3537
3538 This is used internally to prevent double-free when ownership is transferred
3539 via the C API capsule mechanism. Not intended for normal use.)")
3540 .def_static(
3541 "parse",
3542 [](const std::string &moduleAsm, DefaultingPyMlirContext context)
3543 -> nb::typed<nb::object, PyModule> {
3544 PyMlirContext::ErrorCapture errors(context->getRef());
3545 MlirModule module = mlirModuleCreateParse(
3546 context->get(), toMlirStringRef(moduleAsm));
3547 if (mlirModuleIsNull(module))
3548 throw MLIRError("Unable to parse module assembly", errors.take());
3549 return PyModule::forModule(module).releaseObject();
3550 },
3551 "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
3552 .def_static(
3553 "parse",
3554 [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
3555 -> nb::typed<nb::object, PyModule> {
3556 PyMlirContext::ErrorCapture errors(context->getRef());
3557 MlirModule module = mlirModuleCreateParse(
3558 context->get(), toMlirStringRef(moduleAsm));
3559 if (mlirModuleIsNull(module))
3560 throw MLIRError("Unable to parse module assembly", errors.take());
3561 return PyModule::forModule(module).releaseObject();
3562 },
3563 "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
3564 .def_static(
3565 "parseFile",
3566 [](const std::string &path, DefaultingPyMlirContext context)
3567 -> nb::typed<nb::object, PyModule> {
3568 PyMlirContext::ErrorCapture errors(context->getRef());
3569 MlirModule module = mlirModuleCreateParseFromFile(
3570 context->get(), toMlirStringRef(path));
3571 if (mlirModuleIsNull(module))
3572 throw MLIRError("Unable to parse module assembly", errors.take());
3573 return PyModule::forModule(module).releaseObject();
3574 },
3575 "path"_a, "context"_a = nb::none(), kModuleParseDocstring)
3576 .def_static(
3577 "create",
3578 [](const std::optional<PyLocation> &loc)
3579 -> nb::typed<nb::object, PyModule> {
3580 PyLocation pyLoc = maybeGetTracebackLocation(loc);
3581 MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
3582 return PyModule::forModule(module).releaseObject();
3583 },
3584 "loc"_a = nb::none(), "Creates an empty module.")
3585 .def_prop_ro(
3586 "context",
3587 [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
3588 return self.getContext().getObject();
3589 },
3590 "Context that created the `Module`.")
3591 .def_prop_ro(
3592 "operation",
3593 [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
3594 return PyOperation::forOperation(self.getContext(),
3595 mlirModuleGetOperation(self.get()),
3596 self.getRef().releaseObject())
3597 .releaseObject();
3598 },
3599 "Accesses the module as an operation.")
3600 .def_prop_ro(
3601 "body",
3602 [](PyModule &self) {
3604 self.getContext(), mlirModuleGetOperation(self.get()),
3605 self.getRef().releaseObject());
3606 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3607 return returnBlock;
3608 },
3609 "Return the block for this module.")
3610 .def(
3611 "dump",
3612 [](PyModule &self) {
3614 },
3616 .def(
3617 "__str__",
3618 [](const nb::object &self) {
3619 // Defer to the operation's __str__.
3620 return self.attr("operation").attr("__str__")();
3621 },
3622 nb::sig("def __str__(self) -> str"),
3623 R"(
3624 Gets the assembly form of the operation with default options.
3625
3626 If more advanced control over the assembly formatting or I/O options is needed,
3627 use the dedicated print or get_asm method, which supports keyword arguments to
3628 customize behavior.
3629 )")
3630 .def(
3631 "__eq__",
3632 [](PyModule &self, PyModule &other) {
3633 return mlirModuleEqual(self.get(), other.get());
3634 },
3635 "other"_a, "Compares two modules for equality.")
3636 .def(
3637 "__hash__",
3638 [](PyModule &self) { return mlirModuleHashValue(self.get()); },
3639 "Returns the hash value of the module.");
3640
3641 //----------------------------------------------------------------------------
3642 // Mapping of Operation.
3643 //----------------------------------------------------------------------------
  // `_OperationBase`: members shared by `Operation` and `OpView`. Most members
  // resolve the concrete PyOperation via `self.getOperation()` and forward to
  // the MLIR C API or to a PyOperationBase helper method.
  nb::class_<PyOperationBase>(m, "_OperationBase")
      // NOTE(review): the property-name argument of this `def_prop_ro` is not
      // present in this copy of the file (the lambda follows `(` directly) —
      // confirm against upstream.
      .def_prop_ro(
          [](PyOperationBase &self) {
            return self.getOperation().getCapsule();
          },
          "Gets a capsule wrapping the `MlirOperation`.")
      // Equality delegates to the C API on the underlying MlirOperation.
      .def(
          "__eq__",
          [](PyOperationBase &self, PyOperationBase &other) {
            return mlirOperationEqual(self.getOperation().get(),
                                      other.getOperation().get());
          },
          "Compares two operations for equality.")
      // Fallback overload for any non-operation `other`: always unequal.
      .def(
          "__eq__",
          [](PyOperationBase &self, nb::object other) { return false; },
          "Compares operation with non-operation object (always returns "
          "False).")
      // Hash is computed by the C API on the same handle that `__eq__`
      // compares.
      .def(
          "__hash__",
          [](PyOperationBase &self) {
            return mlirOperationHashValue(self.getOperation().get());
          },
          "Returns the hash value of the operation.")
      .def_prop_ro(
          "attributes",
          [](PyOperationBase &self) {
            return PyOpAttributeMap(self.getOperation().getRef());
          },
          "Returns a dictionary-like map of operation attributes.")
      // Validates the operation before handing out its owning context.
      .def_prop_ro(
          "context",
          [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
            PyOperation &concreteOperation = self.getOperation();
            concreteOperation.checkValid();
            return concreteOperation.getContext().getObject();
          },
          "Context that owns the operation.")
      .def_prop_ro(
          "name",
          [](PyOperationBase &self) {
            auto &concreteOperation = self.getOperation();
            concreteOperation.checkValid();
            MlirOperation operation = concreteOperation.get();
            return mlirIdentifierStr(mlirOperationGetName(operation));
          },
          "Returns the fully qualified name of the operation.")
      // The list/map wrappers below hold a reference to the operation
      // (`getRef()`), not a copy of its contents.
      .def_prop_ro(
          "operands",
          [](PyOperationBase &self) {
            return PyOpOperandList(self.getOperation().getRef());
          },
          "Returns the list of operation operands.")
      .def_prop_ro(
          "op_operands",
          [](PyOperationBase &self) {
            return PyOpOperands(self.getOperation().getRef());
          },
          "Returns the list of op operands.")
      .def_prop_ro(
          "regions",
          [](PyOperationBase &self) {
            return PyRegionList(self.getOperation().getRef());
          },
          "Returns the list of operation regions.")
      .def_prop_ro(
          "results",
          [](PyOperationBase &self) {
            return PyOpResultList(self.getOperation().getRef());
          },
          "Returns the list of Operation results.")
      // `getUniqueResult` is defined elsewhere; per the docstring it is
      // expected to raise unless the op has exactly one result.
      .def_prop_ro(
          "result",
          [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
            auto &operation = self.getOperation();
            return PyOpResult(operation.getRef(), getUniqueResult(operation))
                .maybeDownCast();
          },
          "Shortcut to get an op result if it has only one (throws an error "
          "otherwise).")
      // Read/write property: the getter wraps the op's MlirLocation in the op's
      // context; the setter overwrites it in place.
      .def_prop_rw(
          "location",
          [](PyOperationBase &self) {
            PyOperation &operation = self.getOperation();
            return PyLocation(operation.getContext(),
                              mlirOperationGetLocation(operation.get()));
          },
          [](PyOperationBase &self, const PyLocation &location) {
            PyOperation &operation = self.getOperation();
            mlirOperationSetLocation(operation.get(), location.get());
          },
          nb::for_getter("Returns the source location the operation was "
                         "defined or derived from."),
          nb::for_setter("Sets the source location the operation was defined "
                         "or derived from."))
      // An empty optional maps to Python `None` for top-level operations.
      .def_prop_ro(
          "parent",
          [](PyOperationBase &self)
              -> std::optional<nb::typed<nb::object, PyOperation>> {
            auto parent = self.getOperation().getParentOperation();
            if (parent)
              return parent->getObject();
            return {};
          },
          "Returns the parent operation, or `None` if at top level.")
      // `str(op)`: textual assembly with every printing option at its
      // default/off value.
      .def(
          "__str__",
          [](PyOperationBase &self) {
            return self.getAsm(/*binary=*/false,
                               /*largeElementsLimit=*/std::nullopt,
                               /*largeResourceLimit=*/std::nullopt,
                               /*enableDebugInfo=*/false,
                               /*prettyDebugInfo=*/false,
                               /*printGenericOpForm=*/false,
                               /*useLocalScope=*/false,
                               /*useNameLocAsPrefix=*/false,
                               /*assumeVerified=*/false,
                               /*skipRegions=*/false);
          },
          nb::sig("def __str__(self) -> str"),
          "Returns the assembly form of the operation.")
      // `print` overload taking a pre-built AsmState.
      // NOTE(review): the member-function argument of `overload_cast<...>(`
      // is not present in this copy of the file — confirm against upstream.
      .def("print",
           nb::overload_cast<PyAsmState &, nb::object, bool>(
           "state"_a, "file"_a = nb::none(), "binary"_a = false,
           R"(
Prints the assembly form of the operation to a file like object.

Args:
state: `AsmState` capturing the operation numbering and flags.
file: Optional file like object to write to. Defaults to sys.stdout.
binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
      // `print` overload taking the individual printing flags; the keyword
      // list below must stay in the same order as PyOperationBase::print's
      // parameters.
      .def("print",
           nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
                             bool, bool, bool, bool, bool, bool, nb::object,
                             bool, bool>(&PyOperationBase::print),
           // Careful: Lots of arguments must match up with print method.
           "large_elements_limit"_a = nb::none(),
           "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
           "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
           "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
           "assume_verified"_a = false, "file"_a = nb::none(),
           "binary"_a = false, "skip_regions"_a = false,
           R"(
Prints the assembly form of the operation to a file like object.

Args:
large_elements_limit: Whether to elide elements attributes above this
number of elements. Defaults to None (no limit).
large_resource_limit: Whether to elide resource attributes above this
number of characters. Defaults to None (no limit). If large_elements_limit
is set and this is None, the behavior will be to use large_elements_limit
as large_resource_limit.
enable_debug_info: Whether to print debug/location information. Defaults
to False.
pretty_debug_info: Whether to format debug information for easier reading
by a human (warning: the result is unparseable). Defaults to False.
print_generic_op_form: Whether to print the generic assembly forms of all
ops. Defaults to False.
use_local_scope: Whether to print in a way that is more optimized for
multi-threaded access but may not be consistent with how the overall
module prints.
use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
prefixes for the SSA identifiers. Defaults to False.
assume_verified: By default, if not printing generic form, the verifier
will be run and if it fails, generic form will be printed with a comment
about failed verification. While a reasonable default for interactive use,
for systematic use, it is often better for the caller to verify explicitly
and report failures in a more robust fashion. Set this to True if doing this
in order to avoid running a redundant verification. If the IR is actually
invalid, behavior is undefined.
file: The file like object to write to. Defaults to sys.stdout.
binary: Whether to write bytes (True) or str (False). Defaults to False.
skip_regions: Whether to skip printing regions. Defaults to False.)")
      .def("write_bytecode", &PyOperationBase::writeBytecode, "file"_a,
           "desired_version"_a = nb::none(),
           R"(
Write the bytecode form of the operation to a file like object.

Args:
file: The file like object to write to.
desired_version: Optional version of bytecode to emit.
Returns:
The bytecode writer status.)")
      // Keyword order must match PyOperationBase::getAsm's parameters.
      .def("get_asm", &PyOperationBase::getAsm,
           // Careful: Lots of arguments must match up with get_asm method.
           "binary"_a = false, "large_elements_limit"_a = nb::none(),
           "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
           "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
           "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
           "assume_verified"_a = false, "skip_regions"_a = false,
           R"(
Gets the assembly form of the operation with all options available.

Args:
binary: Whether to return a bytes (True) or str (False) object. Defaults to
False.
... others ...: See the print() method for common keyword arguments for
configuring the printout.
Returns:
Either a bytes or str object, depending on the setting of the `binary`
argument.)")
      .def("verify", &PyOperationBase::verify,
           "Verify the operation. Raises MLIRError if verification fails, and "
           "returns true otherwise.")
      .def("move_after", &PyOperationBase::moveAfter, "other"_a,
           "Puts self immediately after the other operation in its parent "
           "block.")
      .def("move_before", &PyOperationBase::moveBefore, "other"_a,
           "Puts self immediately before the other operation in its parent "
           "block.")
      .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, "other"_a,
           R"(
Checks if this operation is before another in the same block.

Args:
other: Another operation in the same parent block.

Returns:
True if this operation is before `other` in the operation list of the parent block.)")
      // `ip` is forwarded untouched; PyOperation::clone interprets None/False.
      .def(
          "clone",
          [](PyOperationBase &self,
             const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
            return self.getOperation().clone(ip);
          },
          "ip"_a = nb::none(),
          R"(
Creates a deep copy of the operation.

Args:
ip: Optional insertion point where the cloned operation should be inserted.
If None, the current insertion point is used. If False, the operation
remains detached.

Returns:
A new Operation that is a clone of this operation.)")
      // Raises ValueError for an already-detached op; otherwise detaches it and
      // returns an OpView of the same operation.
      .def(
          "detach_from_parent",
          [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
            PyOperation &operation = self.getOperation();
            operation.checkValid();
            if (!operation.isAttached())
              throw nb::value_error("Detached operation has no parent.");

            operation.detachFromParent();
            return operation.createOpView();
          },
          "Detaches the operation from its parent block.")
      .def_prop_ro(
          "attached",
          [](PyOperationBase &self) {
            PyOperation &operation = self.getOperation();
            operation.checkValid();
            return operation.isAttached();
          },
          "Reports if the operation is attached to its parent block.")
      .def(
          "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
          R"(
Erases the operation and frees its memory.

Note:
After erasing, any Python references to the operation become invalid.)")
      // Without `op_class`, forwards the callback straight to
      // PyOperationBase::walk. With it, each visited op is wrapped as an OpView
      // and the callback runs only when `isinstance(opview, op_class)`; other
      // ops yield PyWalkResult::Advance so the walk continues.
      // NOTE(review): the call that produces the PyOperation for `mlirOp`
      // (before `->createOpView()`) is not present in this copy of the file.
      .def(
          "walk",
          [](PyOperationBase &self,
             std::function<PyWalkResult(MlirOperation)> callback,
             PyWalkOrder walkOrder, std::optional<nb::object> opClass) {
            if (!opClass)
              return self.walk(callback, walkOrder);
            self.walk(
                [&](MlirOperation mlirOp) -> PyWalkResult {
                  nb::object opview =
                          self.getOperation().getContext(), mlirOp)
                          ->createOpView();
                  if (nb::isinstance(opview, *opClass))
                    return callback(mlirOp);
                  return PyWalkResult::Advance;
                },
                walkOrder);
          },
          "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
          "op_class"_a = nb::none(),
          // clang-format off
          nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = ..., op_class: type[OpView] | None = None) -> None"),
          // clang-format on
          R"(
Walks the operation tree with a callback function.

If op_class is provided, the callback is only invoked on operations
of that type; all other operations are skipped silently.

Args:
callback: A callable that takes an Operation and returns a WalkResult.
walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
op_class: If provided, only operations of this type are passed to the callback.)")
      // Reads the trait's TypeID from the attribute named by
      // PyDynamicOpTrait::typeIDAttr on the trait class, then queries by the
      // op's name within the op's context.
      // NOTE(review): the call that receives these arguments is not present in
      // this copy of the file — confirm against upstream.
      .def(
          "has_trait",
          [](PyOperationBase &self, nb::type_object &traitCls) {
            PyTypeID traitTypeID =
                nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
            MlirIdentifier opName =
                mlirOperationGetName(self.getOperation().get());
                mlirIdentifierStr(opName), traitTypeID.get(),
                self.getOperation().getContext()->get());
          },
          "trait_cls"_a, "Checks if the operation has a given trait.");
3955
3956 nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3957 .def_static(
3958 "create",
3959 [](std::string_view name,
3960 std::optional<std::vector<PyType *>> results,
3961 std::optional<std::vector<PyValue *>> operands,
3962 std::optional<nb::typed<nb::dict, nb::str, PyAttribute>>
3963 attributes,
3964 std::optional<std::vector<PyBlock *>> successors, int regions,
3965 const std::optional<PyLocation> &location,
3966 const nb::object &maybeIp,
3967 bool inferType) -> nb::typed<nb::object, PyOperation> {
3968 // Unpack/validate operands.
3969 std::vector<MlirValue> mlirOperands;
3970 if (operands) {
3971 mlirOperands.reserve(operands->size());
3972 for (PyValue *operand : *operands) {
3973 if (!operand)
3974 throw nb::value_error("operand value cannot be None");
3975 mlirOperands.push_back(operand->get());
3976 }
3977 }
3978
3979 PyLocation pyLoc = maybeGetTracebackLocation(location);
3980 return PyOperation::create(
3981 name, results, mlirOperands.data(), mlirOperands.size(),
3982 attributes, successors, regions, pyLoc, maybeIp, inferType);
3983 },
3984 "name"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
3985 "attributes"_a = nb::none(), "successors"_a = nb::none(),
3986 "regions"_a = 0, "loc"_a = nb::none(), "ip"_a = nb::none(),
3987 "infer_type"_a = false,
3988 R"(
3989 Creates a new operation.
3990
3991 Args:
3992 name: Operation name (e.g. `dialect.operation`).
3993 results: Optional sequence of Type representing op result types.
3994 operands: Optional operands of the operation.
3995 attributes: Optional Dict of {str: Attribute}.
3996 successors: Optional List of Block for the operation's successors.
3997 regions: Number of regions to create (default = 0).
3998 location: Optional Location object (defaults to resolve from context manager).
3999 ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
4000 infer_type: Whether to infer result types (default = False).
4001 Returns:
4002 A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
      // Parses text or bytecode into an operation and returns its OpView.
      // NOTE(review): the `context` parameter declaration is not present on
      // the lambda's parameter list in this copy of the file, although the body
      // uses `context` — confirm against upstream.
      .def_static(
          "parse",
          [](const std::string &sourceStr, const std::string &sourceName,
              -> nb::typed<nb::object, PyOpView> {
            return PyOperation::parse(context->getRef(), sourceStr, sourceName)
                ->createOpView();
          },
          "source"_a, nb::kw_only(), "source_name"_a = "",
          "context"_a = nb::none(),
          "Parses an operation. Supports both text assembly format and binary "
          "bytecode format.")
          // NOTE(review): the binding heads for the two capsule accessors
          // documented below are not present in this copy of the file.
          "Gets a capsule wrapping the MlirOperation.")
          "Creates an Operation from a capsule wrapping MlirOperation.")
      // On an `Operation`, `.operation` is the identity (same Python object).
      .def_prop_ro(
          "operation",
          [](nb::object self) -> nb::typed<nb::object, PyOperation> {
            return self;
          },
          "Returns self (the operation).")
      .def_prop_ro(
          "opview",
          [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
            return self.createOpView();
          },
          R"(
Returns an OpView of this operation.

Note:
If the operation has a registered and loaded dialect then this OpView will
be concrete wrapper class.)")
      .def_prop_ro("block", &PyOperation::getBlock,
                   "Returns the block containing this operation.")
      .def_prop_ro(
          "successors",
          [](PyOperationBase &self) {
            return PyOpSuccessors(self.getOperation().getRef());
          },
          "Returns the list of Operation successors.")
      // The Python keyword is `with_` because `with` is a Python keyword.
      .def(
          "replace_uses_of_with",
          [](PyOperation &self, PyValue &of, PyValue &with) {
            mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
          },
          "of"_a, "with_"_a,
          "Replaces uses of the 'of' value with the 'with' value inside the "
          "operation.")
      // Private helper: marks this Python wrapper as no longer valid.
      .def("_set_invalid", &PyOperation::setInvalid,
           "Invalidate the operation.");
4055
  // `OpView`: Python base class for (generated) op-specific wrappers. The
  // class object is kept in `opViewClass` so class-level attributes and
  // classmethods can be attached below.
  auto opViewClass =
      nb::class_<PyOpView, PyOperationBase>(m, "OpView")
          // Wrap an existing `Operation` object.
          .def(nb::init<nb::typed<nb::object, PyOperation>>(), "operation"_a)
          // Build-and-wrap constructor: the op name, region spec and segment
          // specs are passed explicitly, so class attributes need not be read.
          // NOTE(review): the line that constructs `*self` from
          // PyOpView::buildGeneric( ... ) is not present in this copy of the
          // file — the arguments below are its trailing part.
          .def(
              "__init__",
              [](PyOpView *self, std::string_view name,
                 std::tuple<int, bool> opRegionSpec,
                 nb::object operandSegmentSpecObj,
                 nb::object resultSegmentSpecObj,
                 std::optional<nb::sequence> resultTypeList,
                 nb::sequence operandList,
                 std::optional<nb::typed<nb::dict, nb::str, PyAttribute>>
                     attributes,
                 std::optional<std::vector<PyBlock *>> successors,
                 std::optional<int> regions,
                 const std::optional<PyLocation> &location,
                 const nb::object &maybeIp) {
                PyLocation pyLoc = maybeGetTracebackLocation(location);
                    name, opRegionSpec, operandSegmentSpecObj,
                    resultSegmentSpecObj, resultTypeList, operandList,
                    attributes, successors, regions, pyLoc, maybeIp));
              },
              "name"_a, "opRegionSpec"_a,
              "operandSegmentSpecObj"_a = nb::none(),
              "resultSegmentSpecObj"_a = nb::none(), "results"_a = nb::none(),
              "operands"_a = nb::none(), "attributes"_a = nb::none(),
              "successors"_a = nb::none(), "regions"_a = nb::none(),
              "loc"_a = nb::none(), "ip"_a = nb::none())
4085 .def_prop_ro(
4086 "operation",
4087 [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
4088 return self.getOperationObject();
4089 })
4090 .def_prop_ro("opview",
4091 [](nb::object self) -> nb::typed<nb::object, PyOpView> {
4092 return self;
4093 })
4094 .def(
4095 "__str__",
4096 [](PyOpView &self) { return nb::str(self.getOperationObject()); })
4097 .def_prop_ro(
4098 "successors",
4099 [](PyOperationBase &self) {
4100 return PyOpSuccessors(self.getOperation().getRef());
4101 },
4102 "Returns the list of Operation successors.")
4103 .def(
4104 "_set_invalid",
4105 [](PyOpView &self) { self.getOperation().setInvalid(); },
4106 "Invalidate the operation.");
4107 opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
4108 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
4109 opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
4110 // It is faster to pass the operation_name, ods_regions, and
4111 // ods_operand_segments/ods_result_segments as arguments to the constructor,
4112 // rather than to access them as attributes.
4113 opViewClass.attr("build_generic") = classmethod(
4114 [](nb::handle cls, std::optional<nb::sequence> resultTypeList,
4115 nb::sequence operandList,
4116 std::optional<nb::typed<nb::dict, nb::str, PyAttribute>> attributes,
4117 std::optional<std::vector<PyBlock *>> successors,
4118 std::optional<int> regions, std::optional<PyLocation> location,
4119 const nb::object &maybeIp) {
4120 std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
4121 std::tuple<int, bool> opRegionSpec =
4122 nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
4123 nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
4124 nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
4125 PyLocation pyLoc = maybeGetTracebackLocation(location);
4126 return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
4127 resultSegmentSpec, resultTypeList,
4128 operandList, attributes, successors,
4129 regions, pyLoc, maybeIp);
4130 },
4131 "cls"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
4132 "attributes"_a = nb::none(), "successors"_a = nb::none(),
4133 "regions"_a = nb::none(), "loc"_a = nb::none(), "ip"_a = nb::none(),
4134 // clang-format off
4135 nb::sig("def build_generic(cls, results: Sequence[Type] | None = None, operands: Sequence[Value] | None = None, attributes: dict[str, Attribute] | None = None, successors: Sequence[Block] | None = None, regions: int | None = None, loc: Location | None = None, ip: InsertionPoint | None = None) -> typing.Self"),
4136 // clang-format on
4137 "Builds a specific, generated OpView based on class level attributes.");
  // `OpView.parse` classmethod: parses `source` and returns it as an instance
  // of `cls`, raising MLIRError if the parsed op's name differs from
  // `cls.OPERATION_NAME`.
  opViewClass.attr("parse") = classmethod(
      [](const nb::object &cls, const std::string &sourceStr,
         const std::string &sourceName,
         DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
        PyOperationRef parsed =
            PyOperation::parse(context->getRef(), sourceStr, sourceName);

        // Check if the expected operation was parsed, and cast to the
        // appropriate `OpView` subclass if successful.
        // NOTE: This accesses attributes that have been automatically added to
        // `OpView` subclasses, and is not intended to be used on `OpView`
        // directly.
        // NOTE(review): the initializer expression for `identifier` (the
        // parsed op's name) is not present in this copy of the file.
        std::string clsOpName =
            nb::cast<std::string>(cls.attr("OPERATION_NAME"));
        MlirStringRef identifier =
        std::string_view parsedOpName(identifier.data, identifier.length);
        if (clsOpName != parsedOpName)
          throw MLIRError(join("Expected a '", clsOpName, "' op, got: '",
                               parsedOpName, "'"));
        return PyOpView::constructDerived(cls, parsed.getObject());
      },
      "cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "",
      "context"_a = nb::none(),
      // clang-format off
      nb::sig("def parse(cls, source: str, *, source_name: str = '', context: Context | None = None) -> typing.Self"),
      // clang-format on
      "Parses a specific, generated OpView based on class level attributes.");
  // `OpView.has_trait` classmethod: checks a trait by op *name*
  // (`cls.OPERATION_NAME`), so no operation instance is needed. The first
  // lambda parameter is named `self` but receives the class (`cls`).
  // NOTE(review): the call that receives the op-name / TypeID / context
  // arguments below is not present in this copy of the file.
  opViewClass.attr("has_trait") = classmethod(
      [](nb::object &self, nb::type_object &traitCls,
         DefaultingPyMlirContext &context) {
        PyTypeID traitTypeID =
            nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
        std::string opName = nb::cast<std::string>(self.attr("OPERATION_NAME"));
            mlirStringRefCreate(opName.data(), opName.size()),
            traitTypeID.get(), context->get());
      },
      "cls"_a, "trait_cls"_a, "context"_a = nb::none(),
      "Checks if the operation has a given trait.");
4178
4180
4181 //----------------------------------------------------------------------------
4182 // Mapping of PyRegion.
4183 //----------------------------------------------------------------------------
4184 nb::class_<PyRegion>(m, "Region")
4185 .def_prop_ro(
4186 "blocks",
4187 [](PyRegion &self) {
4188 return PyBlockList(self.getParentOperation(), self.get());
4189 },
4190 "Returns a forward-optimized sequence of blocks.")
4191 .def_prop_ro(
4192 "owner",
4193 [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
4194 return self.getParentOperation()->createOpView();
4195 },
4196 "Returns the operation owning this region.")
4197 .def(
4198 "__iter__",
4199 [](PyRegion &self) {
4200 self.checkValid();
4201 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
4202 return PyBlockIterator(self.getParentOperation(), firstBlock);
4203 },
4204 "Iterates over blocks in the region.")
4205 .def(
4206 "__eq__",
4207 [](PyRegion &self, PyRegion &other) {
4208 return self.get().ptr == other.get().ptr;
4209 },
4210 "Compares two regions for pointer equality.")
4211 .def(
4212 "__eq__", [](PyRegion &self, nb::object &other) { return false; },
4213 "Compares region with non-region object (always returns False).");
4214
4215 //----------------------------------------------------------------------------
4216 // Mapping of PyBlock.
4217 //----------------------------------------------------------------------------
  // `Block`: wraps an MlirBlock together with a reference to its parent
  // operation. Equality and hashing are based on the block pointer.
  // NOTE(review): the capsule property binding whose docstring follows is not
  // fully present in this copy of the file.
  nb::class_<PyBlock>(m, "Block")
          "Gets a capsule wrapping the MlirBlock.")
4221 .def_prop_ro(
4222 "owner",
4223 [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
4224 return self.getParentOperation()->createOpView();
4225 },
4226 "Returns the owning operation of this block.")
4227 .def_prop_ro(
4228 "region",
4229 [](PyBlock &self) {
4230 MlirRegion region = mlirBlockGetParentRegion(self.get());
4231 return PyRegion(self.getParentOperation(), region);
4232 },
4233 "Returns the owning region of this block.")
4234 .def_prop_ro(
4235 "arguments",
4236 [](PyBlock &self) {
4237 return PyBlockArgumentList(self.getParentOperation(), self.get());
4238 },
4239 "Returns a list of block arguments.")
4240 .def(
4241 "add_argument",
4242 [](PyBlock &self, const PyType &type, const PyLocation &loc) {
4243 return PyBlockArgument(self.getParentOperation(),
4244 mlirBlockAddArgument(self.get(), type, loc));
4245 },
4246 "type"_a, "loc"_a,
4247 R"(
4248 Appends an argument of the specified type to the block.
4249
4250 Args:
4251 type: The type of the argument to add.
4252 loc: The source location for the argument.
4253
4254 Returns:
4255 The newly added block argument.)")
4256 .def(
4257 "erase_argument",
4258 [](PyBlock &self, unsigned index) {
4259 return mlirBlockEraseArgument(self.get(), index);
4260 },
4261 "index"_a,
4262 R"(
4263 Erases the argument at the specified index.
4264
4265 Args:
4266 index: The index of the argument to erase.)")
4267 .def_prop_ro(
4268 "operations",
4269 [](PyBlock &self) {
4270 return PyOperationList(self.getParentOperation(), self.get());
4271 },
4272 "Returns a forward-optimized sequence of operations.")
4273 .def_static(
4274 "create_at_start",
4275 [](PyRegion &parent, nb::typed<nb::sequence, PyType> pyArgTypes,
4276 const std::optional<nb::typed<nb::sequence, PyLocation>>
4277 &pyArgLocs) {
4278 parent.checkValid();
4279 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
4280 mlirRegionInsertOwnedBlock(parent, 0, block);
4281 return PyBlock(parent.getParentOperation(), block);
4282 },
4283 "parent"_a, "arg_types"_a = nb::list(), "arg_locs"_a = std::nullopt,
4284 "Creates and returns a new Block at the beginning of the given "
4285 "region (with given argument types and locations).")
      // Appends this block to the end of `region` via
      // mlirRegionAppendOwnedBlock.
      // NOTE(review): the docstring promises ownership transfer from another
      // region, but no detach of `b` from its current region is visible
      // between `self.get()` and the append in this copy of the file — confirm
      // that step against upstream.
      .def(
          "append_to",
          [](PyBlock &self, PyRegion &region) {
            MlirBlock b = self.get();
            mlirRegionAppendOwnedBlock(region.get(), b);
          },
          "region"_a,
          R"(
Appends this block to a region.

Transfers ownership if the block is currently owned by another region.

Args:
region: The region to append the block to.)")
4302 .def(
4303 "create_before",
4304 [](PyBlock &self, const nb::args &pyArgTypes,
4305 const std::optional<nb::typed<nb::sequence, PyLocation>>
4306 &pyArgLocs) {
4307 self.checkValid();
4308 MlirBlock block =
4309 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4310 MlirRegion region = mlirBlockGetParentRegion(self.get());
4311 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
4312 return PyBlock(self.getParentOperation(), block);
4313 },
4314 "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
4315 "Creates and returns a new Block before this block "
4316 "(with given argument types and locations).")
4317 .def(
4318 "create_after",
4319 [](PyBlock &self, const nb::args &pyArgTypes,
4320 const std::optional<nb::typed<nb::sequence, PyLocation>>
4321 &pyArgLocs) {
4322 self.checkValid();
4323 MlirBlock block =
4324 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4325 MlirRegion region = mlirBlockGetParentRegion(self.get());
4326 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
4327 return PyBlock(self.getParentOperation(), block);
4328 },
4329 "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
4330 "Creates and returns a new Block after this block "
4331 "(with given argument types and locations).")
4332 .def(
4333 "__iter__",
4334 [](PyBlock &self) {
4335 self.checkValid();
4336 MlirOperation firstOperation =
4337 mlirBlockGetFirstOperation(self.get());
4338 return PyOperationIterator(self.getParentOperation(),
4339 firstOperation);
4340 },
4341 "Iterates over operations in the block.")
4342 .def(
4343 "__eq__",
4344 [](PyBlock &self, PyBlock &other) {
4345 return self.get().ptr == other.get().ptr;
4346 },
4347 "Compares two blocks for pointer equality.")
4348 .def(
4349 "__eq__", [](PyBlock &self, nb::object &other) { return false; },
4350 "Compares block with non-block object (always returns False).")
4351 .def(
4352 "__hash__", [](PyBlock &self) { return hash(self.get().ptr); },
4353 "Returns the hash value of the block.")
4354 .def(
4355 "__str__",
4356 [](PyBlock &self) {
4357 self.checkValid();
4358 PyPrintAccumulator printAccum;
4359 mlirBlockPrint(self.get(), printAccum.getCallback(),
4360 printAccum.getUserData());
4361 return printAccum.join();
4362 },
4363 "Returns the assembly form of the block.")
4364 .def(
4365 "append",
4366 [](PyBlock &self, PyOperationBase &operation) {
4367 if (operation.getOperation().isAttached())
4368 operation.getOperation().detachFromParent();
4369
4370 MlirOperation mlirOperation = operation.getOperation().get();
4371 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
4372 operation.getOperation().setAttached(
4373 self.getParentOperation().getObject());
4374 },
4375 "operation"_a,
4376 R"(
4377 Appends an operation to this block.
4378
4379 If the operation is currently in another block, it will be moved.
4380
4381 Args:
4382 operation: The operation to append to the block.)")
4383 .def_prop_ro(
4384 "successors",
4385 [](PyBlock &self) {
4386 return PyBlockSuccessors(self, self.getParentOperation());
4387 },
4388 "Returns the list of Block successors.")
4389 .def_prop_ro(
4390 "predecessors",
4391 [](PyBlock &self) {
4392 return PyBlockPredecessors(self, self.getParentOperation());
4393 },
4394 "Returns the list of Block predecessors.");
4395
4396 //----------------------------------------------------------------------------
4397 // Mapping of PyInsertionPoint.
4398 //----------------------------------------------------------------------------
4399
  // `InsertionPoint`: where newly created operations are inserted. It can also
  // be used as a context manager (`__enter__`/`__exit__`).
  nb::class_<PyInsertionPoint>(m, "InsertionPoint")
      // Constructing from a Block targets the end of that block.
      .def(nb::init<PyBlock &>(), "block"_a,
           "Inserts after the last operation but still inside the block.")
      .def("__enter__", &PyInsertionPoint::contextEnter,
           "Enters the insertion point as a context manager.",
           nb::sig("def __enter__(self, /) -> InsertionPoint"))
      .def("__exit__", &PyInsertionPoint::contextExit, "exc_type"_a.none(),
           "exc_value"_a.none(), "traceback"_a.none(),
           "Exits the insertion point context manager.")
      // Static property: raises ValueError when no insertion point is active.
      // NOTE(review): the declaration of `ip` (the current thread's insertion
      // point) is not present in this copy of the file.
      .def_prop_ro_static(
          "current",
          [](nb::object & /*class*/) {
            if (!ip)
              throw nb::value_error("No current InsertionPoint");
            return ip;
          },
          nb::sig("def current(/) -> InsertionPoint"),
          "Gets the InsertionPoint bound to the current thread or raises "
          "ValueError if none has been set.")
4420 .def(nb::init<PyOperationBase &>(), "beforeOperation"_a,
4421 "Inserts before a referenced operation.")
4422 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, "block"_a,
4423 R"(
4424 Creates an insertion point at the beginning of a block.
4425
4426 Args:
4427 block: The block at whose beginning operations should be inserted.
4428
4429 Returns:
4430 An InsertionPoint at the block's beginning.)")
4431 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
4432 "block"_a,
4433 R"(
4434 Creates an insertion point before a block's terminator.
4435
4436 Args:
4437 block: The block whose terminator to insert before.
4438
4439 Returns:
4440 An InsertionPoint before the terminator.
4441
4442 Raises:
4443 ValueError: If the block has no terminator.)")
4444 .def_static("after", &PyInsertionPoint::after, "operation"_a,
4445 R"(
4446 Creates an insertion point immediately after an operation.
4447
4448 Args:
4449 operation: The operation after which to insert.
4450
4451 Returns:
4452 An InsertionPoint after the operation.)")
4453 .def("insert", &PyInsertionPoint::insert, "operation"_a,
4454 R"(
4455 Inserts an operation at this insertion point.
4456
4457 Args:
4458 operation: The operation to insert.)")
4459 .def_prop_ro(
4460 "block", [](PyInsertionPoint &self) { return self.getBlock(); },
4461 "Returns the block that this `InsertionPoint` points to.")
4462 .def_prop_ro(
4463 "ref_operation",
4464 [](PyInsertionPoint &self)
4465 -> std::optional<nb::typed<nb::object, PyOperation>> {
4466 auto refOperation = self.getRefOperation();
4467 if (refOperation)
4468 return refOperation->getObject();
4469 return {};
4470 },
4471 "The reference operation before which new operations are "
4472 "inserted, or None if the insertion point is at the end of "
4473 "the block.");
4474
4475 //----------------------------------------------------------------------------
4476 // Mapping of PyAttribute.
4477 //----------------------------------------------------------------------------
4478 nb::class_<PyAttribute>(m, "Attribute")
4479 // Delegate to the PyAttribute copy constructor, which will also lifetime
4480 // extend the backing context which owns the MlirAttribute.
4481 .def(nb::init<PyAttribute &>(), "cast_from_type"_a,
4482 "Casts the passed attribute to the generic `Attribute`.")
4484 "Gets a capsule wrapping the MlirAttribute.")
4485 .def_static(
4487 "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
4488 .def_static(
4489 "parse",
4490 [](const std::string &attrSpec, DefaultingPyMlirContext context)
4491 -> nb::typed<nb::object, PyAttribute> {
4492 PyMlirContext::ErrorCapture errors(context->getRef());
4493 MlirAttribute attr = mlirAttributeParseGet(
4494 context->get(), toMlirStringRef(attrSpec));
4495 if (mlirAttributeIsNull(attr))
4496 throw MLIRError("Unable to parse attribute", errors.take());
4497 return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
4498 },
4499 "asm"_a, "context"_a = nb::none(),
4500 "Parses an attribute from an assembly form. Raises an `MLIRError` on "
4501 "failure.")
4502 .def_prop_ro(
4503 "context",
4504 [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
4505 return self.getContext().getObject();
4506 },
4507 "Context that owns the `Attribute`.")
4508 .def_prop_ro(
4509 "type",
4510 [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
4511 return PyType(self.getContext(), mlirAttributeGetType(self))
4512 .maybeDownCast();
4513 },
4514 "Returns the type of the `Attribute`.")
4515 .def(
4516 "get_named",
4517 [](PyAttribute &self, std::string name) {
4518 return PyNamedAttribute(self, std::move(name));
4519 },
4520 nb::keep_alive<0, 1>(),
4521 R"(
4522 Binds a name to the attribute, creating a `NamedAttribute`.
4523
4524 Args:
4525 name: The name to bind to the `Attribute`.
4526
4527 Returns:
4528 A `NamedAttribute` with the given name and this attribute.)")
4529 .def(
4530 "__eq__",
4531 [](PyAttribute &self, PyAttribute &other) { return self == other; },
4532 "Compares two attributes for equality.")
4533 .def(
4534 "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
4535 "Compares attribute with non-attribute object (always returns "
4536 "False).")
4537 .def(
4538 "__hash__", [](PyAttribute &self) { return hash(self.get().ptr); },
4539 "Returns the hash value of the attribute.")
4540 .def(
4541 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
4543 .def(
4544 "__str__",
4545 [](PyAttribute &self) {
4546 PyPrintAccumulator printAccum;
4547 mlirAttributePrint(self, printAccum.getCallback(),
4548 printAccum.getUserData());
4549 return printAccum.join();
4550 },
4551 "Returns the assembly form of the Attribute.")
4552 .def(
4553 "__repr__",
4554 [](PyAttribute &self) {
4555 // Generally, assembly formats are not printed for __repr__ because
4556 // this can cause exceptionally long debug output and exceptions.
4557 // However, attribute values are generally considered useful and
4558 // are printed. This may need to be re-evaluated if debug dumps end
4559 // up being excessive.
4560 PyPrintAccumulator printAccum;
4561 printAccum.parts.append("Attribute(");
4562 mlirAttributePrint(self, printAccum.getCallback(),
4563 printAccum.getUserData());
4564 printAccum.parts.append(")");
4565 return printAccum.join();
4566 },
4567 "Returns a string representation of the attribute.")
4568 .def_prop_ro(
4569 "typeid",
4570 [](PyAttribute &self) {
4571 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
4572 assert(!mlirTypeIDIsNull(mlirTypeID) &&
4573 "mlirTypeID was expected to be non-null.");
4574 return PyTypeID(mlirTypeID);
4575 },
4576 "Returns the `TypeID` of the attribute.")
4577 .def(
4579 [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
4580 return self.maybeDownCast();
4581 },
4582 "Downcasts the attribute to a more specific attribute if possible.");
4583
4584 //----------------------------------------------------------------------------
4585 // Mapping of PyNamedAttribute
4586 //----------------------------------------------------------------------------
4587 nb::class_<PyNamedAttribute>(m, "NamedAttribute")
4588 .def(
4589 "__repr__",
4590 [](PyNamedAttribute &self) {
4591 PyPrintAccumulator printAccum;
4592 printAccum.parts.append("NamedAttribute(");
4593 printAccum.parts.append(
4594 nb::str(mlirIdentifierStr(self.namedAttr.name).data,
4595 mlirIdentifierStr(self.namedAttr.name).length));
4596 printAccum.parts.append("=");
4597 mlirAttributePrint(self.namedAttr.attribute,
4598 printAccum.getCallback(),
4599 printAccum.getUserData());
4600 printAccum.parts.append(")");
4601 return printAccum.join();
4602 },
4603 "Returns a string representation of the named attribute.")
4604 .def_prop_ro(
4605 "name",
4606 [](PyNamedAttribute &self) {
4607 return mlirIdentifierStr(self.namedAttr.name);
4608 },
4609 "The name of the `NamedAttribute` binding.")
4610 .def_prop_ro(
4611 "attr",
4612 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
4613 nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
4614 "The underlying generic attribute of the `NamedAttribute` binding.");
4615
4616 //----------------------------------------------------------------------------
4617 // Mapping of PyType.
4618 //----------------------------------------------------------------------------
4619 nb::class_<PyType>(m, "Type")
4620 // Delegate to the PyType copy constructor, which will also lifetime
4621 // extend the backing context which owns the MlirType.
4622 .def(nb::init<PyType &>(), "cast_from_type"_a,
4623 "Casts the passed type to the generic `Type`.")
4625 "Gets a capsule wrapping the `MlirType`.")
4627 "Creates a Type from a capsule wrapping `MlirType`.")
4628 .def_static(
4629 "parse",
4630 [](std::string typeSpec,
4631 DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
4632 PyMlirContext::ErrorCapture errors(context->getRef());
4633 MlirType type =
4634 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
4635 if (mlirTypeIsNull(type))
4636 throw MLIRError("Unable to parse type", errors.take());
4637 return PyType(context.get()->getRef(), type).maybeDownCast();
4638 },
4639 "asm"_a, "context"_a = nb::none(),
4640 R"(
4641 Parses the assembly form of a type.
4642
4643 Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
4644
4645 See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
4646 .def_prop_ro(
4647 "context",
4648 [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
4649 return self.getContext().getObject();
4650 },
4651 "Context that owns the `Type`.")
4652 .def(
4653 "__eq__", [](PyType &self, PyType &other) { return self == other; },
4654 "Compares two types for equality.")
4655 .def(
4656 "__eq__", [](PyType &self, nb::object &other) { return false; },
4657 "other"_a.none(),
4658 "Compares type with non-type object (always returns False).")
4659 .def(
4660 "__hash__", [](PyType &self) { return hash(self.get().ptr); },
4661 "Returns the hash value of the `Type`.")
4662 .def(
4663 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
4664 .def(
4665 "__str__",
4666 [](PyType &self) {
4667 PyPrintAccumulator printAccum;
4668 mlirTypePrint(self, printAccum.getCallback(),
4669 printAccum.getUserData());
4670 return printAccum.join();
4671 },
4672 "Returns the assembly form of the `Type`.")
4673 .def(
4674 "__repr__",
4675 [](PyType &self) {
4676 // Generally, assembly formats are not printed for __repr__ because
4677 // this can cause exceptionally long debug output and exceptions.
4678 // However, types are an exception as they typically have compact
4679 // assembly forms and printing them is useful.
4680 PyPrintAccumulator printAccum;
4681 printAccum.parts.append("Type(");
4682 mlirTypePrint(self, printAccum.getCallback(),
4683 printAccum.getUserData());
4684 printAccum.parts.append(")");
4685 return printAccum.join();
4686 },
4687 "Returns a string representation of the `Type`.")
4688 .def(
4690 [](PyType &self) -> nb::typed<nb::object, PyType> {
4691 return self.maybeDownCast();
4692 },
4693 "Downcasts the Type to a more specific `Type` if possible.")
4694 .def_prop_ro(
4695 "typeid",
4696 [](PyType &self) {
4697 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
4698 if (!mlirTypeIDIsNull(mlirTypeID))
4699 return PyTypeID(mlirTypeID);
4700 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
4701 throw nb::value_error(join(origRepr, " has no typeid.").c_str());
4702 },
4703 "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
4704 "`Type` has no "
4705 "`TypeID`.");
4706
4707 //----------------------------------------------------------------------------
4708 // Mapping of PyTypeID.
4709 //----------------------------------------------------------------------------
4710 nb::class_<PyTypeID>(m, "TypeID")
4712 "Gets a capsule wrapping the `MlirTypeID`.")
4714 "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
4715 // Note, this tests whether the underlying TypeIDs are the same,
4716 // not whether the wrapper MlirTypeIDs are the same, nor whether
4717 // the Python objects are the same (i.e., PyTypeID is a value type).
4718 .def(
4719 "__eq__",
4720 [](PyTypeID &self, PyTypeID &other) { return self == other; },
4721 "Compares two `TypeID`s for equality.")
4722 .def(
4723 "__eq__",
4724 [](PyTypeID &self, const nb::object &other) { return false; },
4725 "Compares TypeID with non-TypeID object (always returns False).")
4726 // Note, this gives the hash value of the underlying TypeID, not the
4727 // hash value of the Python object, nor the hash value of the
4728 // MlirTypeID wrapper.
4729 .def(
4730 "__hash__",
4731 [](PyTypeID &self) {
4732 return static_cast<size_t>(mlirTypeIDHashValue(self));
4733 },
4734 "Returns the hash value of the `TypeID`.");
4735
4736 //----------------------------------------------------------------------------
4737 // Mapping of Value.
4738 //----------------------------------------------------------------------------
4739 m.attr("_T") = nb::type_var("_T", "bound"_a = m.attr("Type"));
4740
4741 nb::class_<PyValue>(m, "Value", nb::is_generic(),
4742 nb::sig("class Value(typing.Generic[_T])"))
4743 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), "value"_a,
4744 "Creates a Value reference from another `Value`.")
4746 "Gets a capsule wrapping the `MlirValue`.")
4748 "Creates a `Value` from a capsule wrapping `MlirValue`.")
4749 .def_prop_ro(
4750 "context",
4751 [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
4752 return self.getParentOperation()->getContext().getObject();
4753 },
4754 "Context in which the value lives.")
4755 .def(
4756 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
4758 .def_prop_ro(
4759 "owner",
4760 [](PyValue &self)
4761 -> nb::typed<nb::object, std::variant<PyOpView, PyBlock>> {
4762 MlirValue v = self.get();
4763 if (mlirValueIsAOpResult(v)) {
4764 assert(mlirOperationEqual(self.getParentOperation()->get(),
4765 mlirOpResultGetOwner(self.get())) &&
4766 "expected the owner of the value in Python to match "
4767 "that in "
4768 "the IR");
4769 return self.getParentOperation()->createOpView();
4770 }
4771
4773 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
4774 return nb::cast(PyBlock(self.getParentOperation(), block));
4775 }
4776
4777 assert(false && "Value must be a block argument or an op result");
4778 return nb::none();
4779 },
4780 "Returns the owner of the value (`Operation` for results, `Block` "
4781 "for "
4782 "arguments).")
4783 .def_prop_ro(
4784 "uses",
4785 [](PyValue &self) {
4786 return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
4787 },
4788 "Returns an iterator over uses of this value.")
4789 .def(
4790 "__eq__",
4791 [](PyValue &self, PyValue &other) {
4792 return self.get().ptr == other.get().ptr;
4793 },
4794 "Compares two values for pointer equality.")
4795 .def(
4796 "__eq__", [](PyValue &self, nb::object other) { return false; },
4797 "Compares value with non-value object (always returns False).")
4798 .def(
4799 "__hash__", [](PyValue &self) { return hash(self.get().ptr); },
4800 "Returns the hash value of the value.")
4801 .def(
4802 "__str__",
4803 [](PyValue &self) {
4804 PyPrintAccumulator printAccum;
4805 printAccum.parts.append("Value(");
4806 mlirValuePrint(self.get(), printAccum.getCallback(),
4807 printAccum.getUserData());
4808 printAccum.parts.append(")");
4809 return printAccum.join();
4810 },
4811 R"(
4812 Returns the string form of the value.
4813
4814 If the value is a block argument, this is the assembly form of its type and the
4815 position in the argument list. If the value is an operation result, this is
4816 equivalent to printing the operation that produced it.
4817 )")
4818 .def(
4819 "get_name",
4820 [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
4821 PyPrintAccumulator printAccum;
4822 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
4823 if (useLocalScope)
4825 if (useNameLocAsPrefix)
4827 MlirAsmState valueState =
4828 mlirAsmStateCreateForValue(self.get(), flags);
4829 mlirValuePrintAsOperand(self.get(), valueState,
4830 printAccum.getCallback(),
4831 printAccum.getUserData());
4833 mlirAsmStateDestroy(valueState);
4834 return printAccum.join();
4835 },
4836 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
4837 R"(
4838 Returns the string form of value as an operand.
4839
4840 Args:
4841 use_local_scope: Whether to use local scope for naming.
4842 use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
4843
4844 Returns:
4845 The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
4846 .def(
4847 "get_name",
4848 [](PyValue &self, PyAsmState &state) {
4849 PyPrintAccumulator printAccum;
4850 MlirAsmState valueState = state.get();
4851 mlirValuePrintAsOperand(self.get(), valueState,
4852 printAccum.getCallback(),
4853 printAccum.getUserData());
4854 return printAccum.join();
4855 },
4856 "state"_a,
4857 "Returns the string form of value as an operand (i.e., the ValueID).")
4858 .def_prop_ro(
4859 "type",
4860 [](PyValue &self) {
4861 return PyType(self.getParentOperation()->getContext(),
4862 mlirValueGetType(self.get()))
4863 .maybeDownCast();
4864 },
4865 "Returns the type of the value.", nb::sig("def type(self) -> _T"))
4866 .def(
4867 "set_type",
4868 [](PyValue &self, const PyType &type) {
4869 mlirValueSetType(self.get(), type);
4870 },
4871 "type"_a, "Sets the type of the value.",
4872 nb::sig("def set_type(self, type: _T)"))
4873 .def(
4874 "replace_all_uses_with",
4875 [](PyValue &self, PyValue &with) {
4876 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
4877 },
4878 "Replace all uses of value with the new value, updating anything in "
4879 "the IR that uses `self` to use the other value instead.")
4880 .def(
4881 "replace_all_uses_except",
4882 [](PyValue &self, PyValue &with, PyOperation &exception) {
4883 MlirOperation exceptedUser = exception.get();
4884 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4885 },
4886 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4887 .def(
4888 "replace_all_uses_except",
4889 [](PyValue &self, PyValue &with,
4890 std::vector<PyOperation> &exceptions) {
4891 // Convert Python list to a std::vector of MlirOperations
4892 std::vector<MlirOperation> exceptionOps;
4893 for (PyOperation &exception : exceptions)
4894 exceptionOps.push_back(exception);
4896 self, with, static_cast<intptr_t>(exceptionOps.size()),
4897 exceptionOps.data());
4898 },
4899 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4900 .def(
4902 [](PyValue &self) { return self.maybeDownCast(); },
4903 "Downcasts the `Value` to a more specific kind if possible.")
4904 .def_prop_ro(
4905 "location",
4906 [](PyValue self) {
4907 return PyLocation(
4909 mlirValueGetLocation(self));
4910 },
4911 "Returns the source location of the value.");
4912
4916
4917 nb::class_<PyAsmState>(m, "AsmState")
4918 .def(nb::init<PyValue &, bool>(), "value"_a, "use_local_scope"_a = false,
4919 R"(
4920 Creates an `AsmState` for consistent SSA value naming.
4921
4922 Args:
4923 value: The value to create state for.
4924 use_local_scope: Whether to use local scope for naming.)")
4925 .def(nb::init<PyOperationBase &, bool>(), "op"_a,
4926 "use_local_scope"_a = false,
4927 R"(
4928 Creates an AsmState for consistent SSA value naming.
4929
4930 Args:
4931 op: The operation to create state for.
4932 use_local_scope: Whether to use local scope for naming.)");
4933
4934 //----------------------------------------------------------------------------
4935 // Mapping of SymbolTable.
4936 //----------------------------------------------------------------------------
4937 nb::class_<PySymbolTable>(m, "SymbolTable")
4938 .def(nb::init<PyOperationBase &>(),
4939 R"(
4940 Creates a symbol table for an operation.
4941
4942 Args:
4943 operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
4944
4945 Raises:
4946 TypeError: If the operation is not a symbol table.)")
4947 .def(
4948 "__getitem__",
4949 [](PySymbolTable &self,
4950 const std::string &name) -> nb::typed<nb::object, PyOpView> {
4951 return self.dunderGetItem(name);
4952 },
4953 R"(
4954 Looks up a symbol by name in the symbol table.
4955
4956 Args:
4957 name: The name of the symbol to look up.
4958
4959 Returns:
4960 The operation defining the symbol.
4961
4962 Raises:
4963 KeyError: If the symbol is not found.)")
4964 .def("insert", &PySymbolTable::insert, "operation"_a,
4965 R"(
4966 Inserts a symbol operation into the symbol table.
4967
4968 Args:
4969 operation: An operation with a symbol name to insert.
4970
4971 Returns:
4972 The symbol name attribute of the inserted operation.
4973
4974 Raises:
4975 ValueError: If the operation does not have a symbol name.)")
4976 .def("erase", &PySymbolTable::erase, "operation"_a,
4977 R"(
4978 Erases a symbol operation from the symbol table.
4979
4980 Args:
4981 operation: The symbol operation to erase.
4982
4983 Note:
4984 The operation is also erased from the IR and invalidated.)")
4985 .def("__delitem__", &PySymbolTable::dunderDel,
4986 "Deletes a symbol by name from the symbol table.")
4987 .def(
4988 "__contains__",
4989 [](PySymbolTable &table, const std::string &name) {
4990 return !mlirOperationIsNull(mlirSymbolTableLookup(
4991 table, mlirStringRefCreate(name.data(), name.length())));
4992 },
4993 "Checks if a symbol with the given name exists in the table.")
4994 // Static helpers.
4995 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, "symbol"_a,
4996 "name"_a, "Sets the symbol name for a symbol operation.")
4997 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, "symbol"_a,
4998 "Gets the symbol name from a symbol operation.")
4999 .def_static("get_visibility", &PySymbolTable::getVisibility, "symbol"_a,
5000 "Gets the visibility attribute of a symbol operation.")
5001 .def_static("set_visibility", &PySymbolTable::setVisibility, "symbol"_a,
5002 "visibility"_a,
5003 "Sets the visibility attribute of a symbol operation.")
5004 .def_static("replace_all_symbol_uses",
5005 &PySymbolTable::replaceAllSymbolUses, "old_symbol"_a,
5006 "new_symbol"_a, "from_op"_a,
5007 "Replaces all uses of a symbol with a new symbol name within "
5008 "the given operation.")
5009 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
5010 "from_op"_a, "all_sym_uses_visible"_a, "callback"_a,
5011 "Walks symbol tables starting from an operation with a "
5012 "callback function.");
5013
5014 // Container bindings.
5029
5030 // Debug bindings.
5032
5033 // Attribute builder getter.
5035
5036 // Extensible Dialect
5040
5041 // MLIRError exception.
5042 MLIRError::bind(m);
5043}
5044} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
5045} // namespace python
5046} // namespace mlir
return success()
void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Definition Debug.cpp:28
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition Debug.cpp:20
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition Debug.cpp:16
static const char kDumpDocstring[]
Definition IRAffine.cpp:32
static const char kModuleParseDocstring[]
Definition IRCore.cpp:32
static size_t hash(const T &value)
Local helper to compute std::hash for a value.
Definition IRCore.cpp:55
static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor)
Definition IRCore.cpp:60
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static const char kValueReplaceAllUsesExceptDocstring[]
Definition IRCore.cpp:43
MlirContext mlirModuleGetContext(MlirModule module)
Definition IR.cpp:445
size_t mlirModuleHashValue(MlirModule mod)
Definition IR.cpp:471
intptr_t mlirBlockGetNumPredecessors(MlirBlock block)
Definition IR.cpp:1105
MlirIdentifier mlirOperationGetName(MlirOperation op)
Definition IR.cpp:674
bool mlirValueIsABlockArgument(MlirValue value)
Definition IR.cpp:1125
intptr_t mlirOperationGetNumRegions(MlirOperation op)
Definition IR.cpp:686
MlirBlock mlirOperationGetBlock(MlirOperation op)
Definition IR.cpp:678
void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Definition IR.cpp:1142
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition IR.cpp:520
MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Definition IR.cpp:741
MlirModule mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName)
Definition IR.cpp:436
bool mlirOperationNameHasTrait(MlirStringRef opName, MlirTypeID traitTypeID, MlirContext context)
Definition IR.cpp:654
MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Definition IR.cpp:177
intptr_t mlirOperationGetNumResults(MlirOperation op)
Definition IR.cpp:737
void mlirOperationDestroy(MlirOperation op)
Definition IR.cpp:638
MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Definition IR.cpp:1290
MlirType mlirValueGetType(MlirValue value)
Definition IR.cpp:1161
void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Definition IR.cpp:1091
MlirOpPrintingFlags mlirOpPrintingFlagsCreate()
Definition IR.cpp:201
bool mlirModuleEqual(MlirModule lhs, MlirModule rhs)
Definition IR.cpp:467
void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Definition IR.cpp:209
void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Definition IR.cpp:802
MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Definition IR.cpp:710
void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Definition IR.cpp:219
MlirOperation mlirModuleGetOperation(MlirModule module)
Definition IR.cpp:459
void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
Definition IR.cpp:214
void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Definition IR.cpp:232
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Definition IR.cpp:1137
MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Definition IR.cpp:749
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Definition IR.cpp:1309
MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags)
Definition IR.cpp:156
bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Definition IR.cpp:642
void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Definition IR.cpp:236
void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Definition IR.cpp:251
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1101
void mlirModuleDestroy(MlirModule module)
Definition IR.cpp:453
MlirModule mlirModuleCreateEmpty(MlirLocation location)
Definition IR.cpp:424
void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Definition IR.cpp:224
MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Definition IR.cpp:682
void mlirValueSetType(MlirValue value, MlirType type)
Definition IR.cpp:1165
intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Definition IR.cpp:745
MlirDialect mlirAttributeGetDialect(MlirAttribute attr)
Definition IR.cpp:1305
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Definition IR.cpp:414
void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Definition IR.cpp:821
void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Definition IR.cpp:726
MlirOperation mlirOpResultGetOwner(MlirValue value)
Definition IR.cpp:1152
MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Definition IR.cpp:428
size_t mlirOperationHashValue(MlirOperation op)
Definition IR.cpp:646
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Definition IR.cpp:503
MlirOperation mlirOperationClone(MlirOperation op)
Definition IR.cpp:634
MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Definition IR.cpp:1133
void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc)
Definition IR.cpp:1147
MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Definition IR.cpp:718
MlirOpOperand mlirOperationGetOpOperand(MlirOperation op, intptr_t pos)
Definition IR.cpp:722
MlirLocation mlirOperationGetLocation(MlirOperation op)
Definition IR.cpp:660
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:816
MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr)
Definition IR.cpp:1301
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition IR.cpp:512
void mlirOperationSetLocation(MlirOperation op, MlirLocation loc)
Definition IR.cpp:664
MlirType mlirAttributeGetType(MlirAttribute attribute)
Definition IR.cpp:1294
bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:826
bool mlirValueIsAOpResult(MlirValue value)
Definition IR.cpp:1129
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1110
MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Definition IR.cpp:690
MlirOperation mlirOperationCreate(MlirOperationState *state)
Definition IR.cpp:588
void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Definition IR.cpp:255
MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Definition IR.cpp:1286
void mlirOperationRemoveFromParent(MlirOperation op)
Definition IR.cpp:640
intptr_t mlirBlockGetNumSuccessors(MlirBlock block)
Definition IR.cpp:1097
MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Definition IR.cpp:811
void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Definition IR.cpp:205
void mlirValueDump(MlirValue value)
Definition IR.cpp:1169
void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Definition IR.cpp:1275
MlirBlock mlirModuleGetBody(MlirModule module)
Definition IR.cpp:449
MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Definition IR.cpp:625
void mlirAsmStateDestroy(MlirAsmState state)
Destroys an asm state created with mlirAsmStateCreateForOperation or mlirAsmStateCreateForValue.
Definition IR.cpp:195
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:650
intptr_t mlirOpResultGetResultNumber(MlirValue value)
Definition IR.cpp:1156
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition IR.cpp:516
MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate()
Definition IR.cpp:247
void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags)
Definition IR.cpp:228
void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Definition IR.cpp:240
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition IR.cpp:508
MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Definition IR.cpp:479
intptr_t mlirOperationGetNumOperands(MlirOperation op)
Definition IR.cpp:714
void mlirTypeDump(MlirType type)
Definition IR.cpp:1280
intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Definition IR.cpp:807
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition Interop.h:348
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition Interop.h:216
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition Interop.h:118
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced by mlirPythonOperationToCapsule.
Definition Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition Interop.h:189
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition Interop.h:367
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition Interop.h:330
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition Interop.h:180
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts an MlirContext from a capsule as produced by mlirPythonContextToCapsule.
Definition Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition Interop.h:357
#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the value caster registration bindin...
Definition Interop.h:142
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition Interop.h:198
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition Interop.h:255
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition Interop.h:245
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition Interop.h:454
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition Interop.h:445
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition Interop.h:273
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition Interop.h:264
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the type caster registration binding...
Definition Interop.h:130
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition Interop.h:235
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static std::string diag(const llvm::Value &value)
Accumulates into a file, either writing text (default) or binary.
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
ReferrentTy * get() const
PyMlirContextRef & getContext()
Accesses the context reference.
Definition IRCore.h:310
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:291
PyAsmState(MlirValue value, bool useLocalScope)
Definition IRCore.cpp:1749
Wrapper around the generic MlirAttribute.
Definition IRCore.h:1014
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition IRCore.h:1016
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition IRCore.cpp:1859
bool operator==(const PyAttribute &other) const
Definition IRCore.cpp:1855
static PyAttribute createFromCapsule(const nanobind::object &capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition IRCore.cpp:1863
nanobind::typed< nanobind::object, PyAttribute > maybeDownCast()
Definition IRCore.cpp:1871
PyBlockArgumentList(PyOperationRef operation, MlirBlock block, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2188
Python wrapper for MlirBlockArgument.
Definition IRCore.h:1629
nanobind::typed< nanobind::object, PyBlock > dunderNext()
Definition IRCore.cpp:216
Blocks are exposed by the C-API as a forward-only linked list.
Definition IRCore.h:1424
PyBlock appendBlock(const nanobind::args &pyArgTypes, const std::optional< nanobind::sequence > &pyArgLocs)
Definition IRCore.cpp:272
PyBlockPredecessors(PyBlock block, PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2348
PyBlockSuccessors(PyBlock block, PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2325
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirBlock.
Definition IRCore.cpp:185
Represents a diagnostic handler attached to the context.
Definition IRCore.h:418
void detach()
Detaches the handler. Does nothing if not attached.
Definition IRCore.cpp:727
PyDiagnosticHandler(MlirContext context, nanobind::object callback)
Definition IRCore.cpp:721
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.h:430
Python class mirroring the C MlirDiagnostic struct.
Definition IRCore.h:368
nanobind::typed< nanobind::tuple, PyDiagnostic > getNotes()
Definition IRCore.cpp:766
Wrapper around an MlirDialectRegistry.
Definition IRCore.h:510
static PyDialectRegistry createFromCapsule(nanobind::object capsule)
Definition IRCore.cpp:811
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition IRCore.h:486
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition IRCore.cpp:794
static bool attach(const nanobind::object &opName, const nanobind::object &target, PyMlirContext &context)
Definition IRCore.cpp:2553
static bool attach(const nanobind::object &opName, PyMlirContext &context)
Definition IRCore.cpp:2605
static bool attach(const nanobind::object &opName, PyMlirContext &context)
Definition IRCore.cpp:2624
Globals that are always accessible once the extension has been initialized.
Definition Globals.h:29
void registerOpAdaptorImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds an operation adaptor class.
Definition Globals.cpp:167
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:203
bool loadDialectModule(std::string_view dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition Globals.cpp:64
void registerDialectImpl(const std::string &dialectNamespace, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation dialect class.
Definition Globals.cpp:145
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.cpp:59
std::optional< nanobind::object > lookupOperationClass(std::string_view operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition Globals.cpp:233
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition Globals.cpp:125
void registerOperationImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation operation class.
Definition Globals.cpp:156
void setDialectSearchPrefixes(std::vector< std::string > newValues)
Definition Globals.h:43
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:189
void registerValueCaster(MlirTypeID mlirTypeID, nanobind::callable valueCaster, bool replace=false)
Adds a user-friendly value caster.
Definition Globals.cpp:135
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition Globals.cpp:179
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition Globals.cpp:218
std::vector< std::string > getDialectSearchPrefixes()
Get and set the list of parent modules to search for dialect implementation classes.
Definition Globals.h:39
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false, bool allow_existing=false)
Adds a user-friendly Attribute builder.
Definition Globals.cpp:99
An insertion point maintains a pointer to a Block and a reference operation.
Definition IRCore.h:845
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition IRCore.cpp:1780
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:1845
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition IRCore.cpp:1819
static PyInsertionPoint after(PyOperationBase &op)
Shortcut to create an insertion point to the node after the specified operation.
Definition IRCore.cpp:1828
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition IRCore.cpp:1806
PyInsertionPoint(const PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition IRCore.cpp:1771
static nanobind::object contextEnter(nanobind::object insertionPoint)
Enter and exit the context manager.
Definition IRCore.cpp:1841
static nanobind::object contextEnter(nanobind::object location)
Enter and exit the context manager.
Definition IRCore.cpp:835
static PyLocation createFromCapsule(nanobind::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition IRCore.cpp:827
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition IRCore.cpp:823
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:839
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition IRCore.h:319
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:460
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition IRCore.cpp:485
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition IRCore.cpp:453
nanobind::object attachDiagnosticHandler(nanobind::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition IRCore.cpp:500
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:494
MlirContext get()
Accesses the underlying MlirContext.
Definition IRCore.h:224
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition IRCore.cpp:445
void setEmitErrorDiagnostics(bool value)
Controls whether error diagnostics should be propagated to diagnostic handlers, instead of being capt...
Definition IRCore.h:258
static nanobind::object contextEnter(nanobind::object context)
Enter and exit the context manager.
Definition IRCore.cpp:490
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition IRCore.cpp:449
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition IRCore.cpp:1839
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition IRCore.cpp:904
MlirModule get()
Gets the backing MlirModule.
Definition IRCore.h:560
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition IRCore.cpp:872
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition IRCore.cpp:897
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition IRCore.h:1040
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition IRCore.cpp:1889
Template for a reference to a concrete type which captures a python reference to its underlying pytho...
Definition IRCore.h:66
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
Definition IRCore.h:104
void dunderSetItem(const std::string &name, const PyAttribute &attr)
Definition IRCore.cpp:2408
nanobind::typed< nanobind::object, PyAttribute > dunderGetItemNamed(const std::string &name)
Definition IRCore.cpp:2375
nanobind::typed< nanobind::object, std::optional< PyAttribute > > get(const std::string &key, nanobind::object defaultValue)
Definition IRCore.cpp:2385
static void forEachAttr(MlirOperation op, std::function< void(MlirStringRef, MlirAttribute)> fn)
Definition IRCore.cpp:2430
PyNamedAttribute dunderGetItemIndexed(intptr_t index)
Definition IRCore.cpp:2393
nanobind::typed< nanobind::object, PyOpOperand > dunderNext()
Definition IRCore.cpp:381
PyOpOperandList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2220
void dunderSetItem(intptr_t index, PyValue value)
Definition IRCore.cpp:2228
nanobind::typed< nanobind::object, PyOpView > getOwner() const
Definition IRCore.cpp:362
PyOpOperands(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2263
Sliceable< PyOpOperandList, PyOpOperand > SliceableT
Definition IRCore.cpp:2261
PyOpResultList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:1399
PyOpSuccessors(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2292
void dunderSetItem(intptr_t index, PyBlock block)
Definition IRCore.cpp:2300
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition IRCore.h:747
PyOpView(const nanobind::object &operationObject)
Definition IRCore.cpp:1739
static nanobind::typed< nanobind::object, PyOperation > buildGeneric(std::string_view name, std::tuple< int, bool > opRegionSpec, nanobind::object operandSegmentSpecObj, nanobind::object resultSegmentSpecObj, std::optional< nanobind::sequence > resultTypeList, nanobind::sequence operandList, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, std::optional< int > regions, PyLocation &location, const nanobind::object &maybeIp)
Definition IRCore.cpp:1562
static nanobind::object constructDerived(const nanobind::object &cls, const nanobind::object &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition IRCore.cpp:1731
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition IRCore.h:590
bool isBeforeInBlock(PyOperationBase &other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IRCore.cpp:1163
nanobind::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions)
Definition IRCore.cpp:1117
void print(std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
Definition IRCore.cpp:1016
void writeBytecode(const nanobind::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition IRCore.cpp:1064
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition IRCore.cpp:1145
void walk(std::function< PyWalkResult(MlirOperation)> callback, PyWalkOrder walkOrder)
Definition IRCore.cpp:1085
nanobind::typed< nanobind::object, PyOpView > dunderNext()
Definition IRCore.cpp:293
Operations are exposed by the C-API as a forward-only linked list.
Definition IRCore.h:1465
nanobind::typed< nanobind::object, PyOpView > dunderGetItem(intptr_t index)
Definition IRCore.cpp:332
static nanobind::object create(std::string_view name, std::optional< std::vector< PyType * > > results, const MlirValue *operands, size_t numOperands, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, int regions, PyLocation &location, const nanobind::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition IRCore.cpp:1227
void setInvalid()
Invalidate the operation.
Definition IRCore.h:714
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition IRCore.h:647
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition IRCore.cpp:973
nanobind::object clone(const nanobind::object &ip)
Clones this operation.
Definition IRCore.cpp:1342
static nanobind::object createFromCapsule(const nanobind::object &capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition IRCore.cpp:1203
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition IRCore.cpp:1179
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1351
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition IRCore.cpp:1198
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:957
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition IRCore.cpp:985
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition IRCore.cpp:1362
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition IRCore.cpp:1189
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Creates a detached operation.
Definition IRCore.cpp:964
PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
Definition IRCore.cpp:912
void setAttached(const nanobind::object &parent=nanobind::object())
Definition IRCore.cpp:1000
Regions of an op are fixed length and indexed numerically so are represented with a sequence-like con...
Definition IRCore.h:1385
PyRegionList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:193
PyStringAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition IRCore.cpp:2041
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition IRCore.cpp:2005
static PyStringAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition IRCore.cpp:2081
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition IRCore.cpp:2026
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, nanobind::object callback)
Walks all symbol tables under and including 'from'.
Definition IRCore.cpp:2122
nanobind::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition IRCore.cpp:2013
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition IRCore.cpp:2110
static PyStringAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition IRCore.cpp:2053
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition IRCore.cpp:2036
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition IRCore.cpp:2066
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition IRCore.cpp:2092
Tracks an entry in the thread context stack.
Definition IRCore.h:137
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and return nullptr if not defined.
Definition IRCore.cpp:637
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint)
Definition IRCore.cpp:665
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition IRCore.cpp:677
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition IRCore.cpp:642
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition IRCore.cpp:585
static nanobind::object pushLocation(nanobind::object location)
Definition IRCore.cpp:688
static nanobind::object pushContext(nanobind::object context)
Definition IRCore.cpp:647
static PyMlirContext * getDefaultContext()
Gets the top of stack context and return nullptr if not defined.
Definition IRCore.cpp:632
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition IRCore.cpp:580
Wrapper around MlirLlvmThreadPool. The Python object owns the C++ thread pool.
Definition IRCore.h:193
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition IRCore.h:913
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition IRCore.cpp:1935
bool operator==(const PyTypeID &other) const
Definition IRCore.cpp:1945
static PyTypeID createFromCapsule(nanobind::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition IRCore.cpp:1939
Wrapper around the generic MlirType.
Definition IRCore.h:887
PyType(PyMlirContextRef contextRef, MlirType type)
Definition IRCore.h:889
bool operator==(const PyType &other) const
Definition IRCore.cpp:1901
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition IRCore.cpp:1905
static PyType createFromCapsule(nanobind::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition IRCore.cpp:1909
nanobind::typed< nanobind::object, PyType > maybeDownCast()
Definition IRCore.cpp:1917
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition IRCore.cpp:1953
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition IRCore.h:1181
nanobind::typed< nanobind::object, std::variant< PyBlockArgument, PyOpResult, PyValue > > maybeDownCast()
Definition IRCore.cpp:1972
static PyValue createFromCapsule(nanobind::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition IRCore.cpp:1993
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
struct MlirDiagnostic MlirDiagnostic
Definition Diagnostics.h:29
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitIsTerminatorCreate(void)
Get the dynamic op trait that indicates the operation is a terminator.
MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID(void)
Get the type ID of the dynamic op trait that indicates the operation is a terminator.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitCreate(MlirTypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData)
Create a custom dynamic op trait with the given type ID and callbacks.
MLIR_CAPI_EXPORTED bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait, MlirStringRef opName, MlirContext context)
Attach a dynamic op trait to the given operation name.
MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID(void)
Get the type ID of the dynamic op trait that indicates regions have no terminator.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitNoTerminatorCreate(void)
Get the dynamic op trait that indicates regions have no terminator.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition IR.cpp:264
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition IR.h:860
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v)
Gets the location of the value.
Definition IR.cpp:1212
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context)
Gets the number of threads of the thread pool of the context when multithreading is enabled.
Definition IR.cpp:116
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition IR.cpp:851
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates a File/Line/Column location owned by the given context.
Definition IR.cpp:272
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition IR.cpp:1394
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition IR.cpp:136
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location)
Getter for end_column of FileLineColRange.
Definition IR.cpp:310
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition IR.cpp:1373
MlirWalkOrder
Traversal order for operation walk.
Definition IR.h:853
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition IR.cpp:1321
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGetChildLoc(MlirLocation location)
Getter for childLoc of Name.
Definition IR.cpp:391
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition IR.cpp:1378
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition IR.cpp:83
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition IR.cpp:1342
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition IR.cpp:1255
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition IR.cpp:1238
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition IR.cpp:72
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition IR.cpp:961
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location)
Checks whether the given location is a FileLineColRange.
Definition IR.cpp:320
MLIR_CAPI_EXPORTED unsigned mlirLocationFusedGetNumLocations(MlirLocation location)
Getter for number of locations fused together.
Definition IR.cpp:354
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition IR.cpp:1194
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition IR.cpp:1177
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition IR.cpp:402
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition IR.cpp:1226
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location)
Getter for filename of FileLineColRange.
Definition IR.cpp:288
MLIR_CAPI_EXPORTED void mlirLocationFusedGetLocations(MlirLocation location, MlirLocation *locationsCPtr)
Getter for locations of Fused.
Definition IR.cpp:360
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding `userData` to `callback`...
Definition IR.cpp:1313
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition IR.cpp:1000
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition IR.cpp:875
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
Definition IR.cpp:1198
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition IR.cpp:842
MlirWalkResult
Operation walk result.
Definition IR.h:846
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition IR.cpp:941
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1165
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e.
Definition IR.cpp:99
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition IR.cpp:1069
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartLine(MlirLocation location)
Getter for start_line of FileLineColRange.
Definition IR.cpp:292
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition IR.cpp:346
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition IR.cpp:1050
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition IR.cpp:94
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location)
Checks whether the given location is a CallSite.
Definition IR.cpp:342
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition IR.cpp:947
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition IR.cpp:984
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition IR.h:946
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition IR.cpp:1025
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition IR.cpp:1087
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition IR.cpp:1368
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition IR.cpp:1259
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, unsigned start_line, unsigned start_col, unsigned end_line, unsigned end_col)
Creates a File/Line/Column range location owned by the given context.
Definition IR.cpp:280
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition IR.cpp:1224
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition IR.cpp:1358
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition IR.cpp:406
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location)
Getter for start_column of FileLineColRange.
Definition IR.cpp:298
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location)
Checks whether the given location is a Fused.
Definition IR.cpp:374
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition IR.h:370
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition IR.cpp:1073
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition IR.cpp:836
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition IR.cpp:1184
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool)
Sets the thread pool of the context explicitly, enabling multithreading in the process.
Definition IR.cpp:111
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition IR.cpp:867
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition IR.cpp:1271
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition IR.cpp:1234
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location)
Getter for caller of CallSite.
Definition IR.cpp:333
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition IR.cpp:1015
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationNameGetName(MlirLocation location)
Getter for name of Name.
Definition IR.cpp:387
MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IR.cpp:879
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition IR.cpp:268
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition IR.cpp:1263
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition IR.cpp:1354
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition IR.h:182
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition IR.cpp:1004
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition IR.cpp:324
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location)
Checks whether the given location is a Name.
Definition IR.cpp:398
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition IR.h:244
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition IR.cpp:897
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit multithreading setting and transfers its owne...
Definition IR.cpp:54
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition IR.cpp:996
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition IR.cpp:410
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition IR.cpp:1078
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition IR.cpp:1319
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition IR.cpp:1383
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition IR.cpp:937
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition IR.cpp:1008
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition IR.h:885
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition IR.cpp:1267
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition IR.cpp:1330
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition IR.cpp:107
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context)
Gets the thread pool of the context when multithreading is enabled; otherwise, an assertion is raised.
Definition IR.cpp:120
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndLine(MlirLocation location)
Getter for end_line of FileLineColRange.
Definition IR.cpp:304
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition IR.cpp:378
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Sets the threading mode (must be set to false to use mlir-print-ir-after-all).
Definition IR.cpp:103
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location)
Getter for callee of CallSite.
Definition IR.cpp:328
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v)
Gets the context that a value was created with.
Definition IR.cpp:1216
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition IR.cpp:1350
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition IR.cpp:924
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition IR.cpp:1064
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition IR.cpp:865
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition IR.h:1255
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition IR.cpp:76
MLIR_CAPI_EXPORTED void mlirOperationReplaceUsesOfWith(MlirOperation op, MlirValue of, MlirValue with)
Replace uses of 'of' value with the 'with' value inside the 'op' operation.
Definition IR.cpp:915
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition IR.cpp:871
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition IR.cpp:1171
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition IR.cpp:858
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition IR.cpp:70
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition IR.cpp:930
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:87
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:143
struct MlirLogicalResult MlirLogicalResult
Definition Support.h:124
MLIR_CAPI_EXPORTED int mlirLlvmThreadPoolGetMaxConcurrency(MlirLlvmThreadPool pool)
Returns the maximum number of threads in the thread pool.
Definition Support.cpp:38
MLIR_CAPI_EXPORTED void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool pool)
Destroy an LLVM thread pool.
Definition Support.cpp:34
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirLlvmThreadPoolCreate(void)
Create an LLVM thread pool.
Definition Support.cpp:30
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition Support.cpp:93
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:137
struct MlirStringRef MlirStringRef
Definition Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:132
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition Support.h:201
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition Support.cpp:89
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation)
Definition IRCore.cpp:1528
static std::string formatMLIRError(const MLIRError &e)
Definition IRCore.cpp:2755
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m)
static void maybeInsertOperation(PyOperationRef &op, const nb::object &maybeIp)
Definition IRCore.cpp:1212
static void populateResultTypes(std::string_view name, nb::sequence resultTypeList, const nb::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition IRCore.cpp:1441
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:210
static MlirValue getOpResultOrValue(nb::handle operand)
Definition IRCore.cpp:1543
PyObjectRef< PyOperation > PyOperationRef
Definition IRCore.h:642
MlirStringRef toMlirStringRef(const std::string &s)
Definition IRCore.h:1348
static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait, PyMlirContext &context)
Definition IRCore.cpp:2538
PyObjectRef< PyModule > PyModuleRef
Definition IRCore.h:549
static MlirLogicalResult verifyTraitByMethod(MlirOperation op, void *userData, const char *methodName)
Definition IRCore.cpp:2527
static PyOperationRef getValueOwnerRef(MlirValue value)
Definition IRCore.cpp:1957
MlirBlock MLIR_PYTHON_API_EXPORTED createBlock(const nanobind::typed< nanobind::sequence, PyType > &pyArgTypes, const std::optional< nanobind::typed< nanobind::sequence, PyLocation > > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
static std::vector< nb::typed< nb::object, PyType > > getValueTypes(Container &container, PyMlirContextRef &context)
Returns the list of types of the values held by container.
Definition IRCore.cpp:1388
PyWalkOrder
Traversal order for operation walk.
Definition IRCore.h:359
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m)
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition IRCore.h:1873
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition Diagnostics.h:26
MlirLogicalResult(* verifyTrait)(MlirOperation op, void *userData)
The callback function to verify the operation.
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
MlirLogicalResult(* verifyRegionTrait)(MlirOperation op, void *userData)
The callback function to verify the operation with access to regions.
A logical result value, essentially a boolean with named states.
Definition Support.h:121
Named MLIR attribute.
Definition IR.h:76
MlirAttribute attribute
Definition IR.h:78
MlirIdentifier name
Definition IR.h:77
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition IRCore.h:1330
MLIRError(std::string message, std::vector< PyDiagnostic::DiagnosticInfo > &&errorDiagnostics={})
Definition IRCore.h:1331
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Definition IRCore.h:1341
static void bind(nanobind::module_ &m)
Bind the MLIRError exception class to the given module.
Definition IRCore.cpp:2790
static bool dunderContains(const std::string &attributeKind)
Definition IRCore.cpp:145
static nanobind::callable dunderGetItemNamed(const std::string &attributeKind)
Definition IRCore.cpp:150
static void dunderSetItemNamed(const std::string &attributeKind, nanobind::callable func, bool replace, bool allow_existing)
Definition IRCore.cpp:157
static void set(nanobind::object &o, bool enable)
Definition IRCore.cpp:107
RAII object that captures any error diagnostics emitted to the provided context.
Definition IRCore.h:446
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition IRCore.h:456