MLIR 23.0.0git
IRCore.cpp
Go to the documentation of this file.
//===- IRCore.cpp - IR Submodules of nanobind module ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// clang-format off
13#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
14// clang-format on
16#include "mlir-c/Debug.h"
17#include "mlir-c/Diagnostics.h"
19#include "mlir-c/IR.h"
20#include "mlir-c/Support.h"
21
22#include <array>
23#include <functional>
24#include <optional>
25#include <string>
26
27namespace nb = nanobind;
28using namespace nb::literals;
29using namespace mlir;
31
// Python docstring text for parsing a module from its textual assembly form
// (the binding that uses it is outside this view).
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Python docstring text for `dump`-style debug printing methods (binding
// sites are outside this view).
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
42
44 R"(Replace all uses of this value with the `with` value, except for those
45in `exceptions`. `exceptions` can be either a single operation or a list of
46operations.
47)";
48
49//------------------------------------------------------------------------------
50// Utilities.
51//------------------------------------------------------------------------------
52
/// Returns `std::hash<T>` applied to `value`.
///
/// Lets callers write `hash(x)` and have `T` deduced, instead of spelling out
/// `std::hash<T>{}(x)` at every call site.
template <typename T>
static size_t hash(const T &value) {
  const std::hash<T> hasher{};
  return hasher(value);
}
58
59static nb::object
60createCustomDialectWrapper(const std::string &dialectNamespace,
61 nb::object dialectDescriptor) {
62 auto dialectClass =
64 dialectNamespace);
65 if (!dialectClass) {
66 // Use the base class.
68 std::move(dialectDescriptor)));
69 }
70
71 // Create the custom implementation.
72 return (*dialectClass)(std::move(dialectDescriptor));
73}
74
75namespace mlir {
76namespace python {
78
79MlirBlock createBlock(const nb::sequence &pyArgTypes,
80 const std::optional<nb::sequence> &pyArgLocs) {
81 std::vector<MlirType> argTypes;
82 argTypes.reserve(nb::len(pyArgTypes));
83 for (nb::handle pyType : pyArgTypes)
84 argTypes.push_back(
85 nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyType &>(pyType));
86
87 std::vector<MlirLocation> argLocs;
88 if (pyArgLocs) {
89 argLocs.reserve(nb::len(*pyArgLocs));
90 for (nb::handle pyLoc : *pyArgLocs)
91 argLocs.push_back(
92 nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyLocation &>(pyLoc));
93 } else if (!argTypes.empty()) {
94 argLocs.assign(
95 argTypes.size(),
97 }
98
99 if (argTypes.size() != argLocs.size())
100 throw nb::value_error(
101 join("Expected ", argTypes.size(), " locations, got: ", argLocs.size())
102 .c_str());
103 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
104}
105
106void PyGlobalDebugFlag::set(nb::object &o, bool enable) {
107 nb::ft_lock_guard lock(mutex);
108 mlirEnableGlobalDebug(enable);
109}
110
111bool PyGlobalDebugFlag::get(const nb::object &) {
112 nb::ft_lock_guard lock(mutex);
114}
115
116void PyGlobalDebugFlag::bind(nb::module_ &m) {
117 // Debug flags.
118 nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
119 .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
120 &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
121 .def_static(
122 "set_types",
123 [](const std::string &type) {
124 nb::ft_lock_guard lock(mutex);
125 mlirSetGlobalDebugType(type.c_str());
126 },
127 "types"_a, "Sets specific debug types to be produced by LLVM.")
128 .def_static(
129 "set_types",
130 [](const std::vector<std::string> &types) {
131 std::vector<const char *> pointers;
132 pointers.reserve(types.size());
133 for (const std::string &str : types)
134 pointers.push_back(str.c_str());
135 nb::ft_lock_guard lock(mutex);
136 mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
137 },
138 "types"_a,
139 "Sets multiple specific debug types to be produced by LLVM.");
140}
141
/// Out-of-line definition of the static mutex. It serializes the
/// PyGlobalDebugFlag get/set accessors and the `set_types` bindings whenever
/// they touch LLVM's global debug state.
nb::ft_mutex PyGlobalDebugFlag::mutex;
143
144bool PyAttrBuilderMap::dunderContains(const std::string &attributeKind) {
145 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
146}
147
148nb::callable
149PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
150 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
151 if (!builder)
152 throw nb::key_error(attributeKind.c_str());
153 return *builder;
154}
155
156void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
157 nb::callable func, bool replace) {
158 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
159 replace);
160}
161
162void PyAttrBuilderMap::bind(nb::module_ &m) {
163 nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
164 .def_static("contains", &PyAttrBuilderMap::dunderContains,
165 "attribute_kind"_a,
166 "Checks whether an attribute builder is registered for the "
167 "given attribute kind.")
168 .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
169 "attribute_kind"_a,
170 "Gets the registered attribute builder for the given "
171 "attribute kind.")
172 .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
173 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
174 "Register an attribute builder for building MLIR "
175 "attributes from Python values.");
176}
177
178//------------------------------------------------------------------------------
179// PyBlock
180//------------------------------------------------------------------------------
181
183 return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
184}
185
186//------------------------------------------------------------------------------
187// Collections.
188//------------------------------------------------------------------------------
189
190nb::typed<nb::object, PyRegion> PyRegionIterator::dunderNext() {
191 operation->checkValid();
192 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
193 PyErr_SetNone(PyExc_StopIteration);
194 // python functions should return NULL after setting any exception
195 return nb::object();
196 }
197 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
198 return nb::cast(PyRegion(operation, region));
199}
200
201void PyRegionIterator::bind(nb::module_ &m) {
202 nb::class_<PyRegionIterator>(m, "RegionIterator")
203 .def("__iter__", &PyRegionIterator::dunderIter,
204 "Returns an iterator over the regions in the operation.")
205 .def("__next__", &PyRegionIterator::dunderNext,
206 "Returns the next region in the iteration.");
207}
208
212 length == -1 ? mlirOperationGetNumRegions(operation->get())
213 : length,
214 step),
215 operation(std::move(operation)) {}
216
218 operation->checkValid();
219 return PyRegionIterator(operation, startIndex);
220}
221
223 c.def("__iter__", &PyRegionList::dunderIter,
224 "Returns an iterator over the regions in the sequence.");
225}
226
227intptr_t PyRegionList::getRawNumElements() {
228 operation->checkValid();
229 return mlirOperationGetNumRegions(operation->get());
230}
231
232PyRegion PyRegionList::getRawElement(intptr_t pos) {
233 operation->checkValid();
234 return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
235}
236
237PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
238 intptr_t step) const {
239 return PyRegionList(operation, startIndex, length, step);
240}
241
242nb::typed<nb::object, PyBlock> PyBlockIterator::dunderNext() {
243 operation->checkValid();
244 if (mlirBlockIsNull(next)) {
245 PyErr_SetNone(PyExc_StopIteration);
246 // python functions should return NULL after setting any exception
247 return nb::object();
248 }
249
250 PyBlock returnBlock(operation, next);
251 next = mlirBlockGetNextInRegion(next);
252 return nb::cast(returnBlock);
253}
254
255void PyBlockIterator::bind(nb::module_ &m) {
256 nb::class_<PyBlockIterator>(m, "BlockIterator")
257 .def("__iter__", &PyBlockIterator::dunderIter,
258 "Returns an iterator over the blocks in the operation's region.")
259 .def("__next__", &PyBlockIterator::dunderNext,
260 "Returns the next block in the iteration.");
261}
262
264 operation->checkValid();
265 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
266}
267
269 operation->checkValid();
270 intptr_t count = 0;
271 MlirBlock block = mlirRegionGetFirstBlock(region);
272 while (!mlirBlockIsNull(block)) {
273 count += 1;
274 block = mlirBlockGetNextInRegion(block);
275 }
276 return count;
277}
278
280 operation->checkValid();
281 if (index < 0) {
282 index += dunderLen();
283 }
284 if (index < 0) {
285 throw nb::index_error("attempt to access out of bounds block");
286 }
287 MlirBlock block = mlirRegionGetFirstBlock(region);
288 while (!mlirBlockIsNull(block)) {
289 if (index == 0) {
290 return PyBlock(operation, block);
291 }
292 block = mlirBlockGetNextInRegion(block);
293 index -= 1;
294 }
295 throw nb::index_error("attempt to access out of bounds block");
296}
297
298PyBlock PyBlockList::appendBlock(const nb::args &pyArgTypes,
299 const std::optional<nb::sequence> &pyArgLocs) {
300 operation->checkValid();
301 MlirBlock block = createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
302 mlirRegionAppendOwnedBlock(region, block);
303 return PyBlock(operation, block);
304}
305
/// Registers the `BlockList` Python class, the sequence-like view over the
/// blocks of one region.
void PyBlockList::bind(nb::module_ &m) {
  nb::class_<PyBlockList>(m, "BlockList")
      .def("__getitem__", &PyBlockList::dunderGetItem,
           "Returns the block at the specified index.")
      .def("__iter__", &PyBlockList::dunderIter,
           "Returns an iterator over blocks in the operation's region.")
      .def("__len__", &PyBlockList::dunderLen,
           "Returns the number of blocks in the operation's region.")
      // `append(*types, arg_locs=None)`: argument types are positional and
      // the optional locations are keyword-only.
      .def("append", &PyBlockList::appendBlock,
           R"(
 Appends a new block, with argument types as positional args.

 Returns:
 The created block.
 )",
           "args"_a, nb::kw_only(), "arg_locs"_a = std::nullopt);
}
323
324nb::typed<nb::object, PyOpView> PyOperationIterator::dunderNext() {
325 parentOperation->checkValid();
326 if (mlirOperationIsNull(next)) {
327 PyErr_SetNone(PyExc_StopIteration);
328 // python functions should return NULL after setting any exception
329 return nb::object();
330 }
331
332 PyOperationRef returnOperation =
333 PyOperation::forOperation(parentOperation->getContext(), next);
334 next = mlirOperationGetNextInBlock(next);
335 return returnOperation->createOpView();
336}
337
338void PyOperationIterator::bind(nb::module_ &m) {
339 nb::class_<PyOperationIterator>(m, "OperationIterator")
340 .def("__iter__", &PyOperationIterator::dunderIter,
341 "Returns an iterator over the operations in an operation's block.")
342 .def("__next__", &PyOperationIterator::dunderNext,
343 "Returns the next operation in the iteration.");
344}
345
347 parentOperation->checkValid();
348 return PyOperationIterator(parentOperation,
350}
351
353 parentOperation->checkValid();
354 intptr_t count = 0;
355 MlirOperation childOp = mlirBlockGetFirstOperation(block);
356 while (!mlirOperationIsNull(childOp)) {
357 count += 1;
358 childOp = mlirOperationGetNextInBlock(childOp);
359 }
360 return count;
361}
362
363nb::typed<nb::object, PyOpView> PyOperationList::dunderGetItem(intptr_t index) {
364 parentOperation->checkValid();
365 if (index < 0) {
366 index += dunderLen();
367 }
368 if (index < 0) {
369 throw nb::index_error("attempt to access out of bounds operation");
370 }
371 MlirOperation childOp = mlirBlockGetFirstOperation(block);
372 while (!mlirOperationIsNull(childOp)) {
373 if (index == 0) {
374 return PyOperation::forOperation(parentOperation->getContext(), childOp)
375 ->createOpView();
376 }
377 childOp = mlirOperationGetNextInBlock(childOp);
378 index -= 1;
379 }
380 throw nb::index_error("attempt to access out of bounds operation");
381}
382
383void PyOperationList::bind(nb::module_ &m) {
384 nb::class_<PyOperationList>(m, "OperationList")
385 .def("__getitem__", &PyOperationList::dunderGetItem,
386 "Returns the operation at the specified index.")
387 .def("__iter__", &PyOperationList::dunderIter,
388 "Returns an iterator over operations in the list.")
389 .def("__len__", &PyOperationList::dunderLen,
390 "Returns the number of operations in the list.");
391}
392
393nb::typed<nb::object, PyOpView> PyOpOperand::getOwner() const {
394 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
398}
400size_t PyOpOperand::getOperandNumber() const {
401 return mlirOpOperandGetOperandNumber(opOperand);
402}
403
404void PyOpOperand::bind(nb::module_ &m) {
405 nb::class_<PyOpOperand>(m, "OpOperand")
406 .def_prop_ro("owner", &PyOpOperand::getOwner,
407 "Returns the operation that owns this operand.")
408 .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
409 "Returns the operand number in the owning operation.");
410}
411
412nb::typed<nb::object, PyOpOperand> PyOpOperandIterator::dunderNext() {
413 if (mlirOpOperandIsNull(opOperand)) {
414 PyErr_SetNone(PyExc_StopIteration);
415 // python functions should return NULL after setting any exception
416 return nb::object();
417 }
418
419 PyOpOperand returnOpOperand(opOperand);
420 opOperand = mlirOpOperandGetNextUse(opOperand);
421 return nb::cast(returnOpOperand);
422}
423
424void PyOpOperandIterator::bind(nb::module_ &m) {
425 nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
426 .def("__iter__", &PyOpOperandIterator::dunderIter,
427 "Returns an iterator over operands.")
428 .def("__next__", &PyOpOperandIterator::dunderNext,
429 "Returns the next operand in the iteration.");
430}
432//------------------------------------------------------------------------------
433// PyThreadPool
434//------------------------------------------------------------------------------
435
437
439 if (threadPool.ptr)
440 mlirLlvmThreadPoolDestroy(threadPool);
441}
444 return mlirLlvmThreadPoolGetMaxConcurrency(threadPool);
445}
446
447std::string PyThreadPool::_mlir_thread_pool_ptr() const {
448 std::stringstream ss;
449 ss << threadPool.ptr;
450 return ss.str();
451}
453//------------------------------------------------------------------------------
454// PyMlirContext
455//------------------------------------------------------------------------------
456
457PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
458 nb::gil_scoped_acquire acquire;
459 nb::ft_lock_guard lock(live_contexts_mutex);
460 auto &liveContexts = getLiveContexts();
461 liveContexts[context.ptr] = this;
462}
463
465 // Note that the only public way to construct an instance is via the
466 // forContext method, which always puts the associated handle into
467 // liveContexts.
468 nb::gil_scoped_acquire acquire;
469 {
470 nb::ft_lock_guard lock(live_contexts_mutex);
471 getLiveContexts().erase(context.ptr);
472 }
473 mlirContextDestroy(context);
474}
477 return PyMlirContextRef(this, nb::cast(this));
478}
480nb::object PyMlirContext::getCapsule() {
481 return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
482}
483
484nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
485 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
486 if (mlirContextIsNull(rawContext))
487 throw nb::python_error();
488 return forContext(rawContext).releaseObject();
489}
490
/// Returns a reference to the Python wrapper for `context`.
///
/// If no wrapper is live for this raw context yet, a new one is created and
/// registered in the live-context map; otherwise the existing wrapper is
/// reused.
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  nb::gil_scoped_acquire acquire;
  nb::ft_lock_guard lock(live_contexts_mutex);
  auto &liveContexts = getLiveContexts();
  auto it = liveContexts.find(context.ptr);
  if (it == liveContexts.end()) {
    // Create.
    // NOTE(review): the PyMlirContext constructor also acquires
    // `live_contexts_mutex` and inserts itself into the map, while this
    // function still holds that lock. If nb::ft_mutex is non-recursive in
    // free-threaded builds (PyMutex), this nested acquisition deadlocks
    // there; confirm and restructure together with the constructor. The
    // map insertion below duplicates the one done by the constructor.
    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
    nb::object pyRef = nb::cast(unownedContextWrapper);
    assert(pyRef && "cast to nb::object failed");
    liveContexts[context.ptr] = unownedContextWrapper;
    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
  }
  // Use existing.
  nb::object pyRef = nb::cast(it->second);
  return PyMlirContextRef(it->second, std::move(pyRef));
}
508
509nb::ft_mutex PyMlirContext::live_contexts_mutex;
510
511PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
512 static LiveContextMap liveContexts;
513 return liveContexts;
514}
515
517 nb::ft_lock_guard lock(live_contexts_mutex);
518 return getLiveContexts().size();
519}
521nb::object PyMlirContext::contextEnter(nb::object context) {
522 return PyThreadContextEntry::pushContext(context);
523}
524
525void PyMlirContext::contextExit(const nb::object &excType,
526 const nb::object &excVal,
527 const nb::object &excTb) {
529}
530
531nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
532 // Note that ownership is transferred to the delete callback below by way of
533 // an explicit inc_ref (borrow).
534 PyDiagnosticHandler *pyHandler =
535 new PyDiagnosticHandler(get(), std::move(callback));
536 nb::object pyHandlerObject =
537 nb::cast(pyHandler, nb::rv_policy::take_ownership);
538 (void)pyHandlerObject.inc_ref();
539
540 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
541 // guaranteed to be known to pybind.
542 auto handlerCallback =
543 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
544 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
545 nb::object pyDiagnosticObject =
546 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
547
548 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
549 bool result = false;
550 {
551 // Since this can be called from arbitrary C++ contexts, always get the
552 // gil.
553 nb::gil_scoped_acquire gil;
554 try {
555 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
556 } catch (std::exception &e) {
557 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
558 e.what());
559 pyHandler->hadError = true;
560 }
561 }
562
563 pyDiagnostic->invalidate();
565 };
566 auto deleteCallback = +[](void *userData) {
567 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
568 assert(pyHandler->registeredID && "handler is not registered");
569 pyHandler->registeredID.reset();
570
571 // Decrement reference, balancing the inc_ref() above.
572 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
573 pyHandlerObject.dec_ref();
574 };
575
576 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
577 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
578 return pyHandlerObject;
579}
580
581MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
582 void *userData) {
583 auto *self = static_cast<ErrorCapture *>(userData);
584 // Check if the context requested we emit errors instead of capturing them.
585 if (self->ctx->emitErrorDiagnostics)
587
589 MlirDiagnosticSeverity::MlirDiagnosticError)
592 self->errors.emplace_back(PyDiagnostic(diag).getInfo());
594}
595
598 if (!context) {
599 throw std::runtime_error(
600 "An MLIR function requires a Context but none was provided in the call "
601 "or from the surrounding environment. Either pass to the function with "
602 "a 'context=' argument or establish a default using 'with Context():'");
603 }
604 return *context;
605}
607//------------------------------------------------------------------------------
608// PyThreadContextEntry management
609//------------------------------------------------------------------------------
610
611std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
612 static thread_local std::vector<PyThreadContextEntry> stack;
613 return stack;
614}
615
617 auto &stack = getStack();
618 if (stack.empty())
619 return nullptr;
620 return &stack.back();
621}
622
623void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
624 nb::object insertionPoint,
625 nb::object location) {
626 auto &stack = getStack();
627 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
628 std::move(location));
629 // If the new stack has more than one entry and the context of the new top
630 // entry matches the previous, copy the insertionPoint and location from the
631 // previous entry if missing from the new top entry.
632 if (stack.size() > 1) {
633 auto &prev = *(stack.rbegin() + 1);
634 auto &current = stack.back();
635 if (current.context.is(prev.context)) {
636 // Default non-context objects from the previous entry.
637 if (!current.insertionPoint)
638 current.insertionPoint = prev.insertionPoint;
639 if (!current.location)
640 current.location = prev.location;
641 }
642 }
643}
644
646 if (!context)
647 return nullptr;
648 return nb::cast<PyMlirContext *>(context);
649}
650
652 if (!insertionPoint)
653 return nullptr;
654 return nb::cast<PyInsertionPoint *>(insertionPoint);
655}
656
658 if (!location)
659 return nullptr;
660 return nb::cast<PyLocation *>(location);
661}
662
664 auto *tos = getTopOfStack();
665 return tos ? tos->getContext() : nullptr;
666}
667
669 auto *tos = getTopOfStack();
670 return tos ? tos->getInsertionPoint() : nullptr;
671}
672
674 auto *tos = getTopOfStack();
675 return tos ? tos->getLocation() : nullptr;
676}
677
678nb::object PyThreadContextEntry::pushContext(nb::object context) {
679 push(FrameKind::Context, /*context=*/context,
680 /*insertionPoint=*/nb::object(),
681 /*location=*/nb::object());
682 return context;
683}
684
686 auto &stack = getStack();
687 if (stack.empty())
688 throw std::runtime_error("Unbalanced Context enter/exit");
689 auto &tos = stack.back();
690 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
691 throw std::runtime_error("Unbalanced Context enter/exit");
692 stack.pop_back();
693}
694
695nb::object
696PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
697 PyInsertionPoint &insertionPoint =
698 nb::cast<PyInsertionPoint &>(insertionPointObj);
699 nb::object contextObj =
700 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
701 push(FrameKind::InsertionPoint,
702 /*context=*/contextObj,
703 /*insertionPoint=*/insertionPointObj,
704 /*location=*/nb::object());
705 return insertionPointObj;
706}
707
709 auto &stack = getStack();
710 if (stack.empty())
711 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
712 auto &tos = stack.back();
713 if (tos.frameKind != FrameKind::InsertionPoint &&
714 tos.getInsertionPoint() != &insertionPoint)
715 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
716 stack.pop_back();
717}
718
719nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
720 PyLocation &location = nb::cast<PyLocation &>(locationObj);
721 nb::object contextObj = location.getContext().getObject();
722 push(FrameKind::Location, /*context=*/contextObj,
723 /*insertionPoint=*/nb::object(),
724 /*location=*/locationObj);
725 return locationObj;
726}
727
729 auto &stack = getStack();
730 if (stack.empty())
731 throw std::runtime_error("Unbalanced Location enter/exit");
732 auto &tos = stack.back();
733 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
734 throw std::runtime_error("Unbalanced Location enter/exit");
735 stack.pop_back();
736}
738//------------------------------------------------------------------------------
739// PyDiagnostic*
740//------------------------------------------------------------------------------
741
743 valid = false;
744 if (materializedNotes) {
745 for (nb::handle noteObject : *materializedNotes) {
746 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
747 note->invalidate();
748 }
749 }
750}
751
753 nb::object callback)
754 : context(context), callback(std::move(callback)) {}
755
757
759 if (!registeredID)
760 return;
761 MlirDiagnosticHandlerID localID = *registeredID;
762 mlirContextDetachDiagnosticHandler(context, localID);
763 assert(!registeredID && "should have unregistered");
764 // Not strictly necessary but keeps stale pointers from being around to cause
765 // issues.
766 context = {nullptr};
767}
768
769void PyDiagnostic::checkValid() {
770 if (!valid) {
771 throw std::invalid_argument(
772 "Diagnostic is invalid (used outside of callback)");
773 }
774}
775
777 checkValid();
778 return static_cast<PyDiagnosticSeverity>(
779 mlirDiagnosticGetSeverity(diagnostic));
780}
781
783 checkValid();
784 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
785 MlirContext context = mlirLocationGetContext(loc);
786 return PyLocation(PyMlirContext::forContext(context), loc);
787}
788
789nb::str PyDiagnostic::getMessage() {
790 checkValid();
791 nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
792 PyFileAccumulator accum(fileObject, /*binary=*/false);
793 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
794 return nb::cast<nb::str>(fileObject.attr("getvalue")());
795}
796
797nb::tuple PyDiagnostic::getNotes() {
798 checkValid();
799 if (materializedNotes)
800 return *materializedNotes;
801 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
802 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
803 for (intptr_t i = 0; i < numNotes; ++i) {
804 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
805 nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
806 PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
807 }
808 materializedNotes = std::move(notes);
809
810 return *materializedNotes;
811}
812
814 std::vector<DiagnosticInfo> notes;
815 for (nb::handle n : getNotes())
816 notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
817 return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
818 std::move(notes)};
819}
821//------------------------------------------------------------------------------
822// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
823//------------------------------------------------------------------------------
824
825MlirDialect PyDialects::getDialectForKey(const std::string &key,
826 bool attrError) {
827 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
828 {key.data(), key.size()});
829 if (mlirDialectIsNull(dialect)) {
830 std::string msg = join("Dialect '", key, "' not found");
831 if (attrError)
832 throw nb::attribute_error(msg.c_str());
833 throw nb::index_error(msg.c_str());
834 }
835 return dialect;
836}
839 return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
840}
841
843 MlirDialectRegistry rawRegistry =
845 if (mlirDialectRegistryIsNull(rawRegistry))
846 throw nb::python_error();
847 return PyDialectRegistry(rawRegistry);
848}
850//------------------------------------------------------------------------------
851// PyLocation
852//------------------------------------------------------------------------------
854nb::object PyLocation::getCapsule() {
855 return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
856}
857
858PyLocation PyLocation::createFromCapsule(nb::object capsule) {
859 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
860 if (mlirLocationIsNull(rawLoc))
861 throw nb::python_error();
863 rawLoc);
864}
866nb::object PyLocation::contextEnter(nb::object locationObj) {
867 return PyThreadContextEntry::pushLocation(locationObj);
868}
869
870void PyLocation::contextExit(const nb::object &excType,
871 const nb::object &excVal,
872 const nb::object &excTb) {
874}
875
878 if (!location) {
879 throw std::runtime_error(
880 "An MLIR function requires a Location but none was provided in the "
881 "call or from the surrounding environment. Either pass to the function "
882 "with a 'loc=' argument or establish a default using 'with loc:'");
883 }
884 return *location;
885}
886
887//------------------------------------------------------------------------------
888// PyModule
889//------------------------------------------------------------------------------
890
891PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
892 : BaseContextObject(std::move(contextRef)), module(module) {}
893
895 nb::gil_scoped_acquire acquire;
896 auto &liveModules = getContext()->liveModules;
897 assert(liveModules.count(module.ptr) == 1 &&
898 "destroying module not in live map");
899 liveModules.erase(module.ptr);
900 mlirModuleDestroy(module);
901}
902
903PyModuleRef PyModule::forModule(MlirModule module) {
904 MlirContext context = mlirModuleGetContext(module);
905 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
906
907 nb::gil_scoped_acquire acquire;
908 auto &liveModules = contextRef->liveModules;
909 auto it = liveModules.find(module.ptr);
910 if (it == liveModules.end()) {
911 // Create.
912 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
913 // Note that the default return value policy on cast is automatic_reference,
914 // which does not take ownership (delete will not be called).
915 // Just be explicit.
916 nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
917 unownedModule->handle = pyRef;
918 liveModules[module.ptr] =
919 std::make_pair(unownedModule->handle, unownedModule);
920 return PyModuleRef(unownedModule, std::move(pyRef));
921 }
922 // Use existing.
923 PyModule *existing = it->second.second;
924 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
925 return PyModuleRef(existing, std::move(pyRef));
926}
927
928nb::object PyModule::createFromCapsule(nb::object capsule) {
929 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
930 if (mlirModuleIsNull(rawModule))
931 throw nb::python_error();
932 return forModule(rawModule).releaseObject();
933}
934
935nb::object PyModule::getCapsule() {
936 return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
937}
939//------------------------------------------------------------------------------
940// PyOperation
941//------------------------------------------------------------------------------
942
943PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
944 : BaseContextObject(std::move(contextRef)), operation(operation) {}
945
947 // If the operation has already been invalidated there is nothing to do.
948 if (!valid)
949 return;
950 // Otherwise, invalidate the operation when it is attached.
951 if (isAttached())
952 setInvalid();
953 else {
954 // And destroy it when it is detached, i.e. owned by Python.
955 erase();
956 }
957}
958
959namespace {
960
961// Constructs a new object of type T in-place on the Python heap, returning a
962// PyObjectRef to it, loosely analogous to std::make_shared<T>().
963template <typename T, class... Args>
964PyObjectRef<T> makeObjectRef(Args &&...args) {
965 nb::handle type = nb::type<T>();
966 nb::object instance = nb::inst_alloc(type);
967 T *ptr = nb::inst_ptr<T>(instance);
968 new (ptr) T(std::forward<Args>(args)...);
969 nb::inst_mark_ready(instance);
970 return PyObjectRef<T>(ptr, std::move(instance));
971}
972
973} // namespace
974
975PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
976 MlirOperation operation,
977 nb::object parentKeepAlive) {
978 // Create.
979 PyOperationRef unownedOperation =
980 makeObjectRef<PyOperation>(std::move(contextRef), operation);
981 unownedOperation->handle = unownedOperation.getObject();
982 if (parentKeepAlive) {
983 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
984 }
985 return unownedOperation;
986}
987
989 MlirOperation operation,
990 nb::object parentKeepAlive) {
991 return createInstance(std::move(contextRef), operation,
992 std::move(parentKeepAlive));
993}
994
996 MlirOperation operation,
997 nb::object parentKeepAlive) {
998 PyOperationRef created = createInstance(std::move(contextRef), operation,
999 std::move(parentKeepAlive));
1000 created->attached = false;
1001 return created;
1002}
1003
1005 const std::string &sourceStr,
1006 const std::string &sourceName) {
1007 PyMlirContext::ErrorCapture errors(contextRef);
1008 MlirOperation op =
1009 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1010 toMlirStringRef(sourceName));
1011 if (mlirOperationIsNull(op))
1012 throw MLIRError("Unable to parse operation assembly", errors.take());
1013 return PyOperation::createDetached(std::move(contextRef), op);
1014}
1015
1018 setDetached();
1019 parentKeepAlive = nb::object();
1020}
1021
1022MlirOperation PyOperation::get() const {
1023 checkValid();
1024 return operation;
1025}
1028 return PyOperationRef(this, nb::borrow<nb::object>(handle));
1029}
1030
1031void PyOperation::setAttached(const nb::object &parent) {
1032 assert(!attached && "operation already attached");
1033 attached = true;
1034}
1035
1037 assert(attached && "operation already detached");
1038 attached = false;
1039}
1040
1041void PyOperation::checkValid() const {
1042 if (!valid) {
1043 throw std::runtime_error("the operation has been invalidated");
1044 }
1045}
1046
1047void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1048 std::optional<int64_t> largeResourceLimit,
1049 bool enableDebugInfo, bool prettyDebugInfo,
1050 bool printGenericOpForm, bool useLocalScope,
1051 bool useNameLocAsPrefix, bool assumeVerified,
1052 nb::object fileObject, bool binary,
1053 bool skipRegions) {
1054 PyOperation &operation = getOperation();
1055 operation.checkValid();
1056 if (fileObject.is_none())
1057 fileObject = nb::module_::import_("sys").attr("stdout");
1058
1059 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1060 if (largeElementsLimit)
1061 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1062 if (largeResourceLimit)
1063 mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit);
1064 if (enableDebugInfo)
1065 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1066 /*prettyForm=*/prettyDebugInfo);
1067 if (printGenericOpForm)
1069 if (useLocalScope)
1071 if (assumeVerified)
1073 if (skipRegions)
1075 if (useNameLocAsPrefix)
1077
1078 PyFileAccumulator accum(fileObject, binary);
1079 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1080 accum.getUserData());
1082}
1083
1084void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1085 bool binary) {
1086 PyOperation &operation = getOperation();
1087 operation.checkValid();
1088 if (fileObject.is_none())
1089 fileObject = nb::module_::import_("sys").attr("stdout");
1090 PyFileAccumulator accum(fileObject, binary);
1091 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1092 accum.getUserData());
1093}
1094
1095void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
1096 std::optional<int64_t> bytecodeVersion) {
1097 PyOperation &operation = getOperation();
1098 operation.checkValid();
1099 PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
1100
1101 if (!bytecodeVersion.has_value())
1102 return mlirOperationWriteBytecode(operation, accum.getCallback(),
1103 accum.getUserData());
1104
1105 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1108 operation, config, accum.getCallback(), accum.getUserData());
1111 throw nb::value_error(
1112 join("Unable to honor desired bytecode version ", *bytecodeVersion)
1113 .c_str());
1114}
1115
1116void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
1117 PyWalkOrder walkOrder) {
1118 PyOperation &operation = getOperation();
1119 operation.checkValid();
1120 struct UserData {
1121 std::function<PyWalkResult(MlirOperation)> callback;
1122 bool gotException;
1123 std::string exceptionWhat;
1124 nb::object exceptionType;
1125 };
1126 UserData userData{callback, false, {}, {}};
1127 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1128 void *userData) {
1129 UserData *calleeUserData = static_cast<UserData *>(userData);
1130 try {
1131 return static_cast<MlirWalkResult>((calleeUserData->callback)(op));
1132 } catch (nb::python_error &e) {
1133 calleeUserData->gotException = true;
1134 calleeUserData->exceptionWhat = std::string(e.what());
1135 calleeUserData->exceptionType = nb::borrow(e.type());
1136 return MlirWalkResult::MlirWalkResultInterrupt;
1137 }
1138 };
1139 mlirOperationWalk(operation, walkCallback, &userData,
1140 static_cast<MlirWalkOrder>(walkOrder));
1141 if (userData.gotException) {
1142 std::string message("Exception raised in callback: ");
1143 message.append(userData.exceptionWhat);
1144 throw std::runtime_error(message);
1145 }
1146}
1147
1148nb::object PyOperationBase::getAsm(bool binary,
1149 std::optional<int64_t> largeElementsLimit,
1150 std::optional<int64_t> largeResourceLimit,
1151 bool enableDebugInfo, bool prettyDebugInfo,
1152 bool printGenericOpForm, bool useLocalScope,
1153 bool useNameLocAsPrefix, bool assumeVerified,
1154 bool skipRegions) {
1155 nb::object fileObject;
1156 if (binary) {
1157 fileObject = nb::module_::import_("io").attr("BytesIO")();
1158 } else {
1159 fileObject = nb::module_::import_("io").attr("StringIO")();
1160 }
1161 print(/*largeElementsLimit=*/largeElementsLimit,
1162 /*largeResourceLimit=*/largeResourceLimit,
1163 /*enableDebugInfo=*/enableDebugInfo,
1164 /*prettyDebugInfo=*/prettyDebugInfo,
1165 /*printGenericOpForm=*/printGenericOpForm,
1166 /*useLocalScope=*/useLocalScope,
1167 /*useNameLocAsPrefix=*/useNameLocAsPrefix,
1168 /*assumeVerified=*/assumeVerified,
1169 /*fileObject=*/fileObject,
1170 /*binary=*/binary,
1171 /*skipRegions=*/skipRegions);
1172
1173 return fileObject.attr("getvalue")();
1174}
1175
1177 PyOperation &operation = getOperation();
1178 PyOperation &otherOp = other.getOperation();
1179 operation.checkValid();
1180 otherOp.checkValid();
1181 mlirOperationMoveAfter(operation, otherOp);
1182 operation.parentKeepAlive = otherOp.parentKeepAlive;
1183}
1184
1186 PyOperation &operation = getOperation();
1187 PyOperation &otherOp = other.getOperation();
1188 operation.checkValid();
1189 otherOp.checkValid();
1190 mlirOperationMoveBefore(operation, otherOp);
1191 operation.parentKeepAlive = otherOp.parentKeepAlive;
1192}
1193
1195 PyOperation &operation = getOperation();
1196 PyOperation &otherOp = other.getOperation();
1197 operation.checkValid();
1198 otherOp.checkValid();
1199 return mlirOperationIsBeforeInBlock(operation, otherOp);
1200}
1201
1203 PyOperation &op = getOperation();
1206 throw MLIRError("Verification failed", errors.take());
1207 return true;
1208}
1209
1210std::optional<PyOperationRef> PyOperation::getParentOperation() {
1211 checkValid();
1212 if (!isAttached())
1213 throw nb::value_error("Detached operations have no parent");
1214 MlirOperation operation = mlirOperationGetParentOperation(get());
1215 if (mlirOperationIsNull(operation))
1216 return {};
1217 return PyOperation::forOperation(getContext(), operation);
1218}
1219
1221 checkValid();
1222 std::optional<PyOperationRef> parentOperation = getParentOperation();
1223 MlirBlock block = mlirOperationGetBlock(get());
1224 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1225 assert(parentOperation && "Operation has no parent");
1226 return PyBlock{std::move(*parentOperation), block};
1227}
1228
1230 checkValid();
1231 return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1232}
1233
1234nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
1235 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1236 if (mlirOperationIsNull(rawOperation))
1237 throw nb::python_error();
1238 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1239 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1240 .releaseObject();
1241}
1242
1243static void maybeInsertOperation(PyOperationRef &op,
1244 const nb::object &maybeIp) {
1245 // InsertPoint active?
1246 if (!maybeIp.is(nb::cast(false))) {
1247 PyInsertionPoint *ip;
1248 if (maybeIp.is_none()) {
1250 } else {
1251 ip = nb::cast<PyInsertionPoint *>(maybeIp);
1252 }
1253 if (ip)
1254 ip->insert(*op.get());
1255 }
1256}
1257
1258nb::object PyOperation::create(std::string_view name,
1259 std::optional<std::vector<PyType *>> results,
1260 const MlirValue *operands, size_t numOperands,
1261 std::optional<nb::dict> attributes,
1262 std::optional<std::vector<PyBlock *>> successors,
1263 int regions, PyLocation &location,
1264 const nb::object &maybeIp, bool inferType) {
1265 std::vector<MlirType> mlirResults;
1266 std::vector<MlirBlock> mlirSuccessors;
1267 std::vector<std::pair<std::string, MlirAttribute>> mlirAttributes;
1268
1269 // General parameter validation.
1270 if (regions < 0)
1271 throw nb::value_error("number of regions must be >= 0");
1272
1273 // Unpack/validate results.
1274 if (results) {
1275 mlirResults.reserve(results->size());
1276 for (PyType *result : *results) {
1277 // TODO: Verify result type originate from the same context.
1278 if (!result)
1279 throw nb::value_error("result type cannot be None");
1280 mlirResults.push_back(*result);
1281 }
1282 }
1283 // Unpack/validate attributes.
1284 if (attributes) {
1285 mlirAttributes.reserve(attributes->size());
1286 for (std::pair<nb::handle, nb::handle> it : *attributes) {
1287 std::string key;
1288 try {
1289 key = nb::cast<std::string>(it.first);
1290 } catch (nb::cast_error &err) {
1291 std::string msg = join("Invalid attribute key (not a string) when "
1292 "attempting to create the operation \"",
1293 name, "\" (", err.what(), ")");
1294 throw nb::type_error(msg.c_str());
1295 }
1296 try {
1297 auto &attribute = nb::cast<PyAttribute &>(it.second);
1298 // TODO: Verify attribute originates from the same context.
1299 mlirAttributes.emplace_back(std::move(key), attribute);
1300 } catch (nb::cast_error &err) {
1301 std::string msg = join("Invalid attribute value for the key \"", key,
1302 "\" when attempting to create the operation \"",
1303 name, "\" (", err.what(), ")");
1304 throw nb::type_error(msg.c_str());
1305 } catch (std::runtime_error &) {
1306 // This exception seems thrown when the value is "None".
1307 std::string msg = join(
1308 "Found an invalid (`None`?) attribute value for the key \"", key,
1309 "\" when attempting to create the operation \"", name, "\"");
1310 throw std::runtime_error(msg);
1311 }
1312 }
1313 }
1314 // Unpack/validate successors.
1315 if (successors) {
1316 mlirSuccessors.reserve(successors->size());
1317 for (PyBlock *successor : *successors) {
1318 // TODO: Verify successor originate from the same context.
1319 if (!successor)
1320 throw nb::value_error("successor block cannot be None");
1321 mlirSuccessors.push_back(successor->get());
1322 }
1323 }
1324
1325 // Apply unpacked/validated to the operation state. Beyond this
1326 // point, exceptions cannot be thrown or else the state will leak.
1327 MlirOperationState state =
1328 mlirOperationStateGet(toMlirStringRef(name), location);
1329 if (numOperands > 0)
1330 mlirOperationStateAddOperands(&state, numOperands, operands);
1331 state.enableResultTypeInference = inferType;
1332 if (!mlirResults.empty())
1333 mlirOperationStateAddResults(&state, mlirResults.size(),
1334 mlirResults.data());
1335 if (!mlirAttributes.empty()) {
1336 // Note that the attribute names directly reference bytes in
1337 // mlirAttributes, so that vector must not be changed from here
1338 // on.
1339 std::vector<MlirNamedAttribute> mlirNamedAttributes;
1340 mlirNamedAttributes.reserve(mlirAttributes.size());
1341 for (const std::pair<std::string, MlirAttribute> &it : mlirAttributes)
1342 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1344 toMlirStringRef(it.first)),
1345 it.second));
1346 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1347 mlirNamedAttributes.data());
1348 }
1349 if (!mlirSuccessors.empty())
1350 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1351 mlirSuccessors.data());
1352 if (regions) {
1353 std::vector<MlirRegion> mlirRegions;
1354 mlirRegions.resize(regions);
1355 for (int i = 0; i < regions; ++i)
1356 mlirRegions[i] = mlirRegionCreate();
1357 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1358 mlirRegions.data());
1359 }
1360
1361 // Construct the operation.
1362 PyMlirContext::ErrorCapture errors(location.getContext());
1363 MlirOperation operation = mlirOperationCreate(&state);
1364 if (!operation.ptr)
1365 throw MLIRError("Operation creation failed", errors.take());
1366 PyOperationRef created =
1367 PyOperation::createDetached(location.getContext(), operation);
1368 maybeInsertOperation(created, maybeIp);
1369
1370 return created.getObject();
1371}
1372
1373nb::object PyOperation::clone(const nb::object &maybeIp) {
1374 MlirOperation clonedOperation = mlirOperationClone(operation);
1375 PyOperationRef cloned =
1376 PyOperation::createDetached(getContext(), clonedOperation);
1377 maybeInsertOperation(cloned, maybeIp);
1378
1379 return cloned->createOpView();
1380}
1381
1382nb::object PyOperation::createOpView() {
1383 checkValid();
1384 MlirIdentifier ident = mlirOperationGetName(get());
1385 MlirStringRef identStr = mlirIdentifierStr(ident);
1386 auto operationCls = PyGlobals::get().lookupOperationClass(
1387 std::string_view(identStr.data, identStr.length));
1388 if (operationCls)
1389 return PyOpView::constructDerived(*operationCls, getRef().getObject());
1390 return nb::cast(PyOpView(getRef().getObject()));
1391}
1392
1393void PyOperation::erase() {
1395 setInvalid();
1396 mlirOperationDestroy(operation);
1397}
1398
1399void PyOpResult::bindDerived(ClassTy &c) {
1400 c.def_prop_ro(
1401 "owner",
1402 [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
1403 assert(mlirOperationEqual(self.getParentOperation()->get(),
1404 mlirOpResultGetOwner(self.get())) &&
1405 "expected the owner of the value in Python to match that in "
1406 "the IR");
1407 return self.getParentOperation()->createOpView();
1408 },
1409 "Returns the operation that produces this result.");
1410 c.def_prop_ro(
1411 "result_number",
1412 [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); },
1413 "Returns the position of this result in the operation's result list.");
1415
1416/// Returns the list of types of the values held by container.
1417template <typename Container>
1418static std::vector<nb::typed<nb::object, PyType>>
1419getValueTypes(Container &container, PyMlirContextRef &context) {
1420 std::vector<nb::typed<nb::object, PyType>> result;
1421 result.reserve(container.size());
1422 for (int i = 0, e = container.size(); i < e; ++i) {
1423 result.push_back(PyType(context->getRef(),
1424 mlirValueGetType(container.getElement(i).get()))
1426 }
1427 return result;
1428}
1429
1431 intptr_t length, intptr_t step)
1432 : Sliceable(startIndex,
1434 : length,
1435 step),
1436 operation(std::move(operation)) {}
1437
1438void PyOpResultList::bindDerived(ClassTy &c) {
1439 c.def_prop_ro(
1440 "types",
1441 [](PyOpResultList &self) {
1442 return getValueTypes(self, self.operation->getContext());
1443 },
1444 "Returns a list of types for all results in this result list.");
1445 c.def_prop_ro(
1446 "owner",
1447 [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
1448 return self.operation->createOpView();
1449 },
1450 "Returns the operation that owns this result list.");
1451}
1452
1453intptr_t PyOpResultList::getRawNumElements() {
1454 operation->checkValid();
1455 return mlirOperationGetNumResults(operation->get());
1456}
1457
1458PyOpResult PyOpResultList::getRawElement(intptr_t index) {
1459 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1460 return PyOpResult(value);
1461}
1462
1463PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
1464 intptr_t step) const {
1465 return PyOpResultList(operation, startIndex, length, step);
1466}
1468//------------------------------------------------------------------------------
1469// PyOpView
1470//------------------------------------------------------------------------------
1471
1472static void populateResultTypes(std::string_view name, nb::list resultTypeList,
1473 const nb::object &resultSegmentSpecObj,
1474 std::vector<int32_t> &resultSegmentLengths,
1475 std::vector<PyType *> &resultTypes) {
1476 resultTypes.reserve(resultTypeList.size());
1477 if (resultSegmentSpecObj.is_none()) {
1478 // Non-variadic result unpacking.
1479 size_t index = 0;
1480 for (nb::handle resultType : resultTypeList) {
1481 try {
1482 resultTypes.push_back(nb::cast<PyType *>(resultType));
1483 if (!resultTypes.back())
1484 throw nb::cast_error();
1485 } catch (nb::cast_error &err) {
1486 throw nb::value_error(join("Result ", index, " of operation \"", name,
1487 "\" must be a Type (", err.what(), ")")
1488 .c_str());
1489 }
1490 ++index;
1491 }
1492 } else {
1493 // Sized result unpacking.
1494 auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1495 if (resultSegmentSpec.size() != resultTypeList.size()) {
1496 throw nb::value_error(
1497 join("Operation \"", name, "\" requires ", resultSegmentSpec.size(),
1498 " result segments but was provided ", resultTypeList.size())
1499 .c_str());
1500 }
1501 resultSegmentLengths.reserve(resultTypeList.size());
1502 for (size_t i = 0, e = resultSegmentSpec.size(); i < e; ++i) {
1503 int segmentSpec = resultSegmentSpec[i];
1504 if (segmentSpec == 1 || segmentSpec == 0) {
1505 // Unpack unary element.
1506 try {
1507 auto *resultType = nb::cast<PyType *>(resultTypeList[i]);
1508 if (resultType) {
1509 resultTypes.push_back(resultType);
1510 resultSegmentLengths.push_back(1);
1511 } else if (segmentSpec == 0) {
1512 // Allowed to be optional.
1513 resultSegmentLengths.push_back(0);
1514 } else {
1515 throw nb::value_error(
1516 join("Result ", i, " of operation \"", name,
1517 "\" must be a Type (was None and result is not optional)")
1518 .c_str());
1519 }
1520 } catch (nb::cast_error &err) {
1521 throw nb::value_error(join("Result ", i, " of operation \"", name,
1522 "\" must be a Type (", err.what(), ")")
1523 .c_str());
1524 }
1525 } else if (segmentSpec == -1) {
1526 // Unpack sequence by appending.
1527 try {
1528 if (resultTypeList[i].is_none()) {
1529 // Treat it as an empty list.
1530 resultSegmentLengths.push_back(0);
1531 } else {
1532 // Unpack the list.
1533 auto segment = nb::cast<nb::sequence>(resultTypeList[i]);
1534 for (nb::handle segmentItem : segment) {
1535 resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1536 if (!resultTypes.back()) {
1537 throw nb::type_error("contained a None item");
1538 }
1539 }
1540 resultSegmentLengths.push_back(nb::len(segment));
1541 }
1542 } catch (std::exception &err) {
1543 // NOTE: Sloppy to be using a catch-all here, but there are at least
1544 // three different unrelated exceptions that can be thrown in the
1545 // above "casts". Just keep the scope above small and catch them all.
1546 throw nb::value_error(join("Result ", i, " of operation \"", name,
1547 "\" must be a Sequence of Types (",
1548 err.what(), ")")
1549 .c_str());
1550 }
1551 } else {
1552 throw nb::value_error("Unexpected segment spec");
1554 }
1555 }
1556}
1557
1558MlirValue getUniqueResult(MlirOperation operation) {
1559 auto numResults = mlirOperationGetNumResults(operation);
1560 if (numResults != 1) {
1561 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1562 throw nb::value_error(
1563 join("Cannot call .result on operation ",
1564 std::string_view(name.data, name.length), " which has ",
1565 numResults,
1566 " results (it is only valid for operations with a "
1567 "single result)")
1568 .c_str());
1569 }
1570 return mlirOperationGetResult(operation, 0);
1571}
1572
1573static MlirValue getOpResultOrValue(nb::handle operand) {
1574 if (operand.is_none()) {
1575 throw nb::value_error("contained a None item");
1576 }
1577 PyOperationBase *op;
1578 if (nb::try_cast<PyOperationBase *>(operand, op)) {
1579 return getUniqueResult(op->getOperation());
1580 }
1581 PyOpResultList *opResultList;
1582 if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1583 return getUniqueResult(opResultList->getOperation()->get());
1584 }
1585 PyValue *value;
1586 if (nb::try_cast<PyValue *>(operand, value)) {
1587 return value->get();
1588 }
1589 throw nb::value_error("is not a Value");
1590}
1591
1592nb::object PyOpView::buildGeneric(
1593 std::string_view name, std::tuple<int, bool> opRegionSpec,
1594 nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1595 std::optional<nb::list> resultTypeList, nb::list operandList,
1596 std::optional<nb::dict> attributes,
1597 std::optional<std::vector<PyBlock *>> successors,
1598 std::optional<int> regions, PyLocation &location,
1599 const nb::object &maybeIp) {
1600 PyMlirContextRef context = location.getContext();
1601
1602 // Class level operation construction metadata.
1603 // Operand and result segment specs are either none, which does no
1604 // variadic unpacking, or a list of ints with segment sizes, where each
1605 // element is either a positive number (typically 1 for a scalar) or -1 to
1606 // indicate that it is derived from the length of the same-indexed operand
1607 // or result (implying that it is a list at that position).
1608 std::vector<int32_t> operandSegmentLengths;
1609 std::vector<int32_t> resultSegmentLengths;
1610
1611 // Validate/determine region count.
1612 int opMinRegionCount = std::get<0>(opRegionSpec);
1613 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1614 if (!regions) {
1615 regions = opMinRegionCount;
1616 }
1617 if (*regions < opMinRegionCount) {
1618 throw nb::value_error(join("Operation \"", name,
1619 "\" requires a minimum of ", opMinRegionCount,
1620 " regions but was built with regions=", *regions)
1621 .c_str());
1622 }
1623 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1624 throw nb::value_error(join("Operation \"", name,
1625 "\" requires a maximum of ", opMinRegionCount,
1626 " regions but was built with regions=", *regions)
1627 .c_str());
1628 }
1629
1630 // Unpack results.
1631 std::vector<PyType *> resultTypes;
1632 if (resultTypeList.has_value()) {
1633 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1634 resultSegmentLengths, resultTypes);
1635 }
1636
1637 // Unpack operands.
1638 std::vector<MlirValue> operands;
1639 operands.reserve(operands.size());
1640 size_t index = 0;
1641 if (operandSegmentSpecObj.is_none()) {
1642 // Non-sized operand unpacking.
1643 for (nb::handle operand : operandList) {
1644 try {
1645 operands.push_back(getOpResultOrValue(operand));
1646 } catch (nb::builtin_exception &err) {
1647 throw nb::value_error(join("Operand ", index, " of operation \"", name,
1648 "\" must be a Value (", err.what(), ")")
1649 .c_str());
1650 }
1651 ++index;
1652 }
1653 } else {
1654 // Sized operand unpacking.
1655 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1656 if (operandSegmentSpec.size() != operandList.size()) {
1657 throw nb::value_error(
1658 join("Operation \"", name, "\" requires ", operandSegmentSpec.size(),
1659 "operand segments but was provided ", operandList.size())
1660 .c_str());
1661 }
1662 operandSegmentLengths.reserve(operandList.size());
1663 for (size_t i = 0, e = operandSegmentSpec.size(); i < e; ++i) {
1664 int segmentSpec = operandSegmentSpec[i];
1665 if (segmentSpec == 1 || segmentSpec == 0) {
1666 // Unpack unary element.
1667 const nanobind::handle operand = operandList[i];
1668 if (!operand.is_none()) {
1669 try {
1670 operands.push_back(getOpResultOrValue(operand));
1671 } catch (nb::builtin_exception &err) {
1672 throw nb::value_error(join("Operand ", i, " of operation \"", name,
1673 "\" must be a Value (", err.what(), ")")
1674 .c_str());
1675 }
1676
1677 operandSegmentLengths.push_back(1);
1678 } else if (segmentSpec == 0) {
1679 // Allowed to be optional.
1680 operandSegmentLengths.push_back(0);
1681 } else {
1682 throw nb::value_error(
1683 join("Operand ", i, " of operation \"", name,
1684 "\" must be a Value (was None and operand is not optional)")
1685 .c_str());
1686 }
1687 } else if (segmentSpec == -1) {
1688 // Unpack sequence by appending.
1689 try {
1690 if (operandList[i].is_none()) {
1691 // Treat it as an empty list.
1692 operandSegmentLengths.push_back(0);
1693 } else {
1694 // Unpack the list.
1695 auto segment = nb::cast<nb::sequence>(operandList[i]);
1696 for (nb::handle segmentItem : segment) {
1697 operands.push_back(getOpResultOrValue(segmentItem));
1698 }
1699 operandSegmentLengths.push_back(nb::len(segment));
1700 }
1701 } catch (std::exception &err) {
1702 // NOTE: Sloppy to be using a catch-all here, but there are at least
1703 // three different unrelated exceptions that can be thrown in the
1704 // above "casts". Just keep the scope above small and catch them all.
1705 throw nb::value_error(join("Operand ", i, " of operation \"", name,
1706 "\" must be a Sequence of Values (",
1707 err.what(), ")")
1708 .c_str());
1709 }
1710 } else {
1711 throw nb::value_error("Unexpected segment spec");
1712 }
1713 }
1714 }
1715
1716 // Merge operand/result segment lengths into attributes if needed.
1717 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1718 // Dup.
1719 if (attributes) {
1720 attributes = nb::dict(*attributes);
1721 } else {
1722 attributes = nb::dict();
1723 }
1724 if (attributes->contains("resultSegmentSizes") ||
1725 attributes->contains("operandSegmentSizes")) {
1726 throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
1727 "'operandSegmentSizes' attribute is unsupported. "
1728 "Use Operation.create for such low-level access.");
1729 }
1730
1731 // Add resultSegmentSizes attribute.
1732 if (!resultSegmentLengths.empty()) {
1733 MlirAttribute segmentLengthAttr =
1734 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1735 resultSegmentLengths.data());
1736 (*attributes)["resultSegmentSizes"] =
1737 PyAttribute(context, segmentLengthAttr);
1738 }
1739
1740 // Add operandSegmentSizes attribute.
1741 if (!operandSegmentLengths.empty()) {
1742 MlirAttribute segmentLengthAttr =
1743 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1744 operandSegmentLengths.data());
1745 (*attributes)["operandSegmentSizes"] =
1746 PyAttribute(context, segmentLengthAttr);
1747 }
1748 }
1749
1750 // Delegate to create.
1751 return PyOperation::create(name,
1752 /*results=*/std::move(resultTypes),
1753 /*operands=*/operands.data(),
1754 /*numOperands=*/operands.size(),
1755 /*attributes=*/std::move(attributes),
1756 /*successors=*/std::move(successors),
1757 /*regions=*/*regions, location, maybeIp,
1758 !resultTypeList);
1759}
1760
1761nb::object PyOpView::constructDerived(const nb::object &cls,
1762 const nb::object &operation) {
1763 nb::handle opViewType = nb::type<PyOpView>();
1764 nb::object instance = cls.attr("__new__")(cls);
1765 opViewType.attr("__init__")(instance, operation);
1766 return instance;
1767}
1768
/// Builds an OpView around any PyOperationBase Python object.
PyOpView::PyOpView(const nb::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    // In the first initializer, `operationObject` is the constructor
    // parameter, which shadows the member of the same name.
    // The stored member is taken from the PyOperation's own reference rather
    // than copied from the argument, which may be another PyOperationBase.
    // NOTE(review): the second initializer reads the `operation` member, so
    // this relies on `operation` being declared before `operationObject` in
    // the class definition — confirm against the header.
    : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1775//------------------------------------------------------------------------------
1776// PyAsmState
1777//------------------------------------------------------------------------------
1778
1779PyAsmState::PyAsmState(MlirValue value, bool useLocalScope) {
1780 flags = mlirOpPrintingFlagsCreate();
1781 // The OpPrintingFlags are not exposed Python side, create locally and
1782 // associate lifetime with the state.
1783 if (useLocalScope)
1785 state = mlirAsmStateCreateForValue(value, flags);
1786}
1787
1788PyAsmState::PyAsmState(PyOperationBase &operation, bool useLocalScope) {
1789 flags = mlirOpPrintingFlagsCreate();
1790 // The OpPrintingFlags are not exposed Python side, create locally and
1791 // associate lifetime with the state.
1792 if (useLocalScope)
1794 state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
1795}
1797//------------------------------------------------------------------------------
1798// PyInsertionPoint.
1799//------------------------------------------------------------------------------
1800
1801PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
1804 : refOperation(beforeOperationBase.getOperation().getRef()),
1805 block((*refOperation)->getBlock()) {}
1806
1808 : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
1809
1810void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1811 PyOperation &operation = operationBase.getOperation();
1812 if (operation.isAttached())
1813 throw nb::value_error(
1814 "Attempt to insert operation that is already attached");
1815 block.getParentOperation()->checkValid();
1816 MlirOperation beforeOp = {nullptr};
1817 if (refOperation) {
1818 // Insert before operation.
1819 (*refOperation)->checkValid();
1820 beforeOp = (*refOperation)->get();
1821 } else {
1822 // Insert at end (before null) is only valid if the block does not
1823 // already end in a known terminator (violating this will cause assertion
1824 // failures later).
1825 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1826 throw nb::index_error("Cannot insert operation at the end of a block "
1827 "that already has a terminator. Did you mean to "
1828 "use 'InsertionPoint.at_block_terminator(block)' "
1829 "versus 'InsertionPoint(block)'?");
1830 }
1832 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1833 operation.setAttached();
1834}
1835
1837 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1838 if (mlirOperationIsNull(firstOp)) {
1839 // Just insert at end.
1840 return PyInsertionPoint(block);
1841 }
1842
1843 // Insert before first op.
1845 block.getParentOperation()->getContext(), firstOp);
1846 return PyInsertionPoint{block, std::move(firstOpRef)};
1847}
1848
1850 MlirOperation terminator = mlirBlockGetTerminator(block.get());
1851 if (mlirOperationIsNull(terminator))
1852 throw nb::value_error("Block has no terminator");
1854 block.getParentOperation()->getContext(), terminator);
1855 return PyInsertionPoint{block, std::move(terminatorOpRef)};
1856}
1857
1859 PyOperation &operation = op.getOperation();
1860 PyBlock block = operation.getBlock();
1861 MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
1862 if (mlirOperationIsNull(nextOperation))
1863 return PyInsertionPoint(block);
1865 block.getParentOperation()->getContext(), nextOperation);
1866 return PyInsertionPoint{block, std::move(nextOpRef)};
1867}
1868
1869size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
1871nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
1872 return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
1873}
1874
1875void PyInsertionPoint::contextExit(const nb::object &excType,
1876 const nb::object &excVal,
1877 const nb::object &excTb) {
1879}
1881//------------------------------------------------------------------------------
1882// PyAttribute.
1883//------------------------------------------------------------------------------
1885bool PyAttribute::operator==(const PyAttribute &other) const {
1886 return mlirAttributeEqual(attr, other.attr);
1887}
1889nb::object PyAttribute::getCapsule() {
1890 return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
1891}
1892
1893PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) {
1894 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1895 if (mlirAttributeIsNull(rawAttr))
1896 throw nb::python_error();
1897 return PyAttribute(
1899}
1900
1901nb::typed<nb::object, PyAttribute> PyAttribute::maybeDownCast() {
1902 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get());
1903 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1904 "mlirTypeID was expected to be non-null.");
1905 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
1906 mlirTypeID, mlirAttributeGetDialect(this->get()));
1907 // nb::rv_policy::move means use std::move to move the return value
1908 // contents into a new instance that will be owned by Python.
1909 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
1910 if (!typeCaster)
1911 return thisObj;
1912 return typeCaster.value()(thisObj);
1913}
1915//------------------------------------------------------------------------------
1916// PyNamedAttribute.
1917//------------------------------------------------------------------------------
1918
/// Builds a named attribute from `attr` and a name string.
/// The name is moved into a heap-allocated std::string held by `ownedName`,
/// and the MlirStringRef used for the name is made from that copy via
/// toMlirStringRef — presumably so the ref does not dangle once the caller's
/// string goes away (confirm).
/// NOTE(review): original lines 1921-1922 are missing from this listing.
1919PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1920 : ownedName(new std::string(std::move(ownedName))) {
1923 toMlirStringRef(*this->ownedName)),
1924 attr);
1925}
1927//------------------------------------------------------------------------------
1928// PyType.
1929//------------------------------------------------------------------------------
1931bool PyType::operator==(const PyType &other) const {
1932 return mlirTypeEqual(type, other.type);
1933}
1935nb::object PyType::getCapsule() {
1936 return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
1937}
1938
/// Rebuilds a PyType from a Python capsule.
/// Throws nb::python_error when the capsule does not yield a non-null MlirType.
/// NOTE(review): original line 1943 (start of the return expression) is missing
/// from this listing.
1939PyType PyType::createFromCapsule(nb::object capsule) {
1940 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1941 if (mlirTypeIsNull(rawType))
1942 throw nb::python_error();
1944 rawType);
1945}
1946
1947nb::typed<nb::object, PyType> PyType::maybeDownCast() {
1948 MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get());
1949 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1950 "mlirTypeID was expected to be non-null.");
1951 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
1952 mlirTypeID, mlirTypeGetDialect(this->get()));
1953 // nb::rv_policy::move means use std::move to move the return value
1954 // contents into a new instance that will be owned by Python.
1955 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
1956 if (!typeCaster)
1957 return thisObj;
1958 return typeCaster.value()(thisObj);
1959}
1961//------------------------------------------------------------------------------
1962// PyTypeID.
1963//------------------------------------------------------------------------------
1965nb::object PyTypeID::getCapsule() {
1966 return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
1967}
1968
1969PyTypeID PyTypeID::createFromCapsule(nb::object capsule) {
1970 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
1971 if (mlirTypeIDIsNull(mlirTypeID))
1972 throw nb::python_error();
1973 return PyTypeID(mlirTypeID);
1974}
1975bool PyTypeID::operator==(const PyTypeID &other) const {
1976 return mlirTypeIDEqual(typeID, other.typeID);
1977}
1979//------------------------------------------------------------------------------
1980// PyValue and subclasses.
1981//------------------------------------------------------------------------------
1983nb::object PyValue::getCapsule() {
1984 return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
1985}
1986
/// Returns a reference to the operation that owns `value`: the defining op for
/// an op result, or (per the missing line 1992, presumably the parent op of the
/// argument's block) for a block argument. Throws nb::python_error if the owner
/// is null.
/// NOTE(review): lines 1992 and 1998 (the final return) are missing from this
/// listing.
/// NOTE(review): `owner` is declared without an initializer and the fallback
/// branch is only an assert, so in NDEBUG builds a value that is neither kind
/// leaves `owner` uninitialized before it is read — consider initializing it.
1987static PyOperationRef getValueOwnerRef(MlirValue value) {
1988 MlirOperation owner;
1989 if (mlirValueIsAOpResult(value))
1990 owner = mlirOpResultGetOwner(value);
1991 else if (mlirValueIsABlockArgument(value))
1993 else
1994 assert(false && "Value must be an block arg or op result.");
1995 if (mlirOperationIsNull(owner))
1996 throw nb::python_error();
1997 MlirContext ctx = mlirOperationGetContext(owner);
1999}
2000
/// Converts this value to the most specific Python wrapper: PyOpResult or
/// PyBlockArgument depending on its kind. If a value caster is registered
/// (looked up using the TypeID of the value's type), its result is returned
/// instead.
/// NOTE(review): lines 2002 (the signature's second line) and 2008 (the caster
/// lookup call) are missing from this listing.
/// NOTE(review): in NDEBUG builds, a value that is neither kind leaves
/// `thisObj` default-constructed (null) and it is still passed on or returned.
2001nb::typed<nb::object, std::variant<PyBlockArgument, PyOpResult, PyValue>>
2003 MlirType type = mlirValueGetType(get());
2004 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2005 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2006 "mlirTypeID was expected to be non-null.");
2007 std::optional<nb::callable> valueCaster =
2009 // nb::rv_policy::move means use std::move to move the return value
2010 // contents into a new instance that will be owned by Python.
2011 nb::object thisObj;
2012 if (mlirValueIsAOpResult(value))
2013 thisObj = nb::cast<PyOpResult>(*this, nb::rv_policy::move);
2014 else if (mlirValueIsABlockArgument(value))
2015 thisObj = nb::cast<PyBlockArgument>(*this, nb::rv_policy::move);
2016 else
2017 assert(false && "Value must be an block arg or op result.");
2018 if (valueCaster)
2019 return valueCaster.value()(thisObj);
2020 return thisObj;
2021}
2022
2023PyValue PyValue::createFromCapsule(nb::object capsule) {
2024 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2025 if (mlirValueIsNull(value))
2026 throw nb::python_error();
2027 PyOperationRef ownerRef = getValueOwnerRef(value);
2028 return PyValue(ownerRef, value);
2029}
2031//------------------------------------------------------------------------------
2032// PySymbolTable.
2033//------------------------------------------------------------------------------
2034
// PySymbolTable constructor. It keeps a reference to the given operation and
// creates an MLIR symbol table for it. Throws TypeError if the operation is not
// a symbol table.
// NOTE(review): the signature line (original 2035) is missing from this
// listing.
2036 : operation(operation.getOperation().getRef()) {
2037 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2038 if (mlirSymbolTableIsNull(symbolTable)) {
2039 throw nb::type_error("Operation is not a Symbol Table.");
2040 }
2041}
2042
2043nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2044 operation->checkValid();
2045 MlirOperation symbol = mlirSymbolTableLookup(
2046 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2047 if (mlirOperationIsNull(symbol))
2048 throw nb::key_error(
2049 join("Symbol '", name, "' not in the symbol table.").c_str());
2050
2051 return PyOperation::forOperation(operation->getContext(), symbol,
2052 operation.getObject())
2053 ->createOpView();
2054}
2055
// Body of (presumably) PySymbolTable::erase, which dunderDel calls below.
// It removes `symbol` from this symbol table and marks the Python-side
// operation wrapper as invalid.
// NOTE(review): the signature line (original 2056) is missing from this
// listing.
2057 operation->checkValid();
2058 symbol.getOperation().checkValid();
2059 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2060 // The operation is also erased, so we must invalidate it. There may be Python
2061 // references to this operation so we don't want to delete it from the list of
2062 // live operations here.
2063 symbol.getOperation().valid = false;
2064}
2065
2066void PySymbolTable::dunderDel(const std::string &name) {
2067 nb::object operation = dunderGetItem(name);
2068 erase(nb::cast<PyOperationBase &>(operation));
2069}
2070
// Body of the symbol-table insert method. It requires both the table's
// operation and `symbol` to be valid, and requires `symbol` to carry a symbol
// name attribute (otherwise ValueError). The result wraps the value returned by
// mlirSymbolTableInsert.
// NOTE(review): lines 2071 (signature), 2075 (attribute-name argument) and
// 2078 (start of the return expression) are missing from this listing.
2072 operation->checkValid();
2073 symbol.getOperation().checkValid();
2074 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2076 if (mlirAttributeIsNull(symbolAttr))
2077 throw nb::value_error("Expected operation to have a symbol name.");
2079 symbol.getOperation().getContext(),
2080 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
2081}
2082
// Body of the symbol-name getter. It returns the operation's `attrName`
// attribute as a PyStringAttribute, or raises ValueError if the attribute is
// absent.
// NOTE(review): lines 2083 (signature) and 2087 (definition of `attrName`) are
// missing from this listing.
2084 // Op must already be a symbol.
2085 PyOperation &operation = symbol.getOperation();
2086 operation.checkValid();
2088 MlirAttribute existingNameAttr =
2089 mlirOperationGetAttributeByName(operation.get(), attrName);
2090 if (mlirAttributeIsNull(existingNameAttr))
2091 throw nb::value_error("Expected operation to have a symbol name.");
2092 return PyStringAttribute(symbol.getOperation().getContext(),
2093 existingNameAttr);
2094}
2095
// Tail of the symbol-name setter. It replaces the operation's existing
// `attrName` string attribute with one holding `name`; an operation without
// that attribute raises ValueError.
// NOTE(review): lines 2096 (start of signature) and 2101 (definition of
// `attrName`) are missing from this listing.
2097 const std::string &name) {
2098 // Op must already be a symbol.
2099 PyOperation &operation = symbol.getOperation();
2100 operation.checkValid();
2102 MlirAttribute existingNameAttr =
2103 mlirOperationGetAttributeByName(operation.get(), attrName);
2104 if (mlirAttributeIsNull(existingNameAttr))
2105 throw nb::value_error("Expected operation to have a symbol name.");
2106 MlirAttribute newNameAttr =
2107 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2108 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2109}
2110
// Body of the symbol-visibility getter. It returns the `attrName` attribute as
// a PyStringAttribute, or raises ValueError if the attribute is absent.
// NOTE(review): lines 2111 (signature) and 2114 (definition of `attrName`) are
// missing from this listing.
2112 PyOperation &operation = symbol.getOperation();
2113 operation.checkValid();
2115 MlirAttribute existingVisAttr =
2116 mlirOperationGetAttributeByName(operation.get(), attrName);
2117 if (mlirAttributeIsNull(existingVisAttr))
2118 throw nb::value_error("Expected operation to have a symbol visibility.");
2119 return PyStringAttribute(symbol.getOperation().getContext(), existingVisAttr);
2120}
2121
// Tail of the symbol-visibility setter. Only "public", "private" and "nested"
// are accepted (ValueError otherwise). The op must already have a visibility
// attribute, which is then overwritten with a new string attribute.
// NOTE(review): lines 2122 (start of signature) and 2130 (definition of
// `attrName`) are missing from this listing.
2123 const std::string &visibility) {
2124 if (visibility != "public" && visibility != "private" &&
2125 visibility != "nested")
2126 throw nb::value_error(
2127 "Expected visibility to be 'public', 'private' or 'nested'");
2128 PyOperation &operation = symbol.getOperation();
2129 operation.checkValid();
2131 MlirAttribute existingVisAttr =
2132 mlirOperationGetAttributeByName(operation.get(), attrName);
2133 if (mlirAttributeIsNull(existingVisAttr))
2134 throw nb::value_error("Expected operation to have a symbol visibility.");
2135 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2136 toMlirStringRef(visibility));
2137 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2138}
2139
/// Renames all uses of `oldSymbol` to `newSymbol`, starting from `from`.
/// Raises ValueError("Symbol rename failed") when the rename reports failure.
/// NOTE(review): lines 2145 and 2147 (the C API call and its failure test) are
/// missing from this listing.
2140void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2141 const std::string &newSymbol,
2142 PyOperationBase &from) {
2143 PyOperation &fromOperation = from.getOperation();
2144 fromOperation.checkValid();
2146 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2148
2149 throw nb::value_error("Symbol rename failed");
2150}
2151
// Tail of the symbol-table walk. `callback(op, isVisible)` is invoked for each
// symbol-table op found from `from`.
// - The first Python exception raised by the callback is recorded, and later
//   invocations are skipped.
// - After the walk, the recorded exception is rethrown as std::runtime_error
//   with the original message appended.
// NOTE(review): lines 2152 (start of signature) and 2166 (the walk call) are
// missing from this listing.
2153 bool allSymUsesVisible,
2154 nb::object callback) {
2155 PyOperation &fromOperation = from.getOperation();
2156 fromOperation.checkValid();
2157 struct UserData {
2158 PyMlirContextRef context;
2159 nb::object callback;
2160 bool gotException;
2161 std::string exceptionWhat;
2162 nb::object exceptionType;
2163 };
2164 UserData userData{
2165 fromOperation.getContext(), std::move(callback), false, {}, {}};
2167 fromOperation.get(), allSymUsesVisible,
2168 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2169 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2170 // NOTE(review): the wrapper is created before the gotException check,
2171 // so it is still created for every op after a failure.
2170 auto pyFoundOp =
2171 PyOperation::forOperation(calleeUserData->context, foundOp);
2172 if (calleeUserData->gotException)
2173 return;
2174 try {
2175 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2176 } catch (nb::python_error &e) {
2177 calleeUserData->gotException = true;
2178 calleeUserData->exceptionWhat = e.what();
2179 calleeUserData->exceptionType = nb::borrow(e.type());
2180 }
2181 },
2182 static_cast<void *>(&userData));
2183 // NOTE(review): `exceptionType` is captured above but not used here; the
2184 // Python exception type is lost and std::runtime_error is thrown instead.
2183 if (userData.gotException) {
2184 std::string message("Exception raised in callback: ");
2185 message.append(userData.exceptionWhat);
2186 throw std::runtime_error(message);
2187 }
2188}
2189
/// Adds block-argument-specific members to the Python class:
/// - `owner` (the owning block)
/// - `arg_number` (position in the block's argument list)
/// - `set_type(type)`
/// - `set_location(loc)`
/// NOTE(review): lines 2195 (the second PyBlock constructor argument) and 2213
/// (the body of `set_location`) are missing from this listing.
2190void PyBlockArgument::bindDerived(ClassTy &c) {
2191 c.def_prop_ro(
2192 "owner",
2193 [](PyBlockArgument &self) {
2194 return PyBlock(self.getParentOperation(),
2196 },
2197 "Returns the block that owns this argument.");
2198 c.def_prop_ro(
2199 "arg_number",
2200 [](PyBlockArgument &self) {
2201 return mlirBlockArgumentGetArgNumber(self.get());
2202 },
2203 "Returns the position of this argument in the block's argument list.");
2204 c.def(
2205 "set_type",
2206 [](PyBlockArgument &self, PyType type) {
2207 return mlirBlockArgumentSetType(self.get(), type);
2208 },
2209 "type"_a, "Sets the type of this block argument.");
2210 c.def(
2211 "set_location",
2212 [](PyBlockArgument &self, PyLocation loc) {
2214 },
2215 "loc"_a, "Sets the location of this block argument.");
2216}
2217
// Part of the PyBlockArgumentList constructor. A `length` of -1 means
// "all arguments of `block`" (mlirBlockGetNumArguments); the owning
// operation reference is moved in and the block is stored.
// NOTE(review): lines 2218, 2220 and 2221 are missing from this listing.
2219 MlirBlock block, intptr_t startIndex,
2222 length == -1 ? mlirBlockGetNumArguments(block) : length, step),
2223 operation(std::move(operation)), block(block) {}
2224
2225void PyBlockArgumentList::bindDerived(ClassTy &c) {
2226 c.def_prop_ro(
2227 "types",
2228 [](PyBlockArgumentList &self) {
2229 return getValueTypes(self, self.operation->getContext());
2230 },
2231 "Returns a list of types for all arguments in this argument list.");
2232}
2233
2234intptr_t PyBlockArgumentList::getRawNumElements() {
2235 operation->checkValid();
2236 return mlirBlockGetNumArguments(block);
2237}
2238
2239PyBlockArgument PyBlockArgumentList::getRawElement(intptr_t pos) const {
2240 MlirValue argument = mlirBlockGetArgument(block, pos);
2241 return PyBlockArgument(operation, argument);
2242}
2243
/// Returns a new argument list over the same block and owning operation,
/// restricted to the given start/length/step.
/// NOTE(review): line 2245 (the `length` parameter) is missing from this
/// listing.
2244PyBlockArgumentList PyBlockArgumentList::slice(intptr_t startIndex,
2246 intptr_t step) const {
2247 return PyBlockArgumentList(operation, block, startIndex, length, step);
2248}
2249
// Part of the PyOpOperandList constructor. It forwards start/length/step to
// Sliceable and stores the owning operation reference.
// NOTE(review): lines 2250 and 2253 (presumably the -1 => "all operands"
// default) are missing from this listing.
2251 intptr_t length, intptr_t step)
2252 : Sliceable(startIndex,
2254 : length,
2255 step),
2256 operation(operation) {}
2257
// Tail of PyOpOperandList::dunderSetItem (bound as `__setitem__` below). It
// sets operand `index` of the operation to `value`.
// NOTE(review): lines 2258-2259 (signature and any preceding checks) are
// missing from this listing.
2260 mlirOperationSetOperand(operation->get(), index, value.get());
2261}
2262
2263void PyOpOperandList::bindDerived(ClassTy &c) {
2264 c.def("__setitem__", &PyOpOperandList::dunderSetItem, "index"_a, "value"_a,
2265 "Sets the operand at the specified index to a new value.");
2266}
2267
2268intptr_t PyOpOperandList::getRawNumElements() {
2269 operation->checkValid();
2270 return mlirOperationGetNumOperands(operation->get());
2271}
2272
2273PyValue PyOpOperandList::getRawElement(intptr_t pos) {
2274 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2275 PyOperationRef pyOwner = getValueOwnerRef(operand);
2276 return PyValue(pyOwner, operand);
2277}
2278
2279PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
2280 intptr_t step) const {
2281 return PyOpOperandList(operation, startIndex, length, step);
2282}
2284/// A list of OpOperands. Internally, these are stored as consecutive elements,
2285/// so random access is cheap. The (returned) OpOperand list is associated with
2286/// the operation whose operands these are, and thus extends the lifetime of
2287/// that operation.
2288class PyOpOperands : public Sliceable<PyOpOperands, PyOpOperand> {
2289public:
2290 static constexpr const char *pyClassName = "OpOperands";
// NOTE(review): lines 2291, 2293, 2295, 2315 and 2317 (including the
// constructor signature and the slice() signature) are missing from this
// listing. A `length` of -1 selects all operands of the operation.
2292
2294 intptr_t length = -1, intptr_t step = 1)
2296 length == -1 ? mlirOperationGetNumOperands(operation->get())
2297 : length,
2298 step),
2299 operation(operation) {}
2300
2301private:
2302 /// Give the parent CRTP class access to hook implementations below.
2303 friend class Sliceable<PyOpOperands, PyOpOperand>;
2304
// Validates the owning operation, then reports its operand count.
2305 intptr_t getRawNumElements() {
2306 operation->checkValid();
2307 return mlirOperationGetNumOperands(operation->get());
2308 }
2309
// Wraps the OpOperand at raw index `pos`.
2310 PyOpOperand getRawElement(intptr_t pos) {
2311 MlirOpOperand opOperand = mlirOperationGetOpOperand(operation->get(), pos);
2312 return PyOpOperand(opOperand);
2313 }
2314
2316 return PyOpOperands(operation, startIndex, length, step);
2318
2319 PyOperationRef operation;
2320};
2321
// Part of the PyOpSuccessors constructor. It forwards start/length/step to
// Sliceable and stores the owning operation reference.
// NOTE(review): lines 2322 and 2325 are missing from this listing.
2323 intptr_t length, intptr_t step)
2324 : Sliceable(startIndex,
2326 : length,
2327 step),
2328 operation(operation) {}
2329
// Tail of PyOpSuccessors::dunderSetItem (bound as `__setitem__` below). It
// sets successor `index` of the operation to `block`.
// NOTE(review): lines 2330-2331 are missing from this listing.
2332 mlirOperationSetSuccessor(operation->get(), index, block.get());
2333}
2334
2335void PyOpSuccessors::bindDerived(ClassTy &c) {
2336 c.def("__setitem__", &PyOpSuccessors::dunderSetItem, "index"_a, "block"_a,
2337 "Sets the successor block at the specified index.");
2338}
2339
2340intptr_t PyOpSuccessors::getRawNumElements() {
2341 operation->checkValid();
2342 return mlirOperationGetNumSuccessors(operation->get());
2343}
2344
2345PyBlock PyOpSuccessors::getRawElement(intptr_t pos) {
2346 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2347 return PyBlock(operation, block);
2348}
2349
// Tail of PyOpSuccessors::slice. It returns a sub-list sharing the owning
// operation.
// NOTE(review): the signature line (original 2350) is missing from this
// listing.
2351 intptr_t step) const {
2352 return PyOpSuccessors(operation, startIndex, length, step);
2353}
2354
// Part of the PyBlockSuccessors constructor. A `length` of -1 selects all
// successors of `block`; the owning operation and the block are stored.
// NOTE(review): the first signature line (original 2355) is missing from this
// listing.
2356 intptr_t startIndex, intptr_t length,
2357 intptr_t step)
2358 : Sliceable(startIndex,
2359 length == -1 ? mlirBlockGetNumSuccessors(block.get()) : length,
2360 step),
2361 operation(operation), block(block) {}
2362
2363intptr_t PyBlockSuccessors::getRawNumElements() {
2364 block.checkValid();
2365 return mlirBlockGetNumSuccessors(block.get());
2366}
2367
2368PyBlock PyBlockSuccessors::getRawElement(intptr_t pos) {
2369 MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2370 return PyBlock(operation, block);
2371}
2372
// Tail of PyBlockSuccessors::slice. It returns a sub-list over the same block
// and operation.
// NOTE(review): the signature line (original 2373) is missing from this
// listing.
2374 intptr_t step) const {
2375 return PyBlockSuccessors(block, operation, startIndex, length, step);
2376}
2377
// Part of the PyBlockPredecessors constructor. A `length` of -1 selects all
// predecessors of `block`; the owning operation and the block are stored.
// NOTE(review): the first signature line (original 2378) is missing from this
// listing.
2379 PyOperationRef operation,
2380 intptr_t startIndex, intptr_t length,
2381 intptr_t step)
2382 : Sliceable(startIndex,
2383 length == -1 ? mlirBlockGetNumPredecessors(block.get())
2384 : length,
2385 step),
2386 operation(operation), block(block) {}
2387
2388intptr_t PyBlockPredecessors::getRawNumElements() {
2389 block.checkValid();
2390 return mlirBlockGetNumPredecessors(block.get());
2391}
2392
2393PyBlock PyBlockPredecessors::getRawElement(intptr_t pos) {
2394 MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2395 return PyBlock(operation, block);
2396}
2397
2398PyBlockPredecessors PyBlockPredecessors::slice(intptr_t startIndex,
2399 intptr_t length,
2400 intptr_t step) const {
2401 return PyBlockPredecessors(block, operation, startIndex, length, step);
2402}
2403
/// `attrs[name]`: returns the named attribute, downcast via maybeDownCast.
/// Raises KeyError if the operation has no attribute called `name`.
/// NOTE(review): lines 2407 (the lookup call) and 2410 (closing brace of the
/// if) are missing from this listing.
2404nb::typed<nb::object, PyAttribute>
2405PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
2406 MlirAttribute attr =
2408 if (mlirAttributeIsNull(attr)) {
2409 throw nb::key_error("attempt to access a non-existent attribute");
2411 return PyAttribute(operation->getContext(), attr).maybeDownCast();
2412}
2413
/// `attrs.get(key, default)`: returns the attribute `key` (downcast), or
/// `defaultValue` unchanged when the operation has no such attribute.
/// NOTE(review): line 2417 (the lookup call) is missing from this listing.
2414nb::typed<nb::object, std::optional<PyAttribute>>
2415PyOpAttributeMap::get(const std::string &key, nb::object defaultValue) {
2416 MlirAttribute attr =
2418 if (mlirAttributeIsNull(attr))
2419 return defaultValue;
2420 return PyAttribute(operation->getContext(), attr).maybeDownCast();
2421}
2422
// Body of PyOpAttributeMap::dunderGetItemIndexed (bound as `__getitem__` by
// index).
// - Negative indices count from the end.
// - Out-of-range indices raise IndexError.
// - The result is a PyNamedAttribute holding a copy of the identifier string.
// NOTE(review): the signature line (original 2423) is missing from this
// listing.
2424 if (index < 0) {
2425 index += dunderLen();
2426 }
2427 if (index < 0 || index >= dunderLen()) {
2428 throw nb::index_error("attempt to access out of bounds attribute");
2429 }
2430 MlirNamedAttribute namedAttr =
2431 mlirOperationGetAttribute(operation->get(), index);
2432 return PyNamedAttribute(
2433 namedAttr.attribute,
2434 std::string(mlirIdentifierStr(namedAttr.name).data,
2435 mlirIdentifierStr(namedAttr.name).length));
2436}
2437
2438void PyOpAttributeMap::dunderSetItem(const std::string &name,
2439 const PyAttribute &attr) {
2440 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2441 attr);
2442}
2443
/// `del attrs[name]`: removes attribute `name` from the operation, raising
/// KeyError when nothing was removed.
/// NOTE(review): line 2446 (the name argument) is missing from this listing.
2444void PyOpAttributeMap::dunderDelItem(const std::string &name) {
2445 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2447 if (!removed)
2448 throw nb::key_error("attempt to delete a non-existent attribute");
2449}
// Body of (presumably) PyOpAttributeMap::dunderLen, bound as `__len__` below.
// It returns the number of attributes on the operation.
// NOTE(review): lines 2450-2451 (signature) are missing from this listing.
2452 return mlirOperationGetNumAttributes(operation->get());
2453}
2454
2455bool PyOpAttributeMap::dunderContains(const std::string &name) {
2456 return !mlirAttributeIsNull(
2457 mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)));
2458}
2459
// Tail of a helper (used as forEachAttr-style iteration by bind() below). It
// calls `fn(name, attribute)` for each of the `n` attributes of `op`.
// NOTE(review): lines 2460 (signature start), 2462 (definition of `n`), and
// 2464-2465 (fetching `na` and `name`) are missing from this listing.
2461 MlirOperation op, std::function<void(MlirStringRef, MlirAttribute)> fn) {
2463 for (intptr_t i = 0; i < n; ++i) {
2466 fn(name, na.attribute);
2467 }
2468}
2469
/// Registers the Python `OpAttributeMap` class. It provides mapping-style
/// access (`in`, `len`, `[]` by name or index, assignment, deletion, `get`)
/// and the `__iter__`/`keys`/`values`/`items` views.
/// The iteration helpers build a fresh Python list snapshot of the current
/// attributes on each call.
/// NOTE(review): lines 2492, 2503, 2514 and 2526 (the call to the attribute
/// iteration helper in each lambda) are missing from this listing.
2470void PyOpAttributeMap::bind(nb::module_ &m) {
2471 nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2472 .def("__contains__", &PyOpAttributeMap::dunderContains, "name"_a,
2473 "Checks if an attribute with the given name exists in the map.")
2474 .def("__len__", &PyOpAttributeMap::dunderLen,
2475 "Returns the number of attributes in the map.")
2476 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed, "name"_a,
2477 "Gets an attribute by name.")
2478 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed, "index"_a,
2479 "Gets a named attribute by index.")
2480 .def("__setitem__", &PyOpAttributeMap::dunderSetItem, "name"_a, "attr"_a,
2481 "Sets an attribute with the given name.")
2482 .def("__delitem__", &PyOpAttributeMap::dunderDelItem, "name"_a,
2483 "Deletes an attribute with the given name.")
2484 .def("get", &PyOpAttributeMap::get, nb::arg("key"),
2485 nb::arg("default") = nb::none(),
2486 "Gets an attribute by name or the default value, if it does not "
2487 "exist.")
      // Iterates over a snapshot list of names taken when iteration starts.
