MLIR 23.0.0git
IRCore.cpp
Go to the documentation of this file.
1//===- IRCore.cpp - IR Submodules of nanobind module ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// clang-format off
13#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
14// clang-format on
16#include "mlir-c/Debug.h"
17#include "mlir-c/Diagnostics.h"
19#include "mlir-c/IR.h"
20#include "mlir-c/Support.h"
21
22#include <array>
23#include <functional>
24#include <optional>
25#include <string>
26
27namespace nb = nanobind;
28using namespace nb::literals;
29using namespace mlir;
31
32static const char kModuleParseDocstring[] =
33 R"(Parses a module's assembly format from a string.
34
35Returns a new MlirModule or raises an MLIRError if the parsing fails.
36
37See also: https://mlir.llvm.org/docs/LangRef/
38)";
39
40static const char kDumpDocstring[] =
41 "Dumps a debug representation of the object to stderr.";
42
44 R"(Replace all uses of this value with the `with` value, except for those
45in `exceptions`. `exceptions` can be either a single operation or a list of
46operations.
47)";
48
49//------------------------------------------------------------------------------
50// Utilities.
51//------------------------------------------------------------------------------
52
/// Hashes `value` using the standard library's std::hash specialization for
/// its type T.
template <typename T>
static size_t hash(const T &value) {
  const std::hash<T> hasher{};
  return hasher(value);
}
58
// Wraps `dialectDescriptor` in a Python dialect object.
//
// A dialect class is looked up using `dialectNamespace`; when one is found it
// is called with the descriptor and the result is returned. When none is
// found, a base-class wrapper is returned instead.
// NOTE(review): the lookup call and the start of the fallback return
// expression are missing from this listing; confirm against the full file.
59static nb::object
60createCustomDialectWrapper(const std::string &dialectNamespace,
61 nb::object dialectDescriptor) {
62 auto dialectClass =
64 dialectNamespace);
65 if (!dialectClass) {
66 // Use the base class.
68 std::move(dialectDescriptor)));
69 }
70
71 // Create the custom implementation.
72 return (*dialectClass)(std::move(dialectDescriptor));
73}
74
75namespace mlir {
76namespace python {
78
// Creates a new MlirBlock whose block arguments have the given types.
//
// `pyArgTypes` is a Python sequence whose elements are cast to PyType.
// `pyArgLocs`, when provided, is a sequence whose elements are cast to
// PyLocation, one per argument. When it is absent and there is at least one
// argument type, every argument location is filled with the same default
// value.
// NOTE(review): the expression producing that default location is missing
// from this listing.
//
// Throws nb::value_error when the number of locations does not match the
// number of argument types.
79MlirBlock createBlock(const nb::sequence &pyArgTypes,
80 const std::optional<nb::sequence> &pyArgLocs) {
81 std::vector<MlirType> argTypes;
82 argTypes.reserve(nb::len(pyArgTypes));
83 for (nb::handle pyType : pyArgTypes)
84 argTypes.push_back(
85 nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyType &>(pyType));
86
87 std::vector<MlirLocation> argLocs;
88 if (pyArgLocs) {
89 argLocs.reserve(nb::len(*pyArgLocs));
90 for (nb::handle pyLoc : *pyArgLocs)
91 argLocs.push_back(
92 nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyLocation &>(pyLoc));
93 } else if (!argTypes.empty()) {
94 argLocs.assign(
95 argTypes.size(),
97 }
98
// An explicitly provided location list must have one entry per argument type.
99 if (argTypes.size() != argLocs.size())
100 throw nb::value_error(
101 join("Expected ", argTypes.size(), " locations, got: ", argLocs.size())
102 .c_str());
103 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
104}
105
106void PyGlobalDebugFlag::set(nb::object &o, bool enable) {
107 nb::ft_lock_guard lock(mutex);
108 mlirEnableGlobalDebug(enable);
109}
110
111bool PyGlobalDebugFlag::get(const nb::object &) {
112 nb::ft_lock_guard lock(mutex);
114}
115
116void PyGlobalDebugFlag::bind(nb::module_ &m) {
117 // Debug flags.
118 nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
119 .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
120 &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
121 .def_static(
122 "set_types",
123 [](const std::string &type) {
124 nb::ft_lock_guard lock(mutex);
125 mlirSetGlobalDebugType(type.c_str());
126 },
127 "types"_a, "Sets specific debug types to be produced by LLVM.")
128 .def_static(
129 "set_types",
130 [](const std::vector<std::string> &types) {
131 std::vector<const char *> pointers;
132 pointers.reserve(types.size());
133 for (const std::string &str : types)
134 pointers.push_back(str.c_str());
135 nb::ft_lock_guard lock(mutex);
136 mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
137 },
138 "types"_a,
139 "Sets multiple specific debug types to be produced by LLVM.");
140}
141
142nb::ft_mutex PyGlobalDebugFlag::mutex;
143
144bool PyAttrBuilderMap::dunderContains(const std::string &attributeKind) {
145 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
146}
147
148nb::callable
149PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
150 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
151 if (!builder)
152 throw nb::key_error(attributeKind.c_str());
153 return *builder;
154}
155
156void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
157 nb::callable func, bool replace) {
158 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
159 replace);
160}
161
162void PyAttrBuilderMap::bind(nb::module_ &m) {
163 nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
164 .def_static("contains", &PyAttrBuilderMap::dunderContains,
165 "attribute_kind"_a,
166 "Checks whether an attribute builder is registered for the "
167 "given attribute kind.")
168 .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
169 "attribute_kind"_a,
170 "Gets the registered attribute builder for the given "
171 "attribute kind.")
172 .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
173 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
174 "Register an attribute builder for building MLIR "
175 "attributes from Python values.");
176}
177
178//------------------------------------------------------------------------------
179// PyBlock
180//------------------------------------------------------------------------------
181
183 return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
184}
185
186//------------------------------------------------------------------------------
187// Collections.
188//------------------------------------------------------------------------------
189
190nb::typed<nb::object, PyRegion> PyRegionIterator::dunderNext() {
191 operation->checkValid();
192 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
193 PyErr_SetNone(PyExc_StopIteration);
194 // python functions should return NULL after setting any exception
195 return nb::object();
196 }
197 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
198 return nb::cast(PyRegion(operation, region));
199}
200
201void PyRegionIterator::bind(nb::module_ &m) {
202 nb::class_<PyRegionIterator>(m, "RegionIterator")
203 .def("__iter__", &PyRegionIterator::dunderIter,
204 "Returns an iterator over the regions in the operation.")
205 .def("__next__", &PyRegionIterator::dunderNext,
206 "Returns the next region in the iteration.");
207}
208
212 length == -1 ? mlirOperationGetNumRegions(operation->get())
213 : length,
214 step),
215 operation(std::move(operation)) {}
216
218 operation->checkValid();
219 return PyRegionIterator(operation, startIndex);
220}
221
223 c.def("__iter__", &PyRegionList::dunderIter,
224 "Returns an iterator over the regions in the sequence.");
225}
226
227intptr_t PyRegionList::getRawNumElements() {
228 operation->checkValid();
229 return mlirOperationGetNumRegions(operation->get());
230}
231
232PyRegion PyRegionList::getRawElement(intptr_t pos) {
233 operation->checkValid();
234 return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
235}
236
237PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
238 intptr_t step) const {
239 return PyRegionList(operation, startIndex, length, step);
240}
241
242nb::typed<nb::object, PyBlock> PyBlockIterator::dunderNext() {
243 operation->checkValid();
244 if (mlirBlockIsNull(next)) {
245 PyErr_SetNone(PyExc_StopIteration);
246 // python functions should return NULL after setting any exception
247 return nb::object();
248 }
249
250 PyBlock returnBlock(operation, next);
251 next = mlirBlockGetNextInRegion(next);
252 return nb::cast(returnBlock);
253}
254
255void PyBlockIterator::bind(nb::module_ &m) {
256 nb::class_<PyBlockIterator>(m, "BlockIterator")
257 .def("__iter__", &PyBlockIterator::dunderIter,
258 "Returns an iterator over the blocks in the operation's region.")
259 .def("__next__", &PyBlockIterator::dunderNext,
260 "Returns the next block in the iteration.");
261}
262
264 operation->checkValid();
265 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
266}
267
269 operation->checkValid();
270 intptr_t count = 0;
271 MlirBlock block = mlirRegionGetFirstBlock(region);
272 while (!mlirBlockIsNull(block)) {
273 count += 1;
274 block = mlirBlockGetNextInRegion(block);
275 }
276 return count;
277}
278
280 operation->checkValid();
281 if (index < 0) {
282 index += dunderLen();
283 }
284 if (index < 0) {
285 throw nb::index_error("attempt to access out of bounds block");
286 }
287 MlirBlock block = mlirRegionGetFirstBlock(region);
288 while (!mlirBlockIsNull(block)) {
289 if (index == 0) {
290 return PyBlock(operation, block);
291 }
292 block = mlirBlockGetNextInRegion(block);
293 index -= 1;
294 }
295 throw nb::index_error("attempt to access out of bounds block");
296}
297
298PyBlock PyBlockList::appendBlock(const nb::args &pyArgTypes,
299 const std::optional<nb::sequence> &pyArgLocs) {
300 operation->checkValid();
301 MlirBlock block = createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
302 mlirRegionAppendOwnedBlock(region, block);
303 return PyBlock(operation, block);
304}
305
// Registers the `BlockList` class on module `m`: indexing, iteration, length,
// and `append`, which takes the argument types as positional args plus a
// keyword-only `arg_locs` defaulting to std::nullopt.
306void PyBlockList::bind(nb::module_ &m) {
307 nb::class_<PyBlockList>(m, "BlockList")
308 .def("__getitem__", &PyBlockList::dunderGetItem,
309 "Returns the block at the specified index.")
310 .def("__iter__", &PyBlockList::dunderIter,
311 "Returns an iterator over blocks in the operation's region.")
312 .def("__len__", &PyBlockList::dunderLen,
313 "Returns the number of blocks in the operation's region.")
314 .def("append", &PyBlockList::appendBlock,
315 R"(
316 Appends a new block, with argument types as positional args.
317
318 Returns:
319 The created block.
320 )",
321 "args"_a, nb::kw_only(), "arg_locs"_a = std::nullopt);
322}
323
324nb::typed<nb::object, PyOpView> PyOperationIterator::dunderNext() {
325 parentOperation->checkValid();
326 if (mlirOperationIsNull(next)) {
327 PyErr_SetNone(PyExc_StopIteration);
328 // python functions should return NULL after setting any exception
329 return nb::object();
330 }
331
332 PyOperationRef returnOperation =
333 PyOperation::forOperation(parentOperation->getContext(), next);
334 next = mlirOperationGetNextInBlock(next);
335 return returnOperation->createOpView();
336}
337
338void PyOperationIterator::bind(nb::module_ &m) {
339 nb::class_<PyOperationIterator>(m, "OperationIterator")
340 .def("__iter__", &PyOperationIterator::dunderIter,
341 "Returns an iterator over the operations in an operation's block.")
342 .def("__next__", &PyOperationIterator::dunderNext,
343 "Returns the next operation in the iteration.");
344}
345
347 parentOperation->checkValid();
348 return PyOperationIterator(parentOperation,
350}
351
353 parentOperation->checkValid();
354 intptr_t count = 0;
355 MlirOperation childOp = mlirBlockGetFirstOperation(block);
356 while (!mlirOperationIsNull(childOp)) {
357 count += 1;
358 childOp = mlirOperationGetNextInBlock(childOp);
359 }
360 return count;
361}
362
363nb::typed<nb::object, PyOpView> PyOperationList::dunderGetItem(intptr_t index) {
364 parentOperation->checkValid();
365 if (index < 0) {
366 index += dunderLen();
367 }
368 if (index < 0) {
369 throw nb::index_error("attempt to access out of bounds operation");
370 }
371 MlirOperation childOp = mlirBlockGetFirstOperation(block);
372 while (!mlirOperationIsNull(childOp)) {
373 if (index == 0) {
374 return PyOperation::forOperation(parentOperation->getContext(), childOp)
375 ->createOpView();
376 }
377 childOp = mlirOperationGetNextInBlock(childOp);
378 index -= 1;
379 }
380 throw nb::index_error("attempt to access out of bounds operation");
381}
382
383void PyOperationList::bind(nb::module_ &m) {
384 nb::class_<PyOperationList>(m, "OperationList")
385 .def("__getitem__", &PyOperationList::dunderGetItem,
386 "Returns the operation at the specified index.")
387 .def("__iter__", &PyOperationList::dunderIter,
388 "Returns an iterator over operations in the list.")
389 .def("__len__", &PyOperationList::dunderLen,
390 "Returns the number of operations in the list.");
391}
392
393nb::typed<nb::object, PyOpView> PyOpOperand::getOwner() const {
394 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
398}
400size_t PyOpOperand::getOperandNumber() const {
401 return mlirOpOperandGetOperandNumber(opOperand);
402}
403
404void PyOpOperand::bind(nb::module_ &m) {
405 nb::class_<PyOpOperand>(m, "OpOperand")
406 .def_prop_ro("owner", &PyOpOperand::getOwner,
407 "Returns the operation that owns this operand.")
408 .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
409 "Returns the operand number in the owning operation.");
410}
411
412nb::typed<nb::object, PyOpOperand> PyOpOperandIterator::dunderNext() {
413 if (mlirOpOperandIsNull(opOperand)) {
414 PyErr_SetNone(PyExc_StopIteration);
415 // python functions should return NULL after setting any exception
416 return nb::object();
417 }
418
419 PyOpOperand returnOpOperand(opOperand);
420 opOperand = mlirOpOperandGetNextUse(opOperand);
421 return nb::cast(returnOpOperand);
422}
423
424void PyOpOperandIterator::bind(nb::module_ &m) {
425 nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
426 .def("__iter__", &PyOpOperandIterator::dunderIter,
427 "Returns an iterator over operands.")
428 .def("__next__", &PyOpOperandIterator::dunderNext,
429 "Returns the next operand in the iteration.");
430}
432//------------------------------------------------------------------------------
433// PyThreadPool
434//------------------------------------------------------------------------------
435
437
439 if (threadPool.ptr)
440 mlirLlvmThreadPoolDestroy(threadPool);
441}
444 return mlirLlvmThreadPoolGetMaxConcurrency(threadPool);
445}
446
447std::string PyThreadPool::_mlir_thread_pool_ptr() const {
448 std::stringstream ss;
449 ss << threadPool.ptr;
450 return ss.str();
451}
453//------------------------------------------------------------------------------
454// PyMlirContext
455//------------------------------------------------------------------------------
456
457PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
458 nb::gil_scoped_acquire acquire;
459 nb::ft_lock_guard lock(live_contexts_mutex);
460 auto &liveContexts = getLiveContexts();
461 liveContexts[context.ptr] = this;
462}
463
465 // Note that the only public way to construct an instance is via the
466 // forContext method, which always puts the associated handle into
467 // liveContexts.
468 nb::gil_scoped_acquire acquire;
469 {
470 nb::ft_lock_guard lock(live_contexts_mutex);
471 getLiveContexts().erase(context.ptr);
472 }
473 mlirContextDestroy(context);
474}
477 return PyMlirContextRef(this, nb::cast(this));
478}
480nb::object PyMlirContext::getCapsule() {
481 return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
482}
483
484nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
485 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
486 if (mlirContextIsNull(rawContext))
487 throw nb::python_error();
488 return forContext(rawContext).releaseObject();
489}
490
// Returns a reference to the wrapper associated with `context`.
//
// The live-contexts map is consulted under the GIL and live_contexts_mutex.
// On a miss, a new wrapper is allocated, cast to a Python object, recorded in
// the map and returned; on a hit, the existing wrapper is cast and returned.
//
// NOTE(review): the creation path calls the PyMlirContext constructor while
// live_contexts_mutex is held, and that constructor takes the same mutex and
// also writes the map entry. Confirm this is safe for the mutex type in
// free-threaded builds, and whether the second map write below is needed.
491PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
492 nb::gil_scoped_acquire acquire;
493 nb::ft_lock_guard lock(live_contexts_mutex);
494 auto &liveContexts = getLiveContexts();
495 auto it = liveContexts.find(context.ptr);
496 if (it == liveContexts.end()) {
497 // Create.
498 PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
499 nb::object pyRef = nb::cast(unownedContextWrapper);
500 assert(pyRef && "cast to nb::object failed");
501 liveContexts[context.ptr] = unownedContextWrapper;
502 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
503 }
504 // Use existing.
505 nb::object pyRef = nb::cast(it->second);
506 return PyMlirContextRef(it->second, std::move(pyRef));
507}
508
509nb::ft_mutex PyMlirContext::live_contexts_mutex;
510
511PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
512 static LiveContextMap liveContexts;
513 return liveContexts;
514}
515
517 nb::ft_lock_guard lock(live_contexts_mutex);
518 return getLiveContexts().size();
519}
521nb::object PyMlirContext::contextEnter(nb::object context) {
522 return PyThreadContextEntry::pushContext(context);
523}
524
525void PyMlirContext::contextExit(const nb::object &excType,
526 const nb::object &excVal,
527 const nb::object &excTb) {
529}
530
// Attaches `callback` to this context as a diagnostic handler and returns the
// Python object wrapping the new PyDiagnosticHandler.
//
// Two C callbacks are registered with the context: one that forwards each
// diagnostic to the Python callback, and one that runs on deletion to clear
// the registration id and drop the extra reference taken here.
531nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
532 // Note that ownership is transferred to the delete callback below by way of
533 // an explicit inc_ref (borrow).
534 PyDiagnosticHandler *pyHandler =
535 new PyDiagnosticHandler(get(), std::move(callback));
536 nb::object pyHandlerObject =
537 nb::cast(pyHandler, nb::rv_policy::take_ownership);
538 (void)pyHandlerObject.inc_ref();
539
540 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
541 // guaranteed to be known to nanobind.
542 auto handlerCallback =
543 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
// NOTE(review): the Python object for the diagnostic is created here, before
// the GIL is acquired further down; confirm every caller already holds it.
544 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
545 nb::object pyDiagnosticObject =
546 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
547
548 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
549 bool result = false;
550 {
551 // Since this can be called from arbitrary C++ contexts, always get the
552 // gil.
553 nb::gil_scoped_acquire gil;
554 try {
555 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
556 } catch (std::exception &e) {
// An exception from the Python callback is logged to stderr and recorded on
// the handler; `result` keeps its initial value of false.
557 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
558 e.what());
559 pyHandler->hadError = true;
560 }
561 }
562
// The diagnostic is invalidated once the callback has run.
563 pyDiagnostic->invalidate();
// NOTE(review): the lambda's return statement is missing from this listing.
565 };
566 auto deleteCallback = +[](void *userData) {
567 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
568 assert(pyHandler->registeredID && "handler is not registered");
569 pyHandler->registeredID.reset();
570
571 // Decrement reference, balancing the inc_ref() above.
572 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
573 pyHandlerObject.dec_ref();
574 };
575
576 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
577 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
578 return pyHandlerObject;
579}
580
581MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
582 void *userData) {
583 auto *self = static_cast<ErrorCapture *>(userData);
584 // Check if the context requested we emit errors instead of capturing them.
585 if (self->ctx->emitErrorDiagnostics)
587
589 MlirDiagnosticSeverity::MlirDiagnosticError)
592 self->errors.emplace_back(PyDiagnostic(diag).getInfo());
594}
595
598 if (!context) {
599 throw std::runtime_error(
600 "An MLIR function requires a Context but none was provided in the call "
601 "or from the surrounding environment. Either pass to the function with "
602 "a 'context=' argument or establish a default using 'with Context():'");
603 }
604 return *context;
605}
607//------------------------------------------------------------------------------
608// PyThreadContextEntry management
609//------------------------------------------------------------------------------
610
611std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
612 static thread_local std::vector<PyThreadContextEntry> stack;
613 return stack;
614}
615
617 auto &stack = getStack();
618 if (stack.empty())
619 return nullptr;
620 return &stack.back();
621}
622
623void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
624 nb::object insertionPoint,
625 nb::object location) {
626 auto &stack = getStack();
627 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
628 std::move(location));
629 // If the new stack has more than one entry and the context of the new top
630 // entry matches the previous, copy the insertionPoint and location from the
631 // previous entry if missing from the new top entry.
632 if (stack.size() > 1) {
633 auto &prev = *(stack.rbegin() + 1);
634 auto &current = stack.back();
635 if (current.context.is(prev.context)) {
636 // Default non-context objects from the previous entry.
637 if (!current.insertionPoint)
638 current.insertionPoint = prev.insertionPoint;
639 if (!current.location)
640 current.location = prev.location;
641 }
642 }
643}
644
646 if (!context)
647 return nullptr;
648 return nb::cast<PyMlirContext *>(context);
649}
650
652 if (!insertionPoint)
653 return nullptr;
654 return nb::cast<PyInsertionPoint *>(insertionPoint);
655}
656
658 if (!location)
659 return nullptr;
660 return nb::cast<PyLocation *>(location);
661}
662
664 auto *tos = getTopOfStack();
665 return tos ? tos->getContext() : nullptr;
666}
667
669 auto *tos = getTopOfStack();
670 return tos ? tos->getInsertionPoint() : nullptr;
671}
672
674 auto *tos = getTopOfStack();
675 return tos ? tos->getLocation() : nullptr;
676}
677
678nb::object PyThreadContextEntry::pushContext(nb::object context) {
679 push(FrameKind::Context, /*context=*/context,
680 /*insertionPoint=*/nb::object(),
681 /*location=*/nb::object());
682 return context;
683}
684
// Pops the top entry of the thread's context stack, throwing
// std::runtime_error when the stack is empty or the balance check below
// fails.
686 auto &stack = getStack();
687 if (stack.empty())
688 throw std::runtime_error("Unbalanced Context enter/exit");
689 auto &tos = stack.back();
// NOTE(review): because the two conditions are joined with `&&`, this only
// throws when the frame kind AND the context both mismatch; a frame of another
// kind that shares this context is popped without error. Confirm whether `||`
// was intended.
690 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
691 throw std::runtime_error("Unbalanced Context enter/exit");
692 stack.pop_back();
693}
694
695nb::object
696PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
697 PyInsertionPoint &insertionPoint =
698 nb::cast<PyInsertionPoint &>(insertionPointObj);
699 nb::object contextObj =
700 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
701 push(FrameKind::InsertionPoint,
702 /*context=*/contextObj,
703 /*insertionPoint=*/insertionPointObj,
704 /*location=*/nb::object());
705 return insertionPointObj;
706}
707
// Pops the top entry of the thread's context stack, throwing
// std::runtime_error when the stack is empty or the balance check below
// fails.
709 auto &stack = getStack();
710 if (stack.empty())
711 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
712 auto &tos = stack.back();
// NOTE(review): because the two conditions are joined with `&&`, this only
// throws when the frame kind AND the insertion point both mismatch. Confirm
// whether `||` was intended.
713 if (tos.frameKind != FrameKind::InsertionPoint &&
714 tos.getInsertionPoint() != &insertionPoint)
715 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
716 stack.pop_back();
717}
718
719nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
720 PyLocation &location = nb::cast<PyLocation &>(locationObj);
721 nb::object contextObj = location.getContext().getObject();
722 push(FrameKind::Location, /*context=*/contextObj,
723 /*insertionPoint=*/nb::object(),
724 /*location=*/locationObj);
725 return locationObj;
726}
727
// Pops the top entry of the thread's context stack, throwing
// std::runtime_error when the stack is empty or the balance check below
// fails.
729 auto &stack = getStack();
730 if (stack.empty())
731 throw std::runtime_error("Unbalanced Location enter/exit");
732 auto &tos = stack.back();
// NOTE(review): because the two conditions are joined with `&&`, this only
// throws when the frame kind AND the location both mismatch. Confirm whether
// `||` was intended.
733 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
734 throw std::runtime_error("Unbalanced Location enter/exit");
735 stack.pop_back();
736}
738//------------------------------------------------------------------------------
739// PyDiagnostic*
740//------------------------------------------------------------------------------
741
743 valid = false;
744 if (materializedNotes) {
745 for (nb::handle noteObject : *materializedNotes) {
746 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
747 note->invalidate();
748 }
749 }
750}
751
753 nb::object callback)
754 : context(context), callback(std::move(callback)) {}
755
757
759 if (!registeredID)
760 return;
761 MlirDiagnosticHandlerID localID = *registeredID;
762 mlirContextDetachDiagnosticHandler(context, localID);
763 assert(!registeredID && "should have unregistered");
764 // Not strictly necessary but keeps stale pointers from being around to cause
765 // issues.
766 context = {nullptr};
767}
768
769void PyDiagnostic::checkValid() {
770 if (!valid) {
771 throw std::invalid_argument(
772 "Diagnostic is invalid (used outside of callback)");
773 }
774}
775
777 checkValid();
778 return static_cast<PyDiagnosticSeverity>(
779 mlirDiagnosticGetSeverity(diagnostic));
780}
781
783 checkValid();
784 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
785 MlirContext context = mlirLocationGetContext(loc);
786 return PyLocation(PyMlirContext::forContext(context), loc);
787}
788
789nb::str PyDiagnostic::getMessage() {
790 checkValid();
791 nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
792 PyFileAccumulator accum(fileObject, /*binary=*/false);
793 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
794 return nb::cast<nb::str>(fileObject.attr("getvalue")());
795}
796
797nb::tuple PyDiagnostic::getNotes() {
798 checkValid();
799 if (materializedNotes)
800 return *materializedNotes;
801 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
802 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
803 for (intptr_t i = 0; i < numNotes; ++i) {
804 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
805 nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
806 PyTuple_SetItem(notes.ptr(), i, diagnostic.release().ptr());
807 }
808 materializedNotes = std::move(notes);
809
810 return *materializedNotes;
811}
812
814 std::vector<DiagnosticInfo> notes;
815 for (nb::handle n : getNotes())
816 notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
817 return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
818 std::move(notes)};
819}
821//------------------------------------------------------------------------------
822// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
823//------------------------------------------------------------------------------
824
825MlirDialect PyDialects::getDialectForKey(const std::string &key,
826 bool attrError) {
827 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
828 {key.data(), key.size()});
829 if (mlirDialectIsNull(dialect)) {
830 std::string msg = join("Dialect '", key, "' not found");
831 if (attrError)
832 throw nb::attribute_error(msg.c_str());
833 throw nb::index_error(msg.c_str());
834 }
835 return dialect;
836}
839 return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
840}
841
843 MlirDialectRegistry rawRegistry =
845 if (mlirDialectRegistryIsNull(rawRegistry))
846 throw nb::python_error();
847 return PyDialectRegistry(rawRegistry);
848}
850//------------------------------------------------------------------------------
851// PyLocation
852//------------------------------------------------------------------------------
854nb::object PyLocation::getCapsule() {
855 return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
856}
857
858PyLocation PyLocation::createFromCapsule(nb::object capsule) {
859 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
860 if (mlirLocationIsNull(rawLoc))
861 throw nb::python_error();
863 rawLoc);
864}
866nb::object PyLocation::contextEnter(nb::object locationObj) {
867 return PyThreadContextEntry::pushLocation(locationObj);
868}
869
870void PyLocation::contextExit(const nb::object &excType,
871 const nb::object &excVal,
872 const nb::object &excTb) {
874}
875
878 if (!location) {
879 throw std::runtime_error(
880 "An MLIR function requires a Location but none was provided in the "
881 "call or from the surrounding environment. Either pass to the function "
882 "with a 'loc=' argument or establish a default using 'with loc:'");
883 }
884 return *location;
885}
886
887//------------------------------------------------------------------------------
888// PyModule
889//------------------------------------------------------------------------------
890
891PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
892 : BaseContextObject(std::move(contextRef)), module(module) {}
893
895 nb::gil_scoped_acquire acquire;
896 auto &liveModules = getContext()->liveModules;
897 assert(liveModules.count(module.ptr) == 1 &&
898 "destroying module not in live map");
899 liveModules.erase(module.ptr);
900 mlirModuleDestroy(module);
901}
902
903PyModuleRef PyModule::forModule(MlirModule module) {
904 MlirContext context = mlirModuleGetContext(module);
905 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
906
907 nb::gil_scoped_acquire acquire;
908 auto &liveModules = contextRef->liveModules;
909 auto it = liveModules.find(module.ptr);
910 if (it == liveModules.end()) {
911 // Create.
912 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
913 // Note that the default return value policy on cast is automatic_reference,
914 // which does not take ownership (delete will not be called).
915 // Just be explicit.
916 nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
917 unownedModule->handle = pyRef;
918 liveModules[module.ptr] =
919 std::make_pair(unownedModule->handle, unownedModule);
920 return PyModuleRef(unownedModule, std::move(pyRef));
921 }
922 // Use existing.
923 PyModule *existing = it->second.second;
924 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
925 return PyModuleRef(existing, std::move(pyRef));
926}
927
928nb::object PyModule::createFromCapsule(nb::object capsule) {
929 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
930 if (mlirModuleIsNull(rawModule))
931 throw nb::python_error();
932 return forModule(rawModule).releaseObject();
933}
934
935nb::object PyModule::getCapsule() {
936 return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
937}
939//------------------------------------------------------------------------------
940// PyOperation
941//------------------------------------------------------------------------------
942
943PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
944 : BaseContextObject(std::move(contextRef)), operation(operation) {}
945
947 // If the operation has already been invalidated there is nothing to do.
948 if (!valid)
949 return;
950 // Otherwise, invalidate the operation when it is attached.
951 if (isAttached())
952 setInvalid();
953 else {
954 // And destroy it when it is detached, i.e. owned by Python.
955 erase();
956 }
957}
958
959namespace {
960
961// Constructs a new object of type T in-place on the Python heap, returning a
962// PyObjectRef to it, loosely analogous to std::make_shared<T>().
963template <typename T, class... Args>
964PyObjectRef<T> makeObjectRef(Args &&...args) {
965 nb::handle type = nb::type<T>();
966 nb::object instance = nb::inst_alloc(type);
967 T *ptr = nb::inst_ptr<T>(instance);
968 new (ptr) T(std::forward<Args>(args)...);
969 nb::inst_mark_ready(instance);
970 return PyObjectRef<T>(ptr, std::move(instance));
971}
972
973} // namespace
974
975PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
976 MlirOperation operation,
977 nb::object parentKeepAlive) {
978 // Create.
979 PyOperationRef unownedOperation =
980 makeObjectRef<PyOperation>(std::move(contextRef), operation);
981 unownedOperation->handle = unownedOperation.getObject();
982 if (parentKeepAlive) {
983 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
984 }
985 return unownedOperation;
986}
987
989 MlirOperation operation,
990 nb::object parentKeepAlive) {
991 return createInstance(std::move(contextRef), operation,
992 std::move(parentKeepAlive));
993}
994
996 MlirOperation operation,
997 nb::object parentKeepAlive) {
998 PyOperationRef created = createInstance(std::move(contextRef), operation,
999 std::move(parentKeepAlive));
1000 created->attached = false;
1001 return created;
1002}
1003
1005 const std::string &sourceStr,
1006 const std::string &sourceName) {
1007 PyMlirContext::ErrorCapture errors(contextRef);
1008 MlirOperation op =
1009 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1010 toMlirStringRef(sourceName));
1011 if (mlirOperationIsNull(op))
1012 throw MLIRError("Unable to parse operation assembly", errors.take());
1013 return PyOperation::createDetached(std::move(contextRef), op);
1014}
1015
1018 setDetached();
1019 parentKeepAlive = nb::object();
1020}
1021
1022MlirOperation PyOperation::get() const {
1023 checkValid();
1024 return operation;
1025}
1028 return PyOperationRef(this, nb::borrow<nb::object>(handle));
1029}
1030
1031void PyOperation::setAttached(const nb::object &parent) {
1032 assert(!attached && "operation already attached");
1033 attached = true;
1034}
1035
1037 assert(attached && "operation already detached");
1038 attached = false;
1039}
1040
1041void PyOperation::checkValid() const {
1042 if (!valid) {
1043 throw std::runtime_error("the operation has been invalidated");
1044 }
1045}
1046
1047void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1048 std::optional<int64_t> largeResourceLimit,
1049 bool enableDebugInfo, bool prettyDebugInfo,
1050 bool printGenericOpForm, bool useLocalScope,
1051 bool useNameLocAsPrefix, bool assumeVerified,
1052 nb::object fileObject, bool binary,
1053 bool skipRegions) {
1054 PyOperation &operation = getOperation();
1055 operation.checkValid();
1056 if (fileObject.is_none())
1057 fileObject = nb::module_::import_("sys").attr("stdout");
1058
1059 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1060 if (largeElementsLimit)
1061 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1062 if (largeResourceLimit)
1063 mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit);
1064 if (enableDebugInfo)
1065 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1066 /*prettyForm=*/prettyDebugInfo);
1067 if (printGenericOpForm)
1069 if (useLocalScope)
1071 if (assumeVerified)
1073 if (skipRegions)
1075 if (useNameLocAsPrefix)
1077
1078 PyFileAccumulator accum(fileObject, binary);
1079 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1080 accum.getUserData());
1082}
1083
1084void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1085 bool binary) {
1086 PyOperation &operation = getOperation();
1087 operation.checkValid();
1088 if (fileObject.is_none())
1089 fileObject = nb::module_::import_("sys").attr("stdout");
1090 PyFileAccumulator accum(fileObject, binary);
1091 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1092 accum.getUserData());
1093}
1094
1095void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
1096 std::optional<int64_t> bytecodeVersion) {
1097 PyOperation &operation = getOperation();
1098 operation.checkValid();
1099 PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
1100
1101 if (!bytecodeVersion.has_value())
1102 return mlirOperationWriteBytecode(operation, accum.getCallback(),
1103 accum.getUserData());
1104
1105 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1106 mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1108 operation, config, accum.getCallback(), accum.getUserData());
1111 throw nb::value_error(
1112 join("Unable to honor desired bytecode version ", *bytecodeVersion)
1113 .c_str());
1114}
1115
1116void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
1117 PyWalkOrder walkOrder) {
1118 PyOperation &operation = getOperation();
1119 operation.checkValid();
1120 struct UserData {
1121 std::function<PyWalkResult(MlirOperation)> callback;
1122 bool gotException;
1123 std::string exceptionWhat;
1124 nb::object exceptionType;
1125 };
1126 UserData userData{callback, false, {}, {}};
1127 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1128 void *userData) {
1129 UserData *calleeUserData = static_cast<UserData *>(userData);
1130 try {
1131 return static_cast<MlirWalkResult>((calleeUserData->callback)(op));
1132 } catch (nb::python_error &e) {
1133 calleeUserData->gotException = true;
1134 calleeUserData->exceptionWhat = std::string(e.what());
1135 calleeUserData->exceptionType = nb::borrow(e.type());
1136 return MlirWalkResult::MlirWalkResultInterrupt;
1137 }
1138 };
1139 mlirOperationWalk(operation, walkCallback, &userData,
1140 static_cast<MlirWalkOrder>(walkOrder));
1141 if (userData.gotException) {
1142 std::string message("Exception raised in callback: ");
1143 message.append(userData.exceptionWhat);
1144 throw std::runtime_error(message);
1145 }
1146}
1147
1148nb::object PyOperationBase::getAsm(bool binary,
1149 std::optional<int64_t> largeElementsLimit,
1150 std::optional<int64_t> largeResourceLimit,
1151 bool enableDebugInfo, bool prettyDebugInfo,
1152 bool printGenericOpForm, bool useLocalScope,
1153 bool useNameLocAsPrefix, bool assumeVerified,
1154 bool skipRegions) {
1155 nb::object fileObject;
1156 if (binary) {
1157 fileObject = nb::module_::import_("io").attr("BytesIO")();
1158 } else {
1159 fileObject = nb::module_::import_("io").attr("StringIO")();
1160 }
1161 print(/*largeElementsLimit=*/largeElementsLimit,
1162 /*largeResourceLimit=*/largeResourceLimit,
1163 /*enableDebugInfo=*/enableDebugInfo,
1164 /*prettyDebugInfo=*/prettyDebugInfo,
1165 /*printGenericOpForm=*/printGenericOpForm,
1166 /*useLocalScope=*/useLocalScope,
1167 /*useNameLocAsPrefix=*/useNameLocAsPrefix,
1168 /*assumeVerified=*/assumeVerified,
1169 /*fileObject=*/fileObject,
1170 /*binary=*/binary,
1171 /*skipRegions=*/skipRegions);
1172
1173 return fileObject.attr("getvalue")();
1174}
1175
1177 PyOperation &operation = getOperation();
1178 PyOperation &otherOp = other.getOperation();
1179 operation.checkValid();
1180 otherOp.checkValid();
1181 mlirOperationMoveAfter(operation, otherOp);
1182 operation.parentKeepAlive = otherOp.parentKeepAlive;
1183}
1184
1186 PyOperation &operation = getOperation();
1187 PyOperation &otherOp = other.getOperation();
1188 operation.checkValid();
1189 otherOp.checkValid();
1190 mlirOperationMoveBefore(operation, otherOp);
1191 operation.parentKeepAlive = otherOp.parentKeepAlive;
1192}
1193
1195 PyOperation &operation = getOperation();
1196 PyOperation &otherOp = other.getOperation();
1197 operation.checkValid();
1198 otherOp.checkValid();
1199 return mlirOperationIsBeforeInBlock(operation, otherOp);
1200}
1201
1203 PyOperation &op = getOperation();
1206 throw MLIRError("Verification failed", errors.take());
1207 return true;
1208}
1209
1210std::optional<PyOperationRef> PyOperation::getParentOperation() {
1211 checkValid();
1212 if (!isAttached())
1213 throw nb::value_error("Detached operations have no parent");
1214 MlirOperation operation = mlirOperationGetParentOperation(get());
1215 if (mlirOperationIsNull(operation))
1216 return {};
1217 return PyOperation::forOperation(getContext(), operation);
1218}
1219
1221 checkValid();
1222 std::optional<PyOperationRef> parentOperation = getParentOperation();
1223 MlirBlock block = mlirOperationGetBlock(get());
1224 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1225 assert(parentOperation && "Operation has no parent");
1226 return PyBlock{std::move(*parentOperation), block};
1227}
1228
1230 checkValid();
1231 return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1232}
1233
1234nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
1235 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1236 if (mlirOperationIsNull(rawOperation))
1237 throw nb::python_error();
1238 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1239 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1240 .releaseObject();
1241}
1242
1243static void maybeInsertOperation(PyOperationRef &op,
1244 const nb::object &maybeIp) {
1245 // InsertPoint active?
1246 if (!maybeIp.is(nb::cast(false))) {
1247 PyInsertionPoint *ip;
1248 if (maybeIp.is_none()) {
1250 } else {
1251 ip = nb::cast<PyInsertionPoint *>(maybeIp);
1252 }
1253 if (ip)
1254 ip->insert(*op.get());
1255 }
1256}
1257
1258nb::object PyOperation::create(std::string_view name,
1259 std::optional<std::vector<PyType *>> results,
1260 const MlirValue *operands, size_t numOperands,
1261 std::optional<nb::dict> attributes,
1262 std::optional<std::vector<PyBlock *>> successors,
1263 int regions, PyLocation &location,
1264 const nb::object &maybeIp, bool inferType) {
1265 std::vector<MlirType> mlirResults;
1266 std::vector<MlirBlock> mlirSuccessors;
1267 std::vector<std::pair<std::string, MlirAttribute>> mlirAttributes;
1268
1269 // General parameter validation.
1270 if (regions < 0)
1271 throw nb::value_error("number of regions must be >= 0");
1272
1273 // Unpack/validate results.
1274 if (results) {
1275 mlirResults.reserve(results->size());
1276 for (PyType *result : *results) {
1277 // TODO: Verify result type originate from the same context.
1278 if (!result)
1279 throw nb::value_error("result type cannot be None");
1280 mlirResults.push_back(*result);
1281 }
1282 }
1283 // Unpack/validate attributes.
1284 if (attributes) {
1285 mlirAttributes.reserve(attributes->size());
1286 for (std::pair<nb::handle, nb::handle> it : *attributes) {
1287 std::string key;
1288 try {
1289 key = nb::cast<std::string>(it.first);
1290 } catch (nb::cast_error &err) {
1291 std::string msg = join("Invalid attribute key (not a string) when "
1292 "attempting to create the operation \"",
1293 name, "\" (", err.what(), ")");
1294 throw nb::type_error(msg.c_str());
1295 }
1296 try {
1297 auto &attribute = nb::cast<PyAttribute &>(it.second);
1298 // TODO: Verify attribute originates from the same context.
1299 mlirAttributes.emplace_back(std::move(key), attribute);
1300 } catch (nb::cast_error &err) {
1301 std::string msg = join("Invalid attribute value for the key \"", key,
1302 "\" when attempting to create the operation \"",
1303 name, "\" (", err.what(), ")");
1304 throw nb::type_error(msg.c_str());
1305 } catch (std::runtime_error &) {
1306 // This exception seems thrown when the value is "None".
1307 std::string msg = join(
1308 "Found an invalid (`None`?) attribute value for the key \"", key,
1309 "\" when attempting to create the operation \"", name, "\"");
1310 throw std::runtime_error(msg);
1311 }
1312 }
1313 }
1314 // Unpack/validate successors.
1315 if (successors) {
1316 mlirSuccessors.reserve(successors->size());
1317 for (PyBlock *successor : *successors) {
1318 // TODO: Verify successor originate from the same context.
1319 if (!successor)
1320 throw nb::value_error("successor block cannot be None");
1321 mlirSuccessors.push_back(successor->get());
1322 }
1323 }
1324
1325 // Apply unpacked/validated to the operation state. Beyond this
1326 // point, exceptions cannot be thrown or else the state will leak.
1327 MlirOperationState state =
1328 mlirOperationStateGet(toMlirStringRef(name), location);
1329 if (numOperands > 0)
1330 mlirOperationStateAddOperands(&state, numOperands, operands);
1331 state.enableResultTypeInference = inferType;
1332 if (!mlirResults.empty())
1333 mlirOperationStateAddResults(&state, mlirResults.size(),
1334 mlirResults.data());
1335 if (!mlirAttributes.empty()) {
1336 // Note that the attribute names directly reference bytes in
1337 // mlirAttributes, so that vector must not be changed from here
1338 // on.
1339 std::vector<MlirNamedAttribute> mlirNamedAttributes;
1340 mlirNamedAttributes.reserve(mlirAttributes.size());
1341 for (const std::pair<std::string, MlirAttribute> &it : mlirAttributes)
1342 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1344 toMlirStringRef(it.first)),
1345 it.second));
1346 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1347 mlirNamedAttributes.data());
1348 }
1349 if (!mlirSuccessors.empty())
1350 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1351 mlirSuccessors.data());
1352 if (regions) {
1353 std::vector<MlirRegion> mlirRegions;
1354 mlirRegions.resize(regions);
1355 for (int i = 0; i < regions; ++i)
1356 mlirRegions[i] = mlirRegionCreate();
1357 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1358 mlirRegions.data());
1359 }
1360
1361 // Construct the operation.
1362 PyMlirContext::ErrorCapture errors(location.getContext());
1363 MlirOperation operation = mlirOperationCreate(&state);
1364 if (!operation.ptr)
1365 throw MLIRError("Operation creation failed", errors.take());
1366 PyOperationRef created =
1367 PyOperation::createDetached(location.getContext(), operation);
1368 maybeInsertOperation(created, maybeIp);
1369
1370 return created.getObject();
1371}
1372
1373nb::object PyOperation::clone(const nb::object &maybeIp) {
1374 MlirOperation clonedOperation = mlirOperationClone(operation);
1375 PyOperationRef cloned =
1376 PyOperation::createDetached(getContext(), clonedOperation);
1377 maybeInsertOperation(cloned, maybeIp);
1378
1379 return cloned->createOpView();
1380}
1381
1382nb::object PyOperation::createOpView() {
1383 checkValid();
1384 MlirIdentifier ident = mlirOperationGetName(get());
1385 MlirStringRef identStr = mlirIdentifierStr(ident);
1386 auto operationCls = PyGlobals::get().lookupOperationClass(
1387 std::string_view(identStr.data, identStr.length));
1388 if (operationCls)
1389 return PyOpView::constructDerived(*operationCls, getRef().getObject());
1390 return nb::cast(PyOpView(getRef().getObject()));
1391}
1392
1393void PyOperation::erase() {
1395 setInvalid();
1396 mlirOperationDestroy(operation);
1397}
1398
1399void PyOpResult::bindDerived(ClassTy &c) {
1400 c.def_prop_ro(
1401 "owner",
1402 [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
1403 assert(mlirOperationEqual(self.getParentOperation()->get(),
1404 mlirOpResultGetOwner(self.get())) &&
1405 "expected the owner of the value in Python to match that in "
1406 "the IR");
1407 return self.getParentOperation()->createOpView();
1408 },
1409 "Returns the operation that produces this result.");
1410 c.def_prop_ro(
1411 "result_number",
1412 [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); },
1413 "Returns the position of this result in the operation's result list.");
1415
1416/// Returns the list of types of the values held by container.
1417template <typename Container>
1418static std::vector<nb::typed<nb::object, PyType>>
1419getValueTypes(Container &container, PyMlirContextRef &context) {
1420 std::vector<nb::typed<nb::object, PyType>> result;
1421 result.reserve(container.size());
1422 for (int i = 0, e = container.size(); i < e; ++i) {
1423 result.push_back(PyType(context->getRef(),
1424 mlirValueGetType(container.getElement(i).get()))
1426 }
1427 return result;
1428}
1429
1431 intptr_t length, intptr_t step)
1432 : Sliceable(startIndex,
1434 : length,
1435 step),
1436 operation(std::move(operation)) {}
1437
1438void PyOpResultList::bindDerived(ClassTy &c) {
1439 c.def_prop_ro(
1440 "types",
1441 [](PyOpResultList &self) {
1442 return getValueTypes(self, self.operation->getContext());
1443 },
1444 "Returns a list of types for all results in this result list.");
1445 c.def_prop_ro(
1446 "owner",
1447 [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
1448 return self.operation->createOpView();
1449 },
1450 "Returns the operation that owns this result list.");
1451}
1452
1453intptr_t PyOpResultList::getRawNumElements() {
1454 operation->checkValid();
1455 return mlirOperationGetNumResults(operation->get());
1456}
1457
1458PyOpResult PyOpResultList::getRawElement(intptr_t index) {
1459 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1460 return PyOpResult(value);
1461}
1462
1463PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
1464 intptr_t step) const {
1465 return PyOpResultList(operation, startIndex, length, step);
1466}
1468//------------------------------------------------------------------------------
1469// PyOpView
1470//------------------------------------------------------------------------------
1471
1472static void populateResultTypes(std::string_view name, nb::list resultTypeList,
1473 const nb::object &resultSegmentSpecObj,
1474 std::vector<int32_t> &resultSegmentLengths,
1475 std::vector<PyType *> &resultTypes) {
1476 resultTypes.reserve(resultTypeList.size());
1477 if (resultSegmentSpecObj.is_none()) {
1478 // Non-variadic result unpacking.
1479 size_t index = 0;
1480 for (nb::handle resultType : resultTypeList) {
1481 try {
1482 resultTypes.push_back(nb::cast<PyType *>(resultType));
1483 if (!resultTypes.back())
1484 throw nb::cast_error();
1485 } catch (nb::cast_error &err) {
1486 throw nb::value_error(join("Result ", index, " of operation \"", name,
1487 "\" must be a Type (", err.what(), ")")
1488 .c_str());
1489 }
1490 ++index;
1491 }
1492 } else {
1493 // Sized result unpacking.
1494 auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1495 if (resultSegmentSpec.size() != resultTypeList.size()) {
1496 throw nb::value_error(
1497 join("Operation \"", name, "\" requires ", resultSegmentSpec.size(),
1498 " result segments but was provided ", resultTypeList.size())
1499 .c_str());
1500 }
1501 resultSegmentLengths.reserve(resultTypeList.size());
1502 for (size_t i = 0, e = resultSegmentSpec.size(); i < e; ++i) {
1503 int segmentSpec = resultSegmentSpec[i];
1504 if (segmentSpec == 1 || segmentSpec == 0) {
1505 // Unpack unary element.
1506 try {
1507 auto *resultType = nb::cast<PyType *>(resultTypeList[i]);
1508 if (resultType) {
1509 resultTypes.push_back(resultType);
1510 resultSegmentLengths.push_back(1);
1511 } else if (segmentSpec == 0) {
1512 // Allowed to be optional.
1513 resultSegmentLengths.push_back(0);
1514 } else {
1515 throw nb::value_error(
1516 join("Result ", i, " of operation \"", name,
1517 "\" must be a Type (was None and result is not optional)")
1518 .c_str());
1519 }
1520 } catch (nb::cast_error &err) {
1521 throw nb::value_error(join("Result ", i, " of operation \"", name,
1522 "\" must be a Type (", err.what(), ")")
1523 .c_str());
1524 }
1525 } else if (segmentSpec == -1) {
1526 // Unpack sequence by appending.
1527 try {
1528 if (resultTypeList[i].is_none()) {
1529 // Treat it as an empty list.
1530 resultSegmentLengths.push_back(0);
1531 } else {
1532 // Unpack the list.
1533 auto segment = nb::cast<nb::sequence>(resultTypeList[i]);
1534 for (nb::handle segmentItem : segment) {
1535 resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1536 if (!resultTypes.back()) {
1537 throw nb::type_error("contained a None item");
1538 }
1539 }
1540 resultSegmentLengths.push_back(nb::len(segment));
1541 }
1542 } catch (std::exception &err) {
1543 // NOTE: Sloppy to be using a catch-all here, but there are at least
1544 // three different unrelated exceptions that can be thrown in the
1545 // above "casts". Just keep the scope above small and catch them all.
1546 throw nb::value_error(join("Result ", i, " of operation \"", name,
1547 "\" must be a Sequence of Types (",
1548 err.what(), ")")
1549 .c_str());
1550 }
1551 } else {
1552 throw nb::value_error("Unexpected segment spec");
1554 }
1555 }
1556}
1557
1558MlirValue getUniqueResult(MlirOperation operation) {
1559 auto numResults = mlirOperationGetNumResults(operation);
1560 if (numResults != 1) {
1561 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1562 throw nb::value_error(
1563 join("Cannot call .result on operation ",
1564 std::string_view(name.data, name.length), " which has ",
1565 numResults,
1566 " results (it is only valid for operations with a "
1567 "single result)")
1568 .c_str());
1569 }
1570 return mlirOperationGetResult(operation, 0);
1571}
1572
1573static MlirValue getOpResultOrValue(nb::handle operand) {
1574 if (operand.is_none()) {
1575 throw nb::value_error("contained a None item");
1576 }
1577 PyOperationBase *op;
1578 if (nb::try_cast<PyOperationBase *>(operand, op)) {
1579 return getUniqueResult(op->getOperation());
1580 }
1581 PyOpResultList *opResultList;
1582 if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1583 return getUniqueResult(opResultList->getOperation()->get());
1584 }
1585 PyValue *value;
1586 if (nb::try_cast<PyValue *>(operand, value)) {
1587 return value->get();
1588 }
1589 throw nb::value_error("is not a Value");
1590}
1591
1592nb::object PyOpView::buildGeneric(
1593 std::string_view name, std::tuple<int, bool> opRegionSpec,
1594 nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1595 std::optional<nb::list> resultTypeList, nb::list operandList,
1596 std::optional<nb::dict> attributes,
1597 std::optional<std::vector<PyBlock *>> successors,
1598 std::optional<int> regions, PyLocation &location,
1599 const nb::object &maybeIp) {
1600 PyMlirContextRef context = location.getContext();
1601
1602 // Class level operation construction metadata.
1603 // Operand and result segment specs are either none, which does no
1604 // variadic unpacking, or a list of ints with segment sizes, where each
1605 // element is either a positive number (typically 1 for a scalar) or -1 to
1606 // indicate that it is derived from the length of the same-indexed operand
1607 // or result (implying that it is a list at that position).
1608 std::vector<int32_t> operandSegmentLengths;
1609 std::vector<int32_t> resultSegmentLengths;
1610
1611 // Validate/determine region count.
1612 int opMinRegionCount = std::get<0>(opRegionSpec);
1613 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1614 if (!regions) {
1615 regions = opMinRegionCount;
1616 }
1617 if (*regions < opMinRegionCount) {
1618 throw nb::value_error(join("Operation \"", name,
1619 "\" requires a minimum of ", opMinRegionCount,
1620 " regions but was built with regions=", *regions)
1621 .c_str());
1622 }
1623 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1624 throw nb::value_error(join("Operation \"", name,
1625 "\" requires a maximum of ", opMinRegionCount,
1626 " regions but was built with regions=", *regions)
1627 .c_str());
1628 }
1629
1630 // Unpack results.
1631 std::vector<PyType *> resultTypes;
1632 if (resultTypeList.has_value()) {
1633 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1634 resultSegmentLengths, resultTypes);
1635 }
1636
1637 // Unpack operands.
1638 std::vector<MlirValue> operands;
1639 operands.reserve(operands.size());
1640 size_t index = 0;
1641 if (operandSegmentSpecObj.is_none()) {
1642 // Non-sized operand unpacking.
1643 for (nb::handle operand : operandList) {
1644 try {
1645 operands.push_back(getOpResultOrValue(operand));
1646 } catch (nb::builtin_exception &err) {
1647 throw nb::value_error(join("Operand ", index, " of operation \"", name,
1648 "\" must be a Value (", err.what(), ")")
1649 .c_str());
1650 }
1651 ++index;
1652 }
1653 } else {
1654 // Sized operand unpacking.
1655 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1656 if (operandSegmentSpec.size() != operandList.size()) {
1657 throw nb::value_error(
1658 join("Operation \"", name, "\" requires ", operandSegmentSpec.size(),
1659 "operand segments but was provided ", operandList.size())
1660 .c_str());
1661 }
1662 operandSegmentLengths.reserve(operandList.size());
1663 for (size_t i = 0, e = operandSegmentSpec.size(); i < e; ++i) {
1664 int segmentSpec = operandSegmentSpec[i];
1665 if (segmentSpec == 1 || segmentSpec == 0) {
1666 // Unpack unary element.
1667 const nanobind::handle operand = operandList[i];
1668 if (!operand.is_none()) {
1669 try {
1670 operands.push_back(getOpResultOrValue(operand));
1671 } catch (nb::builtin_exception &err) {
1672 throw nb::value_error(join("Operand ", i, " of operation \"", name,
1673 "\" must be a Value (", err.what(), ")")
1674 .c_str());
1675 }
1676
1677 operandSegmentLengths.push_back(1);
1678 } else if (segmentSpec == 0) {
1679 // Allowed to be optional.
1680 operandSegmentLengths.push_back(0);
1681 } else {
1682 throw nb::value_error(
1683 join("Operand ", i, " of operation \"", name,
1684 "\" must be a Value (was None and operand is not optional)")
1685 .c_str());
1686 }
1687 } else if (segmentSpec == -1) {
1688 // Unpack sequence by appending.
1689 try {
1690 if (operandList[i].is_none()) {
1691 // Treat it as an empty list.
1692 operandSegmentLengths.push_back(0);
1693 } else {
1694 // Unpack the list.
1695 auto segment = nb::cast<nb::sequence>(operandList[i]);
1696 for (nb::handle segmentItem : segment) {
1697 operands.push_back(getOpResultOrValue(segmentItem));
1698 }
1699 operandSegmentLengths.push_back(nb::len(segment));
1700 }
1701 } catch (std::exception &err) {
1702 // NOTE: Sloppy to be using a catch-all here, but there are at least
1703 // three different unrelated exceptions that can be thrown in the
1704 // above "casts". Just keep the scope above small and catch them all.
1705 throw nb::value_error(join("Operand ", i, " of operation \"", name,
1706 "\" must be a Sequence of Values (",
1707 err.what(), ")")
1708 .c_str());
1709 }
1710 } else {
1711 throw nb::value_error("Unexpected segment spec");
1712 }
1713 }
1714 }
1715
1716 // Merge operand/result segment lengths into attributes if needed.
1717 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1718 // Dup.
1719 if (attributes) {
1720 attributes = nb::dict(*attributes);
1721 } else {
1722 attributes = nb::dict();
1723 }
1724 if (attributes->contains("resultSegmentSizes") ||
1725 attributes->contains("operandSegmentSizes")) {
1726 throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
1727 "'operandSegmentSizes' attribute is unsupported. "
1728 "Use Operation.create for such low-level access.");
1729 }
1730
1731 // Add resultSegmentSizes attribute.
1732 if (!resultSegmentLengths.empty()) {
1733 MlirAttribute segmentLengthAttr =
1734 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1735 resultSegmentLengths.data());
1736 (*attributes)["resultSegmentSizes"] =
1737 PyAttribute(context, segmentLengthAttr);
1738 }
1739
1740 // Add operandSegmentSizes attribute.
1741 if (!operandSegmentLengths.empty()) {
1742 MlirAttribute segmentLengthAttr =
1743 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1744 operandSegmentLengths.data());
1745 (*attributes)["operandSegmentSizes"] =
1746 PyAttribute(context, segmentLengthAttr);
1747 }
1748 }
1749
1750 // Delegate to create.
1751 return PyOperation::create(name,
1752 /*results=*/std::move(resultTypes),
1753 /*operands=*/operands.data(),
1754 /*numOperands=*/operands.size(),
1755 /*attributes=*/std::move(attributes),
1756 /*successors=*/std::move(successors),
1757 /*regions=*/*regions, location, maybeIp,
1758 !resultTypeList);
1759}
1760
1761nb::object PyOpView::constructDerived(const nb::object &cls,
1762 const nb::object &operation) {
1763 nb::handle opViewType = nb::type<PyOpView>();
1764 nb::object instance = cls.attr("__new__")(cls);
1765 opViewType.attr("__init__")(instance, operation);
1766 return instance;
1767}
1768
1769PyOpView::PyOpView(const nb::object &operationObject)
1770 // Casting through the PyOperationBase base-class and then back to the
1771 // Operation lets us accept any PyOperationBase subclass.
1772 : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
1773 operationObject(operation.getRef().getObject()) {}
1775//------------------------------------------------------------------------------
1776// PyAsmState
1777//------------------------------------------------------------------------------
1778
1779PyAsmState::PyAsmState(MlirValue value, bool useLocalScope) {
1780 flags = mlirOpPrintingFlagsCreate();
1781 // The OpPrintingFlags are not exposed Python side, create locally and
1782 // associate lifetime with the state.
1783 if (useLocalScope)
1785 state = mlirAsmStateCreateForValue(value, flags);
1786}
1787
1788PyAsmState::PyAsmState(PyOperationBase &operation, bool useLocalScope) {
1789 flags = mlirOpPrintingFlagsCreate();
1790 // The OpPrintingFlags are not exposed Python side, create locally and
1791 // associate lifetime with the state.
1792 if (useLocalScope)
1794 state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
1795}
1797//------------------------------------------------------------------------------
1798// PyInsertionPoint.
1799//------------------------------------------------------------------------------
1800
1801PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
1804 : refOperation(beforeOperationBase.getOperation().getRef()),
1805 block((*refOperation)->getBlock()) {}
1806
1808 : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
1809
1810void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1811 PyOperation &operation = operationBase.getOperation();
1812 if (operation.isAttached())
1813 throw nb::value_error(
1814 "Attempt to insert operation that is already attached");
1815 block.getParentOperation()->checkValid();
1816 MlirOperation beforeOp = {nullptr};
1817 if (refOperation) {
1818 // Insert before operation.
1819 (*refOperation)->checkValid();
1820 beforeOp = (*refOperation)->get();
1821 } else {
1822 // Insert at end (before null) is only valid if the block does not
1823 // already end in a known terminator (violating this will cause assertion
1824 // failures later).
1825 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1826 throw nb::index_error("Cannot insert operation at the end of a block "
1827 "that already has a terminator. Did you mean to "
1828 "use 'InsertionPoint.at_block_terminator(block)' "
1829 "versus 'InsertionPoint(block)'?");
1830 }
1832 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1833 operation.setAttached();
1834}
1835
1837 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1838 if (mlirOperationIsNull(firstOp)) {
1839 // Just insert at end.
1840 return PyInsertionPoint(block);
1841 }
1842
1843 // Insert before first op.
1845 block.getParentOperation()->getContext(), firstOp);
1846 return PyInsertionPoint{block, std::move(firstOpRef)};
1847}
1848
1850 MlirOperation terminator = mlirBlockGetTerminator(block.get());
1851 if (mlirOperationIsNull(terminator))
1852 throw nb::value_error("Block has no terminator");
1854 block.getParentOperation()->getContext(), terminator);
1855 return PyInsertionPoint{block, std::move(terminatorOpRef)};
1856}
1857
1859 PyOperation &operation = op.getOperation();
1860 PyBlock block = operation.getBlock();
1861 MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
1862 if (mlirOperationIsNull(nextOperation))
1863 return PyInsertionPoint(block);
1865 block.getParentOperation()->getContext(), nextOperation);
1866 return PyInsertionPoint{block, std::move(nextOpRef)};
1867}
1868
1869size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
1871nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
1872 return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
1873}
1874
1875void PyInsertionPoint::contextExit(const nb::object &excType,
1876 const nb::object &excVal,
1877 const nb::object &excTb) {
1879}
1881//------------------------------------------------------------------------------
1882// PyAttribute.
1883//------------------------------------------------------------------------------
1885bool PyAttribute::operator==(const PyAttribute &other) const {
1886 return mlirAttributeEqual(attr, other.attr);
1887}
1889nb::object PyAttribute::getCapsule() {
1890 return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
1891}
1892
1893PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) {
1894 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1895 if (mlirAttributeIsNull(rawAttr))
1896 throw nb::python_error();
1897 return PyAttribute(
1899}
1900
1901nb::typed<nb::object, PyAttribute> PyAttribute::maybeDownCast() {
1902 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get());
1903 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1904 "mlirTypeID was expected to be non-null.");
1905 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
1906 mlirTypeID, mlirAttributeGetDialect(this->get()));
1907 // nb::rv_policy::move means use std::move to move the return value
1908 // contents into a new instance that will be owned by Python.
1909 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
1910 if (!typeCaster)
1911 return thisObj;
1912 return typeCaster.value()(thisObj);
1913}
1915//------------------------------------------------------------------------------
1916// PyNamedAttribute.
1917//------------------------------------------------------------------------------
1918
1919PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1920 : ownedName(new std::string(std::move(ownedName))) {
1923 toMlirStringRef(*this->ownedName)),
1924 attr);
1925}
1927//------------------------------------------------------------------------------
1928// PyType.
1929//------------------------------------------------------------------------------
1931bool PyType::operator==(const PyType &other) const {
1932 return mlirTypeEqual(type, other.type);
1933}
1935nb::object PyType::getCapsule() {
1936 return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
1937}
1938
1939PyType PyType::createFromCapsule(nb::object capsule) {
1940 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1941 if (mlirTypeIsNull(rawType))
1942 throw nb::python_error();
1944 rawType);
1945}
1946
1947nb::typed<nb::object, PyType> PyType::maybeDownCast() {
1948 MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get());
1949 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1950 "mlirTypeID was expected to be non-null.");
1951 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
1952 mlirTypeID, mlirTypeGetDialect(this->get()));
1953 // nb::rv_policy::move means use std::move to move the return value
1954 // contents into a new instance that will be owned by Python.
1955 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
1956 if (!typeCaster)
1957 return thisObj;
1958 return typeCaster.value()(thisObj);
1959}
1961//------------------------------------------------------------------------------
1962// PyTypeID.
1963//------------------------------------------------------------------------------
1965nb::object PyTypeID::getCapsule() {
1966 return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
1967}
1968
1969PyTypeID PyTypeID::createFromCapsule(nb::object capsule) {
1970 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
1971 if (mlirTypeIDIsNull(mlirTypeID))
1972 throw nb::python_error();
1973 return PyTypeID(mlirTypeID);
1974}
1975bool PyTypeID::operator==(const PyTypeID &other) const {
1976 return mlirTypeIDEqual(typeID, other.typeID);
1977}
1979//------------------------------------------------------------------------------
1980// PyValue and subclasses.
1981//------------------------------------------------------------------------------
1983nb::object PyValue::getCapsule() {
1984 return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
1985}
1986
1987static PyOperationRef getValueOwnerRef(MlirValue value) {
1988 MlirOperation owner;
1989 if (mlirValueIsAOpResult(value))
1990 owner = mlirOpResultGetOwner(value);
1991 else if (mlirValueIsABlockArgument(value))
1993 else
1994 assert(false && "Value must be an block arg or op result.");
1995 if (mlirOperationIsNull(owner))
1996 throw nb::python_error();
1997 MlirContext ctx = mlirOperationGetContext(owner);
1999}
2000
2001nb::typed<nb::object, std::variant<PyBlockArgument, PyOpResult, PyValue>>
2003 MlirType type = mlirValueGetType(get());
2004 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2005 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2006 "mlirTypeID was expected to be non-null.");
2007 std::optional<nb::callable> valueCaster =
2009 // nb::rv_policy::move means use std::move to move the return value
2010 // contents into a new instance that will be owned by Python.
2011 nb::object thisObj;
2012 if (mlirValueIsAOpResult(value))
2013 thisObj = nb::cast<PyOpResult>(*this, nb::rv_policy::move);
2014 else if (mlirValueIsABlockArgument(value))
2015 thisObj = nb::cast<PyBlockArgument>(*this, nb::rv_policy::move);
2016 else
2017 assert(false && "Value must be an block arg or op result.");
2018 if (valueCaster)
2019 return valueCaster.value()(thisObj);
2020 return thisObj;
2021}
2022
2023PyValue PyValue::createFromCapsule(nb::object capsule) {
2024 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2025 if (mlirValueIsNull(value))
2026 throw nb::python_error();
2027 PyOperationRef ownerRef = getValueOwnerRef(value);
2028 return PyValue(ownerRef, value);
2029}
2031//------------------------------------------------------------------------------
2032// PySymbolTable.
2033//------------------------------------------------------------------------------
2034
2036 : operation(operation.getOperation().getRef()) {
2037 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2038 if (mlirSymbolTableIsNull(symbolTable)) {
2039 throw nb::type_error("Operation is not a Symbol Table.");
2040 }
2041}
2042
2043nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2044 operation->checkValid();
2045 MlirOperation symbol = mlirSymbolTableLookup(
2046 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2047 if (mlirOperationIsNull(symbol))
2048 throw nb::key_error(
2049 join("Symbol '", name, "' not in the symbol table.").c_str());
2050
2051 return PyOperation::forOperation(operation->getContext(), symbol,
2052 operation.getObject())
2053 ->createOpView();
2054}
2055
2057 operation->checkValid();
2058 symbol.getOperation().checkValid();
2059 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2060 // The operation is also erased, so we must invalidate it. There may be Python
2061 // references to this operation so we don't want to delete it from the list of
2062 // live operations here.
2063 symbol.getOperation().valid = false;
2064}
2065
2066void PySymbolTable::dunderDel(const std::string &name) {
2067 nb::object operation = dunderGetItem(name);
2068 erase(nb::cast<PyOperationBase &>(operation));
2069}
2070
2072 operation->checkValid();
2073 symbol.getOperation().checkValid();
2074 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2076 if (mlirAttributeIsNull(symbolAttr))
2077 throw nb::value_error("Expected operation to have a symbol name.");
2079 symbol.getOperation().getContext(),
2080 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
2081}
2082
2084 // Op must already be a symbol.
2085 PyOperation &operation = symbol.getOperation();
2086 operation.checkValid();
2088 MlirAttribute existingNameAttr =
2089 mlirOperationGetAttributeByName(operation.get(), attrName);
2090 if (mlirAttributeIsNull(existingNameAttr))
2091 throw nb::value_error("Expected operation to have a symbol name.");
2092 return PyStringAttribute(symbol.getOperation().getContext(),
2093 existingNameAttr);
2094}
2095
2097 const std::string &name) {
2098 // Op must already be a symbol.
2099 PyOperation &operation = symbol.getOperation();
2100 operation.checkValid();
2102 MlirAttribute existingNameAttr =
2103 mlirOperationGetAttributeByName(operation.get(), attrName);
2104 if (mlirAttributeIsNull(existingNameAttr))
2105 throw nb::value_error("Expected operation to have a symbol name.");
2106 MlirAttribute newNameAttr =
2107 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2108 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2109}
2110
2112 PyOperation &operation = symbol.getOperation();
2113 operation.checkValid();
2115 MlirAttribute existingVisAttr =
2116 mlirOperationGetAttributeByName(operation.get(), attrName);
2117 if (mlirAttributeIsNull(existingVisAttr))
2118 throw nb::value_error("Expected operation to have a symbol visibility.");
2119 return PyStringAttribute(symbol.getOperation().getContext(), existingVisAttr);
2120}
2121
2123 const std::string &visibility) {
2124 if (visibility != "public" && visibility != "private" &&
2125 visibility != "nested")
2126 throw nb::value_error(
2127 "Expected visibility to be 'public', 'private' or 'nested'");
2128 PyOperation &operation = symbol.getOperation();
2129 operation.checkValid();
2131 MlirAttribute existingVisAttr =
2132 mlirOperationGetAttributeByName(operation.get(), attrName);
2133 if (mlirAttributeIsNull(existingVisAttr))
2134 throw nb::value_error("Expected operation to have a symbol visibility.");
2135 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2136 toMlirStringRef(visibility));
2137 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2138}
2139
2140void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2141 const std::string &newSymbol,
2142 PyOperationBase &from) {
2143 PyOperation &fromOperation = from.getOperation();
2144 fromOperation.checkValid();
2146 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2148
2149 throw nb::value_error("Symbol rename failed");
2150}
2151
2153 bool allSymUsesVisible,
2154 nb::object callback) {
2155 PyOperation &fromOperation = from.getOperation();
2156 fromOperation.checkValid();
2157 struct UserData {
2158 PyMlirContextRef context;
2159 nb::object callback;
2160 bool gotException;
2161 std::string exceptionWhat;
2162 nb::object exceptionType;
2163 };
2164 UserData userData{
2165 fromOperation.getContext(), std::move(callback), false, {}, {}};
2167 fromOperation.get(), allSymUsesVisible,
2168 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2169 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2170 auto pyFoundOp =
2171 PyOperation::forOperation(calleeUserData->context, foundOp);
2172 if (calleeUserData->gotException)
2173 return;
2174 try {
2175 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2176 } catch (nb::python_error &e) {
2177 calleeUserData->gotException = true;
2178 calleeUserData->exceptionWhat = e.what();
2179 calleeUserData->exceptionType = nb::borrow(e.type());
2180 }
2181 },
2182 static_cast<void *>(&userData));
2183 if (userData.gotException) {
2184 std::string message("Exception raised in callback: ");
2185 message.append(userData.exceptionWhat);
2186 throw std::runtime_error(message);
2187 }
2188}
2189
2190void PyBlockArgument::bindDerived(ClassTy &c) {
2191 c.def_prop_ro(
2192 "owner",
2193 [](PyBlockArgument &self) {
2194 return PyBlock(self.getParentOperation(),
2196 },
2197 "Returns the block that owns this argument.");
2198 c.def_prop_ro(
2199 "arg_number",
2200 [](PyBlockArgument &self) {
2201 return mlirBlockArgumentGetArgNumber(self.get());
2202 },
2203 "Returns the position of this argument in the block's argument list.");
2204 c.def(
2205 "set_type",
2206 [](PyBlockArgument &self, PyType type) {
2207 return mlirBlockArgumentSetType(self.get(), type);
2208 },
2209 "type"_a, "Sets the type of this block argument.");
2210 c.def(
2211 "set_location",
2212 [](PyBlockArgument &self, PyLocation loc) {
2214 },
2215 "loc"_a, "Sets the location of this block argument.");
2216}
2217
2219 MlirBlock block, intptr_t startIndex,
2222 length == -1 ? mlirBlockGetNumArguments(block) : length, step),
2223 operation(std::move(operation)), block(block) {}
2224
2225void PyBlockArgumentList::bindDerived(ClassTy &c) {
2226 c.def_prop_ro(
2227 "types",
2228 [](PyBlockArgumentList &self) {
2229 return getValueTypes(self, self.operation->getContext());
2230 },
2231 "Returns a list of types for all arguments in this argument list.");
2232}
2233
2234intptr_t PyBlockArgumentList::getRawNumElements() {
2235 operation->checkValid();
2236 return mlirBlockGetNumArguments(block);
2237}
2238
2239PyBlockArgument PyBlockArgumentList::getRawElement(intptr_t pos) const {
2240 MlirValue argument = mlirBlockGetArgument(block, pos);
2241 return PyBlockArgument(operation, argument);
2242}
2243
2244PyBlockArgumentList PyBlockArgumentList::slice(intptr_t startIndex,
2246 intptr_t step) const {
2247 return PyBlockArgumentList(operation, block, startIndex, length, step);
2248}
2249
2251 intptr_t length, intptr_t step)
2252 : Sliceable(startIndex,
2254 : length,
2255 step),
2256 operation(operation) {}
2257
2260 mlirOperationSetOperand(operation->get(), index, value.get());
2261}
2262
2263void PyOpOperandList::bindDerived(ClassTy &c) {
2264 c.def("__setitem__", &PyOpOperandList::dunderSetItem, "index"_a, "value"_a,
2265 "Sets the operand at the specified index to a new value.");
2266}
2267
2268intptr_t PyOpOperandList::getRawNumElements() {
2269 operation->checkValid();
2270 return mlirOperationGetNumOperands(operation->get());
2271}
2272
2273PyValue PyOpOperandList::getRawElement(intptr_t pos) {
2274 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2275 PyOperationRef pyOwner = getValueOwnerRef(operand);
2276 return PyValue(pyOwner, operand);
2277}
2278
2279PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
2280 intptr_t step) const {
2281 return PyOpOperandList(operation, startIndex, length, step);
2282}
2284/// A list of OpOperands. Internally, these are stored as consecutive elements,
2285/// random access is cheap. The (returned) OpOperand list is associated with the
2286/// operation whose operands these are, and thus extends the lifetime of this
2287/// operation.
2288class PyOpOperands : public Sliceable<PyOpOperands, PyOpOperand> {
2289public:
2290 static constexpr const char *pyClassName = "OpOperands";
2292
2294 intptr_t length = -1, intptr_t step = 1)
2296 length == -1 ? mlirOperationGetNumOperands(operation->get())
2297 : length,
2298 step),
2299 operation(operation) {}
2300
2301private:
2302 /// Give the parent CRTP class access to hook implementations below.
2303 friend class Sliceable<PyOpOperands, PyOpOperand>;
2304
2305 intptr_t getRawNumElements() {
2306 operation->checkValid();
2307 return mlirOperationGetNumOperands(operation->get());
2308 }
2309
2310 PyOpOperand getRawElement(intptr_t pos) {
2311 MlirOpOperand opOperand = mlirOperationGetOpOperand(operation->get(), pos);
2312 return PyOpOperand(opOperand);
2313 }
2314
2316 return PyOpOperands(operation, startIndex, length, step);
2318
2319 PyOperationRef operation;
2320};
2321
2323 intptr_t length, intptr_t step)
2324 : Sliceable(startIndex,
2326 : length,
2327 step),
2328 operation(operation) {}
2329
2332 mlirOperationSetSuccessor(operation->get(), index, block.get());
2333}
2334
2335void PyOpSuccessors::bindDerived(ClassTy &c) {
2336 c.def("__setitem__", &PyOpSuccessors::dunderSetItem, "index"_a, "block"_a,
2337 "Sets the successor block at the specified index.");
2338}
2339
2340intptr_t PyOpSuccessors::getRawNumElements() {
2341 operation->checkValid();
2342 return mlirOperationGetNumSuccessors(operation->get());
2343}
2344
2345PyBlock PyOpSuccessors::getRawElement(intptr_t pos) {
2346 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2347 return PyBlock(operation, block);
2348}
2349
2351 intptr_t step) const {
2352 return PyOpSuccessors(operation, startIndex, length, step);
2353}
2354
2356 intptr_t startIndex, intptr_t length,
2357 intptr_t step)
2358 : Sliceable(startIndex,
2359 length == -1 ? mlirBlockGetNumSuccessors(block.get()) : length,
2360 step),
2361 operation(operation), block(block) {}
2362
2363intptr_t PyBlockSuccessors::getRawNumElements() {
2364 block.checkValid();
2365 return mlirBlockGetNumSuccessors(block.get());
2366}
2367
2368PyBlock PyBlockSuccessors::getRawElement(intptr_t pos) {
2369 MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2370 return PyBlock(operation, block);
2371}
2372
2374 intptr_t step) const {
2375 return PyBlockSuccessors(block, operation, startIndex, length, step);
2376}
2377
2379 PyOperationRef operation,
2380 intptr_t startIndex, intptr_t length,
2381 intptr_t step)
2382 : Sliceable(startIndex,
2383 length == -1 ? mlirBlockGetNumPredecessors(block.get())
2384 : length,
2385 step),
2386 operation(operation), block(block) {}
2387
2388intptr_t PyBlockPredecessors::getRawNumElements() {
2389 block.checkValid();
2390 return mlirBlockGetNumPredecessors(block.get());
2391}
2392
2393PyBlock PyBlockPredecessors::getRawElement(intptr_t pos) {
2394 MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2395 return PyBlock(operation, block);
2396}
2397
2398PyBlockPredecessors PyBlockPredecessors::slice(intptr_t startIndex,
2399 intptr_t length,
2400 intptr_t step) const {
2401 return PyBlockPredecessors(block, operation, startIndex, length, step);
2402}
2403
2404nb::typed<nb::object, PyAttribute>
2405PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
2406 MlirAttribute attr =
2408 if (mlirAttributeIsNull(attr)) {
2409 throw nb::key_error("attempt to access a non-existent attribute");
2411 return PyAttribute(operation->getContext(), attr).maybeDownCast();
2412}
2413
2414nb::typed<nb::object, std::optional<PyAttribute>>
2415PyOpAttributeMap::get(const std::string &key, nb::object defaultValue) {
2416 MlirAttribute attr =
2418 if (mlirAttributeIsNull(attr))
2419 return defaultValue;
2420 return PyAttribute(operation->getContext(), attr).maybeDownCast();
2421}
2422
2424 if (index < 0) {
2425 index += dunderLen();
2426 }
2427 if (index < 0 || index >= dunderLen()) {
2428 throw nb::index_error("attempt to access out of bounds attribute");
2429 }
2430 MlirNamedAttribute namedAttr =
2431 mlirOperationGetAttribute(operation->get(), index);
2432 return PyNamedAttribute(
2433 namedAttr.attribute,
2434 std::string(mlirIdentifierStr(namedAttr.name).data,
2435 mlirIdentifierStr(namedAttr.name).length));
2436}
2437
2438void PyOpAttributeMap::dunderSetItem(const std::string &name,
2439 const PyAttribute &attr) {
2440 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2441 attr);
2442}
2443
2444void PyOpAttributeMap::dunderDelItem(const std::string &name) {
2445 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2447 if (!removed)
2448 throw nb::key_error("attempt to delete a non-existent attribute");
2449}
2452 return mlirOperationGetNumAttributes(operation->get());
2453}
2454
2455bool PyOpAttributeMap::dunderContains(const std::string &name) {
2456 return !mlirAttributeIsNull(
2457 mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)));
2458}
2459
2461 MlirOperation op, std::function<void(MlirStringRef, MlirAttribute)> fn) {
2463 for (intptr_t i = 0; i < n; ++i) {
2466 fn(name, na.attribute);
2467 }
2468}
2469
2470void PyOpAttributeMap::bind(nb::module_ &m) {
2471 nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2472 .def("__contains__", &PyOpAttributeMap::dunderContains, "name"_a,
2473 "Checks if an attribute with the given name exists in the map.")
2474 .def("__len__", &PyOpAttributeMap::dunderLen,
2475 "Returns the number of attributes in the map.")
2476 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed, "name"_a,
2477 "Gets an attribute by name.")
2478 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed, "index"_a,
2479 "Gets a named attribute by index.")
2480 .def("__setitem__", &PyOpAttributeMap::dunderSetItem, "name"_a, "attr"_a,
2481 "Sets an attribute with the given name.")
2482 .def("__delitem__", &PyOpAttributeMap::dunderDelItem, "name"_a,
2483 "Deletes an attribute with the given name.")
2484 .def("get", &PyOpAttributeMap::get, nb::arg("key"),
2485 nb::arg("default") = nb::none(),
2486 "Gets an attribute by name or the default value, if it does not "
2487 "exist.")
2488 .def(
2489 "__iter__",
2490 [](PyOpAttributeMap &self) {
2491 nb::list keys;
2493 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
2494 keys.append(nb::str(name.data, name.length));
2495 });
2496 return nb::iter(keys);
2497 },
2498 "Iterates over attribute names.")
2499 .def(
2500 "keys",
2501 [](PyOpAttributeMap &self) {
2502 nb::list out;
2504 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
2505 out.append(nb::str(name.data, name.length));
2506 });
2507 return out;
2508 },
2509 "Returns a list of attribute names.")
2510 .def(
2511 "values",
2512 [](PyOpAttributeMap &self) {
2513 nb::list out;
2515 self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
2516 out.append(PyAttribute(self.operation->getContext(), attr)
2517 .maybeDownCast());
2518 });
2519 return out;
2520 },
2521 "Returns a list of attribute values.")
2522 .def(
2523 "items",
2524 [](PyOpAttributeMap &self) {
2525 nb::list out;
2527 self.operation->get(),
2528 [&](MlirStringRef name, MlirAttribute attr) {
2529 out.append(nb::make_tuple(
2530 nb::str(name.data, name.length),
2531 PyAttribute(self.operation->getContext(), attr)
2532 .maybeDownCast()));
2533 });
2534 return out;
2535 },
2536 "Returns a list of `(name, attribute)` tuples.");
2537}
2538
2539void PyOpAdaptor::bind(nb::module_ &m) {
2540 nb::class_<PyOpAdaptor>(m, "OpAdaptor")
2541 .def(nb::init<nb::list, PyOpAttributeMap>(),
2542 "Creates an OpAdaptor with the given operands and attributes.",
2543 "operands"_a, "attributes"_a)
2544 .def(nb::init<nb::list, PyOpView &>(),
2545 "Creates an OpAdaptor with the given operands and operation view.",
2546 "operands"_a, "opview"_a)
2547 .def_prop_ro(
2548 "operands", [](PyOpAdaptor &self) { return self.operands; },
2549 "Returns the operands of the adaptor.")
2550 .def_prop_ro(
2551 "attributes", [](PyOpAdaptor &self) { return self.attributes; },
2552 "Returns the attributes of the adaptor.");
2553}
2554
2555static MlirLogicalResult verifyTraitByMethod(MlirOperation op, void *userData,
2556 const char *methodName) {
2557 nb::handle targetObj(static_cast<PyObject *>(userData));
2558 if (!nb::hasattr(targetObj, methodName))
2559 return mlirLogicalResultSuccess();
2561 nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
2562 bool success = nb::cast<bool>(targetObj.attr(methodName)(opView));
2564};
2565
2566static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait,
2567 PyMlirContext &context) {
2568 std::string opNameStr;
2569 if (opName.is_type()) {
2570 opNameStr = nb::cast<std::string>(opName.attr("OPERATION_NAME"));
2571 } else if (nb::isinstance<nb::str>(opName)) {
2572 opNameStr = nb::cast<std::string>(opName);
2573 } else {
2574 throw nb::type_error("the root argument must be a type or a string");
2575 }
2578 trait, MlirStringRef{opNameStr.data(), opNameStr.size()}, context.get());
2579}
2580
2581bool PyDynamicOpTrait::attach(const nb::object &opName,
2582 const nb::object &target,
2583 PyMlirContext &context) {
2584 if (!nb::hasattr(target, "verify_invariants") &&
2585 !nb::hasattr(target, "verify_region_invariants"))
2586 throw nb::type_error(
2587 "the target object must have at least one of 'verify_invariants' or "
2588 "'verify_region_invariants' methods");
2589
2591 callbacks.construct = [](void *userData) {
2592 nb::handle(static_cast<PyObject *>(userData)).inc_ref();
2593 };
2594 callbacks.destruct = [](void *userData) {
2595 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
2596 };
2597
2598 callbacks.verifyTrait = [](MlirOperation op,
2599 void *userData) -> MlirLogicalResult {
2600 return verifyTraitByMethod(op, userData, "verify_invariants");
2601 };
2602 callbacks.verifyRegionTrait = [](MlirOperation op,
2603 void *userData) -> MlirLogicalResult {
2604 return verifyTraitByMethod(op, userData, "verify_region_invariants");
2605 };
2606
2607 // To ensure that the same dynamic trait gets the same TypeID despite how many
2608 // times `attach` is called, we store it as an attribute on the target class.
2609 constexpr const char *typeIDAttr = "_TYPE_ID";
2610 if (!nb::hasattr(target, typeIDAttr)) {
2611 nb::setattr(target, typeIDAttr,
2612 nb::cast(PyTypeID(PyGlobals::get().allocateTypeID())));
2613 }
2614 MlirDynamicOpTrait trait = mlirDynamicOpTraitCreate(
2615 nb::cast<PyTypeID>(target.attr(typeIDAttr)).get(), callbacks,
2616 static_cast<void *>(target.ptr()));
2617 return attachOpTrait(opName, trait, context);
2618}
2619
2620void PyDynamicOpTrait::bind(nb::module_ &m) {
2621 nb::class_<PyDynamicOpTrait> cls(m, "DynamicOpTrait");
2622 cls.attr("attach") = classmethod(
2623 [](const nb::object &cls, const nb::object &opName, nb::object target,
2624 DefaultingPyMlirContext context) {
2625 if (target.is_none())
2626 target = cls;
2627 return PyDynamicOpTrait::attach(opName, target, *context.get());
2628 },
2629 nb::arg("cls"), nb::arg("op_name"), nb::arg("target").none() = nb::none(),
2630 nb::arg("context").none() = nb::none(),
2631 "Attach the dynamic op trait subclass to the given operation name.");
2632}
2633
2634bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName,
2635 PyMlirContext &context) {
2636 MlirDynamicOpTrait trait = mlirDynamicOpTraitIsTerminatorCreate();
2637 return attachOpTrait(opName, trait, context);
2638}
2639
2640void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) {
2641 nb::class_<PyDynamicOpTraits::IsTerminator, PyDynamicOpTrait> cls(
2642 m, "IsTerminatorTrait");
2643 cls.attr("attach") = classmethod(
2644 [](const nb::object &cls, const nb::object &opName,
2645 DefaultingPyMlirContext context) {
2646 return PyDynamicOpTraits::IsTerminator::attach(opName, *context.get());
2648 "Attach IsTerminator trait to the given operation name.", nb::arg("cls"),
2649 nb::arg("op_name"), nb::arg("context").none() = nb::none());
2650}
2651
2652bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName,
2653 PyMlirContext &context) {
2654 MlirDynamicOpTrait trait = mlirDynamicOpTraitNoTerminatorCreate();
2655 return attachOpTrait(opName, trait, context);
2656}
2657
2658void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) {
2659 nb::class_<PyDynamicOpTraits::NoTerminator, PyDynamicOpTrait> cls(
2660 m, "NoTerminatorTrait");
2661 cls.attr("attach") = classmethod(
2662 [](const nb::object &cls, const nb::object &opName,
2663 DefaultingPyMlirContext context) {
2664 return PyDynamicOpTraits::NoTerminator::attach(opName, *context.get());
2665 },
2666 "Attach NoTerminator trait to the given operation name.", nb::arg("cls"),
2667 nb::arg("op_name"), nb::arg("context").none() = nb::none());
2668}
2669
2670} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
2671} // namespace python
2672} // namespace mlir
2673
2674namespace {
2675
2676using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
2677
2678MlirLocation tracebackToLocation(MlirContext ctx) {
2679#if defined(Py_LIMITED_API)
2680 // Frame introspection C APIs are not available under the limited API.
2681 // Traceback-based auto-location is not supported; return unknown.
2682 return mlirLocationUnknownGet(ctx);
2683#else
2684 size_t framesLimit =
2686 // Use a thread_local here to avoid requiring a large amount of space.
2687 thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
2688 frames;
2689 size_t count = 0;
2690
2691 nb::gil_scoped_acquire acquire;
2692
2693 PyThreadState *tstate = PyThreadState_GET();
2694 PyFrameObject *next;
2695 PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
2696 // In the increment expression:
2697 // 1. get the next prev frame;
2698 // 2. decrement the ref count on the current frame (in order that it can get
2699 // gc'd, along with any objects in its closure and etc);
2700 // 3. set current = next.
2701 for (; pyFrame != nullptr && count < framesLimit;
2702 next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
2703 PyCodeObject *code = PyFrame_GetCode(pyFrame);
2704 auto fileNameStr =
2705 nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
2706 std::string_view fileName(fileNameStr);
2707 if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
2708 continue;
2709
2710 // co_qualname and PyCode_Addr2Location added in py3.11
2711#if PY_VERSION_HEX < 0x030B00F0
2712 std::string name =
2713 nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
2714 std::string_view funcName(name);
2715 int startLine = PyFrame_GetLineNumber(pyFrame);
2716 MlirLocation loc = mlirLocationFileLineColGet(
2717 ctx, mlirStringRefCreate(fileName.data(), fileName.size()), startLine,
2718 0);
2719#else
2720 std::string name =
2721 nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
2722 std::string_view funcName(name);
2723 int startLine, startCol, endLine, endCol;
2724 int lasti = PyFrame_GetLasti(pyFrame);
2725 if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
2726 &endCol)) {
2727 throw nb::python_error();
2728 }
2729 MlirLocation loc = mlirLocationFileLineColRangeGet(
2730 ctx, mlirStringRefCreate(fileName.data(), fileName.size()), startLine,
2731 startCol, endLine, endCol);
2732#endif
2733
2734 frames[count] = mlirLocationNameGet(
2735 ctx, mlirStringRefCreate(funcName.data(), funcName.size()), loc);
2736 ++count;
2737 }
2738 // When the loop breaks (after the last iter), current frame (if non-null)
2739 // is leaked without this.
2740 Py_XDECREF(pyFrame);
2741
2742 if (count == 0)
2743 return mlirLocationUnknownGet(ctx);
2744
2745 MlirLocation callee = frames[0];
2746 assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
2747 if (count == 1)
2748 return callee;
2749
2750 MlirLocation caller = frames[count - 1];
2751 assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
2752 for (int i = count - 2; i >= 1; i--)
2753 caller = mlirLocationCallSiteGet(frames[i], caller);
2754
2755 return mlirLocationCallSiteGet(callee, caller);
2756#endif
2757}
2758
2759PyLocation
2760maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
2761 if (location.has_value())
2762 return location.value();
2763 if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
2765
2766 PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
2767 MlirLocation mlirLoc = tracebackToLocation(ctx.get());
2769 return {ref, mlirLoc};
2770}
2771} // namespace
2772
2773namespace mlir {
2774namespace python {
2776
2777void populateRoot(nb::module_ &m) {
2778 m.attr("T") = nb::type_var("T");
2779 m.attr("U") = nb::type_var("U");
2780
2781 nb::class_<PyGlobals>(m, "_Globals")
2782 .def_prop_rw("dialect_search_modules",
2785 .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
2786 "module_name"_a)
2787 .def(
2788 "_check_dialect_module_loaded",
2789 [](PyGlobals &self, const std::string &dialectNamespace) {
2790 return self.loadDialectModule(dialectNamespace);
2791 },
2792 "dialect_namespace"_a)
2793 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
2794 "dialect_namespace"_a, "dialect_class"_a,
2795 "Testing hook for directly registering a dialect")
2796 .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
2797 "operation_name"_a, "operation_class"_a, nb::kw_only(),
2798 "replace"_a = false,
2799 "Testing hook for directly registering an operation")
2800 .def("loc_tracebacks_enabled",
2801 [](PyGlobals &self) {
2802 return self.getTracebackLoc().locTracebacksEnabled();
2803 })
2804 .def("set_loc_tracebacks_enabled",
2805 [](PyGlobals &self, bool enabled) {
2807 })
2808 .def("loc_tracebacks_frame_limit",
2809 [](PyGlobals &self) {
2811 })
2812 .def("set_loc_tracebacks_frame_limit",
2813 [](PyGlobals &self, std::optional<int> n) {
2816 })
2817 .def("register_traceback_file_inclusion",
2818 [](PyGlobals &self, const std::string &filename) {
2820 })
2821 .def("register_traceback_file_exclusion",
2822 [](PyGlobals &self, const std::string &filename) {
2824 });
2825
2826 // Aside from making the globals accessible to python, having python manage
2827 // it is necessary to make sure it is destroyed (and releases its python
2828 // resources) properly.
2829 m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
2830
2831 // Registration decorators.
2832 m.def(
2833 "register_dialect",
2834 [](nb::type_object pyClass) {
2835 std::string dialectNamespace =
2836 nb::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
2837 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
2838 return pyClass;
2839 },
2840 "dialect_class"_a,
2841 "Class decorator for registering a custom Dialect wrapper");
2842 m.def(
2843 "register_operation",
2844 [](const nb::type_object &dialectClass, bool replace) -> nb::object {
2845 return nb::cpp_function(
2846 [dialectClass,
2847 replace](nb::type_object opClass) -> nb::type_object {
2848 std::string operationName =
2849 nb::cast<std::string>(opClass.attr("OPERATION_NAME"));
2850 PyGlobals::get().registerOperationImpl(operationName, opClass,
2851 replace);
2852 // Dict-stuff the new opClass by name onto the dialect class.
2853 nb::object opClassName = opClass.attr("__name__");
2854 dialectClass.attr(opClassName) = opClass;
2855 return opClass;
2856 });
2857 },
2858 // clang-format off
2859 nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) "
2860 "-> typing.Callable[[type[T]], type[T]]"),
2861 // clang-format on
2862 "dialect_class"_a, nb::kw_only(), "replace"_a = false,
2863 "Produce a class decorator for registering an Operation class as part of "
2864 "a dialect");
2865 m.def(
2866 "register_op_adaptor",
2867 [](const nb::type_object &opClass, bool replace) -> nb::object {
2868 return nb::cpp_function(
2869 [opClass,
2870 replace](nb::type_object adaptorClass) -> nb::type_object {
2871 std::string operationName =
2872 nb::cast<std::string>(adaptorClass.attr("OPERATION_NAME"));
2873 PyGlobals::get().registerOpAdaptorImpl(operationName,
2874 adaptorClass, replace);
2875 // Dict-stuff the new adaptorClass by name onto the opClass.
2876 opClass.attr("Adaptor") = adaptorClass;
2877 return adaptorClass;
2878 });
2879 },
2880 // clang-format off
2881 nb::sig("def register_op_adaptor(op_class: type, *, replace: bool = False) "
2882 "-> typing.Callable[[type[T]], type[T]]"),
2883 // clang-format on
2884 "op_class"_a, nb::kw_only(), "replace"_a = false,
2885 "Produce a class decorator for registering an OpAdaptor class for an "
2886 "operation.");
2887 m.def(
2889 [](PyTypeID mlirTypeID, bool replace) -> nb::object {
2890 return nb::cpp_function([mlirTypeID, replace](
2891 nb::callable typeCaster) -> nb::object {
2892 PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
2893 return typeCaster;
2894 });
2895 },
2896 // clang-format off
2897 nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
2898 "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
2899 // clang-format on
2900 "typeid"_a, nb::kw_only(), "replace"_a = false,
2901 "Register a type caster for casting MLIR types to custom user types.");
2902 m.def(
2904 [](PyTypeID mlirTypeID, bool replace) -> nb::object {
2905 return nb::cpp_function(
2906 [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
2907 PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
2908 replace);
2909 return valueCaster;
2910 });
2911 },
2912 // clang-format off
2913 nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
2914 "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
2915 // clang-format on
2916 "typeid"_a, nb::kw_only(), "replace"_a = false,
2917 "Register a value caster for casting MLIR values to custom user values.");
2918}
2919
2920//------------------------------------------------------------------------------
2921// Populates the core exports of the 'ir' submodule.
2922//------------------------------------------------------------------------------
2923void populateIRCore(nb::module_ &m) {
2924 //----------------------------------------------------------------------------
2925 // Enums.
2926 //----------------------------------------------------------------------------
2927 nb::enum_<PyDiagnosticSeverity>(m, "DiagnosticSeverity")
2928 .value("ERROR", PyDiagnosticSeverity::Error)
2929 .value("WARNING", PyDiagnosticSeverity::Warning)
2930 .value("NOTE", PyDiagnosticSeverity::Note)
2931 .value("REMARK", PyDiagnosticSeverity::Remark);
2932
2933 nb::enum_<PyWalkOrder>(m, "WalkOrder")
2934 .value("PRE_ORDER", PyWalkOrder::PreOrder)
2935 .value("POST_ORDER", PyWalkOrder::PostOrder);
2936 nb::enum_<PyWalkResult>(m, "WalkResult")
2937 .value("ADVANCE", PyWalkResult::Advance)
2938 .value("INTERRUPT", PyWalkResult::Interrupt)
2939 .value("SKIP", PyWalkResult::Skip);
2940
2941 //----------------------------------------------------------------------------
2942 // Mapping of Diagnostics.
2943 //----------------------------------------------------------------------------
2944 nb::class_<PyDiagnostic>(m, "Diagnostic")
2945 .def_prop_ro("severity", &PyDiagnostic::getSeverity,
2946 "Returns the severity of the diagnostic.")
2947 .def_prop_ro("location", &PyDiagnostic::getLocation,
2948 "Returns the location associated with the diagnostic.")
2949 .def_prop_ro("message", &PyDiagnostic::getMessage,
2950 "Returns the message text of the diagnostic.")
2951 .def_prop_ro("notes", &PyDiagnostic::getNotes,
2952 "Returns a tuple of attached note diagnostics.")
2953 .def(
2954 "__str__",
2955 [](PyDiagnostic &self) -> nb::str {
2956 if (!self.isValid())
2957 return nb::str("<Invalid Diagnostic>");
2958 return self.getMessage();
2959 },
2960 "Returns the diagnostic message as a string.");
2961
2962 nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2963 .def(
2964 "__init__",
2966 new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2967 },
2968 "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
2969 .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
2970 "The severity level of the diagnostic.")
2971 .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
2972 "The location associated with the diagnostic.")
2973 .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
2974 "The message text of the diagnostic.")
2975 .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
2976 "List of attached note diagnostics.")
2977 .def(
2978 "__str__",
2979 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
2980 "Returns the diagnostic message as a string.");
2981
2982 nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
2983 .def("detach", &PyDiagnosticHandler::detach,
2984 "Detaches the diagnostic handler from the context.")
2985 .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
2986 "Returns True if the handler is attached to a context.")
2987 .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
2988 "Returns True if an error was encountered during diagnostic "
2989 "handling.")
2990 .def("__enter__", &PyDiagnosticHandler::contextEnter,
2991 "Enters the diagnostic handler as a context manager.",
2992 nb::sig("def __enter__(self, /) -> DiagnosticHandler"))
2993 .def("__exit__", &PyDiagnosticHandler::contextExit, "exc_type"_a.none(),
2994 "exc_value"_a.none(), "traceback"_a.none(),
2995 "Exits the diagnostic handler context manager.");
2996
2997 // Expose DefaultThreadPool to python
2998 nb::class_<PyThreadPool>(m, "ThreadPool")
2999 .def(
3000 "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
3001 "Creates a new thread pool with default concurrency.")
3002 .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
3003 "Returns the maximum number of threads in the pool.")
3004 .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
3005 "Returns the raw pointer to the LLVM thread pool as a string.");
3006
3007 nb::class_<PyMlirContext>(m, "Context")
3008 .def(
3009 "__init__",
3010 [](PyMlirContext &self) {
3011 MlirContext context = mlirContextCreateWithThreading(false);
3012 new (&self) PyMlirContext(context);
3013 },
3014 R"(
3015 Creates a new MLIR context.
3016
3017 The context is the top-level container for all MLIR objects. It owns the storage
3018 for types, attributes, locations, and other core IR objects. A context can be
3019 configured to allow or disallow unregistered dialects and can have dialects
3020 loaded on-demand.)")
3021 .def_static("_get_live_count", &PyMlirContext::getLiveCount,
3022 "Gets the number of live Context objects.")
3023 .def(
3024 "_get_context_again",
3025 [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
3027 return ref.releaseObject();
3028 },
3029 "Gets another reference to the same context.")
3030 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
3031 "Gets the number of live modules owned by this context.")
3033 "Gets a capsule wrapping the MlirContext.")
3036 "Creates a Context from a capsule wrapping MlirContext.")
3037 .def("__enter__", &PyMlirContext::contextEnter,
3038 "Enters the context as a context manager.",
3039 nb::sig("def __enter__(self, /) -> Context"))
3040 .def("__exit__", &PyMlirContext::contextExit, "exc_type"_a.none(),
3041 "exc_value"_a.none(), "traceback"_a.none(),
3042 "Exits the context manager.")
3043 .def_prop_ro_static(
3044 "current",
3045 [](nb::object & /*class*/)
3046 -> std::optional<nb::typed<nb::object, PyMlirContext>> {
3048 if (!context)
3049 return {};
3050 return nb::cast(context);
3051 },
3052 nb::sig("def current(/) -> Context | None"),
3053 "Gets the Context bound to the current thread or returns None if no "
3054 "context is set.")
3055 .def_prop_ro(
3056 "dialects",
3057 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3058 "Gets a container for accessing dialects by name.")
3059 .def_prop_ro(
3060 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3061 "Alias for `dialects`.")
3062 .def(
3063 "get_dialect_descriptor",
3064 [=](PyMlirContext &self, std::string &name) {
3065 MlirDialect dialect = mlirContextGetOrLoadDialect(
3066 self.get(), {name.data(), name.size()});
3067 if (mlirDialectIsNull(dialect)) {
3068 throw nb::value_error(
3069 join("Dialect '", name, "' not found").c_str());
3070 }
3071 return PyDialectDescriptor(self.getRef(), dialect);
3072 },
3073 "dialect_name"_a,
3074 "Gets or loads a dialect by name, returning its descriptor object.")
3075 .def_prop_rw(
3076 "allow_unregistered_dialects",
3077 [](PyMlirContext &self) -> bool {
3078 return mlirContextGetAllowUnregisteredDialects(self.get());
3079 },
3080 [](PyMlirContext &self, bool value) {
3081 mlirContextSetAllowUnregisteredDialects(self.get(), value);
3082 },
3083 "Controls whether unregistered dialects are allowed in this context.")
3084 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
3085 "callback"_a,
3086 "Attaches a diagnostic handler that will receive callbacks.")
3087 .def(
3088 "enable_multithreading",
3089 [](PyMlirContext &self, bool enable) {
3090 mlirContextEnableMultithreading(self.get(), enable);
3091 },
3092 "enable"_a,
3093 R"(
3094 Enables or disables multi-threading support in the context.
3095
3096 Args:
3097 enable: Whether to enable (True) or disable (False) multi-threading.
3098 )")
3099 .def(
3100 "set_thread_pool",
3101 [](PyMlirContext &self, PyThreadPool &pool) {
3102 // we should disable multi-threading first before setting
3103 // new thread pool otherwise the assert in
3104 // MLIRContext::setThreadPool will be raised.
3105 mlirContextEnableMultithreading(self.get(), false);
3106 mlirContextSetThreadPool(self.get(), pool.get());
3107 },
3108 R"(
3109 Sets a custom thread pool for the context to use.
3110
3111 Args:
3112 pool: A ThreadPool object to use for parallel operations.
3113
3114 Note:
3115 Multi-threading is automatically disabled before setting the thread pool.)")
3116 .def(
3117 "get_num_threads",
3118 [](PyMlirContext &self) {
3119 return mlirContextGetNumThreads(self.get());
3120 },
3121 "Gets the number of threads in the context's thread pool.")
3122 .def(
3123 "_mlir_thread_pool_ptr",
3124 [](PyMlirContext &self) {
3125 MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
3126 std::stringstream ss;
3127 ss << pool.ptr;
3128 return ss.str();
3129 },
3130 "Gets the raw pointer to the LLVM thread pool as a string.")
3131 .def(
3132 "is_registered_operation",
3133 [](PyMlirContext &self, std::string &name) {
3135 self.get(), MlirStringRef{name.data(), name.size()});
3136 },
3137 "operation_name"_a,
3138 R"(
3139 Checks whether an operation with the given name is registered.
3140
3141 Args:
3142 operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
3143
3144 Returns:
3145 True if the operation is registered, False otherwise.)")
3146 .def(
3147 "append_dialect_registry",
3148 [](PyMlirContext &self, PyDialectRegistry &registry) {
3149 mlirContextAppendDialectRegistry(self.get(), registry);
3150 },
3151 "registry"_a,
3152 R"(
3153 Appends the contents of a dialect registry to the context.
3154
3155 Args:
3156 registry: A DialectRegistry containing dialects to append.)")
3157 .def_prop_rw("emit_error_diagnostics",
3160 R"(
3161 Controls whether error diagnostics are emitted to diagnostic handlers.
3162
3163 By default, error diagnostics are captured and reported through MLIRError exceptions.)")
3164 .def(
3165 "load_all_available_dialects",
3166 [](PyMlirContext &self) {
3168 },
3169 R"(
3170 Loads all dialects available in the registry into the context.
3171
3172 This eagerly loads all dialects that have been registered, making them
3173 immediately available for use.)");
3174
3175 //----------------------------------------------------------------------------
3176 // Mapping of PyDialectDescriptor
3177 //----------------------------------------------------------------------------
3178 nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
3179 .def_prop_ro(
3180 "namespace",
3181 [](PyDialectDescriptor &self) {
3182 MlirStringRef ns = mlirDialectGetNamespace(self.get());
3183 return nb::str(ns.data, ns.length);
3184 },
3185 "Returns the namespace of the dialect.")
3186 .def(
3187 "__repr__",
3188 [](PyDialectDescriptor &self) {
3189 MlirStringRef ns = mlirDialectGetNamespace(self.get());
3190 std::string repr("<DialectDescriptor ");
3191 repr.append(ns.data, ns.length);
3192 repr.append(">");
3193 return repr;
3194 },
3195 nb::sig("def __repr__(self) -> str"),
3196 "Returns a string representation of the dialect descriptor.");
3197
3198 //----------------------------------------------------------------------------
3199 // Mapping of PyDialects
3200 //----------------------------------------------------------------------------
3201 nb::class_<PyDialects>(m, "Dialects")
3202 .def(
3203 "__getitem__",
3204 [=](PyDialects &self, std::string keyName) {
3205 MlirDialect dialect =
3206 self.getDialectForKey(keyName, /*attrError=*/false);
3207 nb::object descriptor =
3208 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3209 return createCustomDialectWrapper(keyName, std::move(descriptor));
3210 },
3211 "Gets a dialect by name using subscript notation.")
3212 .def(
3213 "__getattr__",
3214 [=](PyDialects &self, std::string attrName) {
3215 MlirDialect dialect =
3216 self.getDialectForKey(attrName, /*attrError=*/true);
3217 nb::object descriptor =
3218 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3219 return createCustomDialectWrapper(attrName, std::move(descriptor));
3220 },
3221 "Gets a dialect by name using attribute notation.");
3222
3223 //----------------------------------------------------------------------------
3224 // Mapping of PyDialect
3225 //----------------------------------------------------------------------------
3226 nb::class_<PyDialect>(m, "Dialect")
3227 .def(nb::init<nb::object>(), "descriptor"_a,
3228 "Creates a Dialect from a DialectDescriptor.")
3229 .def_prop_ro(
3230 "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
3231 "Returns the DialectDescriptor for this dialect.")
3232 .def(
3233 "__repr__",
3234 [](const nb::object &self) {
3235 auto clazz = self.attr("__class__");
3236 return nb::str("<Dialect ") +
3237 self.attr("descriptor").attr("namespace") +
3238 nb::str(" (class ") + clazz.attr("__module__") +
3239 nb::str(".") + clazz.attr("__name__") + nb::str(")>");
3240 },
3241 nb::sig("def __repr__(self) -> str"),
3242 "Returns a string representation of the dialect.");
3243
3244 //----------------------------------------------------------------------------
3245 // Mapping of PyDialectRegistry
3246 //----------------------------------------------------------------------------
3247 nb::class_<PyDialectRegistry>(m, "DialectRegistry")
3249 "Gets a capsule wrapping the MlirDialectRegistry.")
3252 "Creates a DialectRegistry from a capsule wrapping "
3253 "`MlirDialectRegistry`.")
3254 .def(nb::init<>(), "Creates a new empty dialect registry.");
3255
3256 //----------------------------------------------------------------------------
3257 // Mapping of Location
3258 //----------------------------------------------------------------------------
3259 nb::class_<PyLocation>(m, "Location")
3261 "Gets a capsule wrapping the MlirLocation.")
3263 "Creates a Location from a capsule wrapping MlirLocation.")
3264 .def("__enter__", &PyLocation::contextEnter,
3265 "Enters the location as a context manager.",
3266 nb::sig("def __enter__(self, /) -> Location"))
3267 .def("__exit__", &PyLocation::contextExit, "exc_type"_a.none(),
3268 "exc_value"_a.none(), "traceback"_a.none(),
3269 "Exits the location context manager.")
3270 .def(
3271 "__eq__",
3272 [](PyLocation &self, PyLocation &other) -> bool {
3273 return mlirLocationEqual(self, other);
3274 },
3275 "Compares two locations for equality.")
3276 .def(
3277 "__eq__", [](PyLocation &self, nb::object other) { return false; },
3278 "Compares location with non-location object (always returns False).")
3279 .def_prop_ro_static(
3280 "current",
3281 [](nb::object & /*class*/) -> std::optional<PyLocation *> {
3283 if (!loc)
3284 return std::nullopt;
3285 return loc;
3286 },
3287 // clang-format off
3288 nb::sig("def current(/) -> Location | None"),
3289 // clang-format on
3290 "Gets the Location bound to the current thread or returns None if no location is set.")
3291 .def_static(
3292 "unknown",
3293 [](DefaultingPyMlirContext context) {
3294 return PyLocation(context->getRef(),
3295 mlirLocationUnknownGet(context->get()));
3296 },
3297 "context"_a = nb::none(),
3298 "Gets a Location representing an unknown location.")
3299 .def_static(
3300 "callsite",
3301 [](PyLocation callee, const std::vector<PyLocation> &frames,
3302 DefaultingPyMlirContext context) {
3303 if (frames.empty())
3304 throw nb::value_error("No caller frames provided.");
3305 MlirLocation caller = frames.back().get();
3306 for (size_t index = frames.size() - 1; index-- > 0;) {
3307 caller = mlirLocationCallSiteGet(frames[index].get(), caller);
3308 }
3309 return PyLocation(context->getRef(),
3310 mlirLocationCallSiteGet(callee.get(), caller));
3311 },
3312 "callee"_a, "frames"_a, "context"_a = nb::none(),
3313 "Gets a Location representing a caller and callsite.")
3314 .def("is_a_callsite", mlirLocationIsACallSite,
3315 "Returns True if this location is a CallSiteLoc.")
3316 .def_prop_ro(
3317 "callee",
3318 [](PyLocation &self) {
3319 return PyLocation(self.getContext(),
3321 },
3322 "Gets the callee location from a CallSiteLoc.")
3323 .def_prop_ro(
3324 "caller",
3325 [](PyLocation &self) {
3326 return PyLocation(self.getContext(),
3328 },
3329 "Gets the caller location from a CallSiteLoc.")
3330 .def_static(
3331 "file",
3332 [](std::string filename, int line, int col,
3333 DefaultingPyMlirContext context) {
3334 return PyLocation(
3335 context->getRef(),
3337 context->get(), toMlirStringRef(filename), line, col));
3338 },
3339 "filename"_a, "line"_a, "col"_a, "context"_a = nb::none(),
3340 "Gets a Location representing a file, line and column.")
3341 .def_static(
3342 "file",
3343 [](std::string filename, int startLine, int startCol, int endLine,
3344 int endCol, DefaultingPyMlirContext context) {
3345 return PyLocation(context->getRef(),
3347 context->get(), toMlirStringRef(filename),
3348 startLine, startCol, endLine, endCol));
3349 },
3350 "filename"_a, "start_line"_a, "start_col"_a, "end_line"_a,
3351 "end_col"_a, "context"_a = nb::none(),
3352 "Gets a Location representing a file, line and column range.")
3353 .def("is_a_file", mlirLocationIsAFileLineColRange,
3354 "Returns True if this location is a FileLineColLoc.")
3355 .def_prop_ro(
3356 "filename",
3357 [](PyLocation loc) {
3358 return mlirIdentifierStr(
3360 },
3361 "Gets the filename from a FileLineColLoc.")
3362 .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
3363 "Gets the start line number from a `FileLineColLoc`.")
3364 .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
3365 "Gets the start column number from a `FileLineColLoc`.")
3366 .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
3367 "Gets the end line number from a `FileLineColLoc`.")
3368 .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
3369 "Gets the end column number from a `FileLineColLoc`.")
3370 .def_static(
3371 "fused",
3372 [](const std::vector<PyLocation> &pyLocations,
3373 std::optional<PyAttribute> metadata,
3374 DefaultingPyMlirContext context) {
3375 std::vector<MlirLocation> locations;
3376 locations.reserve(pyLocations.size());
3377 for (const PyLocation &pyLocation : pyLocations)
3378 locations.push_back(pyLocation.get());
3379 MlirLocation location = mlirLocationFusedGet(
3380 context->get(), locations.size(), locations.data(),
3381 metadata ? metadata->get() : MlirAttribute{0});
3382 return PyLocation(context->getRef(), location);
3383 },
3384 "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(),
3385 "Gets a Location representing a fused location with optional "
3386 "metadata.")
3387 .def("is_a_fused", mlirLocationIsAFused,
3388 "Returns True if this location is a `FusedLoc`.")
3389 .def_prop_ro(
3390 "locations",
3391 [](PyLocation &self) {
3392 unsigned numLocations = mlirLocationFusedGetNumLocations(self);
3393 std::vector<MlirLocation> locations(numLocations);
3394 if (numLocations)
3395 mlirLocationFusedGetLocations(self, locations.data());
3396 std::vector<PyLocation> pyLocations{};
3397 pyLocations.reserve(numLocations);
3398 for (unsigned i = 0; i < numLocations; ++i)
3399 pyLocations.emplace_back(self.getContext(), locations[i]);
3400 return pyLocations;
3401 },
3402 "Gets the list of locations from a `FusedLoc`.")
3403 .def_static(
3404 "name",
3405 [](std::string name, std::optional<PyLocation> childLoc,
3406 DefaultingPyMlirContext context) {
3407 return PyLocation(
3408 context->getRef(),
3410 context->get(), toMlirStringRef(name),
3411 childLoc ? childLoc->get()
3412 : mlirLocationUnknownGet(context->get())));
3413 },
3414 "name"_a, "childLoc"_a = nb::none(), "context"_a = nb::none(),
3415 "Gets a Location representing a named location with optional child "
3416 "location.")
3417 .def("is_a_name", mlirLocationIsAName,
3418 "Returns True if this location is a `NameLoc`.")
3419 .def_prop_ro(
3420 "name_str",
3421 [](PyLocation loc) {
3423 },
3424 "Gets the name string from a `NameLoc`.")
3425 .def_prop_ro(
3426 "child_loc",
3427 [](PyLocation &self) {
3428 return PyLocation(self.getContext(),
3430 },
3431 "Gets the child location from a `NameLoc`.")
3432 .def_static(
3433 "from_attr",
3434 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
3435 return PyLocation(context->getRef(),
3436 mlirLocationFromAttribute(attribute));
3437 },
3438 "attribute"_a, "context"_a = nb::none(),
3439 "Gets a Location from a `LocationAttr`.")
3440 .def_prop_ro(
3441 "context",
3442 [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
3443 return self.getContext().getObject();
3444 },
3445 "Context that owns the `Location`.")
3446 .def_prop_ro(
3447 "attr",
3448 [](PyLocation &self) {
3449 return PyAttribute(self.getContext(),
3451 },
3452 "Get the underlying `LocationAttr`.")
3453 .def(
3454 "emit_error",
3455 [](PyLocation &self, std::string message) {
3456 mlirEmitError(self, message.c_str());
3457 },
3458 "message"_a,
3459 R"(
3460 Emits an error diagnostic at this location.
3461
3462 Args:
3463 message: The error message to emit.)")
3464 .def(
3465 "__repr__",
3466 [](PyLocation &self) {
3467 PyPrintAccumulator printAccum;
3468 mlirLocationPrint(self, printAccum.getCallback(),
3469 printAccum.getUserData());
3470 return printAccum.join();
3471 },
3472 "Returns the assembly representation of the location.");
3473
3474 //----------------------------------------------------------------------------
3475 // Mapping of Module
3476 //----------------------------------------------------------------------------
3477 nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3479 "Gets a capsule wrapping the MlirModule.")
3481 R"(
3482 Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
3483
3484 This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
3485 prevent double-frees (of the underlying `mlir::Module`).)")
3486 .def("_clear_mlir_module", &PyModule::clearMlirModule,
3487 R"(
3488 Clears the internal MLIR module reference.
3489
3490 This is used internally to prevent double-free when ownership is transferred
3491 via the C API capsule mechanism. Not intended for normal use.)")
3492 .def_static(
3493 "parse",
3494 [](const std::string &moduleAsm, DefaultingPyMlirContext context)
3495 -> nb::typed<nb::object, PyModule> {
3496 PyMlirContext::ErrorCapture errors(context->getRef());
3497 MlirModule module = mlirModuleCreateParse(
3498 context->get(), toMlirStringRef(moduleAsm));
3499 if (mlirModuleIsNull(module))
3500 throw MLIRError("Unable to parse module assembly", errors.take());
3501 return PyModule::forModule(module).releaseObject();
3502 },
3503 "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
3504 .def_static(
3505 "parse",
3506 [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
3507 -> nb::typed<nb::object, PyModule> {
3508 PyMlirContext::ErrorCapture errors(context->getRef());
3509 MlirModule module = mlirModuleCreateParse(
3510 context->get(), toMlirStringRef(moduleAsm));
3511 if (mlirModuleIsNull(module))
3512 throw MLIRError("Unable to parse module assembly", errors.take());
3513 return PyModule::forModule(module).releaseObject();
3514 },
3515 "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
3516 .def_static(
3517 "parseFile",
3518 [](const std::string &path, DefaultingPyMlirContext context)
3519 -> nb::typed<nb::object, PyModule> {
3520 PyMlirContext::ErrorCapture errors(context->getRef());
3521 MlirModule module = mlirModuleCreateParseFromFile(
3522 context->get(), toMlirStringRef(path));
3523 if (mlirModuleIsNull(module))
3524 throw MLIRError("Unable to parse module assembly", errors.take());
3525 return PyModule::forModule(module).releaseObject();
3526 },
3527 "path"_a, "context"_a = nb::none(), kModuleParseDocstring)
3528 .def_static(
3529 "create",
3530 [](const std::optional<PyLocation> &loc)
3531 -> nb::typed<nb::object, PyModule> {
3532 PyLocation pyLoc = maybeGetTracebackLocation(loc);
3533 MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
3534 return PyModule::forModule(module).releaseObject();
3535 },
3536 "loc"_a = nb::none(), "Creates an empty module.")
3537 .def_prop_ro(
3538 "context",
3539 [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
3540 return self.getContext().getObject();
3541 },
3542 "Context that created the `Module`.")
3543 .def_prop_ro(
3544 "operation",
3545 [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
3546 return PyOperation::forOperation(self.getContext(),
3547 mlirModuleGetOperation(self.get()),
3548 self.getRef().releaseObject())
3549 .releaseObject();
3550 },
3551 "Accesses the module as an operation.")
3552 .def_prop_ro(
3553 "body",
3554 [](PyModule &self) {
3556 self.getContext(), mlirModuleGetOperation(self.get()),
3557 self.getRef().releaseObject());
3558 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3559 return returnBlock;
3560 },
3561 "Return the block for this module.")
3562 .def(
3563 "dump",
3564 [](PyModule &self) {
3566 },
3568 .def(
3569 "__str__",
3570 [](const nb::object &self) {
3571 // Defer to the operation's __str__.
3572 return self.attr("operation").attr("__str__")();
3573 },
3574 nb::sig("def __str__(self) -> str"),
3575 R"(
3576 Gets the assembly form of the operation with default options.
3577
3578 If more advanced control over the assembly formatting or I/O options is needed,
3579 use the dedicated print or get_asm method, which supports keyword arguments to
3580 customize behavior.
3581 )")
3582 .def(
3583 "__eq__",
3584 [](PyModule &self, PyModule &other) {
3585 return mlirModuleEqual(self.get(), other.get());
3586 },
3587 "other"_a, "Compares two modules for equality.")
3588 .def(
3589 "__hash__",
3590 [](PyModule &self) { return mlirModuleHashValue(self.get()); },
3591 "Returns the hash value of the module.");
3592
3593 //----------------------------------------------------------------------------
3594 // Mapping of Operation.
3595 //----------------------------------------------------------------------------
3596 nb::class_<PyOperationBase>(m, "_OperationBase")
3597 .def_prop_ro(
3599 [](PyOperationBase &self) {
3600 return self.getOperation().getCapsule();
3601 },
3602 "Gets a capsule wrapping the `MlirOperation`.")
3603 .def(
3604 "__eq__",
3605 [](PyOperationBase &self, PyOperationBase &other) {
3606 return mlirOperationEqual(self.getOperation().get(),
3607 other.getOperation().get());
3608 },
3609 "Compares two operations for equality.")
3610 .def(
3611 "__eq__",
3612 [](PyOperationBase &self, nb::object other) { return false; },
3613 "Compares operation with non-operation object (always returns "
3614 "False).")
3615 .def(
3616 "__hash__",
3617 [](PyOperationBase &self) {
3618 return mlirOperationHashValue(self.getOperation().get());
3619 },
3620 "Returns the hash value of the operation.")
3621 .def_prop_ro(
3622 "attributes",
3623 [](PyOperationBase &self) {
3624 return PyOpAttributeMap(self.getOperation().getRef());
3625 },
3626 "Returns a dictionary-like map of operation attributes.")
3627 .def_prop_ro(
3628 "context",
3629 [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
3630 PyOperation &concreteOperation = self.getOperation();
3631 concreteOperation.checkValid();
3632 return concreteOperation.getContext().getObject();
3633 },
3634 "Context that owns the operation.")
3635 .def_prop_ro(
3636 "name",
3637 [](PyOperationBase &self) {
3638 auto &concreteOperation = self.getOperation();
3639 concreteOperation.checkValid();
3640 MlirOperation operation = concreteOperation.get();
3641 return mlirIdentifierStr(mlirOperationGetName(operation));
3642 },
3643 "Returns the fully qualified name of the operation.")
3644 .def_prop_ro(
3645 "operands",
3646 [](PyOperationBase &self) {
3647 return PyOpOperandList(self.getOperation().getRef());
3648 },
3649 "Returns the list of operation operands.")
3650 .def_prop_ro(
3651 "op_operands",
3652 [](PyOperationBase &self) {
3653 return PyOpOperands(self.getOperation().getRef());
3654 },
3655 "Returns the list of op operands.")
3656 .def_prop_ro(
3657 "regions",
3658 [](PyOperationBase &self) {
3659 return PyRegionList(self.getOperation().getRef());
3660 },
3661 "Returns the list of operation regions.")
3662 .def_prop_ro(
3663 "results",
3664 [](PyOperationBase &self) {
3665 return PyOpResultList(self.getOperation().getRef());
3666 },
3667 "Returns the list of Operation results.")
3668 .def_prop_ro(
3669 "result",
3670 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
3671 auto &operation = self.getOperation();
3672 return PyOpResult(operation.getRef(), getUniqueResult(operation))
3673 .maybeDownCast();
3674 },
3675 "Shortcut to get an op result if it has only one (throws an error "
3676 "otherwise).")
3677 .def_prop_rw(
3678 "location",
3679 [](PyOperationBase &self) {
3680 PyOperation &operation = self.getOperation();
3681 return PyLocation(operation.getContext(),
3682 mlirOperationGetLocation(operation.get()));
3683 },
3684 [](PyOperationBase &self, const PyLocation &location) {
3685 PyOperation &operation = self.getOperation();
3686 mlirOperationSetLocation(operation.get(), location.get());
3687 },
3688 nb::for_getter("Returns the source location the operation was "
3689 "defined or derived from."),
3690 nb::for_setter("Sets the source location the operation was defined "
3691 "or derived from."))
3692 .def_prop_ro(
3693 "parent",
3694 [](PyOperationBase &self)
3695 -> std::optional<nb::typed<nb::object, PyOperation>> {
3696 auto parent = self.getOperation().getParentOperation();
3697 if (parent)
3698 return parent->getObject();
3699 return {};
3700 },
3701 "Returns the parent operation, or `None` if at top level.")
3702 .def(
3703 "__str__",
3704 [](PyOperationBase &self) {
3705 return self.getAsm(/*binary=*/false,
3706 /*largeElementsLimit=*/std::nullopt,
3707 /*largeResourceLimit=*/std::nullopt,
3708 /*enableDebugInfo=*/false,
3709 /*prettyDebugInfo=*/false,
3710 /*printGenericOpForm=*/false,
3711 /*useLocalScope=*/false,
3712 /*useNameLocAsPrefix=*/false,
3713 /*assumeVerified=*/false,
3714 /*skipRegions=*/false);
3715 },
3716 nb::sig("def __str__(self) -> str"),
3717 "Returns the assembly form of the operation.")
3718 .def("print",
3719 nb::overload_cast<PyAsmState &, nb::object, bool>(
3721 "state"_a, "file"_a = nb::none(), "binary"_a = false,
3722 R"(
3723 Prints the assembly form of the operation to a file like object.
3724
3725 Args:
3726 state: `AsmState` capturing the operation numbering and flags.
3727 file: Optional file like object to write to. Defaults to sys.stdout.
3728 binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
3729 .def("print",
3730 nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
3731 bool, bool, bool, bool, bool, bool, nb::object,
3732 bool, bool>(&PyOperationBase::print),
3733 // Careful: Lots of arguments must match up with print method.
3734 "large_elements_limit"_a = nb::none(),
3735 "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
3736 "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
3737 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
3738 "assume_verified"_a = false, "file"_a = nb::none(),
3739 "binary"_a = false, "skip_regions"_a = false,
3740 R"(
3741 Prints the assembly form of the operation to a file like object.
3742
3743 Args:
3744 large_elements_limit: Whether to elide elements attributes above this
3745 number of elements. Defaults to None (no limit).
3746 large_resource_limit: Whether to elide resource attributes above this
3747 number of characters. Defaults to None (no limit). If large_elements_limit
3748 is set and this is None, the behavior will be to use large_elements_limit
3749 as large_resource_limit.
3750 enable_debug_info: Whether to print debug/location information. Defaults
3751 to False.
3752 pretty_debug_info: Whether to format debug information for easier reading
3753 by a human (warning: the result is unparseable). Defaults to False.
3754 print_generic_op_form: Whether to print the generic assembly forms of all
3755 ops. Defaults to False.
3756 use_local_scope: Whether to print in a way that is more optimized for
3757 multi-threaded access but may not be consistent with how the overall
3758 module prints.
3759 use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
3760 prefixes for the SSA identifiers. Defaults to False.
3761 assume_verified: By default, if not printing generic form, the verifier
3762 will be run and if it fails, generic form will be printed with a comment
3763 about failed verification. While a reasonable default for interactive use,
3764 for systematic use, it is often better for the caller to verify explicitly
3765 and report failures in a more robust fashion. Set this to True if doing this
3766 in order to avoid running a redundant verification. If the IR is actually
3767 invalid, behavior is undefined.
3768 file: The file like object to write to. Defaults to sys.stdout.
3769 binary: Whether to write bytes (True) or str (False). Defaults to False.
3770 skip_regions: Whether to skip printing regions. Defaults to False.)")
3771 .def("write_bytecode", &PyOperationBase::writeBytecode, "file"_a,
3772 "desired_version"_a = nb::none(),
3773 R"(
3774 Write the bytecode form of the operation to a file like object.
3775
3776 Args:
3777 file: The file like object to write to.
3778 desired_version: Optional version of bytecode to emit.
3779 Returns:
3780 The bytecode writer status.)")
3781 .def("get_asm", &PyOperationBase::getAsm,
3782 // Careful: Lots of arguments must match up with get_asm method.
3783 "binary"_a = false, "large_elements_limit"_a = nb::none(),
3784 "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
3785 "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
3786 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
3787 "assume_verified"_a = false, "skip_regions"_a = false,
3788 R"(
3789 Gets the assembly form of the operation with all options available.
3790
3791 Args:
3792 binary: Whether to return a bytes (True) or str (False) object. Defaults to
3793 False.
3794 ... others ...: See the print() method for common keyword arguments for
3795 configuring the printout.
3796 Returns:
3797 Either a bytes or str object, depending on the setting of the `binary`
3798 argument.)")
3799 .def("verify", &PyOperationBase::verify,
3800 "Verify the operation. Raises MLIRError if verification fails, and "
3801 "returns true otherwise.")
3802 .def("move_after", &PyOperationBase::moveAfter, "other"_a,
3803 "Puts self immediately after the other operation in its parent "
3804 "block.")
3805 .def("move_before", &PyOperationBase::moveBefore, "other"_a,
3806 "Puts self immediately before the other operation in its parent "
3807 "block.")
3808 .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, "other"_a,
3809 R"(
3810 Checks if this operation is before another in the same block.
3811
3812 Args:
3813 other: Another operation in the same parent block.
3814
3815 Returns:
3816 True if this operation is before `other` in the operation list of the parent block.)")
3817 .def(
3818 "clone",
3819 [](PyOperationBase &self,
3820 const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
3821 return self.getOperation().clone(ip);
3822 },
3823 "ip"_a = nb::none(),
3824 R"(
3825 Creates a deep copy of the operation.
3826
3827 Args:
3828 ip: Optional insertion point where the cloned operation should be inserted.
3829 If None, the current insertion point is used. If False, the operation
3830 remains detached.
3831
3832 Returns:
3833 A new Operation that is a clone of this operation.)")
3834 .def(
3835 "detach_from_parent",
3836 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
3837 PyOperation &operation = self.getOperation();
3838 operation.checkValid();
3839 if (!operation.isAttached())
3840 throw nb::value_error("Detached operation has no parent.");
3841
3842 operation.detachFromParent();
3843 return operation.createOpView();
3844 },
3845 "Detaches the operation from its parent block.")
3846 .def_prop_ro(
3847 "attached",
3848 [](PyOperationBase &self) {
3849 PyOperation &operation = self.getOperation();
3850 operation.checkValid();
3851 return operation.isAttached();
3852 },
3853 "Reports if the operation is attached to its parent block.")
3854 .def(
3855 "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
3856 R"(
3857 Erases the operation and frees its memory.
3858
3859 Note:
3860 After erasing, any Python references to the operation become invalid.)")
3861 .def("walk", &PyOperationBase::walk, "callback"_a,
3862 "walk_order"_a = PyWalkOrder::PostOrder,
3863 // clang-format off
3864 nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
3865 // clang-format on
3866 R"(
3867 Walks the operation tree with a callback function.
3868
3869 Args:
3870 callback: A callable that takes an Operation and returns a WalkResult.
3871 walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
3872
3873 nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3874 .def_static(
3875 "create",
3876 [](std::string_view name,
3877 std::optional<std::vector<PyType *>> results,
3878 std::optional<std::vector<PyValue *>> operands,
3879 std::optional<nb::dict> attributes,
3880 std::optional<std::vector<PyBlock *>> successors, int regions,
3881 const std::optional<PyLocation> &location,
3882 const nb::object &maybeIp,
3883 bool inferType) -> nb::typed<nb::object, PyOperation> {
3884 // Unpack/validate operands.
3885 std::vector<MlirValue> mlirOperands;
3886 if (operands) {
3887 mlirOperands.reserve(operands->size());
3888 for (PyValue *operand : *operands) {
3889 if (!operand)
3890 throw nb::value_error("operand value cannot be None");
3891 mlirOperands.push_back(operand->get());
3892 }
3893 }
3894
3895 PyLocation pyLoc = maybeGetTracebackLocation(location);
3896 return PyOperation::create(
3897 name, results, mlirOperands.data(), mlirOperands.size(),
3898 attributes, successors, regions, pyLoc, maybeIp, inferType);
3899 },
3900 "name"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
3901 "attributes"_a = nb::none(), "successors"_a = nb::none(),
3902 "regions"_a = 0, "loc"_a = nb::none(), "ip"_a = nb::none(),
3903 "infer_type"_a = false,
3904 R"(
3905 Creates a new operation.
3906
3907 Args:
3908 name: Operation name (e.g. `dialect.operation`).
3909 results: Optional sequence of Type representing op result types.
3910 operands: Optional operands of the operation.
3911 attributes: Optional Dict of {str: Attribute}.
3912 successors: Optional List of Block for the operation's successors.
3913 regions: Number of regions to create (default = 0).
3914 location: Optional Location object (defaults to resolve from context manager).
3915 ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
3916 infer_type: Whether to infer result types (default = False).
3917 Returns:
3918 A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
3919 .def_static(
3920 "parse",
3921 [](const std::string &sourceStr, const std::string &sourceName,
3923 -> nb::typed<nb::object, PyOpView> {
3924 return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3925 ->createOpView();
3926 },
3927 "source"_a, nb::kw_only(), "source_name"_a = "",
3928 "context"_a = nb::none(),
3929 "Parses an operation. Supports both text assembly format and binary "
3930 "bytecode format.")
3932 "Gets a capsule wrapping the MlirOperation.")
3935 "Creates an Operation from a capsule wrapping MlirOperation.")
3936 .def_prop_ro(
3937 "operation",
3938 [](nb::object self) -> nb::typed<nb::object, PyOperation> {
3939 return self;
3940 },
3941 "Returns self (the operation).")
3942 .def_prop_ro(
3943 "opview",
3944 [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
3945 return self.createOpView();
3946 },
3947 R"(
3948 Returns an OpView of this operation.
3949
3950 Note:
3951 If the operation has a registered and loaded dialect then this OpView will
3952 be concrete wrapper class.)")
3953 .def_prop_ro("block", &PyOperation::getBlock,
3954 "Returns the block containing this operation.")
3955 .def_prop_ro(
3956 "successors",
3957 [](PyOperationBase &self) {
3958 return PyOpSuccessors(self.getOperation().getRef());
3959 },
3960 "Returns the list of Operation successors.")
3961 .def(
3962 "replace_uses_of_with",
3963 [](PyOperation &self, PyValue &of, PyValue &with) {
3964 mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
3965 },
3966 "of"_a, "with_"_a,
3967 "Replaces uses of the 'of' value with the 'with' value inside the "
3968 "operation.")
3969 .def("_set_invalid", &PyOperation::setInvalid,
3970 "Invalidate the operation.");
3971
3972 auto opViewClass =
3973 nb::class_<PyOpView, PyOperationBase>(m, "OpView")
3974 .def(nb::init<nb::typed<nb::object, PyOperation>>(), "operation"_a)
3975 .def(
3976 "__init__",
3977 [](PyOpView *self, std::string_view name,
3978 std::tuple<int, bool> opRegionSpec,
3979 nb::object operandSegmentSpecObj,
3980 nb::object resultSegmentSpecObj,
3981 std::optional<nb::list> resultTypeList, nb::list operandList,
3982 std::optional<nb::dict> attributes,
3983 std::optional<std::vector<PyBlock *>> successors,
3984 std::optional<int> regions,
3985 const std::optional<PyLocation> &location,
3986 const nb::object &maybeIp) {
3987 PyLocation pyLoc = maybeGetTracebackLocation(location);
3989 name, opRegionSpec, operandSegmentSpecObj,
3990 resultSegmentSpecObj, resultTypeList, operandList,
3991 attributes, successors, regions, pyLoc, maybeIp));
3992 },
3993 "name"_a, "opRegionSpec"_a,
3994 "operandSegmentSpecObj"_a = nb::none(),
3995 "resultSegmentSpecObj"_a = nb::none(), "results"_a = nb::none(),
3996 "operands"_a = nb::none(), "attributes"_a = nb::none(),
3997 "successors"_a = nb::none(), "regions"_a = nb::none(),
3998 "loc"_a = nb::none(), "ip"_a = nb::none())
3999 .def_prop_ro(
4000 "operation",
4001 [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
4002 return self.getOperationObject();
4003 })
4004 .def_prop_ro("opview",
4005 [](nb::object self) -> nb::typed<nb::object, PyOpView> {
4006 return self;
4007 })
4008 .def(
4009 "__str__",
4010 [](PyOpView &self) { return nb::str(self.getOperationObject()); })
4011 .def_prop_ro(
4012 "successors",
4013 [](PyOperationBase &self) {
4014 return PyOpSuccessors(self.getOperation().getRef());
4015 },
4016 "Returns the list of Operation successors.")
4017 .def(
4018 "_set_invalid",
4019 [](PyOpView &self) { self.getOperation().setInvalid(); },
4020 "Invalidate the operation.");
4021 opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
4022 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
4023 opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
4024 // It is faster to pass the operation_name, ods_regions, and
4025 // ods_operand_segments/ods_result_segments as arguments to the constructor,
4026 // rather than to access them as attributes.
4027 opViewClass.attr("build_generic") = classmethod(
4028 [](nb::handle cls, std::optional<nb::list> resultTypeList,
4029 nb::list operandList, std::optional<nb::dict> attributes,
4030 std::optional<std::vector<PyBlock *>> successors,
4031 std::optional<int> regions, std::optional<PyLocation> location,
4032 const nb::object &maybeIp) {
4033 std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
4034 std::tuple<int, bool> opRegionSpec =
4035 nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
4036 nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
4037 nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
4038 PyLocation pyLoc = maybeGetTracebackLocation(location);
4039 return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
4040 resultSegmentSpec, resultTypeList,
4041 operandList, attributes, successors,
4042 regions, pyLoc, maybeIp);
4043 },
4044 "cls"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
4045 "attributes"_a = nb::none(), "successors"_a = nb::none(),
4046 "regions"_a = nb::none(), "loc"_a = nb::none(), "ip"_a = nb::none(),
4047 "Builds a specific, generated OpView based on class level attributes.");
4048 opViewClass.attr("parse") = classmethod(
4049 [](const nb::object &cls, const std::string &sourceStr,
4050 const std::string &sourceName,
4051 DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
4052 PyOperationRef parsed =
4053 PyOperation::parse(context->getRef(), sourceStr, sourceName);
4054
4055 // Check if the expected operation was parsed, and cast to the
4056 // appropriate `OpView` subclass if successful.
4057 // NOTE: This accesses attributes that have been automatically added to
4058 // `OpView` subclasses, and is not intended to be used on `OpView`
4059 // directly.
4060 std::string clsOpName =
4061 nb::cast<std::string>(cls.attr("OPERATION_NAME"));
4062 MlirStringRef identifier =
4064 std::string_view parsedOpName(identifier.data, identifier.length);
4065 if (clsOpName != parsedOpName)
4066 throw MLIRError(join("Expected a '", clsOpName, "' op, got: '",
4067 parsedOpName, "'"));
4068 return PyOpView::constructDerived(cls, parsed.getObject());
4069 },
4070 "cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "",
4071 "context"_a = nb::none(),
4072 "Parses a specific, generated OpView based on class level attributes.");
4073
4075
4076 //----------------------------------------------------------------------------
4077 // Mapping of PyRegion.
4078 //----------------------------------------------------------------------------
4079 nb::class_<PyRegion>(m, "Region")
4080 .def_prop_ro(
4081 "blocks",
4082 [](PyRegion &self) {
4083 return PyBlockList(self.getParentOperation(), self.get());
4084 },
4085 "Returns a forward-optimized sequence of blocks.")
4086 .def_prop_ro(
4087 "owner",
4088 [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
4089 return self.getParentOperation()->createOpView();
4090 },
4091 "Returns the operation owning this region.")
4092 .def(
4093 "__iter__",
4094 [](PyRegion &self) {
4095 self.checkValid();
4096 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
4097 return PyBlockIterator(self.getParentOperation(), firstBlock);
4098 },
4099 "Iterates over blocks in the region.")
4100 .def(
4101 "__eq__",
4102 [](PyRegion &self, PyRegion &other) {
4103 return self.get().ptr == other.get().ptr;
4104 },
4105 "Compares two regions for pointer equality.")
4106 .def(
4107 "__eq__", [](PyRegion &self, nb::object &other) { return false; },
4108 "Compares region with non-region object (always returns False).");
4109
4110 //----------------------------------------------------------------------------
4111 // Mapping of PyBlock.
4112 //----------------------------------------------------------------------------
4113 nb::class_<PyBlock>(m, "Block")
4115 "Gets a capsule wrapping the MlirBlock.")
4116 .def_prop_ro(
4117 "owner",
4118 [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
4119 return self.getParentOperation()->createOpView();
4120 },
4121 "Returns the owning operation of this block.")
4122 .def_prop_ro(
4123 "region",
4124 [](PyBlock &self) {
4125 MlirRegion region = mlirBlockGetParentRegion(self.get());
4126 return PyRegion(self.getParentOperation(), region);
4127 },
4128 "Returns the owning region of this block.")
4129 .def_prop_ro(
4130 "arguments",
4131 [](PyBlock &self) {
4132 return PyBlockArgumentList(self.getParentOperation(), self.get());
4133 },
4134 "Returns a list of block arguments.")
4135 .def(
4136 "add_argument",
4137 [](PyBlock &self, const PyType &type, const PyLocation &loc) {
4138 return PyBlockArgument(self.getParentOperation(),
4139 mlirBlockAddArgument(self.get(), type, loc));
4140 },
4141 "type"_a, "loc"_a,
4142 R"(
4143 Appends an argument of the specified type to the block.
4144
4145 Args:
4146 type: The type of the argument to add.
4147 loc: The source location for the argument.
4148
4149 Returns:
4150 The newly added block argument.)")
4151 .def(
4152 "erase_argument",
4153 [](PyBlock &self, unsigned index) {
4154 return mlirBlockEraseArgument(self.get(), index);
4155 },
4156 "index"_a,
4157 R"(
4158 Erases the argument at the specified index.
4159
4160 Args:
4161 index: The index of the argument to erase.)")
4162 .def_prop_ro(
4163 "operations",
4164 [](PyBlock &self) {
4165 return PyOperationList(self.getParentOperation(), self.get());
4166 },
4167 "Returns a forward-optimized sequence of operations.")
4168 .def_static(
4169 "create_at_start",
4170 [](PyRegion &parent, const nb::sequence &pyArgTypes,
4171 const std::optional<nb::sequence> &pyArgLocs) {
4172 parent.checkValid();
4173 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
4174 mlirRegionInsertOwnedBlock(parent, 0, block);
4175 return PyBlock(parent.getParentOperation(), block);
4176 },
4177 "parent"_a, "arg_types"_a = nb::list(), "arg_locs"_a = std::nullopt,
4178 "Creates and returns a new Block at the beginning of the given "
4179 "region (with given argument types and locations).")
4180 .def(
4181 "append_to",
4182 [](PyBlock &self, PyRegion &region) {
4183 MlirBlock b = self.get();
4186 mlirRegionAppendOwnedBlock(region.get(), b);
4187 },
4188 "region"_a,
4189 R"(
4190 Appends this block to a region.
4191
4192 Transfers ownership if the block is currently owned by another region.
4193
4194 Args:
4195 region: The region to append the block to.)")
4196 .def(
4197 "create_before",
4198 [](PyBlock &self, const nb::args &pyArgTypes,
4199 const std::optional<nb::sequence> &pyArgLocs) {
4200 self.checkValid();
4201 MlirBlock block =
4202 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4203 MlirRegion region = mlirBlockGetParentRegion(self.get());
4204 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
4205 return PyBlock(self.getParentOperation(), block);
4206 },
4207 "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
4208 "Creates and returns a new Block before this block "
4209 "(with given argument types and locations).")
4210 .def(
4211 "create_after",
4212 [](PyBlock &self, const nb::args &pyArgTypes,
4213 const std::optional<nb::sequence> &pyArgLocs) {
4214 self.checkValid();
4215 MlirBlock block =
4216 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4217 MlirRegion region = mlirBlockGetParentRegion(self.get());
4218 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
4219 return PyBlock(self.getParentOperation(), block);
4220 },
4221 "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
4222 "Creates and returns a new Block after this block "
4223 "(with given argument types and locations).")
4224 .def(
4225 "__iter__",
4226 [](PyBlock &self) {
4227 self.checkValid();
4228 MlirOperation firstOperation =
4229 mlirBlockGetFirstOperation(self.get());
4230 return PyOperationIterator(self.getParentOperation(),
4231 firstOperation);
4232 },
4233 "Iterates over operations in the block.")
4234 .def(
4235 "__eq__",
4236 [](PyBlock &self, PyBlock &other) {
4237 return self.get().ptr == other.get().ptr;
4238 },
4239 "Compares two blocks for pointer equality.")
4240 .def(
4241 "__eq__", [](PyBlock &self, nb::object &other) { return false; },
4242 "Compares block with non-block object (always returns False).")
4243 .def(
4244 "__hash__", [](PyBlock &self) { return hash(self.get().ptr); },
4245 "Returns the hash value of the block.")
4246 .def(
4247 "__str__",
4248 [](PyBlock &self) {
4249 self.checkValid();
4250 PyPrintAccumulator printAccum;
4251 mlirBlockPrint(self.get(), printAccum.getCallback(),
4252 printAccum.getUserData());
4253 return printAccum.join();
4254 },
4255 "Returns the assembly form of the block.")
4256 .def(
4257 "append",
4258 [](PyBlock &self, PyOperationBase &operation) {
4259 if (operation.getOperation().isAttached())
4260 operation.getOperation().detachFromParent();
4261
4262 MlirOperation mlirOperation = operation.getOperation().get();
4263 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
4264 operation.getOperation().setAttached(
4265 self.getParentOperation().getObject());
4266 },
4267 "operation"_a,
4268 R"(
4269 Appends an operation to this block.
4270
4271 If the operation is currently in another block, it will be moved.
4272
4273 Args:
4274 operation: The operation to append to the block.)")
4275 .def_prop_ro(
4276 "successors",
4277 [](PyBlock &self) {
4278 return PyBlockSuccessors(self, self.getParentOperation());
4279 },
4280 "Returns the list of Block successors.")
4281 .def_prop_ro(
4282 "predecessors",
4283 [](PyBlock &self) {
4284 return PyBlockPredecessors(self, self.getParentOperation());
4285 },
4286 "Returns the list of Block predecessors.");
4287
4288 //----------------------------------------------------------------------------
4289 // Mapping of PyInsertionPoint.
4290 //----------------------------------------------------------------------------
4291
4292 nb::class_<PyInsertionPoint>(m, "InsertionPoint")
4293 .def(nb::init<PyBlock &>(), "block"_a,
4294 "Inserts after the last operation but still inside the block.")
4295 .def("__enter__", &PyInsertionPoint::contextEnter,
4296 "Enters the insertion point as a context manager.",
4297 nb::sig("def __enter__(self, /) -> InsertionPoint"))
4298 .def("__exit__", &PyInsertionPoint::contextExit, "exc_type"_a.none(),
4299 "exc_value"_a.none(), "traceback"_a.none(),
4300 "Exits the insertion point context manager.")
4301 .def_prop_ro_static(
4302 "current",
4303 [](nb::object & /*class*/) {
4305 if (!ip)
4306 throw nb::value_error("No current InsertionPoint");
4307 return ip;
4308 },
4309 nb::sig("def current(/) -> InsertionPoint"),
4310 "Gets the InsertionPoint bound to the current thread or raises "
4311 "ValueError if none has been set.")
4312 .def(nb::init<PyOperationBase &>(), "beforeOperation"_a,
4313 "Inserts before a referenced operation.")
4314 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, "block"_a,
4315 R"(
4316 Creates an insertion point at the beginning of a block.
4317
4318 Args:
4319 block: The block at whose beginning operations should be inserted.
4320
4321 Returns:
4322 An InsertionPoint at the block's beginning.)")
4323 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
4324 "block"_a,
4325 R"(
4326 Creates an insertion point before a block's terminator.
4327
4328 Args:
4329 block: The block whose terminator to insert before.
4330
4331 Returns:
4332 An InsertionPoint before the terminator.
4333
4334 Raises:
4335 ValueError: If the block has no terminator.)")
4336 .def_static("after", &PyInsertionPoint::after, "operation"_a,
4337 R"(
4338 Creates an insertion point immediately after an operation.
4339
4340 Args:
4341 operation: The operation after which to insert.
4342
4343 Returns:
4344 An InsertionPoint after the operation.)")
4345 .def("insert", &PyInsertionPoint::insert, "operation"_a,
4346 R"(
4347 Inserts an operation at this insertion point.
4348
4349 Args:
4350 operation: The operation to insert.)")
4351 .def_prop_ro(
4352 "block", [](PyInsertionPoint &self) { return self.getBlock(); },
4353 "Returns the block that this `InsertionPoint` points to.")
4354 .def_prop_ro(
4355 "ref_operation",
4356 [](PyInsertionPoint &self)
4357 -> std::optional<nb::typed<nb::object, PyOperation>> {
4358 auto refOperation = self.getRefOperation();
4359 if (refOperation)
4360 return refOperation->getObject();
4361 return {};
4362 },
4363 "The reference operation before which new operations are "
4364 "inserted, or None if the insertion point is at the end of "
4365 "the block.");
4366
4367 //----------------------------------------------------------------------------
4368 // Mapping of PyAttribute.
4369 //----------------------------------------------------------------------------
4370 nb::class_<PyAttribute>(m, "Attribute")
4371 // Delegate to the PyAttribute copy constructor, which will also lifetime
4372 // extend the backing context which owns the MlirAttribute.
4373 .def(nb::init<PyAttribute &>(), "cast_from_type"_a,
4374 "Casts the passed attribute to the generic `Attribute`.")
4376 "Gets a capsule wrapping the MlirAttribute.")
4377 .def_static(
4379 "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
4380 .def_static(
4381 "parse",
4382 [](const std::string &attrSpec, DefaultingPyMlirContext context)
4383 -> nb::typed<nb::object, PyAttribute> {
4384 PyMlirContext::ErrorCapture errors(context->getRef());
4385 MlirAttribute attr = mlirAttributeParseGet(
4386 context->get(), toMlirStringRef(attrSpec));
4387 if (mlirAttributeIsNull(attr))
4388 throw MLIRError("Unable to parse attribute", errors.take());
4389 return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
4390 },
4391 "asm"_a, "context"_a = nb::none(),
4392 "Parses an attribute from an assembly form. Raises an `MLIRError` on "
4393 "failure.")
4394 .def_prop_ro(
4395 "context",
4396 [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
4397 return self.getContext().getObject();
4398 },
4399 "Context that owns the `Attribute`.")
4400 .def_prop_ro(
4401 "type",
4402 [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
4403 return PyType(self.getContext(), mlirAttributeGetType(self))
4404 .maybeDownCast();
4405 },
4406 "Returns the type of the `Attribute`.")
4407 .def(
4408 "get_named",
4409 [](PyAttribute &self, std::string name) {
4410 return PyNamedAttribute(self, std::move(name));
4411 },
4412 nb::keep_alive<0, 1>(),
4413 R"(
4414 Binds a name to the attribute, creating a `NamedAttribute`.
4415
4416 Args:
4417 name: The name to bind to the `Attribute`.
4418
4419 Returns:
4420 A `NamedAttribute` with the given name and this attribute.)")
4421 .def(
4422 "__eq__",
4423 [](PyAttribute &self, PyAttribute &other) { return self == other; },
4424 "Compares two attributes for equality.")
4425 .def(
4426 "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
4427 "Compares attribute with non-attribute object (always returns "
4428 "False).")
4429 .def(
4430 "__hash__", [](PyAttribute &self) { return hash(self.get().ptr); },
4431 "Returns the hash value of the attribute.")
4432 .def(
4433 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
4435 .def(
4436 "__str__",
4437 [](PyAttribute &self) {
4438 PyPrintAccumulator printAccum;
4439 mlirAttributePrint(self, printAccum.getCallback(),
4440 printAccum.getUserData());
4441 return printAccum.join();
4442 },
4443 "Returns the assembly form of the Attribute.")
4444 .def(
4445 "__repr__",
4446 [](PyAttribute &self) {
4447 // Generally, assembly formats are not printed for __repr__ because
4448 // this can cause exceptionally long debug output and exceptions.
4449 // However, attribute values are generally considered useful and
4450 // are printed. This may need to be re-evaluated if debug dumps end
4451 // up being excessive.
4452 PyPrintAccumulator printAccum;
4453 printAccum.parts.append("Attribute(");
4454 mlirAttributePrint(self, printAccum.getCallback(),
4455 printAccum.getUserData());
4456 printAccum.parts.append(")");
4457 return printAccum.join();
4458 },
4459 "Returns a string representation of the attribute.")
4460 .def_prop_ro(
4461 "typeid",
4462 [](PyAttribute &self) {
4463 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
4464 assert(!mlirTypeIDIsNull(mlirTypeID) &&
4465 "mlirTypeID was expected to be non-null.");
4466 return PyTypeID(mlirTypeID);
4467 },
4468 "Returns the `TypeID` of the attribute.")
4469 .def(
4471 [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
4472 return self.maybeDownCast();
4473 },
4474 "Downcasts the attribute to a more specific attribute if possible.");
4475
4476 //----------------------------------------------------------------------------
4477 // Mapping of PyNamedAttribute
4478 //----------------------------------------------------------------------------
4479 nb::class_<PyNamedAttribute>(m, "NamedAttribute")
4480 .def(
4481 "__repr__",
4482 [](PyNamedAttribute &self) {
4483 PyPrintAccumulator printAccum;
4484 printAccum.parts.append("NamedAttribute(");
4485 printAccum.parts.append(
4486 nb::str(mlirIdentifierStr(self.namedAttr.name).data,
4487 mlirIdentifierStr(self.namedAttr.name).length));
4488 printAccum.parts.append("=");
4489 mlirAttributePrint(self.namedAttr.attribute,
4490 printAccum.getCallback(),
4491 printAccum.getUserData());
4492 printAccum.parts.append(")");
4493 return printAccum.join();
4494 },
4495 "Returns a string representation of the named attribute.")
4496 .def_prop_ro(
4497 "name",
4498 [](PyNamedAttribute &self) {
4499 return mlirIdentifierStr(self.namedAttr.name);
4500 },
4501 "The name of the `NamedAttribute` binding.")
4502 .def_prop_ro(
4503 "attr",
4504 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
4505 nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
4506 "The underlying generic attribute of the `NamedAttribute` binding.");
4507
4508 //----------------------------------------------------------------------------
4509 // Mapping of PyType.
4510 //----------------------------------------------------------------------------
4511 nb::class_<PyType>(m, "Type")
4512 // Delegate to the PyType copy constructor, which will also lifetime
4513 // extend the backing context which owns the MlirType.
4514 .def(nb::init<PyType &>(), "cast_from_type"_a,
4515 "Casts the passed type to the generic `Type`.")
4517 "Gets a capsule wrapping the `MlirType`.")
4519 "Creates a Type from a capsule wrapping `MlirType`.")
4520 .def_static(
4521 "parse",
4522 [](std::string typeSpec,
4523 DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
4524 PyMlirContext::ErrorCapture errors(context->getRef());
4525 MlirType type =
4526 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
4527 if (mlirTypeIsNull(type))
4528 throw MLIRError("Unable to parse type", errors.take());
4529 return PyType(context.get()->getRef(), type).maybeDownCast();
4530 },
4531 "asm"_a, "context"_a = nb::none(),
4532 R"(
4533 Parses the assembly form of a type.
4534
4535 Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
4536
4537 See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
4538 .def_prop_ro(
4539 "context",
4540 [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
4541 return self.getContext().getObject();
4542 },
4543 "Context that owns the `Type`.")
4544 .def(
4545 "__eq__", [](PyType &self, PyType &other) { return self == other; },
4546 "Compares two types for equality.")
4547 .def(
4548 "__eq__", [](PyType &self, nb::object &other) { return false; },
4549 "other"_a.none(),
4550 "Compares type with non-type object (always returns False).")
4551 .def(
4552 "__hash__", [](PyType &self) { return hash(self.get().ptr); },
4553 "Returns the hash value of the `Type`.")
4554 .def(
4555 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
4556 .def(
4557 "__str__",
4558 [](PyType &self) {
4559 PyPrintAccumulator printAccum;
4560 mlirTypePrint(self, printAccum.getCallback(),
4561 printAccum.getUserData());
4562 return printAccum.join();
4563 },
4564 "Returns the assembly form of the `Type`.")
4565 .def(
4566 "__repr__",
4567 [](PyType &self) {
4568 // Generally, assembly formats are not printed for __repr__ because
4569 // this can cause exceptionally long debug output and exceptions.
4570 // However, types are an exception as they typically have compact
4571 // assembly forms and printing them is useful.
4572 PyPrintAccumulator printAccum;
4573 printAccum.parts.append("Type(");
4574 mlirTypePrint(self, printAccum.getCallback(),
4575 printAccum.getUserData());
4576 printAccum.parts.append(")");
4577 return printAccum.join();
4578 },
4579 "Returns a string representation of the `Type`.")
4580 .def(
4582 [](PyType &self) -> nb::typed<nb::object, PyType> {
4583 return self.maybeDownCast();
4584 },
4585 "Downcasts the Type to a more specific `Type` if possible.")
4586 .def_prop_ro(
4587 "typeid",
4588 [](PyType &self) {
4589 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
4590 if (!mlirTypeIDIsNull(mlirTypeID))
4591 return PyTypeID(mlirTypeID);
4592 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
4593 throw nb::value_error(join(origRepr, " has no typeid.").c_str());
4594 },
4595 "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
4596 "`Type` has no "
4597 "`TypeID`.");
4598
4599 //----------------------------------------------------------------------------
4600 // Mapping of PyTypeID.
4601 //----------------------------------------------------------------------------
4602 nb::class_<PyTypeID>(m, "TypeID")
4604 "Gets a capsule wrapping the `MlirTypeID`.")
4606 "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
4607 // Note, this tests whether the underlying TypeIDs are the same,
4608 // not whether the wrapper MlirTypeIDs are the same, nor whether
4609 // the Python objects are the same (i.e., PyTypeID is a value type).
4610 .def(
4611 "__eq__",
4612 [](PyTypeID &self, PyTypeID &other) { return self == other; },
4613 "Compares two `TypeID`s for equality.")
4614 .def(
4615 "__eq__",
4616 [](PyTypeID &self, const nb::object &other) { return false; },
4617 "Compares TypeID with non-TypeID object (always returns False).")
4618 // Note, this gives the hash value of the underlying TypeID, not the
4619 // hash value of the Python object, nor the hash value of the
4620 // MlirTypeID wrapper.
4621 .def(
4622 "__hash__",
4623 [](PyTypeID &self) {
4624 return static_cast<size_t>(mlirTypeIDHashValue(self));
4625 },
4626 "Returns the hash value of the `TypeID`.");
4627
4628 //----------------------------------------------------------------------------
4629 // Mapping of Value.
4630 //----------------------------------------------------------------------------
4631 m.attr("_T") = nb::type_var("_T", "bound"_a = m.attr("Type"));
4632
4633 nb::class_<PyValue>(m, "Value", nb::is_generic(),
4634 nb::sig("class Value(typing.Generic[_T])"))
4635 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), "value"_a,
4636 "Creates a Value reference from another `Value`.")
4638 "Gets a capsule wrapping the `MlirValue`.")
4640 "Creates a `Value` from a capsule wrapping `MlirValue`.")
4641 .def_prop_ro(
4642 "context",
4643 [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
4644 return self.getParentOperation()->getContext().getObject();
4645 },
4646 "Context in which the value lives.")
4647 .def(
4648 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
4650 .def_prop_ro(
4651 "owner",
4652 [](PyValue &self)
4653 -> nb::typed<nb::object, std::variant<PyOpView, PyBlock>> {
4654 MlirValue v = self.get();
4655 if (mlirValueIsAOpResult(v)) {
4656 assert(mlirOperationEqual(self.getParentOperation()->get(),
4657 mlirOpResultGetOwner(self.get())) &&
4658 "expected the owner of the value in Python to match "
4659 "that in "
4660 "the IR");
4661 return self.getParentOperation()->createOpView();
4662 }
4663
4665 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
4666 return nb::cast(PyBlock(self.getParentOperation(), block));
4667 }
4668
4669 assert(false && "Value must be a block argument or an op result");
4670 return nb::none();
4671 },
4672 "Returns the owner of the value (`Operation` for results, `Block` "
4673 "for "
4674 "arguments).")
4675 .def_prop_ro(
4676 "uses",
4677 [](PyValue &self) {
4678 return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
4679 },
4680 "Returns an iterator over uses of this value.")
4681 .def(
4682 "__eq__",
4683 [](PyValue &self, PyValue &other) {
4684 return self.get().ptr == other.get().ptr;
4685 },
4686 "Compares two values for pointer equality.")
4687 .def(
4688 "__eq__", [](PyValue &self, nb::object other) { return false; },
4689 "Compares value with non-value object (always returns False).")
4690 .def(
4691 "__hash__", [](PyValue &self) { return hash(self.get().ptr); },
4692 "Returns the hash value of the value.")
4693 .def(
4694 "__str__",
4695 [](PyValue &self) {
4696 PyPrintAccumulator printAccum;
4697 printAccum.parts.append("Value(");
4698 mlirValuePrint(self.get(), printAccum.getCallback(),
4699 printAccum.getUserData());
4700 printAccum.parts.append(")");
4701 return printAccum.join();
4702 },
4703 R"(
4704 Returns the string form of the value.
4705
4706 If the value is a block argument, this is the assembly form of its type and the
4707 position in the argument list. If the value is an operation result, this is
4708 equivalent to printing the operation that produced it.
4709 )")
4710 .def(
4711 "get_name",
4712 [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
4713 PyPrintAccumulator printAccum;
4714 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
4715 if (useLocalScope)
4717 if (useNameLocAsPrefix)
4719 MlirAsmState valueState =
4720 mlirAsmStateCreateForValue(self.get(), flags);
4721 mlirValuePrintAsOperand(self.get(), valueState,
4722 printAccum.getCallback(),
4723 printAccum.getUserData());
4725 mlirAsmStateDestroy(valueState);
4726 return printAccum.join();
4727 },
4728 "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
4729 R"(
4730 Returns the string form of value as an operand.
4731
4732 Args:
4733 use_local_scope: Whether to use local scope for naming.
4734 use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
4735
4736 Returns:
4737 The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
4738 .def(
4739 "get_name",
4740 [](PyValue &self, PyAsmState &state) {
4741 PyPrintAccumulator printAccum;
4742 MlirAsmState valueState = state.get();
4743 mlirValuePrintAsOperand(self.get(), valueState,
4744 printAccum.getCallback(),
4745 printAccum.getUserData());
4746 return printAccum.join();
4747 },
4748 "state"_a,
4749 "Returns the string form of value as an operand (i.e., the ValueID).")
4750 .def_prop_ro(
4751 "type",
4752 [](PyValue &self) -> nb::typed<nb::object, PyType> {
4753 return PyType(self.getParentOperation()->getContext(),
4754 mlirValueGetType(self.get()))
4755 .maybeDownCast();
4756 },
4757 "Returns the type of the value.")
4758 .def(
4759 "set_type",
4760 [](PyValue &self, const PyType &type) {
4761 mlirValueSetType(self.get(), type);
4762 },
4763 "type"_a, "Sets the type of the value.",
4764 nb::sig("def set_type(self, type: _T)"))
4765 .def(
4766 "replace_all_uses_with",
4767 [](PyValue &self, PyValue &with) {
4768 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
4769 },
4770 "Replace all uses of value with the new value, updating anything in "
4771 "the IR that uses `self` to use the other value instead.")
4772 .def(
4773 "replace_all_uses_except",
4774 [](PyValue &self, PyValue &with, PyOperation &exception) {
4775 MlirOperation exceptedUser = exception.get();
4776 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4777 },
4778 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4779 .def(
4780 "replace_all_uses_except",
4781 [](PyValue &self, PyValue &with,
4782 std::vector<PyOperation> &exceptions) {
4783 // Convert Python list to a std::vector of MlirOperations
4784 std::vector<MlirOperation> exceptionOps;
4785 for (PyOperation &exception : exceptions)
4786 exceptionOps.push_back(exception);
4788 self, with, static_cast<intptr_t>(exceptionOps.size()),
4789 exceptionOps.data());
4790 },
4791 "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
4792 .def(
4794 [](PyValue &self) { return self.maybeDownCast(); },
4795 "Downcasts the `Value` to a more specific kind if possible.")
4796 .def_prop_ro(
4797 "location",
4798 [](PyValue self) {
4799 return PyLocation(
4801 mlirValueGetLocation(self));
4802 },
4803 "Returns the source location of the value.");
4804
4808
4809 nb::class_<PyAsmState>(m, "AsmState")
4810 .def(nb::init<PyValue &, bool>(), "value"_a, "use_local_scope"_a = false,
4811 R"(
4812 Creates an `AsmState` for consistent SSA value naming.
4813
4814 Args:
4815 value: The value to create state for.
4816 use_local_scope: Whether to use local scope for naming.)")
4817 .def(nb::init<PyOperationBase &, bool>(), "op"_a,
4818 "use_local_scope"_a = false,
4819 R"(
4820 Creates an AsmState for consistent SSA value naming.
4821
4822 Args:
4823 op: The operation to create state for.
4824 use_local_scope: Whether to use local scope for naming.)");
4825
4826 //----------------------------------------------------------------------------
4827 // Mapping of SymbolTable.
4828 //----------------------------------------------------------------------------
4829 nb::class_<PySymbolTable>(m, "SymbolTable")
4830 .def(nb::init<PyOperationBase &>(),
4831 R"(
4832 Creates a symbol table for an operation.
4833
4834 Args:
4835 operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
4836
4837 Raises:
4838 TypeError: If the operation is not a symbol table.)")
4839 .def(
4840 "__getitem__",
4841 [](PySymbolTable &self,
4842 const std::string &name) -> nb::typed<nb::object, PyOpView> {
4843 return self.dunderGetItem(name);
4844 },
4845 R"(
4846 Looks up a symbol by name in the symbol table.
4847
4848 Args:
4849 name: The name of the symbol to look up.
4850
4851 Returns:
4852 The operation defining the symbol.
4853
4854 Raises:
4855 KeyError: If the symbol is not found.)")
4856 .def("insert", &PySymbolTable::insert, "operation"_a,
4857 R"(
4858 Inserts a symbol operation into the symbol table.
4859
4860 Args:
4861 operation: An operation with a symbol name to insert.
4862
4863 Returns:
4864 The symbol name attribute of the inserted operation.
4865
4866 Raises:
4867 ValueError: If the operation does not have a symbol name.)")
4868 .def("erase", &PySymbolTable::erase, "operation"_a,
4869 R"(
4870 Erases a symbol operation from the symbol table.
4871
4872 Args:
4873 operation: The symbol operation to erase.
4874
4875 Note:
4876 The operation is also erased from the IR and invalidated.)")
4877 .def("__delitem__", &PySymbolTable::dunderDel,
4878 "Deletes a symbol by name from the symbol table.")
4879 .def(
4880 "__contains__",
4881 [](PySymbolTable &table, const std::string &name) {
4882 return !mlirOperationIsNull(mlirSymbolTableLookup(
4883 table, mlirStringRefCreate(name.data(), name.length())));
4884 },
4885 "Checks if a symbol with the given name exists in the table.")
4886 // Static helpers.
4887 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, "symbol"_a,
4888 "name"_a, "Sets the symbol name for a symbol operation.")
4889 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, "symbol"_a,
4890 "Gets the symbol name from a symbol operation.")
4891 .def_static("get_visibility", &PySymbolTable::getVisibility, "symbol"_a,
4892 "Gets the visibility attribute of a symbol operation.")
4893 .def_static("set_visibility", &PySymbolTable::setVisibility, "symbol"_a,
4894 "visibility"_a,
4895 "Sets the visibility attribute of a symbol operation.")
4896 .def_static("replace_all_symbol_uses",
4897 &PySymbolTable::replaceAllSymbolUses, "old_symbol"_a,
4898 "new_symbol"_a, "from_op"_a,
4899 "Replaces all uses of a symbol with a new symbol name within "
4900 "the given operation.")
4901 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
4902 "from_op"_a, "all_sym_uses_visible"_a, "callback"_a,
4903 "Walks symbol tables starting from an operation with a "
4904 "callback function.");
4905
4906 // Container bindings.
4922
4923 // Debug bindings.
4925
4926 // Attribute builder getter.
4928
4929 // Extensible Dialect
4933}
4934} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
4935} // namespace python
4936} // namespace mlir
return success()
void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Definition Debug.cpp:28
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition Debug.cpp:20
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition Debug.cpp:16
static const char kDumpDocstring[]
Definition IRAffine.cpp:32
static const char kModuleParseDocstring[]
Definition IRCore.cpp:32
static size_t hash(const T &value)
Local helper to compute std::hash for a value.
Definition IRCore.cpp:55
static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor)
Definition IRCore.cpp:60
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static const char kValueReplaceAllUsesExceptDocstring[]
Definition IRCore.cpp:43
MlirContext mlirModuleGetContext(MlirModule module)
Definition IR.cpp:445
size_t mlirModuleHashValue(MlirModule mod)
Definition IR.cpp:471
intptr_t mlirBlockGetNumPredecessors(MlirBlock block)
Definition IR.cpp:1099
MlirIdentifier mlirOperationGetName(MlirOperation op)
Definition IR.cpp:668
bool mlirValueIsABlockArgument(MlirValue value)
Definition IR.cpp:1119
intptr_t mlirOperationGetNumRegions(MlirOperation op)
Definition IR.cpp:680
MlirBlock mlirOperationGetBlock(MlirOperation op)
Definition IR.cpp:672
void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Definition IR.cpp:1136
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition IR.cpp:520
MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Definition IR.cpp:735
MlirModule mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName)
Definition IR.cpp:436
MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Definition IR.cpp:177
intptr_t mlirOperationGetNumResults(MlirOperation op)
Definition IR.cpp:731
void mlirOperationDestroy(MlirOperation op)
Definition IR.cpp:638
MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Definition IR.cpp:1284
MlirType mlirValueGetType(MlirValue value)
Definition IR.cpp:1155
void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Definition IR.cpp:1085
MlirOpPrintingFlags mlirOpPrintingFlagsCreate()
Definition IR.cpp:201
bool mlirModuleEqual(MlirModule lhs, MlirModule rhs)
Definition IR.cpp:467
void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Definition IR.cpp:209
void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Definition IR.cpp:796
MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Definition IR.cpp:704
void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Definition IR.cpp:219
MlirOperation mlirModuleGetOperation(MlirModule module)
Definition IR.cpp:459
void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
Definition IR.cpp:214
void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Definition IR.cpp:232
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Definition IR.cpp:1131
MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Definition IR.cpp:743
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Definition IR.cpp:1303
MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags)
Definition IR.cpp:156
bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Definition IR.cpp:642
void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Definition IR.cpp:236
void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Definition IR.cpp:251
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1095
void mlirModuleDestroy(MlirModule module)
Definition IR.cpp:453
MlirModule mlirModuleCreateEmpty(MlirLocation location)
Definition IR.cpp:424
void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Definition IR.cpp:224
MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Definition IR.cpp:676
void mlirValueSetType(MlirValue value, MlirType type)
Definition IR.cpp:1159
intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Definition IR.cpp:739
MlirDialect mlirAttributeGetDialect(MlirAttribute attr)
Definition IR.cpp:1299
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Definition IR.cpp:414
void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Definition IR.cpp:815
void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Definition IR.cpp:720
MlirOperation mlirOpResultGetOwner(MlirValue value)
Definition IR.cpp:1146
MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Definition IR.cpp:428
size_t mlirOperationHashValue(MlirOperation op)
Definition IR.cpp:646
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Definition IR.cpp:503
MlirOperation mlirOperationClone(MlirOperation op)
Definition IR.cpp:634
MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Definition IR.cpp:1127
void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc)
Definition IR.cpp:1141
MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Definition IR.cpp:712
MlirOpOperand mlirOperationGetOpOperand(MlirOperation op, intptr_t pos)
Definition IR.cpp:716
MlirLocation mlirOperationGetLocation(MlirOperation op)
Definition IR.cpp:654
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:810
MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr)
Definition IR.cpp:1295
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition IR.cpp:512
void mlirOperationSetLocation(MlirOperation op, MlirLocation loc)
Definition IR.cpp:658
MlirType mlirAttributeGetType(MlirAttribute attribute)
Definition IR.cpp:1288
bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:820
bool mlirValueIsAOpResult(MlirValue value)
Definition IR.cpp:1123
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1104
MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Definition IR.cpp:684
MlirOperation mlirOperationCreate(MlirOperationState *state)
Definition IR.cpp:588
void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Definition IR.cpp:255
MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Definition IR.cpp:1280
void mlirOperationRemoveFromParent(MlirOperation op)
Definition IR.cpp:640
intptr_t mlirBlockGetNumSuccessors(MlirBlock block)
Definition IR.cpp:1091
MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Definition IR.cpp:805
void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Definition IR.cpp:205
void mlirValueDump(MlirValue value)
Definition IR.cpp:1163
void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Definition IR.cpp:1269
MlirBlock mlirModuleGetBody(MlirModule module)
Definition IR.cpp:449
MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Definition IR.cpp:625
void mlirAsmStateDestroy(MlirAsmState state)
Destroys an asm state created with mlirAsmStateCreateForOperation or mlirAsmStateCreateForValue.
Definition IR.cpp:195
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:650
intptr_t mlirOpResultGetResultNumber(MlirValue value)
Definition IR.cpp:1150
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition IR.cpp:516
MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate()
Definition IR.cpp:247
void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags)
Definition IR.cpp:228
void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Definition IR.cpp:240
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition IR.cpp:508
MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Definition IR.cpp:479
intptr_t mlirOperationGetNumOperands(MlirOperation op)
Definition IR.cpp:708
void mlirTypeDump(MlirType type)
Definition IR.cpp:1274
intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Definition IR.cpp:801
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition Interop.h:348
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition Interop.h:216
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition Interop.h:118
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition Interop.h:189
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition Interop.h:367
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition Interop.h:330
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition Interop.h:180
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts an MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition Interop.h:357
#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the value caster registration bindin...
Definition Interop.h:142
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition Interop.h:198
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition Interop.h:255
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition Interop.h:245
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition Interop.h:454
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition Interop.h:445
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition Interop.h:273
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition Interop.h:264
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the type caster registration binding...
Definition Interop.h:130
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition Interop.h:235
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static std::string diag(const llvm::Value &value)
Accumulates into a file, either writing text (default) or binary.
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
nanobind::class_< PyRegionList > ClassTy
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
ReferrentTy * get() const
PyMlirContextRef & getContext()
Accesses the context reference.
Definition IRCore.h:297
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:278
PyAsmState(MlirValue value, bool useLocalScope)
Definition IRCore.cpp:1774
Wrapper around the generic MlirAttribute.
Definition IRCore.h:1001
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition IRCore.h:1003
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition IRCore.cpp:1884
bool operator==(const PyAttribute &other) const
Definition IRCore.cpp:1880
static PyAttribute createFromCapsule(const nanobind::object &capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition IRCore.cpp:1888
nanobind::typed< nanobind::object, PyAttribute > maybeDownCast()
Definition IRCore.cpp:1896
PyBlockArgumentList(PyOperationRef operation, MlirBlock block, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2213
Python wrapper for MlirBlockArgument.
Definition IRCore.h:1628
nanobind::typed< nanobind::object, PyBlock > dunderNext()
Definition IRCore.cpp:242
Blocks are exposed by the C-API as a forward-only linked list.
Definition IRCore.h:1424
PyBlock appendBlock(const nanobind::args &pyArgTypes, const std::optional< nanobind::sequence > &pyArgLocs)
Definition IRCore.cpp:298
PyBlockPredecessors(PyBlock block, PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2373
PyBlockSuccessors(PyBlock block, PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2350
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirBlock.
Definition IRCore.cpp:182
Represents a diagnostic handler attached to the context.
Definition IRCore.h:405
void detach()
Detaches the handler. Does nothing if not attached.
Definition IRCore.cpp:753
PyDiagnosticHandler(MlirContext context, nanobind::object callback)
Definition IRCore.cpp:747
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.h:417
Python class mirroring the C MlirDiagnostic struct.
Definition IRCore.h:355
Wrapper around an MlirDialectRegistry.
Definition IRCore.h:497
static PyDialectRegistry createFromCapsule(nanobind::object capsule)
Definition IRCore.cpp:837
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition IRCore.h:473
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition IRCore.cpp:820
static bool attach(const nanobind::object &opName, const nanobind::object &target, PyMlirContext &context)
Definition IRCore.cpp:2576
static bool attach(const nanobind::object &opName, PyMlirContext &context)
Definition IRCore.cpp:2629
static bool attach(const nanobind::object &opName, PyMlirContext &context)
Definition IRCore.cpp:2647
Globals that are always accessible once the extension has been initialized.
Definition Globals.h:29
void registerOpAdaptorImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds an operation adaptor class.
Definition Globals.cpp:154
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:190
bool loadDialectModule(std::string_view dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition Globals.cpp:64
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.cpp:59
std::optional< nanobind::object > lookupOperationClass(std::string_view operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition Globals.cpp:220
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition Globals.cpp:112
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false)
Adds a user-friendly Attribute builder.
Definition Globals.cpp:99
void registerOperationImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation operation class.
Definition Globals.cpp:143
void setDialectSearchPrefixes(std::vector< std::string > newValues)
Definition Globals.h:43
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:176
void registerValueCaster(MlirTypeID mlirTypeID, nanobind::callable valueCaster, bool replace=false)
Adds a user-friendly value caster.
Definition Globals.cpp:122
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition Globals.cpp:166
void registerDialectImpl(const std::string &dialectNamespace, nanobind::object pyClass)
Adds a concrete implementation dialect class.
Definition Globals.cpp:132
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition Globals.cpp:205
std::vector< std::string > getDialectSearchPrefixes()
Get and set the list of parent modules to search for dialect implementation classes.
Definition Globals.h:39
An insertion point maintains a pointer to a Block and a reference operation.
Definition IRCore.h:832
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition IRCore.cpp:1805
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:1870
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition IRCore.cpp:1844
static PyInsertionPoint after(PyOperationBase &op)
Shortcut to create an insertion point to the node after the specified operation.
Definition IRCore.cpp:1853
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition IRCore.cpp:1831
PyInsertionPoint(const PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition IRCore.cpp:1796
static nanobind::object contextEnter(nanobind::object insertionPoint)
Enter and exit the context manager.
Definition IRCore.cpp:1866
static nanobind::object contextEnter(nanobind::object location)
Enter and exit the context manager.
Definition IRCore.cpp:861
static PyLocation createFromCapsule(nanobind::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition IRCore.cpp:853
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition IRCore.cpp:849
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:865
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition IRCore.h:306
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:486
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition IRCore.cpp:511
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition IRCore.cpp:479
nanobind::object attachDiagnosticHandler(nanobind::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition IRCore.cpp:526
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:520
MlirContext get()
Accesses the underlying MlirContext.
Definition IRCore.h:211
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition IRCore.cpp:471
void setEmitErrorDiagnostics(bool value)
Controls whether error diagnostics should be propagated to diagnostic handlers, instead of being capt...
Definition IRCore.h:245
static nanobind::object contextEnter(nanobind::object context)
Enter and exit the context manager.
Definition IRCore.cpp:516
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition IRCore.cpp:475
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition IRCore.cpp:1864
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition IRCore.cpp:930
MlirModule get()
Gets the backing MlirModule.
Definition IRCore.h:547
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition IRCore.cpp:898
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition IRCore.cpp:923
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition IRCore.h:1027
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition IRCore.cpp:1914
Template for a reference to a concrete type which captures a python reference to its underlying pytho...
Definition IRCore.h:65
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
Definition IRCore.h:91
void dunderSetItem(const std::string &name, const PyAttribute &attr)
Definition IRCore.cpp:2433
nanobind::typed< nanobind::object, PyAttribute > dunderGetItemNamed(const std::string &name)
Definition IRCore.cpp:2400
nanobind::typed< nanobind::object, std::optional< PyAttribute > > get(const std::string &key, nanobind::object defaultValue)
Definition IRCore.cpp:2410
static void forEachAttr(MlirOperation op, std::function< void(MlirStringRef, MlirAttribute)> fn)
Definition IRCore.cpp:2455
PyNamedAttribute dunderGetItemIndexed(intptr_t index)
Definition IRCore.cpp:2418
nanobind::typed< nanobind::object, PyOpOperand > dunderNext()
Definition IRCore.cpp:407
PyOpOperandList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2245
void dunderSetItem(intptr_t index, PyValue value)
Definition IRCore.cpp:2253
nanobind::typed< nanobind::object, PyOpView > getOwner() const
Definition IRCore.cpp:388
PyOpOperands(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2288
Sliceable< PyOpOperandList, PyOpOperand > SliceableT
Definition IRCore.cpp:2286
PyOpResultList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:1425
PyOpSuccessors(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:2317
void dunderSetItem(intptr_t index, PyBlock block)
Definition IRCore.cpp:2325
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition IRCore.h:734
PyOpView(const nanobind::object &operationObject)
Definition IRCore.cpp:1764
static nanobind::object buildGeneric(std::string_view name, std::tuple< int, bool > opRegionSpec, nanobind::object operandSegmentSpecObj, nanobind::object resultSegmentSpecObj, std::optional< nanobind::list > resultTypeList, nanobind::list operandList, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, std::optional< int > regions, PyLocation &location, const nanobind::object &maybeIp)
Definition IRCore.cpp:1587
static nanobind::object constructDerived(const nanobind::object &cls, const nanobind::object &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition IRCore.cpp:1756
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition IRCore.h:577
bool isBeforeInBlock(PyOperationBase &other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IRCore.cpp:1189
nanobind::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions)
Definition IRCore.cpp:1143
void print(std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
Definition IRCore.cpp:1042
void writeBytecode(const nanobind::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition IRCore.cpp:1090
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition IRCore.cpp:1171
void walk(std::function< PyWalkResult(MlirOperation)> callback, PyWalkOrder walkOrder)
Definition IRCore.cpp:1111
nanobind::typed< nanobind::object, PyOpView > dunderNext()
Definition IRCore.cpp:319
Operations are exposed by the C-API as a forward-only linked list.
Definition IRCore.h:1465
nanobind::typed< nanobind::object, PyOpView > dunderGetItem(intptr_t index)
Definition IRCore.cpp:358
static nanobind::object create(std::string_view name, std::optional< std::vector< PyType * > > results, const MlirValue *operands, size_t numOperands, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, int regions, PyLocation &location, const nanobind::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition IRCore.cpp:1253
void setInvalid()
Invalidate the operation.
Definition IRCore.h:701
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition IRCore.h:634
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition IRCore.cpp:999
nanobind::object clone(const nanobind::object &ip)
Clones this operation.
Definition IRCore.cpp:1368
static nanobind::object createFromCapsule(const nanobind::object &capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition IRCore.cpp:1229
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition IRCore.cpp:1205
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1377
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition IRCore.cpp:1224
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:983
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition IRCore.cpp:1011
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition IRCore.cpp:1388
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition IRCore.cpp:1215
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Creates a detached operation.
Definition IRCore.cpp:990
PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
Definition IRCore.cpp:938
void setAttached(const nanobind::object &parent=nanobind::object())
Definition IRCore.cpp:1026
nanobind::typed< nanobind::object, PyRegion > dunderNext()
Definition IRCore.cpp:190
Regions of an op are fixed length and indexed numerically so are represented with a sequence-like con...
Definition IRCore.h:1381
PyRegionList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:209
PyStringAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition IRCore.cpp:2066
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition IRCore.cpp:2030
static PyStringAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition IRCore.cpp:2106
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition IRCore.cpp:2051
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, nanobind::object callback)
Walks all symbol tables under and including 'from'.
Definition IRCore.cpp:2147
nanobind::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition IRCore.cpp:2038
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition IRCore.cpp:2135
static PyStringAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition IRCore.cpp:2078
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition IRCore.cpp:2061
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition IRCore.cpp:2091
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition IRCore.cpp:2117
Tracks an entry in the thread context stack.
Definition IRCore.h:124
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and returns nullptr if not defined.
Definition IRCore.cpp:663
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint)
Definition IRCore.cpp:691
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition IRCore.cpp:703
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition IRCore.cpp:668
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition IRCore.cpp:611
static nanobind::object pushLocation(nanobind::object location)
Definition IRCore.cpp:714
static nanobind::object pushContext(nanobind::object context)
Definition IRCore.cpp:673
static PyMlirContext * getDefaultContext()
Gets the top of stack context and returns nullptr if not defined.
Definition IRCore.cpp:658
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition IRCore.cpp:606
Wrapper around MlirLlvmThreadPool. The Python object owns the C++ thread pool.
Definition IRCore.h:180
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition IRCore.h:900
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition IRCore.cpp:1960
bool operator==(const PyTypeID &other) const
Definition IRCore.cpp:1970
static PyTypeID createFromCapsule(nanobind::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition IRCore.cpp:1964
Wrapper around the generic MlirType.
Definition IRCore.h:874
PyType(PyMlirContextRef contextRef, MlirType type)
Definition IRCore.h:876
bool operator==(const PyType &other) const
Definition IRCore.cpp:1926
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition IRCore.cpp:1930
static PyType createFromCapsule(nanobind::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition IRCore.cpp:1934
nanobind::typed< nanobind::object, PyType > maybeDownCast()
Definition IRCore.cpp:1942
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition IRCore.cpp:1978
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition IRCore.h:1168
nanobind::typed< nanobind::object, std::variant< PyBlockArgument, PyOpResult, PyValue > > maybeDownCast()
Definition IRCore.cpp:1997
static PyValue createFromCapsule(nanobind::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition IRCore.cpp:2018
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
struct MlirDiagnostic MlirDiagnostic
Definition Diagnostics.h:29
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitIsTerminatorCreate(void)
Get the dynamic op trait that indicates the operation is a terminator.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitCreate(MlirTypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData)
Create a custom dynamic op trait with the given type ID and callbacks.
MLIR_CAPI_EXPORTED bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait, MlirStringRef opName, MlirContext context)
Attach a dynamic op trait to the given operation name.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitNoTerminatorCreate(void)
Get the dynamic op trait that indicates regions have no terminator.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition IR.cpp:264
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition IR.h:855
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v)
Gets the location of the value.
Definition IR.cpp:1206
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context)
Gets the number of threads of the thread pool of the context when multithreading is enabled.
Definition IR.cpp:116
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition IR.cpp:845
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates a File/Line/Column location owned by the given context.
Definition IR.cpp:272
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition IR.cpp:1388
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition IR.cpp:136
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location)
Getter for end_column of FileLineColRange.
Definition IR.cpp:310
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition IR.cpp:1367
MlirWalkOrder
Traversal order for operation walk.
Definition IR.h:848
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition IR.cpp:1315
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGetChildLoc(MlirLocation location)
Getter for childLoc of Name.
Definition IR.cpp:391
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition IR.cpp:1372
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition IR.cpp:83
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition IR.cpp:1336
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition IR.cpp:1249
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition IR.cpp:1232
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition IR.cpp:72
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition IR.cpp:955
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location)
Checks whether the given location is a FileLineColRange.
Definition IR.cpp:320
MLIR_CAPI_EXPORTED unsigned mlirLocationFusedGetNumLocations(MlirLocation location)
Getter for number of locations fused together.
Definition IR.cpp:354
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition IR.cpp:1188
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition IR.cpp:1171
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition IR.cpp:402
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition IR.cpp:1220
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location)
Getter for filename of FileLineColRange.
Definition IR.cpp:288
MLIR_CAPI_EXPORTED void mlirLocationFusedGetLocations(MlirLocation location, MlirLocation *locationsCPtr)
Getter for locations of Fused.
Definition IR.cpp:360
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding `userData` to `callback`...
Definition IR.cpp:1307
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition IR.cpp:994
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition IR.cpp:869
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
Definition IR.cpp:1192
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition IR.cpp:836
MlirWalkResult
Operation walk result.
Definition IR.h:841
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition IR.cpp:935
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1160
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e. 'dialect.operation') is registered with the context.
Definition IR.cpp:99
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition IR.cpp:1063
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartLine(MlirLocation location)
Getter for start_line of FileLineColRange.
Definition IR.cpp:292
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition IR.cpp:346
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition IR.cpp:1044
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition IR.cpp:94
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location)
Checks whether the given location is a CallSite.
Definition IR.cpp:342
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition IR.cpp:941
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition IR.cpp:978
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition IR.h:941
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition IR.cpp:1019
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition IR.cpp:1081
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition IR.cpp:1362
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition IR.cpp:1253
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, unsigned start_line, unsigned start_col, unsigned end_line, unsigned end_col)
Creates a File/Line/Column range location owned by the given context.
Definition IR.cpp:280
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition IR.cpp:1218
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition IR.cpp:1352
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition IR.cpp:406
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location)
Getter for start_column of FileLineColRange.
Definition IR.cpp:298
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location)
Checks whether the given location is a Fused.
Definition IR.cpp:374
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition IR.h:370
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition IR.cpp:1067
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition IR.cpp:830
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition IR.cpp:1178
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool)
Sets the thread pool of the context explicitly, enabling multithreading in the process.
Definition IR.cpp:111
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition IR.cpp:861
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition IR.cpp:1265
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition IR.cpp:1228
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location)
Getter for caller of CallSite.
Definition IR.cpp:333
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition IR.cpp:1009
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationNameGetName(MlirLocation location)
Getter for name of Name.
Definition IR.cpp:387
MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IR.cpp:873
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition IR.cpp:268
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition IR.cpp:1257
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition IR.cpp:1348
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition IR.h:182
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition IR.cpp:998
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition IR.cpp:324
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location)
Checks whether the given location is a Name.
Definition IR.cpp:398
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition IR.h:244
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition IR.cpp:891
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
Definition IR.cpp:54
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition IR.cpp:990
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition IR.cpp:410
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition IR.cpp:1072
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition IR.cpp:1313
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition IR.cpp:1377
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition IR.cpp:931
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition IR.cpp:1002
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition IR.h:880
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition IR.cpp:1261
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition IR.cpp:1324
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition IR.cpp:107
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context)
Gets the thread pool of the context when multithreading is enabled; otherwise an assertion is raised.
Definition IR.cpp:120
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndLine(MlirLocation location)
Getter for end_line of FileLineColRange.
Definition IR.cpp:304
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition IR.cpp:378
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Set threading mode (must be set to false to use mlir-print-ir-after-all).
Definition IR.cpp:103
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location)
Getter for callee of CallSite.
Definition IR.cpp:328
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v)
Gets the context that a value was created with.
Definition IR.cpp:1210
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition IR.cpp:1344
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition IR.cpp:918
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition IR.cpp:1058
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition IR.cpp:859
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition IR.h:1250
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition IR.cpp:76
MLIR_CAPI_EXPORTED void mlirOperationReplaceUsesOfWith(MlirOperation op, MlirValue of, MlirValue with)
Replace uses of 'of' value with the 'with' value inside the 'op' operation.
Definition IR.cpp:909
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition IR.cpp:865
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding userData to callback.
Definition IR.cpp:1165
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition IR.cpp:852
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition IR.cpp:70
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition IR.cpp:924
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:87
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:143
struct MlirLogicalResult MlirLogicalResult
Definition Support.h:124
MLIR_CAPI_EXPORTED int mlirLlvmThreadPoolGetMaxConcurrency(MlirLlvmThreadPool pool)
Returns the maximum number of threads in the thread pool.
Definition Support.cpp:38
MLIR_CAPI_EXPORTED void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool pool)
Destroy an LLVM thread pool.
Definition Support.cpp:34
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirLlvmThreadPoolCreate(void)
Create an LLVM thread pool.
Definition Support.cpp:30
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition Support.cpp:93
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:137
struct MlirStringRef MlirStringRef
Definition Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:132
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition Support.h:201
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition Support.cpp:89
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation)
Definition IRCore.cpp:1553
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m)
static void maybeInsertOperation(PyOperationRef &op, const nb::object &maybeIp)
Definition IRCore.cpp:1238
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:197
static MlirValue getOpResultOrValue(nb::handle operand)
Definition IRCore.cpp:1568
PyObjectRef< PyOperation > PyOperationRef
Definition IRCore.h:629
static void populateResultTypes(std::string_view name, nb::list resultTypeList, const nb::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition IRCore.cpp:1467
MlirStringRef toMlirStringRef(const std::string &s)
Definition IRCore.h:1330
static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait, PyMlirContext &context)
Definition IRCore.cpp:2561
PyObjectRef< PyModule > PyModuleRef
Definition IRCore.h:536
static MlirLogicalResult verifyTraitByMethod(MlirOperation op, void *userData, const char *methodName)
Definition IRCore.cpp:2550
static PyOperationRef getValueOwnerRef(MlirValue value)
Definition IRCore.cpp:1982
MlirBlock MLIR_PYTHON_API_EXPORTED createBlock(const nanobind::sequence &pyArgTypes, const std::optional< nanobind::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
static std::vector< nb::typed< nb::object, PyType > > getValueTypes(Container &container, PyMlirContextRef &context)
Returns the list of types of the values held by container.
Definition IRCore.cpp:1414
PyWalkOrder
Traversal order for operation walk.
Definition IRCore.h:346
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m)
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition IRCore.h:1869
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition Diagnostics.h:26
MlirLogicalResult(* verifyTrait)(MlirOperation op, void *userData)
The callback function to verify the operation.
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
MlirLogicalResult(* verifyRegionTrait)(MlirOperation op, void *userData)
The callback function to verify the operation with access to regions.
A logical result value, essentially a boolean with named states.
Definition Support.h:121
Named MLIR attribute.
Definition IR.h:76
MlirAttribute attribute
Definition IR.h:78
MlirIdentifier name
Definition IR.h:77
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition IRCore.h:1317
static bool dunderContains(const std::string &attributeKind)
Definition IRCore.cpp:144
static nanobind::callable dunderGetItemNamed(const std::string &attributeKind)
Definition IRCore.cpp:149
static void dunderSetItemNamed(const std::string &attributeKind, nanobind::callable func, bool replace)
Definition IRCore.cpp:156
static void set(nanobind::object &o, bool enable)
Definition IRCore.cpp:106
RAII object that captures any error diagnostics emitted to the provided context.
Definition IRCore.h:433
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition IRCore.h:443