2488 .def(
2489 "__iter__",
2490 [](PyOpAttributeMap &self) {
2491 nb::list keys;
2493 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
2494 keys.append(nb::str(name.data, name.length));
2495 });
2496 return nb::iter(keys);
2497 },
2498 "Iterates over attribute names.")
2499 .def(
2500 "keys",
2501 [](PyOpAttributeMap &self) {
2502 nb::list out;
2504 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
2505 out.append(nb::str(name.data, name.length));
2506 });
2507 return out;
2508 },
2509 "Returns a list of attribute names.")
      // Values are downcast through maybeDownCast, like `__getitem__`.
2510 .def(
2511 "values",
2512 [](PyOpAttributeMap &self) {
2513 nb::list out;
2515 self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
2516 out.append(PyAttribute(self.operation->getContext(), attr)
2517 .maybeDownCast());
2518 });
2519 return out;
2520 },
2521 "Returns a list of attribute values.")
2522 .def(
2523 "items",
2524 [](PyOpAttributeMap &self) {
2525 nb::list out;
2527 self.operation->get(),
2528 [&](MlirStringRef name, MlirAttribute attr) {
2529 out.append(nb::make_tuple(
2530 nb::str(name.data, name.length),
2531 PyAttribute(self.operation->getContext(), attr)
2532 .maybeDownCast()));
2533 });
2534 return out;
2535 },
2536 "Returns a list of `(name, attribute)` tuples.");
2537}
2538
2539void PyOpAdaptor::bind(nb::module_ &m) {
2540 nb::class_<PyOpAdaptor>(m, "OpAdaptor")
2541 .def(nb::init<nb::list, PyOpAttributeMap>(),
2542 "Creates an OpAdaptor with the given operands and attributes.",
2543 "operands"_a, "attributes"_a)
2544 .def(nb::init<nb::list, PyOpView &>(),
2545 "Creates an OpAdaptor with the given operands and operation view.",
2546 "operands"_a, "opview"_a)
2547 .def_prop_ro(
2548 "operands", [](PyOpAdaptor &self) { return self.operands; },
2549 "Returns the operands of the adaptor.")
2550 .def_prop_ro(
2551 "attributes", [](PyOpAdaptor &self) { return self.attributes; },
2552 "Returns the attributes of the adaptor.");
2553}
2554
/// Shared verify callback for Python-defined dynamic op traits.
/// `userData` is the Python object passed when the trait was created.
/// - If that object has no attribute `methodName`, verification succeeds.
/// - Otherwise the method is called with an op view of `op`, and its result is
///   converted to bool.
/// NOTE(review): lines 2560 (presumably the definition of `ctx`) and 2563
/// (presumably the return mapping `success` to an MlirLogicalResult) are
/// missing from this listing.
/// NOTE(review): a Python exception raised by the method is not caught in the
/// visible code, so it would propagate as a C++ exception out of a C API
/// callback — confirm this is safe.
/// NOTE(review): the `;` after the closing brace is a redundant empty
/// declaration.
2555static MlirLogicalResult verifyTraitByMethod(MlirOperation op, void *userData,
2556 const char *methodName) {
2557 nb::handle targetObj(static_cast<PyObject *>(userData));
2558 if (!nb::hasattr(targetObj, methodName))
2559 return mlirLogicalResultSuccess();
2561 nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
2562 bool success = nb::cast<bool>(targetObj.attr(methodName)(opView));
2564};
2565
/// Resolves `opName` to an operation-name string and attaches `trait` to that
/// operation in `context`.
/// `opName` may be a type (its `OPERATION_NAME` attribute is read) or a
/// `str`; anything else raises TypeError.
/// NOTE(review): lines 2576-2577 (the start of the attach call) are missing
/// from this listing.
/// NOTE(review): on the TypeError path `trait` is not used by any visible
/// line — confirm whether it must be released there.
2566static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait,
2567 PyMlirContext &context) {
2568 std::string opNameStr;
2569 if (opName.is_type()) {
2570 opNameStr = nb::cast<std::string>(opName.attr("OPERATION_NAME"));
2571 } else if (nb::isinstance<nb::str>(opName)) {
2572 opNameStr = nb::cast<std::string>(opName);
2573 } else {
2574 throw nb::type_error("the root argument must be a type or a string");
2575 }
2578 trait, MlirStringRef{opNameStr.data(), opNameStr.size()}, context.get());
2579}
2580
/// Creates a dynamic op trait backed by the Python object `target` and attaches
/// it to `opName` in `context`.
/// - Raises TypeError unless `target` has `verify_invariants` and/or
///   `verify_region_invariants`.
/// - The construct/destruct callbacks inc_ref/dec_ref `target`, keeping it
///   alive while the trait holds it.
/// - The verify callbacks forward to verifyTraitByMethod with the matching
///   method name.
/// NOTE(review): line 2590 (presumably the declaration of `callbacks`) is
/// missing from this listing.
2581bool PyDynamicOpTrait::attach(const nb::object &opName,
2582 const nb::object &target,
2583 PyMlirContext &context) {
2584 if (!nb::hasattr(target, "verify_invariants") &&
2585 !nb::hasattr(target, "verify_region_invariants"))
2586 throw nb::type_error(
2587 "the target object must have at least one of 'verify_invariants' or "
2588 "'verify_region_invariants' methods");
2589
2591 callbacks.construct = [](void *userData) {
2592 nb::handle(static_cast<PyObject *>(userData)).inc_ref();
2593 };
2594 callbacks.destruct = [](void *userData) {
2595 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
2596 };
2597
2598 callbacks.verifyTrait = [](MlirOperation op,
2599 void *userData) -> MlirLogicalResult {
2600 return verifyTraitByMethod(op, userData, "verify_invariants");
2601 };
2602 callbacks.verifyRegionTrait = [](MlirOperation op,
2603 void *userData) -> MlirLogicalResult {
2604 return verifyTraitByMethod(op, userData, "verify_region_invariants");
2605 };
2606
2607 // To ensure that the same dynamic trait gets the same TypeID despite how many
2608 // times `attach` is called, we store it as an attribute on the target class.
2609 constexpr const char *typeIDAttr = "_TYPE_ID";
2610 if (!nb::hasattr(target, typeIDAttr)) {
2611 nb::setattr(target, typeIDAttr,
2612 nb::cast(PyTypeID(PyGlobals::get().allocateTypeID())));
2613 }
2614 MlirDynamicOpTrait trait = mlirDynamicOpTraitCreate(
2615 nb::cast<PyTypeID>(target.attr(typeIDAttr)).get(), callbacks,
2616 static_cast<void *>(target.ptr()));
2617 return attachOpTrait(opName, trait, context);
2618}
2619
2620void PyDynamicOpTrait::bind(nb::module_ &m) {
2621 nb::class_<PyDynamicOpTrait> cls(m, "DynamicOpTrait");
2622 cls.attr("attach") = classmethod(
2623 [](const nb::object &cls, const nb::object &opName, nb::object target,
2624 DefaultingPyMlirContext context) {
2625 if (target.is_none())
2626 target = cls;
2627 return PyDynamicOpTrait::attach(opName, target, *context.get());
2628 },
2629 nb::arg("cls"), nb::arg("op_name"), nb::arg("target").none() = nb::none(),
2630 nb::arg("context").none() = nb::none(),
2631 "Attach the dynamic op trait subclass to the given operation name.");
2632}
2633
2634bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName,
2635 PyMlirContext &context) {
2636 MlirDynamicOpTrait trait = mlirDynamicOpTraitIsTerminatorCreate();
2637 return attachOpTrait(opName, trait, context);
2638}
2639
/// Registers `IsTerminatorTrait` (a DynamicOpTrait subclass) and its
/// `attach(op_name, context=None)` classmethod.
/// NOTE(review): line 2647 (closing the lambda) is missing from this listing.
2640void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) {
2641 nb::class_<PyDynamicOpTraits::IsTerminator, PyDynamicOpTrait> cls(
2642 m, "IsTerminatorTrait");
2643 cls.attr("attach") = classmethod(
2644 [](const nb::object &cls, const nb::object &opName,
2645 DefaultingPyMlirContext context) {
2646 return PyDynamicOpTraits::IsTerminator::attach(opName, *context.get());
2648 "Attach IsTerminator trait to the given operation name.", nb::arg("cls"),
2649 nb::arg("op_name"), nb::arg("context").none() = nb::none());
2650}
2651
2652bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName,
2653 PyMlirContext &context) {
2654 MlirDynamicOpTrait trait = mlirDynamicOpTraitNoTerminatorCreate();
2655 return attachOpTrait(opName, trait, context);
2656}
2657
2658void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) {
2659 nb::class_<PyDynamicOpTraits::NoTerminator, PyDynamicOpTrait> cls(
2660 m, "NoTerminatorTrait");
2661 cls.attr("attach") = classmethod(
2662 [](const nb::object &cls, const nb::object &opName,
2663 DefaultingPyMlirContext context) {
2664 return PyDynamicOpTraits::NoTerminator::attach(opName, *context.get());
2665 },
2666 "Attach NoTerminator trait to the given operation name.", nb::arg("cls"),
2667 nb::arg("op_name"), nb::arg("context").none() = nb::none());
2668}
2669
2670} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
2671} // namespace python
2672} // namespace mlir
2673
2674namespace {
2675// Compatibility shims for older CPython versions, adapted from
2676// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
2677
// Fallback definitions of helper macros used by the shims below. Each is
// defined only when the Python headers in use do not already provide it.
2678#ifndef _Py_CAST
2679#define _Py_CAST(type, expr) ((type)(expr))
2680#endif
2681
2682// Static inline functions should use _Py_NULL rather than using NULL directly
2683// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
2684// _Py_NULL is defined as nullptr.
2685#ifndef _Py_NULL
2686#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
2687 (defined(__cplusplus) && __cplusplus >= 201103)
2688#define _Py_NULL nullptr
2689#else
2690#define _Py_NULL NULL
2691#endif
2692#endif
2693
// For Python older than 3.10.0a3, backfill Py_XNewRef/Py_NewRef when they are
// not already defined. Each increments the reference count of its argument
// and returns it; the X variant goes through Py_XINCREF.
2694// Python 3.10.0a3
2695#if PY_VERSION_HEX < 0x030A00A3
2696
2697// bpo-42262 added Py_XNewRef()
2698#if !defined(Py_XNewRef)
2699[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
2700 Py_XINCREF(obj);
2701 return obj;
2702}
2703#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
2704#endif
2705
2706// bpo-42262 added Py_NewRef()
2707#if !defined(Py_NewRef)
2708[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
2709 Py_INCREF(obj);
2710 return obj;
2711}
2712#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
2713#endif
2714
2715#endif // Python 3.10.0a3
2716
// For CPython older than 3.9.0b1 (not PyPy), provide the frame accessor
// functions that newer versions expose. Each returns a new reference (via
// Py_XNewRef/Py_NewRef), which the caller must release.
2717// Python 3.9.0b1
2718#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
2719
2720// bpo-40429 added PyThreadState_GetFrame()
2721PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
2722 assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
2723 return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
2724}
2725
2726// bpo-40421 added PyFrame_GetBack()
2727PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
2728 assert(frame != _Py_NULL && "expected frame != _Py_NULL");
2729 return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
2730}
2731
2732// bpo-40421 added PyFrame_GetCode()
2733PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
2734 assert(frame != _Py_NULL && "expected frame != _Py_NULL");
2735 assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
2736 return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
2737}
2738
2739#endif // Python 3.9.0b1
2740
2741using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
2742
/// Builds an MLIR location from the current Python call stack.
/// - Walks frames outward from the current thread's frame while holding the
///   GIL, keeping only frames whose file passes isUserTracebackFilename.
/// - Stops after `framesLimit` kept frames.
/// - Each kept frame becomes a NameLoc (function name) wrapping a file
///   location: line-only before Python 3.11, full line/column range after.
/// - Returns UnknownLoc if no frame was kept, the single location if one was,
///   and otherwise a CallSiteLoc chain with the innermost frame as callee.
/// NOTE(review): line 2745 (the initializer of `framesLimit`) is missing from
/// this listing. `frames` has kMaxFrames slots, so this assumes
/// framesLimit <= kMaxFrames — confirm.
2743MlirLocation tracebackToLocation(MlirContext ctx) {
2744 size_t framesLimit =
2746 // Use a thread_local here to avoid requiring a large amount of space.
2747 thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
2748 frames;
2749 size_t count = 0;
2750
2751 nb::gil_scoped_acquire acquire;
2752 PyThreadState *tstate = PyThreadState_GET();
2753 PyFrameObject *next;
2754 PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
2755 // In the increment expression:
2756 // 1. get the previous (calling) frame;
2757 // 2. decrement the ref count on the current frame (so that it can be
2758 // collected, along with any objects its closure references);
2759 // 3. set current = next.
2760 for (; pyFrame != nullptr && count < framesLimit;
2761 next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
    // NOTE(review): PyFrame_GetCode returns a new reference (see the compat
    // shim above, which uses Py_NewRef), and no visible line releases `code`,
    // so each iteration appears to leak one reference — consider Py_DECREF.
2762 PyCodeObject *code = PyFrame_GetCode(pyFrame);
2763 auto fileNameStr =
2764 nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
2765 std::string_view fileName(fileNameStr);
2766 if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
2767 continue;
2768
2769 // co_qualname and PyCode_Addr2Location added in py3.11
2770#if PY_VERSION_HEX < 0x030B00F0
2771 std::string name =
2772 nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
2773 std::string_view funcName(name);
2774 int startLine = PyFrame_GetLineNumber(pyFrame);
2775 MlirLocation loc = mlirLocationFileLineColGet(
2776 ctx, mlirStringRefCreate(fileName.data(), fileName.size()), startLine,
2777 0);
2778#else
2779 std::string name =
2780 nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
2781 std::string_view funcName(name);
2782 int startLine, startCol, endLine, endCol;
2783 int lasti = PyFrame_GetLasti(pyFrame);
2784 if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
2785 &endCol)) {
      // NOTE(review): throwing here skips the Py_XDECREF of `pyFrame` (and of
      // `code`), leaking those references on this error path.
2786 throw nb::python_error();
2787 }
2788 MlirLocation loc = mlirLocationFileLineColRangeGet(
2789 ctx, mlirStringRefCreate(fileName.data(), fileName.size()), startLine,
2790 startCol, endLine, endCol);
2791#endif
2792
2793 frames[count] = mlirLocationNameGet(
2794 ctx, mlirStringRefCreate(funcName.data(), funcName.size()), loc);
2795 ++count;
2796 }
2797 // When the loop breaks (after the last iter), current frame (if non-null)
2798 // is leaked without this.
2799 Py_XDECREF(pyFrame);
2800
2801 if (count == 0)
2802 return mlirLocationUnknownGet(ctx);
2803
  // frames[0] is the innermost kept frame (the callee); frames[count - 1] is
  // the outermost. Intermediate frames are folded inward into the caller chain.
2804 MlirLocation callee = frames[0];
2805 assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
2806 if (count == 1)
2807 return callee;
2808
2809 MlirLocation caller = frames[count - 1];
2810 assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
2811 for (int i = count - 2; i >= 1; i--)
2812 caller = mlirLocationCallSiteGet(frames[i], caller);
2813
2814 return mlirLocationCallSiteGet(callee, caller);
2815}
2816
/// Returns `location` if one was supplied.
/// Otherwise, when location tracebacks are enabled, builds a location from the
/// current Python stack with tracebackToLocation, in the context returned by
/// DefaultingPyMlirContext::resolve().
/// NOTE(review): lines 2822 (the result when tracebacks are disabled) and 2826
/// (presumably the construction of `ref`) are missing from this listing.
2817PyLocation
2818maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
2819 if (location.has_value())
2820 return location.value();
2821 if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
2823
2824 PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
2825 MlirLocation mlirLoc = tracebackToLocation(ctx.get());
2827 return {ref, mlirLoc};
2828}
2829} // namespace
2830
2831namespace mlir {
2832namespace python {
2834
/// Populates the root extension module with the Python-visible global state
/// and the class/callable registration decorators.
///
/// Defines on `m`:
///   - `T`, `U`: type variables referenced by the `nb::sig` signatures below;
///   - `_Globals`: bindings for PyGlobals (dialect search modules/prefixes,
///     dialect module loading, testing registration hooks, and
///     traceback-location settings);
///   - `globals`: a PyGlobals instance whose ownership is handed to Python;
///   - `register_dialect`, `register_operation`, `register_op_adaptor`,
///     `register_type_caster`, `register_value_caster`: decorators (or
///     decorator factories) that record user classes/callables in
///     `PyGlobals::get()`.
///
/// \param m The nanobind module to populate.
2835void populateRoot(nb::module_ &m) {
 // Type variables referenced by the typing signatures (`nb::sig`) of the
 // registration decorators defined below.
2836 m.attr("T") = nb::type_var("T");
2837 m.attr("U") = nb::type_var("U");
2838
2839 nb::class_<PyGlobals>(m, "_Globals")
2840 .def_prop_rw("dialect_search_modules",
2843 .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
2844 "module_name"_a)
2845 .def(
2846 "_check_dialect_module_loaded",
2847 [](PyGlobals &self, const std::string &dialectNamespace) {
2848 return self.loadDialectModule(dialectNamespace);
2849 },
2850 "dialect_namespace"_a)
 // Direct registration hooks, documented as testing hooks; the public
 // decorators below go through the same PyGlobals registration methods.
2851 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
2852 "dialect_namespace"_a, "dialect_class"_a,
2853 "Testing hook for directly registering a dialect")
2854 .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
2855 "operation_name"_a, "operation_class"_a, nb::kw_only(),
2856 "replace"_a = false,
2857 "Testing hook for directly registering an operation")
 // Traceback-location configuration. The getter reads the same flag that
 // maybeGetTracebackLocation() consults before synthesizing a location.
2858 .def("loc_tracebacks_enabled",
2859 [](PyGlobals &self) {
2860 return self.getTracebackLoc().locTracebacksEnabled();
2861 })
2862 .def("set_loc_tracebacks_enabled",
2863 [](PyGlobals &self, bool enabled) {
2865 })
2866 .def("loc_tracebacks_frame_limit",
2867 [](PyGlobals &self) {
2869 })
2870 .def("set_loc_tracebacks_frame_limit",
2871 [](PyGlobals &self, std::optional<int> n) {
2874 })
2875 .def("register_traceback_file_inclusion",
2876 [](PyGlobals &self, const std::string &filename) {
2878 })
2879 .def("register_traceback_file_exclusion",
2880 [](PyGlobals &self, const std::string &filename) {
2882 });
2883
2884 // Aside from making the globals accessible to python, having python manage
2885 // it is necessary to make sure it is destroyed (and releases its python
2886 // resources) properly.
2887 m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
2888
2889 // Registration decorators.
 // `register_dialect` is applied directly to a class: it registers the class
 // under its `DIALECT_NAMESPACE` attribute and returns it unchanged.
2890 m.def(
2891 "register_dialect",
2892 [](nb::type_object pyClass) {
2893 std::string dialectNamespace =
2894 nb::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
2895 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
2896 return pyClass;
2897 },
2898 "dialect_class"_a,
2899 "Class decorator for registering a custom Dialect wrapper");
 // `register_operation(dialect_class, replace=...)` returns a decorator. The
 // decorator registers the class under its `OPERATION_NAME` attribute
 // (forwarding `replace`), sets it as an attribute of `dialect_class` named
 // after the class's `__name__`, and returns the class unchanged.
2900 m.def(
2901 "register_operation",
2902 [](const nb::type_object &dialectClass, bool replace) -> nb::object {
2903 return nb::cpp_function(
2904 [dialectClass,
2905 replace](nb::type_object opClass) -> nb::type_object {
2906 std::string operationName =
2907 nb::cast<std::string>(opClass.attr("OPERATION_NAME"));
2908 PyGlobals::get().registerOperationImpl(operationName, opClass,
2909 replace);
2910 // Expose the op class on the dialect class as `<dialect_class>.<opClass.__name__>`.
2911 nb::object opClassName = opClass.attr("__name__");
2912 dialectClass.attr(opClassName) = opClass;
2913 return opClass;
2914 });
2915 },
2916 // clang-format off
2917 nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) "
2918 "-> typing.Callable[[type[T]], type[T]]"),
2919 // clang-format on
2920 "dialect_class"_a, nb::kw_only(), "replace"_a = false,
2921 "Produce a class decorator for registering an Operation class as part of "
2922 "a dialect");
 // `register_op_adaptor(op_class, replace=...)` returns a decorator that
 // registers the adaptor class under its `OPERATION_NAME` attribute
 // (forwarding `replace`), attaches it as `op_class.Adaptor`, and returns
 // the adaptor class unchanged.
2923 m.def(
2924 "register_op_adaptor",
2925 [](const nb::type_object &opClass, bool replace) -> nb::object {
2926 return nb::cpp_function(
2927 [opClass,
2928 replace](nb::type_object adaptorClass) -> nb::type_object {
2929 std::string operationName =
2930 nb::cast<std::string>(adaptorClass.attr("OPERATION_NAME"));
2931 PyGlobals::get().registerOpAdaptorImpl(operationName,
2932 adaptorClass, replace);
2933 // Attach the adaptor under the fixed attribute name `Adaptor` (not its `__name__`).
2934 opClass.attr("Adaptor") = adaptorClass;
2935 return adaptorClass;
2936 });
2937 },
2938 // clang-format off
2939 nb::sig("def register_op_adaptor(op_class: type, *, replace: bool = False) "
2940 "-> typing.Callable[[type[T]], type[T]]"),
2941 // clang-format on
2942 "op_class"_a, nb::kw_only(), "replace"_a = false,
2943 "Produce a class decorator for registering an OpAdaptor class for an "
2944 "operation.");
 // `register_type_caster(typeid, replace=...)` returns a decorator that
 // records the decorated callable as the type caster for `typeid` and
 // returns the callable unchanged.
2945 m.def(
2947 [](PyTypeID mlirTypeID, bool replace) -> nb::object {
2948 return nb::cpp_function([mlirTypeID, replace](
2949 nb::callable typeCaster) -> nb::object {
2950 PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
2951 return typeCaster;
2952 });
2953 },
2954 // clang-format off
2955 nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
2956 "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
2957 // clang-format on
2958 "typeid"_a, nb::kw_only(), "replace"_a = false,
2959 "Register a type caster for casting MLIR types to custom user types.");
 // `register_value_caster(typeid, replace=...)` returns a decorator that
 // records the decorated callable as the value caster for `typeid` and
 // returns the callable unchanged.
2960 m.def(
2962 [](PyTypeID mlirTypeID, bool replace) -> nb::object {
2963 return nb::cpp_function(
2964 [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
2965 PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
2966 replace);
2967 return valueCaster;
2968 });
2969 },
2970 // clang-format off
2971 nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
2972 "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
2973 // clang-format on
2974 "typeid"_a, nb::kw_only(), "replace"_a = false,
2975 "Register a value caster for casting MLIR values to custom user values.");
2976}
2977
2978//------------------------------------------------------------------------------
2979// Populates the core exports of the 'ir' submodule.
2980//------------------------------------------------------------------------------
2981void populateIRCore(nb::module_ &m) {
2982 //----------------------------------------------------------------------------
2983 // Enums.
2984 //----------------------------------------------------------------------------
2985 nb::enum_<PyDiagnosticSeverity>(m, "DiagnosticSeverity")
2986 .value("ERROR", PyDiagnosticSeverity::Error)
2987 .value("WARNING", PyDiagnosticSeverity::Warning)
2988 .value("NOTE", PyDiagnosticSeverity::Note)
2989 .value("REMARK", PyDiagnosticSeverity::Remark);
2990
2991 nb::enum_<PyWalkOrder>(m, "WalkOrder")
2992 .value("PRE_ORDER", PyWalkOrder::PreOrder)
2993 .value("POST_ORDER", PyWalkOrder::PostOrder);
2994 nb::enum_<PyWalkResult>(m, "WalkResult")
2995 .value("ADVANCE", PyWalkResult::Advance)
2996 .value("INTERRUPT", PyWalkResult::Interrupt)
2997 .value("SKIP", PyWalkResult::Skip);
2998
2999 //----------------------------------------------------------------------------
3000 // Mapping of Diagnostics.
3001 //----------------------------------------------------------------------------
3002 nb::class_<PyDiagnostic>(m, "Diagnostic")
3003 .def_prop_ro("severity", &PyDiagnostic::getSeverity,
3004 "Returns the severity of the diagnostic.")
3005 .def_prop_ro("location", &PyDiagnostic::getLocation,
3006 "Returns the location associated with the diagnostic.")
3007 .def_prop_ro("message", &PyDiagnostic::getMessage,
3008 "Returns the message text of the diagnostic.")
3009 .def_prop_ro("notes", &PyDiagnostic::getNotes,
3010 "Returns a tuple of attached note diagnostics.")
3011 .def(
3012 "__str__",
3013 [](PyDiagnostic &self) -> nb::str {
3014 if (!self.isValid())
3015 return nb::str("<Invalid Diagnostic>");
3016 return self.getMessage();
3017 },
3018 "Returns the diagnostic message as a string.");
3019
3020 nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
3021 .def(
3022 "__init__",
3024 new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
3025 },
3026 "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
3027 .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
3028 "The severity level of the diagnostic.")
3029 .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
3030 "The location associated with the diagnostic.")
3031 .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
3032 "The message text of the diagnostic.")
3033 .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
3034 "List of attached note diagnostics.")
3035 .def(
3036 "__str__",
3037 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
3038 "Returns the diagnostic message as a string.");
3039
3040 nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
3041 .def("detach", &PyDiagnosticHandler::detach,
3042 "Detaches the diagnostic handler from the context.")
3043 .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
3044 "Returns True if the handler is attached to a context.")
3045 .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
3046 "Returns True if an error was encountered during diagnostic "
3047 "handling.")
3048 .def("__enter__", &PyDiagnosticHandler::contextEnter,
3049 "Enters the diagnostic handler as a context manager.")
3050 .def("__exit__", &PyDiagnosticHandler::contextExit, "exc_type"_a.none(),
3051 "exc_value"_a.none(), "traceback"_a.none(),
3052 "Exits the diagnostic handler context manager.");
3053
3054 // Expose DefaultThreadPool to python
3055 nb::class_<PyThreadPool>(m, "ThreadPool")
3056 .def(
3057 "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
3058 "Creates a new thread pool with default concurrency.")
3059 .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
3060 "Returns the maximum number of threads in the pool.")
3061 .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
3062 "Returns the raw pointer to the LLVM thread pool as a string.");
3063
3064 nb::class_<PyMlirContext>(m, "Context")
3065 .def(
3066 "__init__",
3067 [](PyMlirContext &self) {
3068 MlirContext context = mlirContextCreateWithThreading(false);
3069 new (&self) PyMlirContext(context);
3070 },
3071 R"(
3072 Creates a new MLIR context.
3073
3074 The context is the top-level container for all MLIR objects. It owns the storage
3075 for types, attributes, locations, and other core IR objects. A context can be
3076 configured to allow or disallow unregistered dialects and can have dialects
3077 loaded on-demand.)")
3078 .def_static("_get_live_count", &PyMlirContext::getLiveCount,
3079 "Gets the number of live Context objects.")
3080 .def(
3081 "_get_context_again",
3082 [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
3084 return ref.releaseObject();
3085 },
3086 "Gets another reference to the same context.")
3087 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
3088 "Gets the number of live modules owned by this context.")
3090 "Gets a capsule wrapping the MlirContext.")
3093 "Creates a Context from a capsule wrapping MlirContext.")
3094 .def("__enter__", &PyMlirContext::contextEnter,
3095 "Enters the context as a context manager.")
3096 .def("__exit__", &PyMlirContext::contextExit, "exc_type"_a.none(),
3097 "exc_value"_a.none(), "traceback"_a.none(),
3098 "Exits the context manager.")
3099 .def_prop_ro_static(
3100 "current",
3101 [](nb::object & /*class*/)
3102 -> std::optional<nb::typed<nb::object, PyMlirContext>> {
3104 if (!context)
3105 return {};
3106 return nb::cast(context);
3107 },
3108 nb::sig("def current(/) -> Context | None"),
3109 "Gets the Context bound to the current thread or returns None if no "
3110 "context is set.")
3111 .def_prop_ro(
3112 "dialects",
3113 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3114 "Gets a container for accessing dialects by name.")
3115 .def_prop_ro(
3116 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3117 "Alias for `dialects`.")
3118 .def(
3119 "get_dialect_descriptor",
3120 [=](PyMlirContext &self, std::string &name) {
3121 MlirDialect dialect = mlirContextGetOrLoadDialect(
3122 self.get(), {name.data(), name.size()});
3123 if (mlirDialectIsNull(dialect)) {
3124 throw nb::value_error(
3125 join("Dialect '", name, "' not found").c_str());
3126 }
3127 return PyDialectDescriptor(self.getRef(), dialect);
3128 },
3129 "dialect_name"_a,
3130 "Gets or loads a dialect by name, returning its descriptor object.")
3131 .def_prop_rw(
3132 "allow_unregistered_dialects",
3133 [](PyMlirContext &self) -> bool {
3134 return mlirContextGetAllowUnregisteredDialects(self.get());
3135 },
3136 [](PyMlirContext &self, bool value) {
3137 mlirContextSetAllowUnregisteredDialects(self.get(), value);
3138 },
3139 "Controls whether unregistered dialects are allowed in this context.")
3140 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
3141 "callback"_a,
3142 "Attaches a diagnostic handler that will receive callbacks.")
3143 .def(
3144 "enable_multithreading",
3145 [](PyMlirContext &self, bool enable) {
3146 mlirContextEnableMultithreading(self.get(), enable);
3147 },
3148 "enable"_a,
3149 R"(
3150 Enables or disables multi-threading support in the context.
3151
3152 Args:
3153 enable: Whether to enable (True) or disable (False) multi-threading.
3154 )")
3155 .def(
3156 "set_thread_pool",
3157 [](PyMlirContext &self, PyThreadPool &pool) {
3158 // we should disable multi-threading first before setting
3159 // new thread pool otherwise the assert in
3160 // MLIRContext::setThreadPool will be raised.
3161 mlirContextEnableMultithreading(self.get(), false);
3162 mlirContextSetThreadPool(self.get(), pool.get());
3163 },
3164 R"(
3165 Sets a custom thread pool for the context to use.
3166
3167 Args:
3168 pool: A ThreadPool object to use for parallel operations.
3169
3170 Note:
3171 Multi-threading is automatically disabled before setting the thread pool.)")
3172 .def(
3173 "get_num_threads",
3174 [](PyMlirContext &self) {
3175 return mlirContextGetNumThreads(self.get());
3176 },
3177 "Gets the number of threads in the context's thread pool.")
3178 .def(
3179 "_mlir_thread_pool_ptr",
3180 [](PyMlirContext &self) {
3181 MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
3182 std::stringstream ss;
3183 ss << pool.ptr;
3184 return ss.str();
3185 },
3186 "Gets the raw pointer to the LLVM thread pool as a string.")
3187 .def(
3188 "is_registered_operation",
3189 [](PyMlirContext &self, std::string &name) {
3191 self.get(), MlirStringRef{name.data(), name.size()});
3192 },
3193 "operation_name"_a,
3194 R"(
3195 Checks whether an operation with the given name is registered.
3196
3197 Args:
3198 operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
3199
3200 Returns:
3201 True if the operation is registered, False otherwise.)")
3202 .def(
3203 "append_dialect_registry",
3204 [](PyMlirContext &self, PyDialectRegistry &registry) {
3205 mlirContextAppendDialectRegistry(self.get(), registry);
3206 },
3207 "registry"_a,
3208 R"(
3209 Appends the contents of a dialect registry to the context.
3210
3211 Args:
3212 registry: A DialectRegistry containing dialects to append.)")
3213 .def_prop_rw("emit_error_diagnostics",
3216 R"(
3217 Controls whether error diagnostics are emitted to diagnostic handlers.
3218
3219 By default, error diagnostics are captured and reported through MLIRError exceptions.)")
3220 .def(
3221 "load_all_available_dialects",
3222 [](PyMlirContext &self) {
3224 },
3225 R"(
3226 Loads all dialects available in the registry into the context.
3227
3228 This eagerly loads all dialects that have been registered, making them
3229 immediately available for use.)");
3230
3231 //----------------------------------------------------------------------------
3232 // Mapping of PyDialectDescriptor
3233 //----------------------------------------------------------------------------
3234 nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
3235 .def_prop_ro(
3236 "namespace",
3237 [](PyDialectDescriptor &self) {
3238 MlirStringRef ns = mlirDialectGetNamespace(self.get());
3239 return nb::str(ns.data, ns.length);
3240 },
3241 "Returns the namespace of the dialect.")
3242 .def(
3243 "__repr__",
3244 [](PyDialectDescriptor &self) {
3245 MlirStringRef ns = mlirDialectGetNamespace(self.get());
3246 std::string repr("<DialectDescriptor ");
3247 repr.append(ns.data, ns.length);
3248 repr.append(">");
3249 return repr;
3250 },
3251 nb::sig("def __repr__(self) -> str"),
3252 "Returns a string representation of the dialect descriptor.");
3253
3254 //----------------------------------------------------------------------------
3255 // Mapping of PyDialects
3256 //----------------------------------------------------------------------------
3257 nb::class_<PyDialects>(m, "Dialects")
3258 .def(
3259 "__getitem__",
3260 [=](PyDialects &self, std::string keyName) {
3261 MlirDialect dialect =
3262 self.getDialectForKey(keyName, /*attrError=*/false);
3263 nb::object descriptor =
3264 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3265 return createCustomDialectWrapper(keyName, std::move(descriptor));
3266 },
3267 "Gets a dialect by name using subscript notation.")
3268 .def(
3269 "__getattr__",
3270 [=](PyDialects &self, std::string attrName) {
3271 MlirDialect dialect =
3272 self.getDialectForKey(attrName, /*attrError=*/true);
3273 nb::object descriptor =
3274 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3275 return createCustomDialectWrapper(attrName, std::move(descriptor));
3276 },
3277 "Gets a dialect by name using attribute notation.");
3278
3279 //----------------------------------------------------------------------------
3280 // Mapping of PyDialect
3281 //----------------------------------------------------------------------------
3282 nb::class_<PyDialect>(m, "Dialect")
3283 .def(nb::init<nb::object>(), "descriptor"_a,
3284 "Creates a Dialect from a DialectDescriptor.")
3285 .def_prop_ro(
3286 "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
3287 "Returns the DialectDescriptor for this dialect.")
3288 .def(
3289 "__repr__",
3290 [](const nb::object &self) {
3291 auto clazz = self.attr("__class__");
3292 return nb::str("<Dialect ") +
3293 self.attr("descriptor").attr("namespace") +
3294 nb::str(" (class ") + clazz.attr("__module__") +
3295 nb::str(".") + clazz.attr("__name__") + nb::str(")>");
3296 },
3297 nb::sig("def __repr__(self) -> str"),
3298 "Returns a string representation of the dialect.");
3299
3300 //----------------------------------------------------------------------------
3301 // Mapping of PyDialectRegistry
3302 //----------------------------------------------------------------------------
3303 nb::class_<PyDialectRegistry>(m, "DialectRegistry")
3305 "Gets a capsule wrapping the MlirDialectRegistry.")
3308 "Creates a DialectRegistry from a capsule wrapping "
3309 "`MlirDialectRegistry`.")
3310 .def(nb::init<>(), "Creates a new empty dialect registry.");
3311
3312 //----------------------------------------------------------------------------
3313 // Mapping of Location
3314 //----------------------------------------------------------------------------
3315 nb::class_<PyLocation>(m, "Location")
3317 "Gets a capsule wrapping the MlirLocation.")
3319 "Creates a Location from a capsule wrapping MlirLocation.")
3320 .def("__enter__", &PyLocation::contextEnter,
3321 "Enters the location as a context manager.")
3322 .def("__exit__", &PyLocation::contextExit, "exc_type"_a.none(),
3323 "exc_value"_a.none(), "traceback"_a.none(),
3324 "Exits the location context manager.")
3325 .def(
3326 "__eq__",
3327 [](PyLocation &self, PyLocation &other) -> bool {
3328 return mlirLocationEqual(self, other);
3329 },
3330 "Compares two locations for equality.")
3331 .def(
3332 "__eq__", [](PyLocation &self, nb::object other) { return false; },
3333 "Compares location with non-location object (always returns False).")
3334 .def_prop_ro_static(
3335 "current",
3336 [](nb::object & /*class*/) -> std::optional<PyLocation *> {
3338 if (!loc)
3339 return std::nullopt;
3340 return loc;
3341 },
3342 // clang-format off
3343 nb::sig("def current(/) -> Location | None"),
3344 // clang-format on
3345 "Gets the Location bound to the current thread or raises ValueError.")
3346 .def_static(
3347 "unknown",
3348 [](DefaultingPyMlirContext context) {
3349 return PyLocation(context->getRef(),
3350 mlirLocationUnknownGet(context->get()));
3351 },
3352 "context"_a = nb::none(),
3353 "Gets a Location representing an unknown location.")
3354 .def_static(
3355 "callsite",
3356 [](PyLocation callee, const std::vector<PyLocation> &frames,
3357 DefaultingPyMlirContext context) {
3358 if (frames.empty())
3359 throw nb::value_error("No caller frames provided.");
3360 MlirLocation caller = frames.back().get();
3361 for (size_t index = frames.size() - 1; index-- > 0;) {
3362 caller = mlirLocationCallSiteGet(frames[index].get(), caller);
3363 }
3364 return PyLocation(context->getRef(),
3365 mlirLocationCallSiteGet(callee.get(), caller));
3366 },
3367 "callee"_a, "frames"_a, "context"_a = nb::none(),
3368 "Gets a Location representing a caller and callsite.")
3369 .def("is_a_callsite", mlirLocationIsACallSite,
3370 "Returns True if this location is a CallSiteLoc.")
3371 .def_prop_ro(
3372 "callee",
3373 [](PyLocation &self) {
3374 return PyLocation(self.getContext(),
3376 },
3377 "Gets the callee location from a CallSiteLoc.")
3378 .def_prop_ro(
3379 "caller",
3380 [](PyLocation &self) {
3381 return PyLocation(self.getContext(),
3383 },
3384 "Gets the caller location from a CallSiteLoc.")
3385 .def_static(
3386 "file",
3387 [](std::string filename, int line, int col,
3388 DefaultingPyMlirContext context) {
3389 return PyLocation(
3390 context->getRef(),
3392 context->get(), toMlirStringRef(filename), line, col));
3393 },
3394 "filename"_a, "line"_a, "col"_a, "context"_a = nb::none(),
3395 "Gets a Location representing a file, line and column.")
3396 .def_static(
3397 "file",
3398 [](std::string filename, int startLine, int startCol, int endLine,
3399 int endCol, DefaultingPyMlirContext context) {
3400 return PyLocation(context->getRef(),
3402 context->get(), toMlirStringRef(filename),
3403 startLine, startCol, endLine, endCol));
3404 },
3405 "filename"_a, "start_line"_a, "start_col"_a, "end_line"_a,
3406 "end_col"_a, "context"_a = nb::none(),
3407 "Gets a Location representing a file, line and column range.")
3408 .def("is_a_file", mlirLocationIsAFileLineColRange,
3409 "Returns True if this location is a FileLineColLoc.")
3410 .def_prop_ro(
3411 "filename",
3412 [](PyLocation loc) {
3413 return mlirIdentifierStr(
3415 },
3416 "Gets the filename from a FileLineColLoc.")
3417 .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
3418 "Gets the start line number from a `FileLineColLoc`.")
3419 .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
3420 "Gets the start column number from a `FileLineColLoc`.")
3421 .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
3422 "Gets the end line number from a `FileLineColLoc`.")
3423 .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
3424 "Gets the end column number from a `FileLineColLoc`.")
3425 .def_static(
3426 "fused",
3427 [](const std::vector<PyLocation> &pyLocations,
3428 std::optional<PyAttribute> metadata,
3429 DefaultingPyMlirContext context) {
3430 std::vector<MlirLocation> locations;
3431 locations.reserve(pyLocations.size());
3432 for (const PyLocation &pyLocation : pyLocations)
3433 locations.push_back(pyLocation.get());
3434 MlirLocation location = mlirLocationFusedGet(
3435 context->get(), locations.size(), locations.data(),
3436 metadata ? metadata->get() : MlirAttribute{0});
3437 return PyLocation(context->getRef(), location);
3438 },
3439 "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(),
3440 "Gets a Location representing a fused location with optional "
3441 "metadata.")
3442 .def("is_a_fused", mlirLocationIsAFused,
3443 "Returns True if this location is a `FusedLoc`.")
3444 .def_prop_ro(
3445 "locations",
3446 [](PyLocation &self) {
3447 unsigned numLocations = mlirLocationFusedGetNumLocations(self);
3448 std::vector<MlirLocation> locations(numLocations);
3449 if (numLocations)
3450 mlirLocationFusedGetLocations(self, locations.data());
3451 std::vector<PyLocation> pyLocations{};
3452 pyLocations.reserve(numLocations);
3453 for (unsigned i = 0; i < numLocations; ++i)
3454 pyLocations.emplace_back(self.getContext(), locations[i]);
3455 return pyLocations;
3456 },
3457 "Gets the list of locations from a `FusedLoc`.")
3458 .def_static(
3459 "name",
3460 [](std::string name, std::optional<PyLocation> childLoc,
3461 DefaultingPyMlirContext context) {
3462 return PyLocation(
3463 context->getRef(),
3465 context->get(), toMlirStringRef(name),
3466 childLoc ? childLoc->get()
3467 : mlirLocationUnknownGet(context->get())));
3468 },
3469 "name"_a, "childLoc"_a = nb::none(), "context"_a = nb::none(),
3470 "Gets a Location representing a named location with optional child "
3471 "location.")
3472 .def("is_a_name", mlirLocationIsAName,
3473 "Returns True if this location is a `NameLoc`.")
3474 .def_prop_ro(
3475 "name_str",
3476 [](PyLocation loc) {
3478 },
3479 "Gets the name string from a `NameLoc`.")
3480 .def_prop_ro(
3481 "child_loc",
3482 [](PyLocation &self) {
3483 return PyLocation(self.getContext(),
3485 },
3486 "Gets the child location from a `NameLoc`.")
3487 .def_static(
3488 "from_attr",
3489 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
3490 return PyLocation(context->getRef(),
3491 mlirLocationFromAttribute(attribute));
3492 },
3493 "attribute"_a, "context"_a = nb::none(),
3494 "Gets a Location from a `LocationAttr`.")
3495 .def_prop_ro(
3496 "context",
3497 [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
3498 return self.getContext().getObject();
3499 },
3500 "Context that owns the `Location`.")
3501 .def_prop_ro(
3502 "attr",
3503 [](PyLocation &self) {
3504 return PyAttribute(self.getContext(),
3506 },
3507 "Get the underlying `LocationAttr`.")
3508 .def(
3509 "emit_error",
3510 [](PyLocation &self, std::string message) {
3511 mlirEmitError(self, message.c_str());
3512 },
3513 "message"_a,
3514 R"(
3515 Emits an error diagnostic at this location.
3516
3517 Args:
3518 message: The error message to emit.)")
3519 .def(
3520 "__repr__",
3521 [](PyLocation &self) {
3522 PyPrintAccumulator printAccum;
3523 mlirLocationPrint(self, printAccum.getCallback(),
3524 printAccum.getUserData());
3525 return printAccum.join();
3526 },
3527 "Returns the assembly representation of the location.");
3528
3529 //----------------------------------------------------------------------------
3530 // Mapping of Module
3531 //----------------------------------------------------------------------------
3532 nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3534 "Gets a capsule wrapping the MlirModule.")
3536 R"(
3537 Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
3538
3539 This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
3540 prevent double-frees (of the underlying `mlir::Module`).)")
3541 .def("_clear_mlir_module", &PyModule::clearMlirModule,
3542 R"(
3543 Clears the internal MLIR module reference.
3544
3545 This is used internally to prevent double-free when ownership is transferred
3546 via the C API capsule mechanism. Not intended for normal use.)")
3547 .def_static(
3548 "parse",
3549 [](const std::string &moduleAsm, DefaultingPyMlirContext context)
3550 -> nb::typed<nb::object, PyModule> {
3551 PyMlirContext::ErrorCapture errors(context->getRef());
3552 MlirModule module = mlirModuleCreateParse(
3553 context->get(), toMlirStringRef(moduleAsm));
3554 if (mlirModuleIsNull(module))
3555 throw MLIRError("Unable to parse module assembly", errors.take());
3556 return PyModule::forModule(module).releaseObject();
3557 },
3558 "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
3559 .def_static(
3560 "parse",
3561 [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
3562 -> nb::typed<nb::object, PyModule> {
3563 PyMlirContext::ErrorCapture errors(context->getRef());
3564 MlirModule module = mlirModuleCreateParse(
3565 context->get(), toMlirStringRef(moduleAsm));
3566 if (mlirModuleIsNull(module))
3567 throw MLIRError("Unable to parse module assembly", errors.take());
3568 return PyModule::forModule(module).releaseObject();
3569 },
3570 "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
3571 .def_static(
3572 "parseFile",
3573 [](const std::string &path, DefaultingPyMlirContext context)
3574 -> nb::typed<nb::object, PyModule> {
3575 PyMlirContext::ErrorCapture errors(context->getRef());
3576 MlirModule module = mlirModuleCreateParseFromFile(
3577 context->get(), toMlirStringRef(path));
3578 if (mlirModuleIsNull(module))
3579 throw MLIRError("Unable to parse module assembly", errors.take());
3580 return PyModule::forModule(module).releaseObject();
3581 },
3582 "path"_a, "context"_a = nb::none(), kModuleParseDocstring)
3583 .def_static(
3584 "create",
3585 [](const std::optional<PyLocation> &loc)
3586 -> nb::typed<nb::object, PyModule> {
3587 PyLocation pyLoc = maybeGetTracebackLocation(loc);
3588 MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
3589 return PyModule::forModule(module).releaseObject();
3590 },
3591 "loc"_a = nb::none(), "Creates an empty module.")
3592 .def_prop_ro(
3593 "context",
3594 [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
3595 return self.getContext().getObject();
3596 },
3597 "Context that created the `Module`.")
3598 .def_prop_ro(
3599 "operation",
3600 [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
3601 return PyOperation::forOperation(self.getContext(),
3602 mlirModuleGetOperation(self.get()),
3603 self.getRef().releaseObject())
3604 .releaseObject();
3605 },
3606 "Accesses the module as an operation.")
3607 .def_prop_ro(
3608 "body",
3609 [](PyModule &self) {
3611 self.getContext(), mlirModuleGetOperation(self.get()),
3612 self.getRef().releaseObject());
3613 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3614 return returnBlock;
3615 },
3616 "Return the block for this module.")
3617 .def(
3618 "dump",
3619 [](PyModule &self) {
3621 },
3623 .def(
3624 "__str__",
3625 [](const nb::object &self) {
3626 // Defer to the operation's __str__.
3627 return self.attr("operation").attr("__str__")();
3628 },
3629 nb::sig("def __str__(self) -> str"),
3630 R"(
3631 Gets the assembly form of the operation with default options.
3632
3633 If more advanced control over the assembly formatting or I/O options is needed,
3634 use the dedicated print or get_asm method, which supports keyword arguments to
3635 customize behavior.
3636 )")
3637 .def(
3638 "__eq__",
3639 [](PyModule &self, PyModule &other) {
3640 return mlirModuleEqual(self.get(), other.get());
3641 },
3642 "other"_a, "Compares two modules for equality.")
3643 .def(
3644 "__hash__",
3645 [](PyModule &self) { return mlirModuleHashValue(self.get()); },
3646 "Returns the hash value of the module.");
3647
3648 //----------------------------------------------------------------------------
3649 // Mapping of Operation.
3650 //----------------------------------------------------------------------------
3651 nb::class_<PyOperationBase>(m, "_OperationBase")
3652 .def_prop_ro(
3654 [](PyOperationBase &self) {
3655 return self.getOperation().getCapsule();
3656 },
3657 "Gets a capsule wrapping the `MlirOperation`.")
3658 .def(
3659 "__eq__",
3660 [](PyOperationBase &self, PyOperationBase &other) {
3661 return mlirOperationEqual(self.getOperation().get(),
3662 other.getOperation().get());
3663 },
3664 "Compares two operations for equality.")
3665 .def(
3666 "__eq__",
3667 [](PyOperationBase &self, nb::object other) { return false; },
3668 "Compares operation with non-operation object (always returns "
3669 "False).")
3670 .def(
3671 "__hash__",
3672 [](PyOperationBase &self) {
3673 return mlirOperationHashValue(self.getOperation().get());
3674 },
3675 "Returns the hash value of the operation.")
3676 .def_prop_ro(
3677 "attributes",
3678 [](PyOperationBase &self) {
3679 return PyOpAttributeMap(self.getOperation().getRef());
3680 },
3681 "Returns a dictionary-like map of operation attributes.")
3682 .def_prop_ro(
3683 "context",
3684 [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
3685 PyOperation &concreteOperation = self.getOperation();
3686 concreteOperation.checkValid();
3687 return concreteOperation.getContext().getObject();
3688 },
3689 "Context that owns the operation.")
3690 .def_prop_ro(
3691 "name",
3692 [](PyOperationBase &self) {
3693 auto &concreteOperation = self.getOperation();
3694 concreteOperation.checkValid();
3695 MlirOperation operation = concreteOperation.get();
3696 return mlirIdentifierStr(mlirOperationGetName(operation));
3697 },
3698 "Returns the fully qualified name of the operation.")
3699 .def_prop_ro(
3700 "operands",
3701 [](PyOperationBase &self) {
3702 return PyOpOperandList(self.getOperation().getRef());
3703 },
3704 "Returns the list of operation operands.")
3705 .def_prop_ro(
3706 "op_operands",
3707 [](PyOperationBase &self) {
3708 return PyOpOperands(self.getOperation().getRef());
3709 },
3710 "Returns the list of op operands.")
3711 .def_prop_ro(
3712 "regions",
3713 [](PyOperationBase &self) {
3714 return PyRegionList(self.getOperation().getRef());
3715 },
3716 "Returns the list of operation regions.")
3717 .def_prop_ro(
3718 "results",
3719 [](PyOperationBase &self) {
3720 return PyOpResultList(self.getOperation().getRef());
3721 },
3722 "Returns the list of Operation results.")
3723 .def_prop_ro(
3724 "result",
3725 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
3726 auto &operation = self.getOperation();
3727 return PyOpResult(operation.getRef(), getUniqueResult(operation))
3728 .maybeDownCast();
3729 },
3730 "Shortcut to get an op result if it has only one (throws an error "
3731 "otherwise).")
3732 .def_prop_rw(
3733 "location",
3734 [](PyOperationBase &self) {
3735 PyOperation &operation = self.getOperation();
3736 return PyLocation(operation.getContext(),
3737 mlirOperationGetLocation(operation.get()));
3738 },
3739 [](PyOperationBase &self, const PyLocation &location) {
3740 PyOperation &operation = self.getOperation();
3741 mlirOperationSetLocation(operation.get(), location.get());
3742 },
3743 nb::for_getter("Returns the source location the operation was "
3744 "defined or derived from."),
3745 nb::for_setter("Sets the source location the operation was defined "
3746 "or derived from."))
3747 .def_prop_ro(
3748 "parent",
3749 [](PyOperationBase &self)
3750 -> std::optional<nb::typed<nb::object, PyOperation>> {
3751 auto parent = self.getOperation().getParentOperation();
3752 if (parent)
3753 return parent->getObject();
3754 return {};
3755 },
3756 "Returns the parent operation, or `None` if at top level.")
3757 .def(
3758 "__str__",
3759 [](PyOperationBase &self) {
3760 return self.getAsm(/*binary=*/false,
3761 /*largeElementsLimit=*/std::nullopt,
3762 /*largeResourceLimit=*/std::nullopt,
3763 /*enableDebugInfo=*/false,
3764 /*prettyDebugInfo=*/false,
3765 /*printGenericOpForm=*/false,
3766 /*useLocalScope=*/false,
3767 /*useNameLocAsPrefix=*/false,
3768 /*assumeVerified=*/false,
3769 /*skipRegions=*/false);
3770 },
3771 nb::sig("def __str__(self) -> str"),
3772 "Returns the assembly form of the operation.")
3773 .def("print",
3774 nb::overload_cast<PyAsmState &, nb::object, bool>(
3776 "state"_a, "file"_a = nb::none(), "binary"_a = false,
3777 R"(
3778 Prints the assembly form of the operation to a file like object.
3779
3780 Args:
3781 state: `AsmState` capturing the operation numbering and flags.
3782 file: Optional file like object to write to. Defaults to sys.stdout.
3783 binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
3784 .def("print",
3785 nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
3786 bool, bool, bool, bool, bool, bool, nb::object,
3787 bool, bool>(&PyOperationBase::print),
3788 // Careful: Lots of arguments must match up with print method.
3789 "large_elements_limit"_a = nb::none(),
3790 "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
3791 "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
3792 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
3793 "assume_verified"_a = false, "file"_a = nb::none(),
3794 "binary"_a = false, "skip_regions"_a = false,
3795 R"(
3796 Prints the assembly form of the operation to a file like object.
3797
3798 Args:
3799 large_elements_limit: Whether to elide elements attributes above this
3800 number of elements. Defaults to None (no limit).
3801 large_resource_limit: Whether to elide resource attributes above this
3802 number of characters. Defaults to None (no limit). If large_elements_limit
3803 is set and this is None, the behavior will be to use large_elements_limit
3804 as large_resource_limit.
3805 enable_debug_info: Whether to print debug/location information. Defaults
3806 to False.
3807 pretty_debug_info: Whether to format debug information for easier reading
3808 by a human (warning: the result is unparseable). Defaults to False.
3809 print_generic_op_form: Whether to print the generic assembly forms of all
3810 ops. Defaults to False.
3811 use_local_scope: Whether to print in a way that is more optimized for
3812 multi-threaded access but may not be consistent with how the overall
3813 module prints.
3814 use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
3815 prefixes for the SSA identifiers. Defaults to False.
3816 assume_verified: By default, if not printing generic form, the verifier
3817 will be run and if it fails, generic form will be printed with a comment
3818 about failed verification. While a reasonable default for interactive use,
3819 for systematic use, it is often better for the caller to verify explicitly
3820 and report failures in a more robust fashion. Set this to True if doing this
3821 in order to avoid running a redundant verification. If the IR is actually
3822 invalid, behavior is undefined.
3823 file: The file like object to write to. Defaults to sys.stdout.
3824 binary: Whether to write bytes (True) or str (False). Defaults to False.
3825 skip_regions: Whether to skip printing regions. Defaults to False.)")
3826 .def("write_bytecode", &PyOperationBase::writeBytecode, "file"_a,
3827 "desired_version"_a = nb::none(),
3828 R"(
3829 Write the bytecode form of the operation to a file like object.
3830
3831 Args:
3832 file: The file like object to write to.
3833 desired_version: Optional version of bytecode to emit.
3834 Returns:
3835 The bytecode writer status.)")
3836 .def("get_asm", &PyOperationBase::getAsm,
3837 // Careful: Lots of arguments must match up with get_asm method.
3838 "binary"_a = false, "large_elements_limit"_a = nb::none(),
3839 "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
3840 "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
3841 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
3842 "assume_verified"_a = false, "skip_regions"_a = false,
3843 R"(
3844 Gets the assembly form of the operation with all options available.
3845
3846 Args:
3847 binary: Whether to return a bytes (True) or str (False) object. Defaults to
3848 False.
3849 ... others ...: See the print() method for common keyword arguments for
3850 configuring the printout.
3851 Returns:
3852 Either a bytes or str object, depending on the setting of the `binary`
3853 argument.)")
3854 .def("verify", &PyOperationBase::verify,
3855 "Verify the operation. Raises MLIRError if verification fails, and "
3856 "returns true otherwise.")
3857 .def("move_after", &PyOperationBase::moveAfter, "other"_a,
3858 "Puts self immediately after the other operation in its parent "
3859 "block.")
3860 .def("move_before", &PyOperationBase::moveBefore, "other"_a,
3861 "Puts self immediately before the other operation in its parent "
3862 "block.")
3863 .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, "other"_a,
3864 R"(
3865 Checks if this operation is before another in the same block.
3866
3867 Args:
3868 other: Another operation in the same parent block.
3869
3870 Returns:
3871 True if this operation is before `other` in the operation list of the parent block.)")
3872 .def(
3873 "clone",
3874 [](PyOperationBase &self,
3875 const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
3876 return self.getOperation().clone(ip);
3877 },
3878 "ip"_a = nb::none(),
3879 R"(
3880 Creates a deep copy of the operation.
3881
3882 Args:
3883 ip: Optional insertion point where the cloned operation should be inserted.
3884 If None, the current insertion point is used. If False, the operation
3885 remains detached.
3886
3887 Returns:
3888 A new Operation that is a clone of this operation.)")
3889 .def(
3890 "detach_from_parent",
3891 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
3892 PyOperation &operation = self.getOperation();
3893 operation.checkValid();
3894 if (!operation.isAttached())
3895 throw nb::value_error("Detached operation has no parent.");
3896
3897 operation.detachFromParent();
3898 return operation.createOpView();
3899 },
3900 "Detaches the operation from its parent block.")
3901 .def_prop_ro(
3902 "attached",
3903 [](PyOperationBase &self) {
3904 PyOperation &operation = self.getOperation();
3905 operation.checkValid();
3906 return operation.isAttached();
3907 },
3908 "Reports if the operation is attached to its parent block.")
3909 .def(
3910 "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
3911 R"(
3912 Erases the operation and frees its memory.
3913
3914 Note:
3915 After erasing, any Python references to the operation become invalid.)")
3916 .def("walk", &PyOperationBase::walk, "callback"_a,
3917 "walk_order"_a = PyWalkOrder::PostOrder,
3918 // clang-format off
3919 nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
3920 // clang-format on
3921 R"(
3922 Walks the operation tree with a callback function.
3923
3924 Args:
3925 callback: A callable that takes an Operation and returns a WalkResult.
3926 walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
3927
3928 nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3929 .def_static(
3930 "create",
3931 [](std::string_view name,
3932 std::optional<std::vector<PyType *>> results,
3933 std::optional<std::vector<PyValue *>> operands,
3934 std::optional<nb::dict> attributes,
3935 std::optional<std::vector<PyBlock *>> successors, int regions,
3936 const std::optional<PyLocation> &location,
3937 const nb::object &maybeIp,
3938 bool inferType) -> nb::typed<nb::object, PyOperation> {
3939 // Unpack/validate operands.
3940 std::vector<MlirValue> mlirOperands;
3941 if (operands) {
3942 mlirOperands.reserve(operands->size());
3943 for (PyValue *operand : *operands) {
3944 if (!operand)
3945 throw nb::value_error("operand value cannot be None");
3946 mlirOperands.push_back(operand->get());
3947 }
3948 }
3949
3950 PyLocation pyLoc = maybeGetTracebackLocation(location);
3951 return PyOperation::create(
3952 name, results, mlirOperands.data(), mlirOperands.size(),
3953 attributes, successors, regions, pyLoc, maybeIp, inferType);
3954 },
3955 "name"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
3956 "attributes"_a = nb::none(), "successors"_a = nb::none(),
3957 "regions"_a = 0, "loc"_a = nb::none(), "ip"_a = nb::none(),
3958 "infer_type"_a = false,
3959 R"(
3960 Creates a new operation.
3961
3962 Args:
3963 name: Operation name (e.g. `dialect.operation`).
3964 results: Optional sequence of Type representing op result types.
3965 operands: Optional operands of the operation.
3966 attributes: Optional Dict of {str: Attribute}.
3967 successors: Optional List of Block for the operation's successors.
3968 regions: Number of regions to create (default = 0).
3969 location: Optional Location object (defaults to resolve from context manager).
3970 ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
3971 infer_type: Whether to infer result types (default = False).
3972 Returns:
3973 A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
3974 .def_static(
3975 "parse",
3976 [](const std::string &sourceStr, const std::string &sourceName,
3978 -> nb::typed<nb::object, PyOpView> {
3979 return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3980 ->createOpView();
3981 },
3982 "source"_a, nb::kw_only(), "source_name"_a = "",
3983 "context"_a = nb::none(),
3984 "Parses an operation. Supports both text assembly format and binary "
3985 "bytecode format.")
3987 "Gets a capsule wrapping the MlirOperation.")
3990 "Creates an Operation from a capsule wrapping MlirOperation.")
3991 .def_prop_ro(
3992 "operation",
3993 [](nb::object self) -> nb::typed<nb::object, PyOperation> {
3994 return self;
3995 },
3996 "Returns self (the operation).")
3997 .def_prop_ro(
3998 "opview",
3999 [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
4000 return self.createOpView();
4001 },
4002 R"(
4003 Returns an OpView of this operation.
4004
4005 Note:
4006 If the operation has a registered and loaded dialect then this OpView will
4007 be concrete wrapper class.)")
4008 .def_prop_ro("block", &PyOperation::getBlock,
4009 "Returns the block containing this operation.")
4010 .def_prop_ro(
4011 "successors",
4012 [](PyOperationBase &self) {
4013 return PyOpSuccessors(self.getOperation().getRef());
4014 },
4015 "Returns the list of Operation successors.")
4016 .def(
4017 "replace_uses_of_with",
4018 [](PyOperation &self, PyValue &of, PyValue &with) {
4019 mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
4020 },
4021 "of"_a, "with_"_a,
4022 "Replaces uses of the 'of' value with the 'with' value inside the "
4023 "operation.")
4024 .def("_set_invalid", &PyOperation::setInvalid,
4025 "Invalidate the operation.");
4026
4027 auto opViewClass =
4028 nb::class_<PyOpView, PyOperationBase>(m, "OpView")
4029 .def(nb::init<nb::typed<nb::object, PyOperation>>(), "operation"_a)
4030 .def(
4031 "__init__",
4032 [](PyOpView *self, std::string_view name,
4033 std::tuple<int, bool> opRegionSpec,
4034 nb::object operandSegmentSpecObj,
4035 nb::object resultSegmentSpecObj,
4036 std::optional<nb::list> resultTypeList, nb::list operandList,
4037 std::optional<nb::dict> attributes,
4038 std::optional<std::vector<PyBlock *>> successors,
4039 std::optional<int> regions,
4040 const std::optional<PyLocation> &location,
4041 const nb::object &maybeIp) {
4042 PyLocation pyLoc = maybeGetTracebackLocation(location);
4044 name, opRegionSpec, operandSegmentSpecObj,
4045 resultSegmentSpecObj, resultTypeList, operandList,
4046 attributes, successors, regions, pyLoc, maybeIp));
4047 },
4048 "name"_a, "opRegionSpec"_a,
4049 "operandSegmentSpecObj"_a = nb::none(),
4050 "resultSegmentSpecObj"_a = nb::none(), "results"_a = nb::none(),
4051 "operands"_a = nb::none(), "attributes"_a = nb::none(),
4052 "successors"_a = nb::none(), "regions"_a = nb::none(),
4053 "loc"_a = nb::none(), "ip"_a = nb::none())
4054 .def_prop_ro(
4055 "operation",
4056 [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
4057 return self.getOperationObject();
4058 })
4059 .def_prop_ro("opview",
4060 [](nb::object self) -> nb::typed<nb::object, PyOpView> {
4061 return self;
4062 })
4063 .def(
4064 "__str__",
4065 [](PyOpView &self) { return nb::str(self.getOperationObject()); })
4066 .def_prop_ro(
4067 "successors",
4068 [](PyOperationBase &self) {
4069 return PyOpSuccessors(self.getOperation().getRef());
4070 },
4071 "Returns the list of Operation successors.")
4072 .def(
4073 "_set_invalid",
4074 [](PyOpView &self) { self.getOperation().setInvalid(); },
4075 "Invalidate the operation.");
4076 opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
4077 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
4078 opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
4079 // It is faster to pass the operation_name, ods_regions, and
4080 // ods_operand_segments/ods_result_segments as arguments to the constructor,
4081 // rather than to access them as attributes.
4082 opViewClass.attr("build_generic") = classmethod(
4083 [](nb::handle cls, std::optional<nb::list> resultTypeList,
4084 nb::list operandList, std::optional<nb::dict> attributes,
4085 std::optional<std::vector<PyBlock *>> successors,
4086 std::optional<int> regions, std::optional<PyLocation> location,
4087 const nb::object &maybeIp) {
4088 std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
4089 std::tuple<int, bool> opRegionSpec =
4090 nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
4091 nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
4092 nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
4093 PyLocation pyLoc = maybeGetTracebackLocation(location);
4094 return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
4095 resultSegmentSpec, resultTypeList,
4096 operandList, attributes, successors,
4097 regions, pyLoc, maybeIp);
4098 },
4099 "cls"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
4100 "attributes"_a = nb::none(), "successors"_a = nb::none(),
4101 "regions"_a = nb::none(), "loc"_a = nb::none(), "ip"_a = nb::none(),
4102 "Builds a specific, generated OpView based on class level attributes.");
4103 opViewClass.attr("parse") = classmethod(
4104 [](const nb::object &cls, const std::string &sourceStr,
4105 const std::string &sourceName,
4106 DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
4107 PyOperationRef parsed =
4108 PyOperation::parse(context->getRef(), sourceStr, sourceName);
4109
4110 // Check if the expected operation was parsed, and cast to to the
4111 // appropriate `OpView` subclass if successful.
4112 // NOTE: This accesses attributes that have been automatically added to
4113 // `OpView` subclasses, and is not intended to be used on `OpView`
4114 // directly.
4115 std::string clsOpName =
4116 nb::cast<std::string>(cls.attr("OPERATION_NAME"));
4117 MlirStringRef identifier =
4119 std::string_view parsedOpName(identifier.data, identifier.length);
4120 if (clsOpName != parsedOpName)
4121 throw MLIRError(join("Expected a '", clsOpName, "' op, got: '",
4122 parsedOpName, "'"));
4123 return PyOpView::constructDerived(cls, parsed.getObject());
4124 },
4125 "cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "",
4126 "context"_a = nb::none(),
4127 "Parses a specific, generated OpView based on class level attributes.");
4128
4130
4131 //----------------------------------------------------------------------------
4132 // Mapping of PyRegion.
4133 //----------------------------------------------------------------------------
4134 nb::class_<PyRegion>(m, "Region")
4135 .def_prop_ro(
4136 "blocks",
4137 [](PyRegion &self) {
4138 return PyBlockList(self.getParentOperation(), self.get());
4139 },
4140 "Returns a forward-optimized sequence of blocks.")
4141 .def_prop_ro(
4142 "owner",
4143 [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
4144 return self.getParentOperation()->createOpView();
4145 },
4146 "Returns the operation owning this region.")
4147 .def(
4148 "__iter__",
4149 [](PyRegion &self) {
4150 self.checkValid();
4151 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
4152 return PyBlockIterator(self.getParentOperation(), firstBlock);
4153 },
4154 "Iterates over blocks in the region.")
4155 .def(
4156 "__eq__",
4157 [](PyRegion &self, PyRegion &other) {
4158 return self.get().ptr == other.get().ptr;
4159 },
4160 "Compares two regions for pointer equality.")
4161 .def(
4162 "__eq__", [](PyRegion &self, nb::object &other) { return false; },
4163 "Compares region with non-region object (always returns False).");
4164
4165 //----------------------------------------------------------------------------
4166 // Mapping of PyBlock.
4167 //----------------------------------------------------------------------------
4168 nb::class_<PyBlock>(m, "Block")
4170 "Gets a capsule wrapping the MlirBlock.")
4171 .def_prop_ro(
4172 "owner",
4173 [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
4174 return self.getParentOperation()->createOpView();
4175 },
4176 "Returns the owning operation of this block.")
4177 .def_prop_ro(
4178 "region",
4179 [](PyBlock &self) {
4180 MlirRegion region = mlirBlockGetParentRegion(self.get());
4181 return PyRegion(self.getParentOperation(), region);
4182 },
4183 "Returns the owning region of this block.")
4184 .def_prop_ro(
4185 "arguments",
4186 [](PyBlock &self) {
4187 return PyBlockArgumentList(self.getParentOperation(), self.get());
4188 },
4189 "Returns a list of block arguments.")
4190 .def(
4191 "add_argument",
4192 [](PyBlock &self, const PyType &type, const PyLocation &loc) {
4193 return PyBlockArgument(self.getParentOperation(),
4194 mlirBlockAddArgument(self.get(), type, loc));
4195 },
4196 "type"_a, "loc"_a,
4197 R"(
4198 Appends an argument of the specified type to the block.
4199
4200 Args:
4201 type: The type of the argument to add.
4202 loc: The source location for the argument.
4203
4204 Returns:
4205 The newly added block argument.)")
4206 .def(
4207 "erase_argument",
4208 [](PyBlock &self, unsigned index) {
4209 return mlirBlockEraseArgument(self.get(), index);
4210 },
4211 "index"_a,
4212 R"(
4213 Erases the argument at the specified index.
4214
4215 Args:
4216 index: The index of the argument to erase.)")
4217 .def_prop_ro(
4218 "operations",
4219 [](PyBlock &self) {
4220 return PyOperationList(self.getParentOperation(), self.get());
4221 },
4222 "Returns a forward-optimized sequence of operations.")
4223 .def_static(
4224 "create_at_start",
4225 [](PyRegion &parent, const nb::sequence &pyArgTypes,
4226 const std::optional<nb::sequence> &pyArgLocs) {
4227 parent.checkValid();
4228 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
4229 mlirRegionInsertOwnedBlock(parent, 0, block);
4230 return PyBlock(parent.getParentOperation(), block);
4231 },
4232 "parent"_a, "arg_types"_a = nb::list(), "arg_locs"_a = std::nullopt,
4233 "Creates and returns a new Block at the beginning of the given "
4234 "region (with given argument types and locations).")
4235 .def(
4236 "append_to",
4237 [](PyBlock &self, PyRegion &region) {
4238 MlirBlock b = self.get();
4241 mlirRegionAppendOwnedBlock(region.get(), b);
4242 },
4243 "region"_a,
4244 R"(
4245 Appends this block to a region.
4246
4247 Transfers ownership if the block is currently owned by another region.
4248
4249 Args:
4250 region: The region to append the block to.)")
4251 .def(
4252 "create_before",
4253 [](PyBlock &self, const nb::args &pyArgTypes,
4254 const std::optional<nb::sequence> &pyArgLocs) {
4255 self.checkValid();
4256 MlirBlock block =
4257 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4258 MlirRegion region = mlirBlockGetParentRegion(self.get());
4259 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
4260 return PyBlock(self.getParentOperation(), block);
4261 },
4262 "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
4263 "Creates and returns a new Block before this block "
4264 "(with given argument types and locations).")
4265 .def(
4266 "create_after",
4267 [](PyBlock &self, const nb::args &pyArgTypes,
4268 const std::optional<nb::sequence> &pyArgLocs) {
4269 self.checkValid();
4270 MlirBlock block =
4271 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4272 MlirRegion region = mlirBlockGetParentRegion(self.get());
4273 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
4274 return PyBlock(self.getParentOperation(), block);
4275 },
4276 "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
4277 "Creates and returns a new Block after this block "
4278 "(with given argument types and locations).")
4279 .def(
4280 "__iter__",
4281 [](PyBlock &self) {
4282 self.checkValid();
4283 MlirOperation firstOperation =
4284 mlirBlockGetFirstOperation(self.get());
4285 return PyOperationIterator(self.getParentOperation(),
4286 firstOperation);
4287 },
4288 "Iterates over operations in the block.")
4289 .def(
4290 "__eq__",
4291 [](PyBlock &self, PyBlock &other) {
4292 return self.get().ptr == other.get().ptr;
4293 },
4294 "Compares two blocks for pointer equality.")
4295 .def(
4296 "__eq__", [](PyBlock &self, nb::object &other) { return false; },
4297 "Compares block with non-block object (always returns False).")
4298 .def(
4299 "__hash__", [](PyBlock &self) { return hash(self.get().ptr); },
4300 "Returns the hash value of the block.")
4301 .def(
4302 "__str__",
4303 [](PyBlock &self) {
4304 self.checkValid();
4305 PyPrintAccumulator printAccum;
4306 mlirBlockPrint(self.get(), printAccum.getCallback(),
4307 printAccum.getUserData());
4308 return printAccum.join();
4309 },
4310 "Returns the assembly form of the block.")
4311 .def(
4312 "append",
4313 [](PyBlock &self, PyOperationBase &operation) {
4314 if (operation.getOperation().isAttached())
4315 operation.getOperation().detachFromParent();
4316
4317 MlirOperation mlirOperation = operation.getOperation().get();
4318 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
4319 operation.getOperation().setAttached(
4320 self.getParentOperation().getObject());
4321 },
4322 "operation"_a,
4323 R"(
4324 Appends an operation to this block.
4325
4326 If the operation is currently in another block, it will be moved.
4327
4328 Args:
4329 operation: The operation to append to the block.)")
4330 .def_prop_ro(
4331 "successors",
4332 [](PyBlock &self) {
4333 return PyBlockSuccessors(self, self.getParentOperation());
4334 },
4335 "Returns the list of Block successors.")
4336 .def_prop_ro(
4337 "predecessors",
4338 [](PyBlock &self) {
4339 return PyBlockPredecessors(self, self.getParentOperation());
4340 },
4341 "Returns the list of Block predecessors.");
4342
4343 //----------------------------------------------------------------------------
4344 // Mapping of PyInsertionPoint.
4345 //----------------------------------------------------------------------------
4346
4347 nb::class_<PyInsertionPoint>(m, "InsertionPoint")
4348 .def(nb::init<PyBlock &>(), "block"_a,
4349 "Inserts after the last operation but still inside the block.")
4350 .def("__enter__", &PyInsertionPoint::contextEnter,
4351 "Enters the insertion point as a context manager.")
4352 .def("__exit__", &PyInsertionPoint::contextExit, "exc_type"_a.none(),
4353 "exc_value"_a.none(), "traceback"_a.none(),
4354 "Exits the insertion point context manager.")
4355 .def_prop_ro_static(
4356 "current",
4357 [](nb::object & /*class*/) {
4359 if (!ip)
4360 throw nb::value_error("No current InsertionPoint");
4361 return ip;
4362 },
4363 nb::sig("def current(/) -> InsertionPoint"),
4364 "Gets the InsertionPoint bound to the current thread or raises "
4365 "ValueError if none has been set.")
4366 .def(nb::init<PyOperationBase &>(), "beforeOperation"_a,
4367 "Inserts before a referenced operation.")
4368 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, "block"_a,
4369 R"(
4370 Creates an insertion point at the beginning of a block.
4371
4372 Args:
4373 block: The block at whose beginning operations should be inserted.
4374
4375 Returns:
4376 An InsertionPoint at the block's beginning.)")
4377 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
4378 "block"_a,
4379 R"(
4380 Creates an insertion point before a block's terminator.
4381
4382 Args:
4383 block: The block whose terminator to insert before.
4384
4385 Returns:
4386 An InsertionPoint before the terminator.
4387
4388 Raises:
4389 ValueError: If the block has no terminator.)")
4390 .def_static("after", &PyInsertionPoint::after, "operation"_a,
4391 R"(
4392 Creates an insertion point immediately after an operation.
4393
4394 Args:
4395 operation: The operation after which to insert.
4396
4397 Returns:
4398 An InsertionPoint after the operation.)")
4399 .def("insert", &PyInsertionPoint::insert, "operation"_a,
4400 R"(
4401 Inserts an operation at this insertion point.
4402
4403 Args:
4404 operation: The operation to insert.)")
4405 .def_prop_ro(
4406 "block", [](PyInsertionPoint &self) { return self.getBlock(); },
4407 "Returns the block that this `InsertionPoint` points to.")
4408 .def_prop_ro(
4409 "ref_operation",
4410 [](PyInsertionPoint &self)
4411 -> std::optional<nb::typed<nb::object, PyOperation>> {
4412 auto refOperation = self.getRefOperation();
4413 if (refOperation)
4414 return refOperation->getObject();
4415 return {};
4416 },
4417 "The reference operation before which new operations are "
4418 "inserted, or None if the insertion point is at the end of "
4419 "the block.");
4420
4421 //----------------------------------------------------------------------------
4422 // Mapping of PyAttribute.
4423 //----------------------------------------------------------------------------
4424 nb::class_<PyAttribute>(m, "Attribute")
4425 // Delegate to the PyAttribute copy constructor, which will also lifetime
4426 // extend the backing context which owns the MlirAttribute.
4427 .def(nb::init<PyAttribute &>(), "cast_from_type"_a,
4428 "Casts the passed attribute to the generic `Attribute`.")
4430 "Gets a capsule wrapping the MlirAttribute.")
4431 .def_static(
4433 "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
4434 .def_static(
4435 "parse",
4436 [](const std::string &attrSpec, DefaultingPyMlirContext context)
4437 -> nb::typed<nb::object, PyAttribute> {
4438 PyMlirContext::ErrorCapture errors(context->getRef());
4439 MlirAttribute attr = mlirAttributeParseGet(
4440 context->get(), toMlirStringRef(attrSpec));
4441 if (mlirAttributeIsNull(attr))
4442 throw MLIRError("Unable to parse attribute", errors.take());
4443 return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
4444 },
4445 "asm"_a, "context"_a = nb::none(),
4446 "Parses an attribute from an assembly form. Raises an `MLIRError` on "
4447 "failure.")
4448 .def_prop_ro(
4449 "context",
4450 [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
4451 return self.getContext().getObject();
4452 },
4453 "Context that owns the `Attribute`.")
4454 .def_prop_ro(
4455 "type",
4456 [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
4457 return PyType(self.getContext(), mlirAttributeGetType(self))
4458 .maybeDownCast();
4459 },
4460 "Returns the type of the `Attribute`.")
4461 .def(
4462 "get_named",
4463 [](PyAttribute &self, std::string name) {
4464 return PyNamedAttribute(self, std::move(name));
4465 },
4466 nb::keep_alive<0, 1>(),
4467 R"(
4468 Binds a name to the attribute, creating a `NamedAttribute`.
4469
4470 Args:
4471 name: The name to bind to the `Attribute`.
4472
4473 Returns:
4474 A `NamedAttribute` with the given name and this attribute.)")
4475 .def(
4476 "__eq__",
4477 [](PyAttribute &self, PyAttribute &other) { return self == other; },
4478 "Compares two attributes for equality.")
4479 .def(
4480 "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
4481 "Compares attribute with non-attribute object (always returns "
4482 "False).")
4483 .def(
4484 "__hash__", [](PyAttribute &self) { return hash(self.get().ptr); },
4485 "Returns the hash value of the attribute.")
4486 .def(
4487 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
4489 .def(
4490 "__str__",
4491 [](PyAttribute &self) {
4492 PyPrintAccumulator printAccum;
4493 mlirAttributePrint(self, printAccum.getCallback(),
4494 printAccum.getUserData());
4495 return printAccum.join();
4496 },
4497 "Returns the assembly form of the Attribute.")
4498 .def(
4499 "__repr__",
4500 [](PyAttribute &self) {
4501 // Generally, assembly formats are not printed for __repr__ because
4502 // this can cause exceptionally long debug output and exceptions.
4503 // However, attribute values are generally considered useful and
4504 // are printed. This may need to be re-evaluated if debug dumps end
4505 // up being excessive.
4506 PyPrintAccumulator printAccum;
4507 printAccum.parts.append("Attribute(");
4508 mlirAttributePrint(self, printAccum.getCallback(),
4509 printAccum.getUserData());
4510 printAccum.parts.append(")");
4511 return printAccum.join();
4512 },
4513 "Returns a string representation of the attribute.")
4514 .def_prop_ro(
4515 "typeid",
4516 [](PyAttribute &self) {
4517 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
4518 assert(!mlirTypeIDIsNull(mlirTypeID) &&
4519 "mlirTypeID was expected to be non-null.");
4520 return PyTypeID(mlirTypeID);
4521 },
4522 "Returns the `TypeID` of the attribute.")
4523 .def(
4525 [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
4526 return self.maybeDownCast();
4527 },
4528 "Downcasts the attribute to a more specific attribute if possible.");
4529
4530 //----------------------------------------------------------------------------
4531 // Mapping of PyNamedAttribute
4532 //----------------------------------------------------------------------------
4533 nb::class_<PyNamedAttribute>(m, "NamedAttribute")
4534 .def(
4535 "__repr__",
4536 [](PyNamedAttribute &self) {
4537 PyPrintAccumulator printAccum;
4538 printAccum.parts.append("NamedAttribute(");
4539 printAccum.parts.append(
4540 nb::str(mlirIdentifierStr(self.namedAttr.name).data,
4541 mlirIdentifierStr(self.namedAttr.name).length));
4542 printAccum.parts.append("=");
4543 mlirAttributePrint(self.namedAttr.attribute,
4544 printAccum.getCallback(),
4545 printAccum.getUserData());
4546 printAccum.parts.append(")");
4547 return printAccum.join();
4548 },
4549 "Returns a string representation of the named attribute.")
4550 .def_prop_ro(
4551 "name",
4552 [](PyNamedAttribute &self) {
4553 return mlirIdentifierStr(self.namedAttr.name);
4554 },
4555 "The name of the `NamedAttribute` binding.")
4556 .def_prop_ro(
4557 "attr",
4558 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
4559 nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
4560 "The underlying generic attribute of the `NamedAttribute` binding.");
4561
4562 //----------------------------------------------------------------------------
4563 // Mapping of PyType.
4564 //----------------------------------------------------------------------------
4565 nb::class_<PyType>(m, "Type")
4566 // Delegate to the PyType copy constructor, which will also lifetime
4567 // extend the backing context which owns the MlirType.
4568 .def(nb::init<PyType &>(), "cast_from_type"_a,
4569 "Casts the passed type to the generic `Type`.")
4571 "Gets a capsule wrapping the `MlirType`.")
4573 "Creates a Type from a capsule wrapping `MlirType`.")
4574 .def_static(
4575 "parse",
4576 [](std::string typeSpec,
4577 DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
4578 PyMlirContext::ErrorCapture errors(context->getRef());
4579 MlirType type =
4580 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
4581 if (mlirTypeIsNull(type))
4582 throw MLIRError("Unable to parse type", errors.take());
4583 return PyType(context.get()->getRef(), type).maybeDownCast();
4584 },
4585 "asm"_a, "context"_a = nb::none(),
4586 R"(
4587 Parses the assembly form of a type.
4588
4589 Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
4590
4591 See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
4592 .def_prop_ro(
4593 "context",
4594 [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
4595 return self.getContext().getObject();
4596 },
4597 "Context that owns the `Type`.")
4598 .def(
4599 "__eq__", [](PyType &self, PyType &other) { return self == other; },
4600 "Compares two types for equality.")
4601 .def(
4602 "__eq__", [](PyType &self, nb::object &other) { return false; },
4603 "other"_a.none(),
4604 "Compares type with non-type object (always returns False).")
4605 .def(
4606 "__hash__", [](PyType &self) { return hash(self.get().ptr); },
4607 "Returns the hash value of the `Type`.")
4608 .def(
4609 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
4610 .def(
4611 "__str__",
4612 [](PyType &self) {
4613 PyPrintAccumulator printAccum;
4614 mlirTypePrint(self, printAccum.getCallback(),
4615 printAccum.getUserData());
4616 return printAccum.join();
4617 },
4618 "Returns the assembly form of the `Type`.")
4619 .def(
4620 "__repr__",
4621 [](PyType &self) {
4622 // Generally, assembly formats are not printed for __repr__ because
4623 // this can cause exceptionally long debug output and exceptions.
4624 // However, types are an exception as they typically have compact
4625 // assembly forms and printing them is useful.
4626 PyPrintAccumulator printAccum;
4627 printAccum.parts.append("Type(");
4628 mlirTypePrint(self, printAccum.getCallback(),
4629 printAccum.getUserData());
4630 printAccum.parts.append(")");
4631 return printAccum.join();
4632 },
4633 "Returns a string representation of the `Type`.")
4634 .def(
4636 [](PyType &self) -> nb::typed<nb::object, PyType> {
4637 return self.maybeDownCast();
4638 },
4639 "Downcasts the Type to a more specific `Type` if possible.")
4640 .def_prop_ro(
4641 "typeid",
4642 [](PyType &self) {
4643 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
4644 if (!mlirTypeIDIsNull(mlirTypeID))
4645 return PyTypeID(mlirTypeID);
4646 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
4647 throw nb::value_error(join(origRepr, " has no typeid.").c_str());
4648 },
4649 "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
4650 "`Type` has no "
4651 "`TypeID`.");
4652
4653 //----------------------------------------------------------------------------
4654 // Mapping of PyTypeID.
4655 //----------------------------------------------------------------------------
4656 nb::class_<PyTypeID>(m, "TypeID")
4658 "Gets a capsule wrapping the `MlirTypeID`.")
4660 "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
4661 // Note, this tests whether the underlying TypeIDs are the same,
4662 // not whether the wrapper MlirTypeIDs are the same, nor whether
4663 // the Python objects are the same (i.e., PyTypeID is a value type).
4664 .def(
4665 "__eq__",
4666 [](PyTypeID &self, PyTypeID &other) { return self == other; },
4667 "Compares two `TypeID`s for equality.")
4668 .def(
4669 "__eq__",
4670 [](PyTypeID &self, const nb::object &other) { return false; },
4671 "Compares TypeID with non-TypeID object (always returns False).")
4672 // Note, this gives the hash value of the underlying TypeID, not the
4673 // hash value of the Python object, nor the hash value of the
4674 // MlirTypeID wrapper.
4675 .def(
4676 "__hash__",
4677 [](PyTypeID &self) {
4678 return static_cast<size_t>(mlirTypeIDHashValue(self));
4679 },
4680 "Returns the hash value of the `TypeID`.");
4681
4682 //----------------------------------------------------------------------------
4683 // Mapping of Value.
4684 //----------------------------------------------------------------------------
4685 m.attr("_T") = nb::type_var("_T", "bound"_a = m.attr("Type"));
4686
4687 nb::class_<PyValue>(m, "Value", nb::is_generic(),
4688 nb::sig("class Value(Generic[_T])"))
4689 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), "value"_a,
4690 "Creates a Value reference from another `Value`.")
4692 "Gets a capsule wrapping the `MlirValue`.")
4694 "Creates a `Value` from a capsule wrapping `MlirValue`.")
4695 .def_prop_ro(
4696 "context",
4697 [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
4698 return self.getParentOperation()->getContext().getObject();
4699 },
4700 "Context in which the value lives.")
4701 .def(
4702 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
4704 .def_prop_ro(
4705 "owner",
4706 [](PyValue &self)
4707 -> nb::typed<nb::object, std::variant<PyOpView, PyBlock>> {
4708 MlirValue v = self.get();
4709 if (mlirValueIsAOpResult(v)) {
4710 assert(mlirOperationEqual(self.getParentOperation()->get(),
4711 mlirOpResultGetOwner(self.get())) &&
4712 "expected the owner of the value in Python to match "
4713 "that in "
4714 "the IR");
4715 return self.getParentOperation()->createOpView();
4716 }
4717
4719 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
4720 return nb::cast(PyBlock(self.getParentOperation(), block));
4721 }
4722
4723 assert(false && "Value must be a block argument or an op result");
4724 return nb::none();
4725 },
4726 "Returns the owner of the value (`Operation` for results, `Block` "
4727 "for "
4728 "arguments).")
4729 .def_prop_ro(
4730 "uses",
4731 [](PyValue &self) {
4732 return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
4733 },
4734 "Returns an iterator over uses of this value.")
4735 .def(
4736 "__eq__",
4737 [](PyValue &self, PyValue &other) {
4738 return self.get().ptr == other.get().ptr;
4739 },
4740 "Compares two values for pointer equality.")
4741 .def(
4742 "__eq__", [](PyValue &self, nb::object other) { return false; },
4743 "Compares value with non-value object (always returns False).")
4744 .def(
4745 "__hash__", [](PyValue &self) { return hash(self.get().ptr); },
4746 "Returns the hash value of the value.")
4747 .def(
4748 "__str__",
4749 [](PyValue &self) {
4750 PyPrintAccumulator printAccum;
4751 printAccum.parts.append("Value(");
4752 mlirValuePrint(self.get(), printAccum.getCallback(),
4753 printAccum.getUserData());
4754 printAccum.parts.append(")");
4755 return printAccum.join();
4756 },
4757 R"(
4758 Returns the string form of the value.
4759
4760 If the value is a block argument, this is the assembly form of its type and the
4761 position in the argument list. If the value is an operation result, this is
4762 equivalent to printing the operation that produced it.
4763 )")
4764 .def(
4765 "get_name",
4766 [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
4767 PyPrintAccumulator printAccum;
4768 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
4769 if (useLocalScope)
4771 if (useNameLocAsPrefix)
4773 MlirAsmState valueState =
4774 mlirAsmStateCreateForValue(self.get(), flags);
4775 mlirValuePrintAsOperand(self.get(), valueState,
4776 printAccum.getCallback(),
4777 printAccum.getUserData());
4779 mlirAsmStateDestroy(valueState);
4780 return printAccum.join();
4781 },
4782 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
4783 R"(
4784 Returns the string form of value as an operand.
4785
4786 Args:
4787 use_local_scope: Whether to use local scope for naming.
4788 use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
4789
4790 Returns:
4791 The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
4792 .def(
4793 "get_name",
4794 [](PyValue &self, PyAsmState &state) {
4795 PyPrintAccumulator printAccum;
4796 MlirAsmState valueState = state.get();
4797 mlirValuePrintAsOperand(self.get(), valueState,
4798 printAccum.getCallback(),
4799 printAccum.getUserData());
4800 return printAccum.join();
4801 },
4802 "state"_a,
4803 "Returns the string form of value as an operand (i.e., the ValueID).")
4804 .def_prop_ro(
4805 "type",
4806 [](PyValue &self) -> nb::typed<nb::object, PyType> {
4807 return PyType(self.getParentOperation()->getContext(),
4808 mlirValueGetType(self.get()))
4809 .maybeDownCast();
4810 },
4811 "Returns the type of the value.")
4812 .def(
4813 "set_type",
4814 [](PyValue &self, const PyType &type) {
4815 mlirValueSetType(self.get(), type);
4816 },
4817 "type"_a, "Sets the type of the value.",
4818 nb::sig("def set_type(self, type: _T)"))
4819 .def(
4820 "replace_all_uses_with",
4821 [](PyValue &self, PyValue &with) {
4822 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
4823 },
4824 "Replace all uses of value with the new value, updating anything in "
4825 "the IR that uses `self` to use the other value instead.")
4826 .def(
4827 "replace_all_uses_except",
4828 [](PyValue &self, PyValue &with, PyOperation &exception) {
4829 MlirOperation exceptedUser = exception.get();
4830 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4831 },
4832 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4833 .def(
4834 "replace_all_uses_except",
4835 [](PyValue &self, PyValue &with, const nb::list &exceptions) {
4836 // Convert Python list to a std::vector of MlirOperations
4837 std::vector<MlirOperation> exceptionOps;
4838 for (nb::handle exception : exceptions) {
4839 exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
4840 }
4841
4843 self, with, static_cast<intptr_t>(exceptionOps.size()),
4844 exceptionOps.data());
4845 },
4846 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4847 .def(
4848 "replace_all_uses_except",
4849 [](PyValue &self, PyValue &with, PyOperation &exception) {
4850 MlirOperation exceptedUser = exception.get();
4851 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4852 },
4853 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4854 .def(
4855 "replace_all_uses_except",
4856 [](PyValue &self, PyValue &with,
4857 std::vector<PyOperation> &exceptions) {
4858 // Convert Python list to a std::vector of MlirOperations
4859 std::vector<MlirOperation> exceptionOps;
4860 for (PyOperation &exception : exceptions)
4861 exceptionOps.push_back(exception);
4863 self, with, static_cast<intptr_t>(exceptionOps.size()),
4864 exceptionOps.data());
4865 },
4866 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4867 .def(
4869 [](PyValue &self) { return self.maybeDownCast(); },
4870 "Downcasts the `Value` to a more specific kind if possible.")
4871 .def_prop_ro(
4872 "location",
4873 [](PyValue self) {
4874 return PyLocation(
4876 mlirValueGetLocation(self));
4877 },
4878 "Returns the source location of the value.");
4879
4883
4884 nb::class_<PyAsmState>(m, "AsmState")
4885 .def(nb::init<PyValue &, bool>(), "value"_a, "use_local_scope"_a = false,
4886 R"(
4887 Creates an `AsmState` for consistent SSA value naming.
4888
4889 Args:
4890 value: The value to create state for.
4891 use_local_scope: Whether to use local scope for naming.)")
4892 .def(nb::init<PyOperationBase &, bool>(), "op"_a,
4893 "use_local_scope"_a = false,
4894 R"(
4895 Creates an AsmState for consistent SSA value naming.
4896
4897 Args:
4898 op: The operation to create state for.
4899 use_local_scope: Whether to use local scope for naming.)");
4900
4901 //----------------------------------------------------------------------------
4902 // Mapping of SymbolTable.
4903 //----------------------------------------------------------------------------
4904 nb::class_<PySymbolTable>(m, "SymbolTable")
4905 .def(nb::init<PyOperationBase &>(),
4906 R"(
4907 Creates a symbol table for an operation.
4908
4909 Args:
4910 operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
4911
4912 Raises:
4913 TypeError: If the operation is not a symbol table.)")
4914 .def(
4915 "__getitem__",
4916 [](PySymbolTable &self,
4917 const std::string &name) -> nb::typed<nb::object, PyOpView> {
4918 return self.dunderGetItem(name);
4919 },
4920 R"(
4921 Looks up a symbol by name in the symbol table.
4922
4923 Args:
4924 name: The name of the symbol to look up.
4925
4926 Returns:
4927 The operation defining the symbol.
4928
4929 Raises:
4930 KeyError: If the symbol is not found.)")
4931 .def("insert", &PySymbolTable::insert, "operation"_a,
4932 R"(
4933 Inserts a symbol operation into the symbol table.
4934
4935 Args:
4936 operation: An operation with a symbol name to insert.
4937
4938 Returns:
4939 The symbol name attribute of the inserted operation.
4940
4941 Raises:
4942 ValueError: If the operation does not have a symbol name.)")
4943 .def("erase", &PySymbolTable::erase, "operation"_a,
4944 R"(
4945 Erases a symbol operation from the symbol table.
4946
4947 Args:
4948 operation: The symbol operation to erase.
4949
4950 Note:
4951 The operation is also erased from the IR and invalidated.)")
4952 .def("__delitem__", &PySymbolTable::dunderDel,
4953 "Deletes a symbol by name from the symbol table.")
4954 .def(
4955 "__contains__",
4956 [](PySymbolTable &table, const std::string &name) {
4957 return !mlirOperationIsNull(mlirSymbolTableLookup(
4958 table, mlirStringRefCreate(name.data(), name.length())));
4959 },
4960 "Checks if a symbol with the given name exists in the table.")
4961 // Static helpers.
4962 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, "symbol"_a,
4963 "name"_a, "Sets the symbol name for a symbol operation.")
4964 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, "symbol"_a,
4965 "Gets the symbol name from a symbol operation.")
4966 .def_static("get_visibility", &PySymbolTable::getVisibility, "symbol"_a,
4967 "Gets the visibility attribute of a symbol operation.")
4968 .def_static("set_visibility", &PySymbolTable::setVisibility, "symbol"_a,
4969 "visibility"_a,
4970 "Sets the visibility attribute of a symbol operation.")
4971 .def_static("replace_all_symbol_uses",
4972 &PySymbolTable::replaceAllSymbolUses, "old_symbol"_a,
4973 "new_symbol"_a, "from_op"_a,
4974 "Replaces all uses of a symbol with a new symbol name within "
4975 "the given operation.")
4976 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
4977 "from_op"_a, "all_sym_uses_visible"_a, "callback"_a,
4978 "Walks symbol tables starting from an operation with a "
4979 "callback function.");
4980
4981 // Container bindings.
4997
4998 // Debug bindings.
5000
5001 // Attribute builder getter.
5003
5004 // Extensible Dialect
5008}
5009} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
5010} // namespace python
5011} // namespace mlir
return success()
void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Definition Debug.cpp:28
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition Debug.cpp:20
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition Debug.cpp:16
static const char kDumpDocstring[]
Definition IRAffine.cpp:32
static const char kModuleParseDocstring[]
Definition IRCore.cpp:32
#define Py_XNewRef(obj)
Definition IRCore.cpp:2703
static size_t hash(const T &value)
Local helper to compute std::hash for a value.
Definition IRCore.cpp:55
#define _Py_CAST(type, expr)
Definition IRCore.cpp:2679
#define Py_NewRef(obj)
Definition IRCore.cpp:2712
#define _Py_NULL
Definition IRCore.cpp:2690
static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor)
Definition IRCore.cpp:60
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static const char kValueReplaceAllUsesExceptDocstring[]
Definition IRCore.cpp:43
MlirContext mlirModuleGetContext(MlirModule module)
Definition IR.cpp:445
size_t mlirModuleHashValue(MlirModule mod)
Definition IR.cpp:471
intptr_t mlirBlockGetNumPredecessors(MlirBlock block)
Definition IR.cpp:1099
MlirIdentifier mlirOperationGetName(MlirOperation op)
Definition IR.cpp:668
bool mlirValueIsABlockArgument(MlirValue value)
Definition IR.cpp:1119
intptr_t mlirOperationGetNumRegions(MlirOperation op)
Definition IR.cpp:680
MlirBlock mlirOperationGetBlock(MlirOperation op)
Definition IR.cpp:672
void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Definition IR.cpp:1136
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition IR.cpp:520
MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Definition IR.cpp:735
MlirModule mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName)
Definition IR.cpp:436
MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Definition IR.cpp:177
intptr_t mlirOperationGetNumResults(MlirOperation op)
Definition IR.cpp:731
void mlirOperationDestroy(MlirOperation op)
Definition IR.cpp:638
MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Definition IR.cpp:1284
MlirType mlirValueGetType(MlirValue value)
Definition IR.cpp:1155
void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Definition IR.cpp:1085
MlirOpPrintingFlags mlirOpPrintingFlagsCreate()
Definition IR.cpp:201
bool mlirModuleEqual(MlirModule lhs, MlirModule rhs)
Definition IR.cpp:467
void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Definition IR.cpp:209
void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Definition IR.cpp:796
MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Definition IR.cpp:704
void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Definition IR.cpp:219
MlirOperation mlirModuleGetOperation(MlirModule module)
Definition IR.cpp:459
void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
Definition IR.cpp:214
void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Definition IR.cpp:232
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Definition IR.cpp:1131
MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Definition IR.cpp:743
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Definition IR.cpp:1303
MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags)
Definition IR.cpp:156
bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Definition IR.cpp:642
void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Definition IR.cpp:236
void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Definition IR.cpp:251
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1095
void mlirModuleDestroy(MlirModule module)
Definition IR.cpp:453
MlirModule mlirModuleCreateEmpty(MlirLocation location)
Definition IR.cpp:424
void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Definition IR.cpp:224
MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Definition IR.cpp:676
void mlirValueSetType(MlirValue value, MlirType type)
Definition IR.cpp:1159
intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Definition IR.cpp:739
MlirDialect mlirAttributeGetDialect(MlirAttribute attr)
Definition IR.cpp:1299
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Definition IR.cpp:414
void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Definition IR.cpp:815
void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Definition IR.cpp:720
MlirOperation mlirOpResultGetOwner(MlirValue value)
Definition IR.cpp:1146
MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Definition IR.cpp:428
size_t mlirOperationHashValue(MlirOperation op)
Definition IR.cpp:646
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Definition IR.cpp:503
MlirOperation mlirOperationClone(MlirOperation op)
Definition IR.cpp:634
MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Definition IR.cpp:1127
void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc)
Definition IR.cpp:1141
MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Definition IR.cpp:712
MlirOpOperand mlirOperationGetOpOperand(MlirOperation op, intptr_t pos)
Definition IR.cpp:716
MlirLocation mlirOperationGetLocation(MlirOperation op)
Definition IR.cpp:654
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:810
MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr)
Definition IR.cpp:1295
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition IR.cpp:512
void mlirOperationSetLocation(MlirOperation op, MlirLocation loc)
Definition IR.cpp:658
MlirType mlirAttributeGetType(MlirAttribute attribute)
Definition IR.cpp:1288
bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:820
bool mlirValueIsAOpResult(MlirValue value)
Definition IR.cpp:1123
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1104
MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Definition IR.cpp:684
MlirOperation mlirOperationCreate(MlirOperationState *state)
Definition IR.cpp:588
void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Definition IR.cpp:255
MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Definition IR.cpp:1280
void mlirOperationRemoveFromParent(MlirOperation op)
Definition IR.cpp:640
intptr_t mlirBlockGetNumSuccessors(MlirBlock block)
Definition IR.cpp:1091
MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Definition IR.cpp:805
void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Definition IR.cpp:205
void mlirValueDump(MlirValue value)
Definition IR.cpp:1163
void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Definition IR.cpp:1269
MlirBlock mlirModuleGetBody(MlirModule module)
Definition IR.cpp:449
MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Definition IR.cpp:625
void mlirAsmStateDestroy(MlirAsmState state)
Destroys an asm state created with mlirAsmStateCreateForOperation or mlirAsmStateCreateForValue.
Definition IR.cpp:195
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:650
intptr_t mlirOpResultGetResultNumber(MlirValue value)
Definition IR.cpp:1150
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition IR.cpp:516
MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate()
Definition IR.cpp:247
void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags)
Definition IR.cpp:228
void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Definition IR.cpp:240
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition IR.cpp:508
MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Definition IR.cpp:479
intptr_t mlirOperationGetNumOperands(MlirOperation op)
Definition IR.cpp:708
void mlirTypeDump(MlirType type)
Definition IR.cpp:1274
intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Definition IR.cpp:801
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition Interop.h:348
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition Interop.h:216
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition Interop.h:118
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition Interop.h:189
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition Interop.h:367
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition Interop.h:330
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition Interop.h:180
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts a MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition Interop.h:357
#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the value caster registration bindin...
Definition Interop.h:142
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition Interop.h:198
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition Interop.h:255
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition Interop.h:245
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition Interop.h:454
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition Interop.h:445
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition Interop.h:273
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition Interop.h:264
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the type caster registration binding...
Definition Interop.h:130
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition Interop.h:235
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static std::string diag(const llvm::Value &value)
Accumulates into a file, either writing text (default) or binary.
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
nanobind::class_< PyRegionList > ClassTy
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
ReferrentTy * get() const
PyMlirContextRef & getContext()
Accesses the context reference.
Definition IRCore.h:298
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:279
PyAsmState(MlirValue value, bool useLocalScope)
Definition IRCore.cpp:1774
Wrapper around the generic MlirAttribute.
Definition IRCore.h:1006
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition IRCore.h:1008
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition IRCore.cpp:1884
bool operator==(const PyAttribute &other) const
Definition IRCore.cpp:1880
static PyAttribute createFromCapsule(const nanobind::object &capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition IRCore.cpp:1888
nanobind::typed< nanobind::object, PyAttribute > maybeDownCast()
Definition IRCore.cpp:1896
PyBlockArgumentList(PyOperationRef operation, MlirBlock block, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2213
Python wrapper for MlirBlockArgument.
Definition IRCore.h:1637
nanobind::typed< nanobind::object, PyBlock > dunderNext()
Definition IRCore.cpp:242
Blocks are exposed by the C-API as a forward-only linked list.
Definition IRCore.h:1433
PyBlock appendBlock(const nanobind::args &pyArgTypes, const std::optional< nanobind::sequence > &pyArgLocs)
Definition IRCore.cpp:298
PyBlockPredecessors(PyBlock block, PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2373
PyBlockSuccessors(PyBlock block, PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2350
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirBlock.
Definition IRCore.cpp:182
Represents a diagnostic handler attached to the context.
Definition IRCore.h:406
void detach()
Detaches the handler. Does nothing if not attached.
Definition IRCore.cpp:753
PyDiagnosticHandler(MlirContext context, nanobind::object callback)
Definition IRCore.cpp:747
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.h:418
Python class mirroring the C MlirDiagnostic struct.
Definition IRCore.h:356
Wrapper around an MlirDialectRegistry.
Definition IRCore.h:498
static PyDialectRegistry createFromCapsule(nanobind::object capsule)
Definition IRCore.cpp:837
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition IRCore.h:474
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition IRCore.cpp:820
static bool attach(const nanobind::object &opName, const nanobind::object &target, PyMlirContext &context)
Definition IRCore.cpp:2576
static bool attach(const nanobind::object &opName, PyMlirContext &context)
Definition IRCore.cpp:2629
static bool attach(const nanobind::object &opName, PyMlirContext &context)
Definition IRCore.cpp:2647
Globals that are always accessible once the extension has been initialized.
Definition Globals.h:29
void registerOpAdaptorImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds an operation adaptor class.
Definition Globals.cpp:154
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:190
bool loadDialectModule(std::string_view dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition Globals.cpp:64
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.cpp:59
std::optional< nanobind::object > lookupOperationClass(std::string_view operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition Globals.cpp:220
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition Globals.cpp:112
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false)
Adds a user-friendly Attribute builder.
Definition Globals.cpp:99
void registerOperationImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation operation class.
Definition Globals.cpp:143
void setDialectSearchPrefixes(std::vector< std::string > newValues)
Definition Globals.h:43
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:176
void registerValueCaster(MlirTypeID mlirTypeID, nanobind::callable valueCaster, bool replace=false)
Adds a user-friendly value caster.
Definition Globals.cpp:122
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition Globals.cpp:166
void registerDialectImpl(const std::string &dialectNamespace, nanobind::object pyClass)
Adds a concrete implementation dialect class.
Definition Globals.cpp:132
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition Globals.cpp:205
std::vector< std::string > getDialectSearchPrefixes()
Get and set the list of parent modules to search for dialect implementation classes.
Definition Globals.h:39
An insertion point maintains a pointer to a Block and a reference operation.
Definition IRCore.h:833
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition IRCore.cpp:1805
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:1870
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition IRCore.cpp:1844
static PyInsertionPoint after(PyOperationBase &op)
Shortcut to create an insertion point to the node after the specified operation.
Definition IRCore.cpp:1853
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition IRCore.cpp:1831
PyInsertionPoint(const PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition IRCore.cpp:1796
static nanobind::object contextEnter(nanobind::object insertionPoint)
Enter and exit the context manager.
Definition IRCore.cpp:1866
static nanobind::object contextEnter(nanobind::object location)
Enter and exit the context manager.
Definition IRCore.cpp:861
static PyLocation createFromCapsule(nanobind::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition IRCore.cpp:853
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition IRCore.cpp:849
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:865
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition IRCore.h:307
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:486
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition IRCore.cpp:511
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition IRCore.cpp:479
nanobind::object attachDiagnosticHandler(nanobind::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition IRCore.cpp:526
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:520
MlirContext get()
Accesses the underlying MlirContext.
Definition IRCore.h:212
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition IRCore.cpp:471
void setEmitErrorDiagnostics(bool value)
Controls whether error diagnostics should be propagated to diagnostic handlers, instead of being capt...
Definition IRCore.h:246
static nanobind::object contextEnter(nanobind::object context)
Enter and exit the context manager.
Definition IRCore.cpp:516
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition IRCore.cpp:475
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition IRCore.cpp:1864
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition IRCore.cpp:930
MlirModule get()
Gets the backing MlirModule.
Definition IRCore.h:548
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition IRCore.cpp:898
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition IRCore.cpp:923
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition IRCore.h:1032
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition IRCore.cpp:1914
Template for a reference to a concrete type which captures a python reference to its underlying pytho...
Definition IRCore.h:66
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
Definition IRCore.h:92
void dunderSetItem(const std::string &name, const PyAttribute &attr)
Definition IRCore.cpp:2433
nanobind::typed< nanobind::object, PyAttribute > dunderGetItemNamed(const std::string &name)
Definition IRCore.cpp:2400
nanobind::typed< nanobind::object, std::optional< PyAttribute > > get(const std::string &key, nanobind::object defaultValue)
Definition IRCore.cpp:2410
static void forEachAttr(MlirOperation op, std::function< void(MlirStringRef, MlirAttribute)> fn)
Definition IRCore.cpp:2455
PyNamedAttribute dunderGetItemIndexed(intptr_t index)
Definition IRCore.cpp:2418
nanobind::typed< nanobind::object, PyOpOperand > dunderNext()
Definition IRCore.cpp:407
PyOpOperandList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2245
void dunderSetItem(intptr_t index, PyValue value)
Definition IRCore.cpp:2253
nanobind::typed< nanobind::object, PyOpView > getOwner() const
Definition IRCore.cpp:388
PyOpOperands(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2288
Sliceable< PyOpOperandList, PyOpOperand > SliceableT
Definition IRCore.cpp:2286
PyOpResultList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:1425
PyOpSuccessors(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2317
void dunderSetItem(intptr_t index, PyBlock block)
Definition IRCore.cpp:2325
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition IRCore.h:735
PyOpView(const nanobind::object &operationObject)
Definition IRCore.cpp:1764
static nanobind::object buildGeneric(std::string_view name, std::tuple< int, bool > opRegionSpec, nanobind::object operandSegmentSpecObj, nanobind::object resultSegmentSpecObj, std::optional< nanobind::list > resultTypeList, nanobind::list operandList, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, std::optional< int > regions, PyLocation &location, const nanobind::object &maybeIp)
Definition IRCore.cpp:1587
static nanobind::object constructDerived(const nanobind::object &cls, const nanobind::object &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition IRCore.cpp:1756
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition IRCore.h:578
bool isBeforeInBlock(PyOperationBase &other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IRCore.cpp:1189
nanobind::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions)
Definition IRCore.cpp:1143
void print(std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
Definition IRCore.cpp:1042
void writeBytecode(const nanobind::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition IRCore.cpp:1090
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition IRCore.cpp:1171
void walk(std::function< PyWalkResult(MlirOperation)> callback, PyWalkOrder walkOrder)
Definition IRCore.cpp:1111
nanobind::typed< nanobind::object, PyOpView > dunderNext()
Definition IRCore.cpp:319
Operations are exposed by the C-API as a forward-only linked list.
Definition IRCore.h:1474
nanobind::typed< nanobind::object, PyOpView > dunderGetItem(intptr_t index)
Definition IRCore.cpp:358
static nanobind::object create(std::string_view name, std::optional< std::vector< PyType * > > results, const MlirValue *operands, size_t numOperands, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, int regions, PyLocation &location, const nanobind::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition IRCore.cpp:1253
void setInvalid()
Invalidate the operation.
Definition IRCore.h:702
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition IRCore.h:635
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition IRCore.cpp:999
nanobind::object clone(const nanobind::object &ip)
Clones this operation.
Definition IRCore.cpp:1368
static nanobind::object createFromCapsule(const nanobind::object &capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition IRCore.cpp:1229
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition IRCore.cpp:1205
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1377
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition IRCore.cpp:1224
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:983
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition IRCore.cpp:1011
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition IRCore.cpp:1388
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition IRCore.cpp:1215
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Creates a detached operation.
Definition IRCore.cpp:990
PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
Definition IRCore.cpp:938
void setAttached(const nanobind::object &parent=nanobind::object())
Definition IRCore.cpp:1026
nanobind::typed< nanobind::object, PyRegion > dunderNext()
Definition IRCore.cpp:190
Regions of an op are fixed length and indexed numerically so are represented with a sequence-like con...
Definition IRCore.h:1390
PyRegionList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:209
PyStringAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition IRCore.cpp:2066
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition IRCore.cpp:2030
static PyStringAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition IRCore.cpp:2106
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition IRCore.cpp:2051
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, nanobind::object callback)
Walks all symbol tables under and including 'from'.
Definition IRCore.cpp:2147
nanobind::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition IRCore.cpp:2038
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition IRCore.cpp:2135
static PyStringAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition IRCore.cpp:2078
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition IRCore.cpp:2061
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition IRCore.cpp:2091
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition IRCore.cpp:2117
Tracks an entry in the thread context stack.
Definition IRCore.h:125
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and returns nullptr if not defined.
Definition IRCore.cpp:663
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint)
Definition IRCore.cpp:691
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition IRCore.cpp:703
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition IRCore.cpp:668
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition IRCore.cpp:611
static nanobind::object pushLocation(nanobind::object location)
Definition IRCore.cpp:714
static nanobind::object pushContext(nanobind::object context)
Definition IRCore.cpp:673
static PyMlirContext * getDefaultContext()
Gets the top of stack context and returns nullptr if not defined.
Definition IRCore.cpp:658
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition IRCore.cpp:606
Wrapper around MlirLlvmThreadPool; the Python object owns the C++ thread pool.
Definition IRCore.h:181
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition IRCore.h:901
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition IRCore.cpp:1960
bool operator==(const PyTypeID &other) const
Definition IRCore.cpp:1970
static PyTypeID createFromCapsule(nanobind::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition IRCore.cpp:1964
Wrapper around the generic MlirType.
Definition IRCore.h:875
PyType(PyMlirContextRef contextRef, MlirType type)
Definition IRCore.h:877
bool operator==(const PyType &other) const
Definition IRCore.cpp:1926
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition IRCore.cpp:1930
static PyType createFromCapsule(nanobind::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition IRCore.cpp:1934
nanobind::typed< nanobind::object, PyType > maybeDownCast()
Definition IRCore.cpp:1942
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition IRCore.cpp:1978
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition IRCore.h:1177
nanobind::typed< nanobind::object, std::variant< PyBlockArgument, PyOpResult, PyValue > > maybeDownCast()
Definition IRCore.cpp:1997
static PyValue createFromCapsule(nanobind::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition IRCore.cpp:2018
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
struct MlirDiagnostic MlirDiagnostic
Definition Diagnostics.h:29
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitIsTerminatorCreate(void)
Get the dynamic op trait that indicates the operation is a terminator.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitCreate(MlirTypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData)
Create a custom dynamic op trait with the given type ID and callbacks.
MLIR_CAPI_EXPORTED bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait, MlirStringRef opName, MlirContext context)
Attach a dynamic op trait to the given operation name.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitNoTerminatorCreate(void)
Get the dynamic op trait that indicates regions have no terminator.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition IR.cpp:264
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition IR.h:855
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v)
Gets the location of the value.
Definition IR.cpp:1206
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context)
Gets the number of threads of the thread pool of the context when multithreading is enabled.
Definition IR.cpp:116
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition IR.cpp:845
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates a File/Line/Column location owned by the given context.
Definition IR.cpp:272
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition IR.cpp:1388
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition IR.cpp:136
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location)
Getter for end_column of FileLineColRange.
Definition IR.cpp:310
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition IR.cpp:1367
MlirWalkOrder
Traversal order for operation walk.
Definition IR.h:848
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition IR.cpp:1315
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGetChildLoc(MlirLocation location)
Getter for childLoc of Name.
Definition IR.cpp:391
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition IR.cpp:1372
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition IR.cpp:83
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition IR.cpp:1336
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition IR.cpp:1249
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition IR.cpp:1232
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition IR.cpp:72
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition IR.cpp:955
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location)
Checks whether the given location is a FileLineColRange.
Definition IR.cpp:320
MLIR_CAPI_EXPORTED unsigned mlirLocationFusedGetNumLocations(MlirLocation location)
Getter for number of locations fused together.
Definition IR.cpp:354
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition IR.cpp:1188
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition IR.cpp:1171
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition IR.cpp:402
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition IR.cpp:1220
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location)
Getter for filename of FileLineColRange.
Definition IR.cpp:288
MLIR_CAPI_EXPORTED void mlirLocationFusedGetLocations(MlirLocation location, MlirLocation *locationsCPtr)
Getter for locations of Fused.
Definition IR.cpp:360
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding userData to callback`...
Definition IR.cpp:1307
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition IR.cpp:994
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition IR.cpp:869
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
Definition IR.cpp:1192
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition IR.cpp:836
MlirWalkResult
Operation walk result.
Definition IR.h:841
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition IR.cpp:935
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1160
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e.
Definition IR.cpp:99
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition IR.cpp:1063
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartLine(MlirLocation location)
Getter for start_line of FileLineColRange.
Definition IR.cpp:292
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition IR.cpp:346
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition IR.cpp:1044
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition IR.cpp:94
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location)
Checks whether the given location is a CallSite.
Definition IR.cpp:342
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition IR.cpp:941
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition IR.cpp:978
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition IR.h:941
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition IR.cpp:1019
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition IR.cpp:1081
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition IR.cpp:1362
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition IR.cpp:1253
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, unsigned start_line, unsigned start_col, unsigned end_line, unsigned end_col)
Creates a File/Line/Column range location owned by the given context.
Definition IR.cpp:280
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition IR.cpp:1218
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition IR.cpp:1352
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition IR.cpp:406
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location)
Getter for start_column of FileLineColRange.
Definition IR.cpp:298
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location)
Checks whether the given location is a Fused location.
Definition IR.cpp:374
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition IR.h:370
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition IR.cpp:1067
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition IR.cpp:830
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition IR.cpp:1178
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool)
Sets the thread pool of the context explicitly, enabling multithreading in the process.
Definition IR.cpp:111
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition IR.cpp:861
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition IR.cpp:1265
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition IR.cpp:1228
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location)
Getter for caller of CallSite.
Definition IR.cpp:333
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition IR.cpp:1009
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationNameGetName(MlirLocation location)
Getter for name of Name.
Definition IR.cpp:387
MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IR.cpp:873
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition IR.cpp:268
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition IR.cpp:1257
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition IR.cpp:1348
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition IR.h:182
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition IR.cpp:998
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition IR.cpp:324
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location)
Checks whether the given location is a Name location.
Definition IR.cpp:398
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition IR.h:244
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition IR.cpp:891
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
Definition IR.cpp:54
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition IR.cpp:990
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition IR.cpp:410
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition IR.cpp:1072
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition IR.cpp:1313
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition IR.cpp:1377
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition IR.cpp:931
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition IR.cpp:1002
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition IR.h:880
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition IR.cpp:1261
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition IR.cpp:1324
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition IR.cpp:107
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context)
Gets the thread pool of the context when multithreading is enabled; otherwise, an assertion is raised.
Definition IR.cpp:120
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndLine(MlirLocation location)
Getter for end_line of FileLineColRange.
Definition IR.cpp:304
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition IR.cpp:378
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Sets the threading mode (must be set to false to use mlir-print-ir-after-all).
Definition IR.cpp:103
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location)
Getter for callee of CallSite.
Definition IR.cpp:328
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v)
Gets the context that a value was created with.
Definition IR.cpp:1210
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition IR.cpp:1344
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition IR.cpp:918
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition IR.cpp:1058
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition IR.cpp:859
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition IR.h:1250
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition IR.cpp:76
MLIR_CAPI_EXPORTED void mlirOperationReplaceUsesOfWith(MlirOperation op, MlirValue of, MlirValue with)
Replace uses of 'of' value with the 'with' value inside the 'op' operation.
Definition IR.cpp:909
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition IR.cpp:865
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition IR.cpp:1165
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition IR.cpp:852
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition IR.cpp:70
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition IR.cpp:924
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:87
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:143
struct MlirLogicalResult MlirLogicalResult
Definition Support.h:124
MLIR_CAPI_EXPORTED int mlirLlvmThreadPoolGetMaxConcurrency(MlirLlvmThreadPool pool)
Returns the maximum number of threads in the thread pool.
Definition Support.cpp:38
MLIR_CAPI_EXPORTED void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool pool)
Destroy an LLVM thread pool.
Definition Support.cpp:34
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirLlvmThreadPoolCreate(void)
Create an LLVM thread pool.
Definition Support.cpp:30
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition Support.cpp:93
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:137
struct MlirStringRef MlirStringRef
Definition Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:132
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition Support.h:201
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition Support.cpp:89
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation)
Definition IRCore.cpp:1553
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m)
static void maybeInsertOperation(PyOperationRef &op, const nb::object &maybeIp)
Definition IRCore.cpp:1238
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:198
static MlirValue getOpResultOrValue(nb::handle operand)
Definition IRCore.cpp:1568
PyObjectRef< PyOperation > PyOperationRef
Definition IRCore.h:630
static void populateResultTypes(std::string_view name, nb::list resultTypeList, const nb::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition IRCore.cpp:1467
MlirStringRef toMlirStringRef(const std::string &s)
Definition IRCore.h:1339
static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait, PyMlirContext &context)
Definition IRCore.cpp:2561
PyObjectRef< PyModule > PyModuleRef
Definition IRCore.h:537
static MlirLogicalResult verifyTraitByMethod(MlirOperation op, void *userData, const char *methodName)
Definition IRCore.cpp:2550
static PyOperationRef getValueOwnerRef(MlirValue value)
Definition IRCore.cpp:1982
MlirBlock MLIR_PYTHON_API_EXPORTED createBlock(const nanobind::sequence &pyArgTypes, const std::optional< nanobind::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
static std::vector< nb::typed< nb::object, PyType > > getValueTypes(Container &container, PyMlirContextRef &context)
Returns the list of types of the values held by container.
Definition IRCore.cpp:1414
PyWalkOrder
Traversal order for operation walk.
Definition IRCore.h:347
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m)
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition IRCore.h:1878
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition Diagnostics.h:26
MlirLogicalResult(* verifyTrait)(MlirOperation op, void *userData)
The callback function to verify the operation.
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
MlirLogicalResult(* verifyRegionTrait)(MlirOperation op, void *userData)
The callback function to verify the operation with access to regions.
A logical result value, essentially a boolean with named states.
Definition Support.h:121
Named MLIR attribute.
Definition IR.h:76
MlirAttribute attribute
Definition IR.h:78
MlirIdentifier name
Definition IR.h:77
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first character.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition IRCore.h:1326
static bool dunderContains(const std::string &attributeKind)
Definition IRCore.cpp:144
static nanobind::callable dunderGetItemNamed(const std::string &attributeKind)
Definition IRCore.cpp:149
static void dunderSetItemNamed(const std::string &attributeKind, nanobind::callable func, bool replace)
Definition IRCore.cpp:156
static void set(nanobind::object &o, bool enable)
Definition IRCore.cpp:106
RAII object that captures any error diagnostics emitted to the provided context.
Definition IRCore.h:434
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition IRCore.h:444