MLIR 22.0.0git
IRCore.cpp
Go to the documentation of this file.
1//===- IRCore.cpp - IR Submodules of nanobind module ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Globals.h"
10#include "IRModule.h"
11#include "NanobindUtils.h"
12#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
14#include "mlir-c/Debug.h"
15#include "mlir-c/Diagnostics.h"
16#include "mlir-c/IR.h"
17#include "mlir-c/Support.h"
20#include "nanobind/nanobind.h"
21#include "nanobind/typing.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/SmallVector.h"
24
25#include <optional>
26
27namespace nb = nanobind;
28using namespace nb::literals;
29using namespace mlir;
30using namespace mlir::python;
31
33using llvm::StringRef;
34using llvm::Twine;
35
36static const char kModuleParseDocstring[] =
37 R"(Parses a module's assembly format from a string.
38
39Returns a new MlirModule or raises an MLIRError if the parsing fails.
40
41See also: https://mlir.llvm.org/docs/LangRef/
42)";
43
44static const char kDumpDocstring[] =
45 "Dumps a debug representation of the object to stderr.";
46
48 R"(Replace all uses of this value with the `with` value, except for those
49in `exceptions`. `exceptions` can be either a single operation or a list of
50operations.
51)";
52
53//------------------------------------------------------------------------------
54// Utilities.
55//------------------------------------------------------------------------------
56
57/// Helper for creating an @classmethod.
58template <class Func, typename... Args>
59static nb::object classmethod(Func f, Args... args) {
60 nb::object cf = nb::cpp_function(f, args...);
61 return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
62}
63
64static nb::object
65createCustomDialectWrapper(const std::string &dialectNamespace,
66 nb::object dialectDescriptor) {
67 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
68 if (!dialectClass) {
69 // Use the base class.
70 return nb::cast(PyDialect(std::move(dialectDescriptor)));
71 }
72
73 // Create the custom implementation.
74 return (*dialectClass)(std::move(dialectDescriptor));
75}
76
77static MlirStringRef toMlirStringRef(const std::string &s) {
78 return mlirStringRefCreate(s.data(), s.size());
79}
80
81static MlirStringRef toMlirStringRef(std::string_view s) {
82 return mlirStringRefCreate(s.data(), s.size());
83}
84
85static MlirStringRef toMlirStringRef(const nb::bytes &s) {
86 return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
87}
88
89/// Create a block, using the current location context if no locations are
90/// specified.
91static MlirBlock createBlock(const nb::sequence &pyArgTypes,
92 const std::optional<nb::sequence> &pyArgLocs) {
93 SmallVector<MlirType> argTypes;
94 argTypes.reserve(nb::len(pyArgTypes));
95 for (const auto &pyType : pyArgTypes)
96 argTypes.push_back(nb::cast<PyType &>(pyType));
97
99 if (pyArgLocs) {
100 argLocs.reserve(nb::len(*pyArgLocs));
101 for (const auto &pyLoc : *pyArgLocs)
102 argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
103 } else if (!argTypes.empty()) {
104 argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
105 }
106
107 if (argTypes.size() != argLocs.size())
108 throw nb::value_error(("Expected " + Twine(argTypes.size()) +
109 " locations, got: " + Twine(argLocs.size()))
110 .str()
111 .c_str());
112 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
113}
114
115/// Wrapper for the global LLVM debugging flag.
117 static void set(nb::object &o, bool enable) {
118 nb::ft_lock_guard lock(mutex);
119 mlirEnableGlobalDebug(enable);
120 }
121
122 static bool get(const nb::object &) {
123 nb::ft_lock_guard lock(mutex);
125 }
126
127 static void bind(nb::module_ &m) {
128 // Debug flags.
129 nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
130 .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
131 &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
132 .def_static(
133 "set_types",
134 [](const std::string &type) {
135 nb::ft_lock_guard lock(mutex);
136 mlirSetGlobalDebugType(type.c_str());
137 },
138 "types"_a, "Sets specific debug types to be produced by LLVM.")
139 .def_static(
140 "set_types",
141 [](const std::vector<std::string> &types) {
142 std::vector<const char *> pointers;
143 pointers.reserve(types.size());
144 for (const std::string &str : types)
145 pointers.push_back(str.c_str());
146 nb::ft_lock_guard lock(mutex);
147 mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
148 },
149 "types"_a,
150 "Sets multiple specific debug types to be produced by LLVM.");
151 }
152
153private:
154 static nb::ft_mutex mutex;
155};
156
157nb::ft_mutex PyGlobalDebugFlag::mutex;
158
160 static bool dunderContains(const std::string &attributeKind) {
161 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
162 }
163 static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
164 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
165 if (!builder)
166 throw nb::key_error(attributeKind.c_str());
167 return *builder;
168 }
169 static void dunderSetItemNamed(const std::string &attributeKind,
170 nb::callable func, bool replace) {
171 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
172 replace);
173 }
174
175 static void bind(nb::module_ &m) {
176 nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
177 .def_static("contains", &PyAttrBuilderMap::dunderContains,
178 "attribute_kind"_a,
179 "Checks whether an attribute builder is registered for the "
180 "given attribute kind.")
181 .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
182 "attribute_kind"_a,
183 "Gets the registered attribute builder for the given "
184 "attribute kind.")
185 .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
186 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
187 "Register an attribute builder for building MLIR "
188 "attributes from Python values.");
189 }
190};
191
192//------------------------------------------------------------------------------
193// PyBlock
194//------------------------------------------------------------------------------
195
197 return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
198}
199
200//------------------------------------------------------------------------------
201// Collections.
202//------------------------------------------------------------------------------
203
204namespace {
205
206class PyRegionIterator {
207public:
208 PyRegionIterator(PyOperationRef operation, int nextIndex)
209 : operation(std::move(operation)), nextIndex(nextIndex) {}
210
211 PyRegionIterator &dunderIter() { return *this; }
212
213 PyRegion dunderNext() {
214 operation->checkValid();
215 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
216 throw nb::stop_iteration();
217 }
218 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
219 return PyRegion(operation, region);
220 }
221
222 static void bind(nb::module_ &m) {
223 nb::class_<PyRegionIterator>(m, "RegionIterator")
224 .def("__iter__", &PyRegionIterator::dunderIter,
225 "Returns an iterator over the regions in the operation.")
226 .def("__next__", &PyRegionIterator::dunderNext,
227 "Returns the next region in the iteration.");
228 }
229
230private:
231 PyOperationRef operation;
232 intptr_t nextIndex = 0;
233};
234
235/// Regions of an op are fixed length and indexed numerically so are represented
236/// with a sequence-like container.
237class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
238public:
239 static constexpr const char *pyClassName = "RegionSequence";
240
241 PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
242 intptr_t length = -1, intptr_t step = 1)
243 : Sliceable(startIndex,
244 length == -1 ? mlirOperationGetNumRegions(operation->get())
245 : length,
246 step),
247 operation(std::move(operation)) {}
248
249 PyRegionIterator dunderIter() {
250 operation->checkValid();
251 return PyRegionIterator(operation, startIndex);
252 }
253
254 static void bindDerived(ClassTy &c) {
255 c.def("__iter__", &PyRegionList::dunderIter,
256 "Returns an iterator over the regions in the sequence.");
257 }
258
259private:
260 /// Give the parent CRTP class access to hook implementations below.
261 friend class Sliceable<PyRegionList, PyRegion>;
262
263 intptr_t getRawNumElements() {
264 operation->checkValid();
265 return mlirOperationGetNumRegions(operation->get());
266 }
267
268 PyRegion getRawElement(intptr_t pos) {
269 operation->checkValid();
270 return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
271 }
272
273 PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
274 return PyRegionList(operation, startIndex, length, step);
275 }
276
277 PyOperationRef operation;
278};
279
280class PyBlockIterator {
281public:
282 PyBlockIterator(PyOperationRef operation, MlirBlock next)
283 : operation(std::move(operation)), next(next) {}
284
285 PyBlockIterator &dunderIter() { return *this; }
286
287 PyBlock dunderNext() {
288 operation->checkValid();
289 if (mlirBlockIsNull(next)) {
290 throw nb::stop_iteration();
291 }
292
293 PyBlock returnBlock(operation, next);
294 next = mlirBlockGetNextInRegion(next);
295 return returnBlock;
296 }
297
298 static void bind(nb::module_ &m) {
299 nb::class_<PyBlockIterator>(m, "BlockIterator")
300 .def("__iter__", &PyBlockIterator::dunderIter,
301 "Returns an iterator over the blocks in the operation's region.")
302 .def("__next__", &PyBlockIterator::dunderNext,
303 "Returns the next block in the iteration.");
304 }
305
306private:
307 PyOperationRef operation;
308 MlirBlock next;
309};
310
311/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
312/// we present them as a more full-featured list-like container but optimize
313/// it for forward iteration. Blocks are always owned by a region.
314class PyBlockList {
315public:
316 PyBlockList(PyOperationRef operation, MlirRegion region)
317 : operation(std::move(operation)), region(region) {}
318
// Returns an iterator positioned at the region's first block.
319 PyBlockIterator dunderIter() {
320 operation->checkValid();
321 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
322 }
323
// Counts the blocks by walking the region's block list from the first block
// until a null block is reached.
324 intptr_t dunderLen() {
325 operation->checkValid();
326 intptr_t count = 0;
327 MlirBlock block = mlirRegionGetFirstBlock(region);
328 while (!mlirBlockIsNull(block)) {
329 count += 1;
330 block = mlirBlockGetNextInRegion(block);
331 }
332 return count;
333 }
334
// Returns the block at `index`. A negative index is offset by the current
// length first; an index that is still negative, or that runs past the last
// block, raises nb::index_error.
335 PyBlock dunderGetItem(intptr_t index) {
336 operation->checkValid();
337 if (index < 0) {
338 index += dunderLen();
339 }
340 if (index < 0) {
341 throw nb::index_error("attempt to access out of bounds block");
342 }
// Walk forward, counting `index` down to zero.
343 MlirBlock block = mlirRegionGetFirstBlock(region);
344 while (!mlirBlockIsNull(block)) {
345 if (index == 0) {
346 return PyBlock(operation, block);
347 }
348 block = mlirBlockGetNextInRegion(block);
349 index -= 1;
350 }
351 throw nb::index_error("attempt to access out of bounds block");
352 }
353
// Creates a block whose argument types are the positional arguments and
// appends it to the region; `pyArgLocs` is forwarded to createBlock.
354 PyBlock appendBlock(const nb::args &pyArgTypes,
355 const std::optional<nb::sequence> &pyArgLocs) {
356 operation->checkValid();
357 MlirBlock block =
358 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
359 mlirRegionAppendOwnedBlock(region, block);
360 return PyBlock(operation, block);
361 }
362
// Registers the class with `m` under the name "BlockList".
363 static void bind(nb::module_ &m) {
364 nb::class_<PyBlockList>(m, "BlockList")
365 .def("__getitem__", &PyBlockList::dunderGetItem,
366 "Returns the block at the specified index.")
367 .def("__iter__", &PyBlockList::dunderIter,
368 "Returns an iterator over blocks in the operation's region.")
369 .def("__len__", &PyBlockList::dunderLen,
370 "Returns the number of blocks in the operation's region.")
371 .def("append", &PyBlockList::appendBlock,
372 R"(
373 Appends a new block, with argument types as positional args.
374
375 Returns:
376 The created block.
377 )",
378 nb::arg("args"), nb::kw_only(),
379 nb::arg("arg_locs") = std::nullopt);
380 }
381
382private:
383 PyOperationRef operation;
384 MlirRegion region;
385};
386
387class PyOperationIterator {
388public:
389 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
390 : parentOperation(std::move(parentOperation)), next(next) {}
391
392 PyOperationIterator &dunderIter() { return *this; }
393
394 nb::typed<nb::object, PyOpView> dunderNext() {
395 parentOperation->checkValid();
396 if (mlirOperationIsNull(next)) {
397 throw nb::stop_iteration();
398 }
399
400 PyOperationRef returnOperation =
401 PyOperation::forOperation(parentOperation->getContext(), next);
402 next = mlirOperationGetNextInBlock(next);
403 return returnOperation->createOpView();
404 }
405
406 static void bind(nb::module_ &m) {
407 nb::class_<PyOperationIterator>(m, "OperationIterator")
408 .def("__iter__", &PyOperationIterator::dunderIter,
409 "Returns an iterator over the operations in an operation's block.")
410 .def("__next__", &PyOperationIterator::dunderNext,
411 "Returns the next operation in the iteration.");
412 }
413
414private:
415 PyOperationRef parentOperation;
416 MlirOperation next;
417};
418
419/// Operations are exposed by the C-API as a forward-only linked list. In
420/// Python, we present them as a more full-featured list-like container but
421/// optimize it for forward iteration. Iterable operations are always owned
422/// by a block.
423class PyOperationList {
424public:
425 PyOperationList(PyOperationRef parentOperation, MlirBlock block)
426 : parentOperation(std::move(parentOperation)), block(block) {}
427
428 PyOperationIterator dunderIter() {
429 parentOperation->checkValid();
430 return PyOperationIterator(parentOperation,
432 }
433
434 intptr_t dunderLen() {
435 parentOperation->checkValid();
436 intptr_t count = 0;
437 MlirOperation childOp = mlirBlockGetFirstOperation(block);
438 while (!mlirOperationIsNull(childOp)) {
439 count += 1;
440 childOp = mlirOperationGetNextInBlock(childOp);
441 }
442 return count;
443 }
444
445 nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
446 parentOperation->checkValid();
447 if (index < 0) {
448 index += dunderLen();
449 }
450 if (index < 0) {
451 throw nb::index_error("attempt to access out of bounds operation");
452 }
453 MlirOperation childOp = mlirBlockGetFirstOperation(block);
454 while (!mlirOperationIsNull(childOp)) {
455 if (index == 0) {
456 return PyOperation::forOperation(parentOperation->getContext(), childOp)
457 ->createOpView();
458 }
459 childOp = mlirOperationGetNextInBlock(childOp);
460 index -= 1;
461 }
462 throw nb::index_error("attempt to access out of bounds operation");
463 }
464
465 static void bind(nb::module_ &m) {
466 nb::class_<PyOperationList>(m, "OperationList")
467 .def("__getitem__", &PyOperationList::dunderGetItem,
468 "Returns the operation at the specified index.")
469 .def("__iter__", &PyOperationList::dunderIter,
470 "Returns an iterator over operations in the list.")
471 .def("__len__", &PyOperationList::dunderLen,
472 "Returns the number of operations in the list.");
473 }
474
475private:
476 PyOperationRef parentOperation;
477 MlirBlock block;
478};
479
480class PyOpOperand {
481public:
482 PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
483
484 nb::typed<nb::object, PyOpView> getOwner() {
485 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
486 PyMlirContextRef context =
488 return PyOperation::forOperation(context, owner)->createOpView();
489 }
490
491 size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
492
493 static void bind(nb::module_ &m) {
494 nb::class_<PyOpOperand>(m, "OpOperand")
495 .def_prop_ro("owner", &PyOpOperand::getOwner,
496 "Returns the operation that owns this operand.")
497 .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
498 "Returns the operand number in the owning operation.");
499 }
500
501private:
502 MlirOpOperand opOperand;
503};
504
505class PyOpOperandIterator {
506public:
507 PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
508
509 PyOpOperandIterator &dunderIter() { return *this; }
510
511 PyOpOperand dunderNext() {
512 if (mlirOpOperandIsNull(opOperand))
513 throw nb::stop_iteration();
514
515 PyOpOperand returnOpOperand(opOperand);
516 opOperand = mlirOpOperandGetNextUse(opOperand);
517 return returnOpOperand;
518 }
519
520 static void bind(nb::module_ &m) {
521 nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
522 .def("__iter__", &PyOpOperandIterator::dunderIter,
523 "Returns an iterator over operands.")
524 .def("__next__", &PyOpOperandIterator::dunderNext,
525 "Returns the next operand in the iteration.");
526 }
527
528private:
529 MlirOpOperand opOperand;
530};
531
532} // namespace
533
534//------------------------------------------------------------------------------
535// PyMlirContext
536//------------------------------------------------------------------------------
537
538PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
539 nb::gil_scoped_acquire acquire;
540 nb::ft_lock_guard lock(live_contexts_mutex);
541 auto &liveContexts = getLiveContexts();
542 liveContexts[context.ptr] = this;
543}
544
546 // Note that the only public way to construct an instance is via the
547 // forContext method, which always puts the associated handle into
548 // liveContexts.
549 nb::gil_scoped_acquire acquire;
550 {
551 nb::ft_lock_guard lock(live_contexts_mutex);
552 getLiveContexts().erase(context.ptr);
553 }
554 mlirContextDestroy(context);
555}
556
558 return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
559}
560
561nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
562 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
563 if (mlirContextIsNull(rawContext))
564 throw nb::python_error();
565 return forContext(rawContext).releaseObject();
566}
567
// Looks `context.ptr` up in the live-context map. On a miss a new
// PyMlirContext wrapper is created, cast to a Python object and stored in the
// map; on a hit the existing wrapper is cast and returned.
// NOTE(review): live_contexts_mutex is held here while
// `new PyMlirContext(context)` runs, and the PyMlirContext constructor takes
// the same mutex and already stores `this` in the map. Confirm that
// nb::ft_mutex tolerates being locked again by the same thread on
// free-threaded builds; the second map store below also looks redundant.
569 nb::gil_scoped_acquire acquire;
570 nb::ft_lock_guard lock(live_contexts_mutex);
571 auto &liveContexts = getLiveContexts();
572 auto it = liveContexts.find(context.ptr);
573 if (it == liveContexts.end()) {
574 // Create.
575 PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
576 nb::object pyRef = nb::cast(unownedContextWrapper);
577 assert(pyRef && "cast to nb::object failed");
578 liveContexts[context.ptr] = unownedContextWrapper;
579 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
580 }
581 // Use existing.
582 nb::object pyRef = nb::cast(it->second);
583 return PyMlirContextRef(it->second, std::move(pyRef));
584}
585
586nb::ft_mutex PyMlirContext::live_contexts_mutex;
587
588PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
589 static LiveContextMap liveContexts;
590 return liveContexts;
591}
592
594 nb::ft_lock_guard lock(live_contexts_mutex);
595 return getLiveContexts().size();
596}
597
598nb::object PyMlirContext::contextEnter(nb::object context) {
599 return PyThreadContextEntry::pushContext(context);
600}
601
// Context-manager exit hook. None of the three exception arguments is used in
// the lines shown here.
// NOTE(review): the body is empty in this listing and the original numbering
// skips a line at this point; presumably the matching pop on
// PyThreadContextEntry belongs here - confirm against the upstream file.
602void PyMlirContext::contextExit(const nb::object &excType,
603 const nb::object &excVal,
604 const nb::object &excTb) {
606}
607
608nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
609 // Note that ownership is transferred to the delete callback below by way of
610 // an explicit inc_ref (borrow).
611 PyDiagnosticHandler *pyHandler =
612 new PyDiagnosticHandler(get(), std::move(callback));
613 nb::object pyHandlerObject =
614 nb::cast(pyHandler, nb::rv_policy::take_ownership);
615 (void)pyHandlerObject.inc_ref();
616
617 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
618 // guaranteed to be known to pybind.
619 auto handlerCallback =
620 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
621 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
622 nb::object pyDiagnosticObject =
623 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
624
625 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
626 bool result = false;
627 {
628 // Since this can be called from arbitrary C++ contexts, always get the
629 // gil.
630 nb::gil_scoped_acquire gil;
631 try {
632 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
633 } catch (std::exception &e) {
634 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
635 e.what());
636 pyHandler->hadError = true;
637 }
638 }
639
640 pyDiagnostic->invalidate();
642 };
643 auto deleteCallback = +[](void *userData) {
644 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
645 assert(pyHandler->registeredID && "handler is not registered");
646 pyHandler->registeredID.reset();
647
648 // Decrement reference, balancing the inc_ref() above.
649 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
650 pyHandlerObject.dec_ref();
651 };
652
653 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
654 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
655 return pyHandlerObject;
656}
657
// Diagnostic callback used while capturing errors; `userData` is the
// ErrorCapture instance. The visible lines append the diagnostic's info to
// `self->errors`.
// NOTE(review): the original numbering skips lines after the `if` below and
// after the emplace_back, so the return statements (and possibly a further
// check) are not visible in this listing - confirm against the upstream file
// before editing.
658MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
659 void *userData) {
660 auto *self = static_cast<ErrorCapture *>(userData);
661 // Check if the context requested we emit errors instead of capturing them.
662 if (self->ctx->emitErrorDiagnostics)
664
667
668 self->errors.emplace_back(PyDiagnostic(diag).getInfo());
670}
671
674 if (!context) {
675 throw std::runtime_error(
676 "An MLIR function requires a Context but none was provided in the call "
677 "or from the surrounding environment. Either pass to the function with "
678 "a 'context=' argument or establish a default using 'with Context():'");
679 }
680 return *context;
681}
682
683//------------------------------------------------------------------------------
684// PyThreadContextEntry management
685//------------------------------------------------------------------------------
686
687std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
688 static thread_local std::vector<PyThreadContextEntry> stack;
689 return stack;
690}
691
693 auto &stack = getStack();
694 if (stack.empty())
695 return nullptr;
696 return &stack.back();
697}
698
699void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
700 nb::object insertionPoint,
701 nb::object location) {
702 auto &stack = getStack();
703 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
704 std::move(location));
705 // If the new stack has more than one entry and the context of the new top
706 // entry matches the previous, copy the insertionPoint and location from the
707 // previous entry if missing from the new top entry.
708 if (stack.size() > 1) {
709 auto &prev = *(stack.rbegin() + 1);
710 auto &current = stack.back();
711 if (current.context.is(prev.context)) {
712 // Default non-context objects from the previous entry.
713 if (!current.insertionPoint)
714 current.insertionPoint = prev.insertionPoint;
715 if (!current.location)
716 current.location = prev.location;
717 }
718 }
719}
720
722 if (!context)
723 return nullptr;
724 return nb::cast<PyMlirContext *>(context);
725}
726
728 if (!insertionPoint)
729 return nullptr;
730 return nb::cast<PyInsertionPoint *>(insertionPoint);
731}
732
734 if (!location)
735 return nullptr;
736 return nb::cast<PyLocation *>(location);
737}
738
740 auto *tos = getTopOfStack();
741 return tos ? tos->getContext() : nullptr;
742}
743
745 auto *tos = getTopOfStack();
746 return tos ? tos->getInsertionPoint() : nullptr;
747}
748
750 auto *tos = getTopOfStack();
751 return tos ? tos->getLocation() : nullptr;
752}
753
754nb::object PyThreadContextEntry::pushContext(nb::object context) {
755 push(FrameKind::Context, /*context=*/context,
756 /*insertionPoint=*/nb::object(),
757 /*location=*/nb::object());
758 return context;
759}
760
// Pops the top frame of the thread context stack; an empty stack raises
// std::runtime_error.
// NOTE(review): with `&&` the second error is raised only when the top frame
// has the wrong kind AND refers to a different context; a frame failing just
// one of the two tests is popped silently. If either mismatch is meant to be
// an error this should probably be `||` - confirm intent.
762 auto &stack = getStack();
763 if (stack.empty())
764 throw std::runtime_error("Unbalanced Context enter/exit");
765 auto &tos = stack.back();
766 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
767 throw std::runtime_error("Unbalanced Context enter/exit");
768 stack.pop_back();
769}
770
771nb::object
772PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
773 PyInsertionPoint &insertionPoint =
774 nb::cast<PyInsertionPoint &>(insertionPointObj);
775 nb::object contextObj =
776 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
778 /*context=*/contextObj,
779 /*insertionPoint=*/insertionPointObj,
780 /*location=*/nb::object());
781 return insertionPointObj;
782}
783
// Pops the top frame of the thread context stack; an empty stack raises
// std::runtime_error.
// NOTE(review): with `&&` the second error is raised only when the top frame
// has the wrong kind AND refers to a different insertion point; a frame
// failing just one of the two tests is popped silently. If either mismatch is
// meant to be an error this should probably be `||` - confirm intent.
785 auto &stack = getStack();
786 if (stack.empty())
787 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
788 auto &tos = stack.back();
789 if (tos.frameKind != FrameKind::InsertionPoint &&
790 tos.getInsertionPoint() != &insertionPoint)
791 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
792 stack.pop_back();
793}
794
795nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
796 PyLocation &location = nb::cast<PyLocation &>(locationObj);
797 nb::object contextObj = location.getContext().getObject();
798 push(FrameKind::Location, /*context=*/contextObj,
799 /*insertionPoint=*/nb::object(),
800 /*location=*/locationObj);
801 return locationObj;
802}
803
// Pops the top frame of the thread context stack; an empty stack raises
// std::runtime_error.
// NOTE(review): with `&&` the second error is raised only when the top frame
// has the wrong kind AND refers to a different location; a frame failing just
// one of the two tests is popped silently. If either mismatch is meant to be
// an error this should probably be `||` - confirm intent.
805 auto &stack = getStack();
806 if (stack.empty())
807 throw std::runtime_error("Unbalanced Location enter/exit");
808 auto &tos = stack.back();
809 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
810 throw std::runtime_error("Unbalanced Location enter/exit");
811 stack.pop_back();
812}
813
814//------------------------------------------------------------------------------
815// PyDiagnostic*
816//------------------------------------------------------------------------------
817
819 valid = false;
820 if (materializedNotes) {
821 for (nb::handle noteObject : *materializedNotes) {
822 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
823 note->invalidate();
824 }
825 }
826}
827
829 nb::object callback)
830 : context(context), callback(std::move(callback)) {}
831
833
835 if (!registeredID)
836 return;
837 MlirDiagnosticHandlerID localID = *registeredID;
838 mlirContextDetachDiagnosticHandler(context, localID);
839 assert(!registeredID && "should have unregistered");
840 // Not strictly necessary but keeps stale pointers from being around to cause
841 // issues.
842 context = {nullptr};
843}
844
845void PyDiagnostic::checkValid() {
846 if (!valid) {
847 throw std::invalid_argument(
848 "Diagnostic is invalid (used outside of callback)");
849 }
850}
851
853 checkValid();
854 return mlirDiagnosticGetSeverity(diagnostic);
855}
856
858 checkValid();
859 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
860 MlirContext context = mlirLocationGetContext(loc);
861 return PyLocation(PyMlirContext::forContext(context), loc);
862}
863
865 checkValid();
866 nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
867 PyFileAccumulator accum(fileObject, /*binary=*/false);
868 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
869 return nb::cast<nb::str>(fileObject.attr("getvalue")());
870}
871
873 checkValid();
874 if (materializedNotes)
875 return *materializedNotes;
876 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
877 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
878 for (intptr_t i = 0; i < numNotes; ++i) {
879 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
880 nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
881 PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
882 }
883 materializedNotes = std::move(notes);
884
885 return *materializedNotes;
886}
887
889 std::vector<DiagnosticInfo> notes;
890 for (nb::handle n : getNotes())
891 notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
892 return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
893 std::move(notes)};
894}
895
896//------------------------------------------------------------------------------
897// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
898//------------------------------------------------------------------------------
899
900MlirDialect PyDialects::getDialectForKey(const std::string &key,
901 bool attrError) {
902 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
903 {key.data(), key.size()});
904 if (mlirDialectIsNull(dialect)) {
905 std::string msg = (Twine("Dialect '") + key + "' not found").str();
906 if (attrError)
907 throw nb::attribute_error(msg.c_str());
908 throw nb::index_error(msg.c_str());
909 }
910 return dialect;
911}
912
914 return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
915}
916
918 MlirDialectRegistry rawRegistry =
920 if (mlirDialectRegistryIsNull(rawRegistry))
921 throw nb::python_error();
922 return PyDialectRegistry(rawRegistry);
923}
924
925//------------------------------------------------------------------------------
926// PyLocation
927//------------------------------------------------------------------------------
928
930 return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
931}
932
934 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
935 if (mlirLocationIsNull(rawLoc))
936 throw nb::python_error();
938 rawLoc);
939}
940
941nb::object PyLocation::contextEnter(nb::object locationObj) {
942 return PyThreadContextEntry::pushLocation(locationObj);
943}
944
// Context-manager exit hook. None of the three exception arguments is used in
// the lines shown here.
// NOTE(review): the body is empty in this listing and the original numbering
// skips a line at this point; presumably the matching pop on
// PyThreadContextEntry belongs here - confirm against the upstream file.
945void PyLocation::contextExit(const nb::object &excType,
946 const nb::object &excVal,
947 const nb::object &excTb) {
949}
950
953 if (!location) {
954 throw std::runtime_error(
955 "An MLIR function requires a Location but none was provided in the "
956 "call or from the surrounding environment. Either pass to the function "
957 "with a 'loc=' argument or establish a default using 'with loc:'");
958 }
959 return *location;
960}
961
962//------------------------------------------------------------------------------
963// PyModule
964//------------------------------------------------------------------------------
965
966PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
967 : BaseContextObject(std::move(contextRef)), module(module) {}
968
970 nb::gil_scoped_acquire acquire;
971 auto &liveModules = getContext()->liveModules;
972 assert(liveModules.count(module.ptr) == 1 &&
973 "destroying module not in live map");
974 liveModules.erase(module.ptr);
975 mlirModuleDestroy(module);
976}
977
979 MlirContext context = mlirModuleGetContext(module);
980 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
981
982 nb::gil_scoped_acquire acquire;
983 auto &liveModules = contextRef->liveModules;
984 auto it = liveModules.find(module.ptr);
985 if (it == liveModules.end()) {
986 // Create.
987 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
988 // Note that the default return value policy on cast is automatic_reference,
989 // which does not take ownership (delete will not be called).
990 // Just be explicit.
991 nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
992 unownedModule->handle = pyRef;
993 liveModules[module.ptr] =
994 std::make_pair(unownedModule->handle, unownedModule);
995 return PyModuleRef(unownedModule, std::move(pyRef));
996 }
997 // Use existing.
998 PyModule *existing = it->second.second;
999 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1000 return PyModuleRef(existing, std::move(pyRef));
1001}
1002
1003nb::object PyModule::createFromCapsule(nb::object capsule) {
1004 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1005 if (mlirModuleIsNull(rawModule))
1006 throw nb::python_error();
1007 return forModule(rawModule).releaseObject();
1008}
1009
1011 return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
1012}
1013
1014//------------------------------------------------------------------------------
1015// PyOperation
1016//------------------------------------------------------------------------------
1017
1018PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1019 : BaseContextObject(std::move(contextRef)), operation(operation) {}
1020
1022 // If the operation has already been invalidated there is nothing to do.
1023 if (!valid)
1024 return;
1025 // Otherwise, invalidate the operation when it is attached.
1026 if (isAttached())
1027 setInvalid();
1028 else {
1029 // And destroy it when it is detached, i.e. owned by Python.
1030 erase();
1031 }
1032}
1033
1034namespace {
1035
1036// Constructs a new object of type T in-place on the Python heap, returning a
1037// PyObjectRef to it, loosely analogous to std::make_shared<T>().
1038template <typename T, class... Args>
1039PyObjectRef<T> makeObjectRef(Args &&...args) {
1040 nb::handle type = nb::type<T>();
1041 nb::object instance = nb::inst_alloc(type);
1042 T *ptr = nb::inst_ptr<T>(instance);
1043 new (ptr) T(std::forward<Args>(args)...);
1044 nb::inst_mark_ready(instance);
1045 return PyObjectRef<T>(ptr, std::move(instance));
1046}
1047
1048} // namespace
1049
1050PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1051 MlirOperation operation,
1052 nb::object parentKeepAlive) {
1053 // Create.
1054 PyOperationRef unownedOperation =
1055 makeObjectRef<PyOperation>(std::move(contextRef), operation);
1056 unownedOperation->handle = unownedOperation.getObject();
1057 if (parentKeepAlive) {
1058 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1059 }
1060 return unownedOperation;
1061}
1062
1064 MlirOperation operation,
1065 nb::object parentKeepAlive) {
1066 return createInstance(std::move(contextRef), operation,
1067 std::move(parentKeepAlive));
1068}
1069
1071 MlirOperation operation,
1072 nb::object parentKeepAlive) {
1073 PyOperationRef created = createInstance(std::move(contextRef), operation,
1074 std::move(parentKeepAlive));
1075 created->attached = false;
1076 return created;
1077}
1078
1080 const std::string &sourceStr,
1081 const std::string &sourceName) {
1082 PyMlirContext::ErrorCapture errors(contextRef);
1083 MlirOperation op =
1084 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1085 toMlirStringRef(sourceName));
1086 if (mlirOperationIsNull(op))
1087 throw MLIRError("Unable to parse operation assembly", errors.take());
1088 return PyOperation::createDetached(std::move(contextRef), op);
1089}
1090
1092 if (!valid) {
1093 throw std::runtime_error("the operation has been invalidated");
1094 }
1095}
1096
1097void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1098 std::optional<int64_t> largeResourceLimit,
1099 bool enableDebugInfo, bool prettyDebugInfo,
1100 bool printGenericOpForm, bool useLocalScope,
1101 bool useNameLocAsPrefix, bool assumeVerified,
1102 nb::object fileObject, bool binary,
1103 bool skipRegions) {
1104 PyOperation &operation = getOperation();
1105 operation.checkValid();
1106 if (fileObject.is_none())
1107 fileObject = nb::module_::import_("sys").attr("stdout");
1108
1109 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1110 if (largeElementsLimit)
1111 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1112 if (largeResourceLimit)
1113 mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit);
1114 if (enableDebugInfo)
1115 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1116 /*prettyForm=*/prettyDebugInfo);
1117 if (printGenericOpForm)
1119 if (useLocalScope)
1121 if (assumeVerified)
1123 if (skipRegions)
1125 if (useNameLocAsPrefix)
1127
1128 PyFileAccumulator accum(fileObject, binary);
1129 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1130 accum.getUserData());
1132}
1133
1134void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1135 bool binary) {
1136 PyOperation &operation = getOperation();
1137 operation.checkValid();
1138 if (fileObject.is_none())
1139 fileObject = nb::module_::import_("sys").attr("stdout");
1140 PyFileAccumulator accum(fileObject, binary);
1141 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1142 accum.getUserData());
1143}
1144
1145void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
1146 std::optional<int64_t> bytecodeVersion) {
1147 PyOperation &operation = getOperation();
1148 operation.checkValid();
1149 PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
1150
1151 if (!bytecodeVersion.has_value())
1152 return mlirOperationWriteBytecode(operation, accum.getCallback(),
1153 accum.getUserData());
1154
1155 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1158 operation, config, accum.getCallback(), accum.getUserData());
1161 throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
1162 Twine(*bytecodeVersion))
1163 .str()
1164 .c_str());
1165}
1166
1168 std::function<MlirWalkResult(MlirOperation)> callback,
1169 MlirWalkOrder walkOrder) {
1170 PyOperation &operation = getOperation();
1171 operation.checkValid();
1172 struct UserData {
1173 std::function<MlirWalkResult(MlirOperation)> callback;
1174 bool gotException;
1175 std::string exceptionWhat;
1176 nb::object exceptionType;
1177 };
1178 UserData userData{callback, false, {}, {}};
1179 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1180 void *userData) {
1181 UserData *calleeUserData = static_cast<UserData *>(userData);
1182 try {
1183 return (calleeUserData->callback)(op);
1184 } catch (nb::python_error &e) {
1185 calleeUserData->gotException = true;
1186 calleeUserData->exceptionWhat = std::string(e.what());
1187 calleeUserData->exceptionType = nb::borrow(e.type());
1188 return MlirWalkResult::MlirWalkResultInterrupt;
1189 }
1190 };
1191 mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1192 if (userData.gotException) {
1193 std::string message("Exception raised in callback: ");
1194 message.append(userData.exceptionWhat);
1195 throw std::runtime_error(message);
1196 }
1197}
1198
1199nb::object PyOperationBase::getAsm(bool binary,
1200 std::optional<int64_t> largeElementsLimit,
1201 std::optional<int64_t> largeResourceLimit,
1202 bool enableDebugInfo, bool prettyDebugInfo,
1203 bool printGenericOpForm, bool useLocalScope,
1204 bool useNameLocAsPrefix, bool assumeVerified,
1205 bool skipRegions) {
1206 nb::object fileObject;
1207 if (binary) {
1208 fileObject = nb::module_::import_("io").attr("BytesIO")();
1209 } else {
1210 fileObject = nb::module_::import_("io").attr("StringIO")();
1211 }
1212 print(/*largeElementsLimit=*/largeElementsLimit,
1213 /*largeResourceLimit=*/largeResourceLimit,
1214 /*enableDebugInfo=*/enableDebugInfo,
1215 /*prettyDebugInfo=*/prettyDebugInfo,
1216 /*printGenericOpForm=*/printGenericOpForm,
1217 /*useLocalScope=*/useLocalScope,
1218 /*useNameLocAsPrefix=*/useNameLocAsPrefix,
1219 /*assumeVerified=*/assumeVerified,
1220 /*fileObject=*/fileObject,
1221 /*binary=*/binary,
1222 /*skipRegions=*/skipRegions);
1223
1224 return fileObject.attr("getvalue")();
1225}
1226
1228 PyOperation &operation = getOperation();
1229 PyOperation &otherOp = other.getOperation();
1230 operation.checkValid();
1231 otherOp.checkValid();
1232 mlirOperationMoveAfter(operation, otherOp);
1233 operation.parentKeepAlive = otherOp.parentKeepAlive;
1234}
1235
1237 PyOperation &operation = getOperation();
1238 PyOperation &otherOp = other.getOperation();
1239 operation.checkValid();
1240 otherOp.checkValid();
1241 mlirOperationMoveBefore(operation, otherOp);
1242 operation.parentKeepAlive = otherOp.parentKeepAlive;
1243}
1244
1246 PyOperation &operation = getOperation();
1247 PyOperation &otherOp = other.getOperation();
1248 operation.checkValid();
1249 otherOp.checkValid();
1250 return mlirOperationIsBeforeInBlock(operation, otherOp);
1251}
1252
1254 PyOperation &op = getOperation();
1256 if (!mlirOperationVerify(op.get()))
1257 throw MLIRError("Verification failed", errors.take());
1258 return true;
1259}
1260
1261std::optional<PyOperationRef> PyOperation::getParentOperation() {
1262 checkValid();
1263 if (!isAttached())
1264 throw nb::value_error("Detached operations have no parent");
1265 MlirOperation operation = mlirOperationGetParentOperation(get());
1266 if (mlirOperationIsNull(operation))
1267 return {};
1268 return PyOperation::forOperation(getContext(), operation);
1269}
1270
1272 checkValid();
1273 std::optional<PyOperationRef> parentOperation = getParentOperation();
1274 MlirBlock block = mlirOperationGetBlock(get());
1275 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1276 assert(parentOperation && "Operation has no parent");
1277 return PyBlock{std::move(*parentOperation), block};
1278}
1279
1281 checkValid();
1282 return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1283}
1284
1285nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
1286 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1287 if (mlirOperationIsNull(rawOperation))
1288 throw nb::python_error();
1289 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1290 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1291 .releaseObject();
1292}
1293
1295 const nb::object &maybeIp) {
1296 // InsertPoint active?
1297 if (!maybeIp.is(nb::cast(false))) {
1298 PyInsertionPoint *ip;
1299 if (maybeIp.is_none()) {
1301 } else {
1302 ip = nb::cast<PyInsertionPoint *>(maybeIp);
1303 }
1304 if (ip)
1305 ip->insert(*op.get());
1306 }
1307}
1308
1309nb::object PyOperation::create(std::string_view name,
1310 std::optional<std::vector<PyType *>> results,
1312 std::optional<nb::dict> attributes,
1313 std::optional<std::vector<PyBlock *>> successors,
1314 int regions, PyLocation &location,
1315 const nb::object &maybeIp, bool inferType) {
1317 llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
1319
1320 // General parameter validation.
1321 if (regions < 0)
1322 throw nb::value_error("number of regions must be >= 0");
1323
1324 // Unpack/validate results.
1325 if (results) {
1326 mlirResults.reserve(results->size());
1327 for (PyType *result : *results) {
1328 // TODO: Verify result type originate from the same context.
1329 if (!result)
1330 throw nb::value_error("result type cannot be None");
1331 mlirResults.push_back(*result);
1332 }
1333 }
1334 // Unpack/validate attributes.
1335 if (attributes) {
1336 mlirAttributes.reserve(attributes->size());
1337 for (std::pair<nb::handle, nb::handle> it : *attributes) {
1338 std::string key;
1339 try {
1340 key = nb::cast<std::string>(it.first);
1341 } catch (nb::cast_error &err) {
1342 std::string msg = "Invalid attribute key (not a string) when "
1343 "attempting to create the operation \"" +
1344 std::string(name) + "\" (" + err.what() + ")";
1345 throw nb::type_error(msg.c_str());
1346 }
1347 try {
1348 auto &attribute = nb::cast<PyAttribute &>(it.second);
1349 // TODO: Verify attribute originates from the same context.
1350 mlirAttributes.emplace_back(std::move(key), attribute);
1351 } catch (nb::cast_error &err) {
1352 std::string msg = "Invalid attribute value for the key \"" + key +
1353 "\" when attempting to create the operation \"" +
1354 std::string(name) + "\" (" + err.what() + ")";
1355 throw nb::type_error(msg.c_str());
1356 } catch (std::runtime_error &) {
1357 // This exception seems thrown when the value is "None".
1358 std::string msg =
1359 "Found an invalid (`None`?) attribute value for the key \"" + key +
1360 "\" when attempting to create the operation \"" +
1361 std::string(name) + "\"";
1362 throw std::runtime_error(msg);
1363 }
1364 }
1365 }
1366 // Unpack/validate successors.
1367 if (successors) {
1368 mlirSuccessors.reserve(successors->size());
1369 for (auto *successor : *successors) {
1370 // TODO: Verify successor originate from the same context.
1371 if (!successor)
1372 throw nb::value_error("successor block cannot be None");
1373 mlirSuccessors.push_back(successor->get());
1374 }
1375 }
1376
1377 // Apply unpacked/validated to the operation state. Beyond this
1378 // point, exceptions cannot be thrown or else the state will leak.
1379 MlirOperationState state =
1380 mlirOperationStateGet(toMlirStringRef(name), location);
1381 if (!operands.empty())
1382 mlirOperationStateAddOperands(&state, operands.size(), operands.data());
1383 state.enableResultTypeInference = inferType;
1384 if (!mlirResults.empty())
1385 mlirOperationStateAddResults(&state, mlirResults.size(),
1386 mlirResults.data());
1387 if (!mlirAttributes.empty()) {
1388 // Note that the attribute names directly reference bytes in
1389 // mlirAttributes, so that vector must not be changed from here
1390 // on.
1391 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1392 mlirNamedAttributes.reserve(mlirAttributes.size());
1393 for (auto &it : mlirAttributes)
1394 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1396 toMlirStringRef(it.first)),
1397 it.second));
1398 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1399 mlirNamedAttributes.data());
1400 }
1401 if (!mlirSuccessors.empty())
1402 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1403 mlirSuccessors.data());
1404 if (regions) {
1406 mlirRegions.resize(regions);
1407 for (int i = 0; i < regions; ++i)
1408 mlirRegions[i] = mlirRegionCreate();
1409 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1410 mlirRegions.data());
1411 }
1412
1413 // Construct the operation.
1414 MlirOperation operation = mlirOperationCreate(&state);
1415 if (!operation.ptr)
1416 throw nb::value_error("Operation creation failed");
1417 PyOperationRef created =
1418 PyOperation::createDetached(location.getContext(), operation);
1419 maybeInsertOperation(created, maybeIp);
1420
1421 return created.getObject();
1422}
1423
1424nb::object PyOperation::clone(const nb::object &maybeIp) {
1425 MlirOperation clonedOperation = mlirOperationClone(operation);
1426 PyOperationRef cloned =
1427 PyOperation::createDetached(getContext(), clonedOperation);
1428 maybeInsertOperation(cloned, maybeIp);
1429
1430 return cloned->createOpView();
1431}
1432
1434 checkValid();
1435 MlirIdentifier ident = mlirOperationGetName(get());
1436 MlirStringRef identStr = mlirIdentifierStr(ident);
1437 auto operationCls = PyGlobals::get().lookupOperationClass(
1438 StringRef(identStr.data, identStr.length));
1439 if (operationCls)
1440 return PyOpView::constructDerived(*operationCls, getRef().getObject());
1441 return nb::cast(PyOpView(getRef().getObject()));
1442}
1443
1445 checkValid();
1446 setInvalid();
1447 mlirOperationDestroy(operation);
1448}
1449
1450namespace {
1451/// CRTP base class for Python MLIR values that subclass Value and should be
1452/// castable from it. The value hierarchy is one level deep and is not supposed
1453/// to accommodate other levels unless core MLIR changes.
1454template <typename DerivedTy>
1455class PyConcreteValue : public PyValue {
1456public:
1457 // Derived classes must define statics for:
1458 // IsAFunctionTy isaFunction
1459 // const char *pyClassName
1460 // and redefine bindDerived.
1461 using ClassTy = nb::class_<DerivedTy, PyValue>;
1462 using IsAFunctionTy = bool (*)(MlirValue);
1463
1464 PyConcreteValue() = default;
1465 PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1466 : PyValue(operationRef, value) {}
1467 PyConcreteValue(PyValue &orig)
1468 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1469
1470 /// Attempts to cast the original value to the derived type and throws on
1471 /// type mismatches.
1472 static MlirValue castFrom(PyValue &orig) {
1473 if (!DerivedTy::isaFunction(orig.get())) {
1474 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
1475 throw nb::value_error((Twine("Cannot cast value to ") +
1476 DerivedTy::pyClassName + " (from " + origRepr +
1477 ")")
1478 .str()
1479 .c_str());
1480 }
1481 return orig.get();
1482 }
1483
1484 /// Binds the Python module objects to functions of this class.
1485 static void bind(nb::module_ &m) {
1486 auto cls = ClassTy(
1487 m, DerivedTy::pyClassName, nb::is_generic(),
1488 nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
1489 .str()
1490 .c_str()));
1491 cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
1492 cls.def_static(
1493 "isinstance",
1494 [](PyValue &otherValue) -> bool {
1495 return DerivedTy::isaFunction(otherValue);
1496 },
1497 nb::arg("other_value"));
1499 [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
1500 return self.maybeDownCast();
1501 });
1502 DerivedTy::bindDerived(cls);
1503 }
1504
1505 /// Implemented by derived classes to add methods to the Python subclass.
1506 static void bindDerived(ClassTy &m) {}
1507};
1508
1509} // namespace
1510
1511/// Python wrapper for MlirOpResult.
1512class PyOpResult : public PyConcreteValue<PyOpResult> {
1513public:
1514 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1515 static constexpr const char *pyClassName = "OpResult";
1516 using PyConcreteValue::PyConcreteValue;
1517
1518 static void bindDerived(ClassTy &c) {
1519 c.def_prop_ro(
1520 "owner",
1521 [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
1522 assert(mlirOperationEqual(self.getParentOperation()->get(),
1523 mlirOpResultGetOwner(self.get())) &&
1524 "expected the owner of the value in Python to match that in "
1525 "the IR");
1526 return self.getParentOperation().getObject();
1527 },
1528 "Returns the operation that produces this result.");
1529 c.def_prop_ro(
1530 "result_number",
1531 [](PyOpResult &self) {
1532 return mlirOpResultGetResultNumber(self.get());
1533 },
1534 "Returns the position of this result in the operation's result list.");
1535 }
1536};
1537
1538/// Returns the list of types of the values held by container.
1539template <typename Container>
1540static std::vector<nb::typed<nb::object, PyType>>
1541getValueTypes(Container &container, PyMlirContextRef &context) {
1542 std::vector<nb::typed<nb::object, PyType>> result;
1543 result.reserve(container.size());
1544 for (int i = 0, e = container.size(); i < e; ++i) {
1545 result.push_back(PyType(context->getRef(),
1546 mlirValueGetType(container.getElement(i).get()))
1547 .maybeDownCast());
1548 }
1549 return result;
1550}
1551
1552/// A list of operation results. Internally, these are stored as consecutive
1553/// elements, random access is cheap. The (returned) result list is associated
1554/// with the operation whose results these are, and thus extends the lifetime of
1555/// this operation.
1556class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1557public:
1558 static constexpr const char *pyClassName = "OpResultList";
1560
1562 intptr_t length = -1, intptr_t step = 1)
1564 length == -1 ? mlirOperationGetNumResults(operation->get())
1565 : length,
1566 step),
1567 operation(std::move(operation)) {}
1568
1569 static void bindDerived(ClassTy &c) {
1570 c.def_prop_ro(
1571 "types",
1572 [](PyOpResultList &self) {
1573 return getValueTypes(self, self.operation->getContext());
1574 },
1575 "Returns a list of types for all results in this result list.");
1576 c.def_prop_ro(
1577 "owner",
1578 [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
1579 return self.operation->createOpView();
1580 },
1581 "Returns the operation that owns this result list.");
1582 }
1583
1584 PyOperationRef &getOperation() { return operation; }
1585
1586private:
1587 /// Give the parent CRTP class access to hook implementations below.
1588 friend class Sliceable<PyOpResultList, PyOpResult>;
1589
1590 intptr_t getRawNumElements() {
1591 operation->checkValid();
1592 return mlirOperationGetNumResults(operation->get());
1593 }
1594
1595 PyOpResult getRawElement(intptr_t index) {
1596 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1597 return PyOpResult(value);
1598 }
1599
1600 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1601 return PyOpResultList(operation, startIndex, length, step);
1602 }
1603
1604 PyOperationRef operation;
1605};
1606
1607//------------------------------------------------------------------------------
1608// PyOpView
1609//------------------------------------------------------------------------------
1610
1611static void populateResultTypes(StringRef name, nb::list resultTypeList,
1612 const nb::object &resultSegmentSpecObj,
1613 std::vector<int32_t> &resultSegmentLengths,
1614 std::vector<PyType *> &resultTypes) {
1615 resultTypes.reserve(resultTypeList.size());
1616 if (resultSegmentSpecObj.is_none()) {
1617 // Non-variadic result unpacking.
1618 for (const auto &it : llvm::enumerate(resultTypeList)) {
1619 try {
1620 resultTypes.push_back(nb::cast<PyType *>(it.value()));
1621 if (!resultTypes.back())
1622 throw nb::cast_error();
1623 } catch (nb::cast_error &err) {
1624 throw nb::value_error((llvm::Twine("Result ") +
1625 llvm::Twine(it.index()) + " of operation \"" +
1626 name + "\" must be a Type (" + err.what() + ")")
1627 .str()
1628 .c_str());
1629 }
1630 }
1631 } else {
1632 // Sized result unpacking.
1633 auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1634 if (resultSegmentSpec.size() != resultTypeList.size()) {
1635 throw nb::value_error((llvm::Twine("Operation \"") + name +
1636 "\" requires " +
1637 llvm::Twine(resultSegmentSpec.size()) +
1638 " result segments but was provided " +
1639 llvm::Twine(resultTypeList.size()))
1640 .str()
1641 .c_str());
1642 }
1643 resultSegmentLengths.reserve(resultTypeList.size());
1644 for (const auto &it :
1645 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1646 int segmentSpec = std::get<1>(it.value());
1647 if (segmentSpec == 1 || segmentSpec == 0) {
1648 // Unpack unary element.
1649 try {
1650 auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
1651 if (resultType) {
1652 resultTypes.push_back(resultType);
1653 resultSegmentLengths.push_back(1);
1654 } else if (segmentSpec == 0) {
1655 // Allowed to be optional.
1656 resultSegmentLengths.push_back(0);
1657 } else {
1658 throw nb::value_error(
1659 (llvm::Twine("Result ") + llvm::Twine(it.index()) +
1660 " of operation \"" + name +
1661 "\" must be a Type (was None and result is not optional)")
1662 .str()
1663 .c_str());
1664 }
1665 } catch (nb::cast_error &err) {
1666 throw nb::value_error((llvm::Twine("Result ") +
1667 llvm::Twine(it.index()) + " of operation \"" +
1668 name + "\" must be a Type (" + err.what() +
1669 ")")
1670 .str()
1671 .c_str());
1672 }
1673 } else if (segmentSpec == -1) {
1674 // Unpack sequence by appending.
1675 try {
1676 if (std::get<0>(it.value()).is_none()) {
1677 // Treat it as an empty list.
1678 resultSegmentLengths.push_back(0);
1679 } else {
1680 // Unpack the list.
1681 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1682 for (nb::handle segmentItem : segment) {
1683 resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1684 if (!resultTypes.back()) {
1685 throw nb::type_error("contained a None item");
1686 }
1687 }
1688 resultSegmentLengths.push_back(nb::len(segment));
1689 }
1690 } catch (std::exception &err) {
1691 // NOTE: Sloppy to be using a catch-all here, but there are at least
1692 // three different unrelated exceptions that can be thrown in the
1693 // above "casts". Just keep the scope above small and catch them all.
1694 throw nb::value_error((llvm::Twine("Result ") +
1695 llvm::Twine(it.index()) + " of operation \"" +
1696 name + "\" must be a Sequence of Types (" +
1697 err.what() + ")")
1698 .str()
1699 .c_str());
1700 }
1701 } else {
1702 throw nb::value_error("Unexpected segment spec");
1703 }
1704 }
1705 }
1706}
1707
1708static MlirValue getUniqueResult(MlirOperation operation) {
1709 auto numResults = mlirOperationGetNumResults(operation);
1710 if (numResults != 1) {
1711 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1712 throw nb::value_error((Twine("Cannot call .result on operation ") +
1713 StringRef(name.data, name.length) + " which has " +
1714 Twine(numResults) +
1715 " results (it is only valid for operations with a "
1716 "single result)")
1717 .str()
1718 .c_str());
1719 }
1720 return mlirOperationGetResult(operation, 0);
1721}
1722
1723static MlirValue getOpResultOrValue(nb::handle operand) {
1724 if (operand.is_none()) {
1725 throw nb::value_error("contained a None item");
1726 }
1727 PyOperationBase *op;
1728 if (nb::try_cast<PyOperationBase *>(operand, op)) {
1729 return getUniqueResult(op->getOperation());
1730 }
1731 PyOpResultList *opResultList;
1732 if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1733 return getUniqueResult(opResultList->getOperation()->get());
1734 }
1735 PyValue *value;
1736 if (nb::try_cast<PyValue *>(operand, value)) {
1737 return value->get();
1738 }
1739 throw nb::value_error("is not a Value");
1740}
1741
1743 std::string_view name, std::tuple<int, bool> opRegionSpec,
1744 nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1745 std::optional<nb::list> resultTypeList, nb::list operandList,
1746 std::optional<nb::dict> attributes,
1747 std::optional<std::vector<PyBlock *>> successors,
1748 std::optional<int> regions, PyLocation &location,
1749 const nb::object &maybeIp) {
1750 PyMlirContextRef context = location.getContext();
1751
1752 // Class level operation construction metadata.
1753 // Operand and result segment specs are either none, which does no
1754 // variadic unpacking, or a list of ints with segment sizes, where each
1755 // element is either a positive number (typically 1 for a scalar) or -1 to
1756 // indicate that it is derived from the length of the same-indexed operand
1757 // or result (implying that it is a list at that position).
1758 std::vector<int32_t> operandSegmentLengths;
1759 std::vector<int32_t> resultSegmentLengths;
1760
1761 // Validate/determine region count.
1762 int opMinRegionCount = std::get<0>(opRegionSpec);
1763 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1764 if (!regions) {
1765 regions = opMinRegionCount;
1766 }
1767 if (*regions < opMinRegionCount) {
1768 throw nb::value_error(
1769 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1770 llvm::Twine(opMinRegionCount) +
1771 " regions but was built with regions=" + llvm::Twine(*regions))
1772 .str()
1773 .c_str());
1774 }
1775 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1776 throw nb::value_error(
1777 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1778 llvm::Twine(opMinRegionCount) +
1779 " regions but was built with regions=" + llvm::Twine(*regions))
1780 .str()
1781 .c_str());
1782 }
1783
1784 // Unpack results.
1785 std::vector<PyType *> resultTypes;
1786 if (resultTypeList.has_value()) {
1787 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1788 resultSegmentLengths, resultTypes);
1789 }
1790
1791 // Unpack operands.
1793 operands.reserve(operands.size());
1794 if (operandSegmentSpecObj.is_none()) {
1795 // Non-sized operand unpacking.
1796 for (const auto &it : llvm::enumerate(operandList)) {
1797 try {
1798 operands.push_back(getOpResultOrValue(it.value()));
1799 } catch (nb::builtin_exception &err) {
1800 throw nb::value_error((llvm::Twine("Operand ") +
1801 llvm::Twine(it.index()) + " of operation \"" +
1802 name + "\" must be a Value (" + err.what() + ")")
1803 .str()
1804 .c_str());
1805 }
1806 }
1807 } else {
1808 // Sized operand unpacking.
1809 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1810 if (operandSegmentSpec.size() != operandList.size()) {
1811 throw nb::value_error((llvm::Twine("Operation \"") + name +
1812 "\" requires " +
1813 llvm::Twine(operandSegmentSpec.size()) +
1814 "operand segments but was provided " +
1815 llvm::Twine(operandList.size()))
1816 .str()
1817 .c_str());
1818 }
1819 operandSegmentLengths.reserve(operandList.size());
1820 for (const auto &it :
1821 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1822 int segmentSpec = std::get<1>(it.value());
1823 if (segmentSpec == 1 || segmentSpec == 0) {
1824 // Unpack unary element.
1825 auto &operand = std::get<0>(it.value());
1826 if (!operand.is_none()) {
1827 try {
1828
1829 operands.push_back(getOpResultOrValue(operand));
1830 } catch (nb::builtin_exception &err) {
1831 throw nb::value_error((llvm::Twine("Operand ") +
1832 llvm::Twine(it.index()) +
1833 " of operation \"" + name +
1834 "\" must be a Value (" + err.what() + ")")
1835 .str()
1836 .c_str());
1837 }
1838
1839 operandSegmentLengths.push_back(1);
1840 } else if (segmentSpec == 0) {
1841 // Allowed to be optional.
1842 operandSegmentLengths.push_back(0);
1843 } else {
1844 throw nb::value_error(
1845 (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
1846 " of operation \"" + name +
1847 "\" must be a Value (was None and operand is not optional)")
1848 .str()
1849 .c_str());
1850 }
1851 } else if (segmentSpec == -1) {
1852 // Unpack sequence by appending.
1853 try {
1854 if (std::get<0>(it.value()).is_none()) {
1855 // Treat it as an empty list.
1856 operandSegmentLengths.push_back(0);
1857 } else {
1858 // Unpack the list.
1859 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1860 for (nb::handle segmentItem : segment) {
1861 operands.push_back(getOpResultOrValue(segmentItem));
1862 }
1863 operandSegmentLengths.push_back(nb::len(segment));
1864 }
1865 } catch (std::exception &err) {
1866 // NOTE: Sloppy to be using a catch-all here, but there are at least
1867 // three different unrelated exceptions that can be thrown in the
1868 // above "casts". Just keep the scope above small and catch them all.
1869 throw nb::value_error((llvm::Twine("Operand ") +
1870 llvm::Twine(it.index()) + " of operation \"" +
1871 name + "\" must be a Sequence of Values (" +
1872 err.what() + ")")
1873 .str()
1874 .c_str());
1875 }
1876 } else {
1877 throw nb::value_error("Unexpected segment spec");
1878 }
1879 }
1880 }
1881
1882 // Merge operand/result segment lengths into attributes if needed.
1883 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1884 // Dup.
1885 if (attributes) {
1886 attributes = nb::dict(*attributes);
1887 } else {
1888 attributes = nb::dict();
1889 }
1890 if (attributes->contains("resultSegmentSizes") ||
1891 attributes->contains("operandSegmentSizes")) {
1892 throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
1893 "'operandSegmentSizes' attribute is unsupported. "
1894 "Use Operation.create for such low-level access.");
1895 }
1896
1897 // Add resultSegmentSizes attribute.
1898 if (!resultSegmentLengths.empty()) {
1899 MlirAttribute segmentLengthAttr =
1900 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1901 resultSegmentLengths.data());
1902 (*attributes)["resultSegmentSizes"] =
1903 PyAttribute(context, segmentLengthAttr);
1904 }
1905
1906 // Add operandSegmentSizes attribute.
1907 if (!operandSegmentLengths.empty()) {
1908 MlirAttribute segmentLengthAttr =
1909 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1910 operandSegmentLengths.data());
1911 (*attributes)["operandSegmentSizes"] =
1912 PyAttribute(context, segmentLengthAttr);
1913 }
1914 }
1915
1916 // Delegate to create.
1917 return PyOperation::create(name,
1918 /*results=*/std::move(resultTypes),
1919 /*operands=*/operands,
1920 /*attributes=*/std::move(attributes),
1921 /*successors=*/std::move(successors),
1922 /*regions=*/*regions, location, maybeIp,
1923 !resultTypeList);
1924}
1925
1926nb::object PyOpView::constructDerived(const nb::object &cls,
1927 const nb::object &operation) {
1928 nb::handle opViewType = nb::type<PyOpView>();
1929 nb::object instance = cls.attr("__new__")(cls);
1930 opViewType.attr("__init__")(instance, operation);
1931 return instance;
1932}
1933
1934PyOpView::PyOpView(const nb::object &operationObject)
1935 // Casting through the PyOperationBase base-class and then back to the
1936 // Operation lets us accept any PyOperationBase subclass.
1937 : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
1938 operationObject(operation.getRef().getObject()) {}
1939
1940//------------------------------------------------------------------------------
1941// PyInsertionPoint.
1942//------------------------------------------------------------------------------
1943
1944PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
1945
1947 : refOperation(beforeOperationBase.getOperation().getRef()),
1948 block((*refOperation)->getBlock()) {}
1949
1951 : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
1952
1954 PyOperation &operation = operationBase.getOperation();
1955 if (operation.isAttached())
1956 throw nb::value_error(
1957 "Attempt to insert operation that is already attached");
1958 block.getParentOperation()->checkValid();
1959 MlirOperation beforeOp = {nullptr};
1960 if (refOperation) {
1961 // Insert before operation.
1962 (*refOperation)->checkValid();
1963 beforeOp = (*refOperation)->get();
1964 } else {
1965 // Insert at end (before null) is only valid if the block does not
1966 // already end in a known terminator (violating this will cause assertion
1967 // failures later).
1968 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1969 throw nb::index_error("Cannot insert operation at the end of a block "
1970 "that already has a terminator. Did you mean to "
1971 "use 'InsertionPoint.at_block_terminator(block)' "
1972 "versus 'InsertionPoint(block)'?");
1973 }
1974 }
1975 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1976 operation.setAttached();
1977}
1978
1980 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1981 if (mlirOperationIsNull(firstOp)) {
1982 // Just insert at end.
1983 return PyInsertionPoint(block);
1984 }
1985
1986 // Insert before first op.
1988 block.getParentOperation()->getContext(), firstOp);
1989 return PyInsertionPoint{block, std::move(firstOpRef)};
1990}
1991
1993 MlirOperation terminator = mlirBlockGetTerminator(block.get());
1994 if (mlirOperationIsNull(terminator))
1995 throw nb::value_error("Block has no terminator");
1996 PyOperationRef terminatorOpRef = PyOperation::forOperation(
1997 block.getParentOperation()->getContext(), terminator);
1998 return PyInsertionPoint{block, std::move(terminatorOpRef)};
1999}
2000
2002 PyOperation &operation = op.getOperation();
2003 PyBlock block = operation.getBlock();
2004 MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
2005 if (mlirOperationIsNull(nextOperation))
2006 return PyInsertionPoint(block);
2008 block.getParentOperation()->getContext(), nextOperation);
2009 return PyInsertionPoint{block, std::move(nextOpRef)};
2010}
2011
2012size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
2013
2014nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
2015 return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
2016}
2017
2018void PyInsertionPoint::contextExit(const nb::object &excType,
2019 const nb::object &excVal,
2020 const nb::object &excTb) {
2022}
2023
2024//------------------------------------------------------------------------------
2025// PyAttribute.
2026//------------------------------------------------------------------------------
2027
2028bool PyAttribute::operator==(const PyAttribute &other) const {
2029 return mlirAttributeEqual(attr, other.attr);
2030}
2031
2033 return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
2034}
2035
2037 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
2038 if (mlirAttributeIsNull(rawAttr))
2039 throw nb::python_error();
2040 return PyAttribute(
2042}
2043
2045 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get());
2046 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2047 "mlirTypeID was expected to be non-null.");
2048 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
2049 mlirTypeID, mlirAttributeGetDialect(this->get()));
2050 // nb::rv_policy::move means use std::move to move the return value
2051 // contents into a new instance that will be owned by Python.
2052 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2053 if (!typeCaster)
2054 return thisObj;
2055 return typeCaster.value()(thisObj);
2056}
2057
2058//------------------------------------------------------------------------------
2059// PyNamedAttribute.
2060//------------------------------------------------------------------------------
2061
2062PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
2063 : ownedName(new std::string(std::move(ownedName))) {
2066 toMlirStringRef(*this->ownedName)),
2067 attr);
2068}
2069
2070//------------------------------------------------------------------------------
2071// PyType.
2072//------------------------------------------------------------------------------
2073
2074bool PyType::operator==(const PyType &other) const {
2075 return mlirTypeEqual(type, other.type);
2076}
2077
2079 return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
2080}
2081
2083 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
2084 if (mlirTypeIsNull(rawType))
2085 throw nb::python_error();
2087 rawType);
2088}
2089
2091 MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get());
2092 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2093 "mlirTypeID was expected to be non-null.");
2094 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
2095 mlirTypeID, mlirTypeGetDialect(this->get()));
2096 // nb::rv_policy::move means use std::move to move the return value
2097 // contents into a new instance that will be owned by Python.
2098 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2099 if (!typeCaster)
2100 return thisObj;
2101 return typeCaster.value()(thisObj);
2102}
2103
2104//------------------------------------------------------------------------------
2105// PyTypeID.
2106//------------------------------------------------------------------------------
2107
2109 return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
2110}
2111
2113 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
2114 if (mlirTypeIDIsNull(mlirTypeID))
2115 throw nb::python_error();
2116 return PyTypeID(mlirTypeID);
2117}
2118bool PyTypeID::operator==(const PyTypeID &other) const {
2119 return mlirTypeIDEqual(typeID, other.typeID);
2120}
2121
2122//------------------------------------------------------------------------------
2123// PyValue and subclasses.
2124//------------------------------------------------------------------------------
2125
2127 return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
2128}
2129
2131 MlirType type = mlirValueGetType(get());
2132 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2133 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2134 "mlirTypeID was expected to be non-null.");
2135 std::optional<nb::callable> valueCaster =
2137 // nb::rv_policy::move means use std::move to move the return value
2138 // contents into a new instance that will be owned by Python.
2139 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2140 if (!valueCaster)
2141 return thisObj;
2142 return valueCaster.value()(thisObj);
2143}
2144
2146 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2147 if (mlirValueIsNull(value))
2148 throw nb::python_error();
2149 MlirOperation owner;
2150 if (mlirValueIsAOpResult(value))
2151 owner = mlirOpResultGetOwner(value);
2152 if (mlirValueIsABlockArgument(value))
2154 if (mlirOperationIsNull(owner))
2155 throw nb::python_error();
2156 MlirContext ctx = mlirOperationGetContext(owner);
2157 PyOperationRef ownerRef =
2159 return PyValue(ownerRef, value);
2160}
2161
2162//------------------------------------------------------------------------------
2163// PySymbolTable.
2164//------------------------------------------------------------------------------
2165
2167 : operation(operation.getOperation().getRef()) {
2168 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2169 if (mlirSymbolTableIsNull(symbolTable)) {
2170 throw nb::type_error("Operation is not a Symbol Table.");
2171 }
2172}
2173
2174nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2175 operation->checkValid();
2176 MlirOperation symbol = mlirSymbolTableLookup(
2177 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2178 if (mlirOperationIsNull(symbol))
2179 throw nb::key_error(
2180 ("Symbol '" + name + "' not in the symbol table.").c_str());
2181
2182 return PyOperation::forOperation(operation->getContext(), symbol,
2183 operation.getObject())
2184 ->createOpView();
2185}
2186
2188 operation->checkValid();
2189 symbol.getOperation().checkValid();
2190 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2191 // The operation is also erased, so we must invalidate it. There may be Python
2192 // references to this operation so we don't want to delete it from the list of
2193 // live operations here.
2194 symbol.getOperation().valid = false;
2195}
2196
2197void PySymbolTable::dunderDel(const std::string &name) {
2198 nb::object operation = dunderGetItem(name);
2199 erase(nb::cast<PyOperationBase &>(operation));
2200}
2201
2203 operation->checkValid();
2204 symbol.getOperation().checkValid();
2205 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2207 if (mlirAttributeIsNull(symbolAttr))
2208 throw nb::value_error("Expected operation to have a symbol name.");
2209 return PyStringAttribute(
2210 symbol.getOperation().getContext(),
2211 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
2212}
2213
2215 // Op must already be a symbol.
2216 PyOperation &operation = symbol.getOperation();
2217 operation.checkValid();
2219 MlirAttribute existingNameAttr =
2220 mlirOperationGetAttributeByName(operation.get(), attrName);
2221 if (mlirAttributeIsNull(existingNameAttr))
2222 throw nb::value_error("Expected operation to have a symbol name.");
2223 return PyStringAttribute(symbol.getOperation().getContext(),
2224 existingNameAttr);
2225}
2226
2228 const std::string &name) {
2229 // Op must already be a symbol.
2230 PyOperation &operation = symbol.getOperation();
2231 operation.checkValid();
2233 MlirAttribute existingNameAttr =
2234 mlirOperationGetAttributeByName(operation.get(), attrName);
2235 if (mlirAttributeIsNull(existingNameAttr))
2236 throw nb::value_error("Expected operation to have a symbol name.");
2237 MlirAttribute newNameAttr =
2238 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2239 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2240}
2241
2243 PyOperation &operation = symbol.getOperation();
2244 operation.checkValid();
2246 MlirAttribute existingVisAttr =
2247 mlirOperationGetAttributeByName(operation.get(), attrName);
2248 if (mlirAttributeIsNull(existingVisAttr))
2249 throw nb::value_error("Expected operation to have a symbol visibility.");
2250 return PyStringAttribute(symbol.getOperation().getContext(), existingVisAttr);
2251}
2252
2254 const std::string &visibility) {
2255 if (visibility != "public" && visibility != "private" &&
2256 visibility != "nested")
2257 throw nb::value_error(
2258 "Expected visibility to be 'public', 'private' or 'nested'");
2259 PyOperation &operation = symbol.getOperation();
2260 operation.checkValid();
2262 MlirAttribute existingVisAttr =
2263 mlirOperationGetAttributeByName(operation.get(), attrName);
2264 if (mlirAttributeIsNull(existingVisAttr))
2265 throw nb::value_error("Expected operation to have a symbol visibility.");
2266 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2267 toMlirStringRef(visibility));
2268 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2269}
2270
2271void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2272 const std::string &newSymbol,
2273 PyOperationBase &from) {
2274 PyOperation &fromOperation = from.getOperation();
2275 fromOperation.checkValid();
2277 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2278 from.getOperation())))
2279
2280 throw nb::value_error("Symbol rename failed");
2281}
2282
2284 bool allSymUsesVisible,
2285 nb::object callback) {
2286 PyOperation &fromOperation = from.getOperation();
2287 fromOperation.checkValid();
2288 struct UserData {
2289 PyMlirContextRef context;
2290 nb::object callback;
2291 bool gotException;
2292 std::string exceptionWhat;
2293 nb::object exceptionType;
2294 };
2295 UserData userData{
2296 fromOperation.getContext(), std::move(callback), false, {}, {}};
2298 fromOperation.get(), allSymUsesVisible,
2299 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2300 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2301 auto pyFoundOp =
2302 PyOperation::forOperation(calleeUserData->context, foundOp);
2303 if (calleeUserData->gotException)
2304 return;
2305 try {
2306 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2307 } catch (nb::python_error &e) {
2308 calleeUserData->gotException = true;
2309 calleeUserData->exceptionWhat = e.what();
2310 calleeUserData->exceptionType = nb::borrow(e.type());
2311 }
2312 },
2313 static_cast<void *>(&userData));
2314 if (userData.gotException) {
2315 std::string message("Exception raised in callback: ");
2316 message.append(userData.exceptionWhat);
2317 throw std::runtime_error(message);
2318 }
2319}
2320
2321namespace {
2322
2323/// Python wrapper for MlirBlockArgument.
2324class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2325public:
2326 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2327 static constexpr const char *pyClassName = "BlockArgument";
2328 using PyConcreteValue::PyConcreteValue;
2329
2330 static void bindDerived(ClassTy &c) {
2331 c.def_prop_ro(
2332 "owner",
2333 [](PyBlockArgument &self) {
2334 return PyBlock(self.getParentOperation(),
2335 mlirBlockArgumentGetOwner(self.get()));
2336 },
2337 "Returns the block that owns this argument.");
2338 c.def_prop_ro(
2339 "arg_number",
2340 [](PyBlockArgument &self) {
2341 return mlirBlockArgumentGetArgNumber(self.get());
2342 },
2343 "Returns the position of this argument in the block's argument list.");
2344 c.def(
2345 "set_type",
2346 [](PyBlockArgument &self, PyType type) {
2347 return mlirBlockArgumentSetType(self.get(), type);
2348 },
2349 nb::arg("type"), "Sets the type of this block argument.");
2350 c.def(
2351 "set_location",
2352 [](PyBlockArgument &self, PyLocation loc) {
2353 return mlirBlockArgumentSetLocation(self.get(), loc);
2354 },
2355 nb::arg("loc"), "Sets the location of this block argument.");
2356 }
2357};
2358
2359/// A list of block arguments. Internally, these are stored as consecutive
2360/// elements, random access is cheap. The argument list is associated with the
2361/// operation that contains the block (detached blocks are not allowed in
2362/// Python bindings) and extends its lifetime.
2363class PyBlockArgumentList
2364 : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2365public:
2366 static constexpr const char *pyClassName = "BlockArgumentList";
2367 using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
2368
2369 PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2370 intptr_t startIndex = 0, intptr_t length = -1,
2371 intptr_t step = 1)
2372 : Sliceable(startIndex,
2373 length == -1 ? mlirBlockGetNumArguments(block) : length,
2374 step),
2375 operation(std::move(operation)), block(block) {}
2376
2377 static void bindDerived(ClassTy &c) {
2378 c.def_prop_ro(
2379 "types",
2380 [](PyBlockArgumentList &self) {
2381 return getValueTypes(self, self.operation->getContext());
2382 },
2383 "Returns a list of types for all arguments in this argument list.");
2384 }
2385
2386private:
2387 /// Give the parent CRTP class access to hook implementations below.
2388 friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2389
2390 /// Returns the number of arguments in the list.
2391 intptr_t getRawNumElements() {
2392 operation->checkValid();
2393 return mlirBlockGetNumArguments(block);
2394 }
2395
2396 /// Returns `pos`-the element in the list.
2397 PyBlockArgument getRawElement(intptr_t pos) {
2398 MlirValue argument = mlirBlockGetArgument(block, pos);
2399 return PyBlockArgument(operation, argument);
2400 }
2401
2402 /// Returns a sublist of this list.
2403 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2404 intptr_t step) {
2405 return PyBlockArgumentList(operation, block, startIndex, length, step);
2406 }
2407
2408 PyOperationRef operation;
2409 MlirBlock block;
2410};
2411
2412/// A list of operation operands. Internally, these are stored as consecutive
2413/// elements, random access is cheap. The (returned) operand list is associated
2414/// with the operation whose operands these are, and thus extends the lifetime
2415/// of this operation.
2416class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2417public:
2418 static constexpr const char *pyClassName = "OpOperandList";
2419 using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2420
2421 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2422 intptr_t length = -1, intptr_t step = 1)
2423 : Sliceable(startIndex,
2424 length == -1 ? mlirOperationGetNumOperands(operation->get())
2425 : length,
2426 step),
2427 operation(operation) {}
2428
2429 void dunderSetItem(intptr_t index, PyValue value) {
2430 index = wrapIndex(index);
2431 mlirOperationSetOperand(operation->get(), index, value.get());
2432 }
2433
2434 static void bindDerived(ClassTy &c) {
2435 c.def("__setitem__", &PyOpOperandList::dunderSetItem, nb::arg("index"),
2436 nb::arg("value"),
2437 "Sets the operand at the specified index to a new value.");
2438 }
2439
2440private:
2441 /// Give the parent CRTP class access to hook implementations below.
2442 friend class Sliceable<PyOpOperandList, PyValue>;
2443
2444 intptr_t getRawNumElements() {
2445 operation->checkValid();
2446 return mlirOperationGetNumOperands(operation->get());
2447 }
2448
2449 PyValue getRawElement(intptr_t pos) {
2450 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2451 MlirOperation owner;
2452 if (mlirValueIsAOpResult(operand))
2453 owner = mlirOpResultGetOwner(operand);
2454 else if (mlirValueIsABlockArgument(operand))
2456 else
2457 assert(false && "Value must be an block arg or op result.");
2458 PyOperationRef pyOwner =
2459 PyOperation::forOperation(operation->getContext(), owner);
2460 return PyValue(pyOwner, operand);
2461 }
2462
2463 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2464 return PyOpOperandList(operation, startIndex, length, step);
2465 }
2466
2467 PyOperationRef operation;
2468};
2469
2470/// A list of operation successors. Internally, these are stored as consecutive
2471/// elements, random access is cheap. The (returned) successor list is
2472/// associated with the operation whose successors these are, and thus extends
2473/// the lifetime of this operation.
2474class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2475public:
2476 static constexpr const char *pyClassName = "OpSuccessors";
2477
2478 PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2479 intptr_t length = -1, intptr_t step = 1)
2480 : Sliceable(startIndex,
2481 length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2482 : length,
2483 step),
2484 operation(operation) {}
2485
2486 void dunderSetItem(intptr_t index, PyBlock block) {
2487 index = wrapIndex(index);
2488 mlirOperationSetSuccessor(operation->get(), index, block.get());
2489 }
2490
2491 static void bindDerived(ClassTy &c) {
2492 c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nb::arg("index"),
2493 nb::arg("block"), "Sets the successor block at the specified index.");
2494 }
2495
2496private:
2497 /// Give the parent CRTP class access to hook implementations below.
2498 friend class Sliceable<PyOpSuccessors, PyBlock>;
2499
2500 intptr_t getRawNumElements() {
2501 operation->checkValid();
2502 return mlirOperationGetNumSuccessors(operation->get());
2503 }
2504
2505 PyBlock getRawElement(intptr_t pos) {
2506 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2507 return PyBlock(operation, block);
2508 }
2509
2510 PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2511 return PyOpSuccessors(operation, startIndex, length, step);
2512 }
2513
2514 PyOperationRef operation;
2515};
2516
2517/// A list of block successors. Internally, these are stored as consecutive
2518/// elements, random access is cheap. The (returned) successor list is
2519/// associated with the operation and block whose successors these are, and thus
2520/// extends the lifetime of this operation and block.
2521class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
2522public:
2523 static constexpr const char *pyClassName = "BlockSuccessors";
2524
2525 PyBlockSuccessors(PyBlock block, PyOperationRef operation,
2526 intptr_t startIndex = 0, intptr_t length = -1,
2527 intptr_t step = 1)
2528 : Sliceable(startIndex,
2529 length == -1 ? mlirBlockGetNumSuccessors(block.get())
2530 : length,
2531 step),
2532 operation(operation), block(block) {}
2533
2534private:
2535 /// Give the parent CRTP class access to hook implementations below.
2536 friend class Sliceable<PyBlockSuccessors, PyBlock>;
2537
2538 intptr_t getRawNumElements() {
2539 block.checkValid();
2540 return mlirBlockGetNumSuccessors(block.get());
2541 }
2542
2543 PyBlock getRawElement(intptr_t pos) {
2544 MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2545 return PyBlock(operation, block);
2546 }
2547
2548 PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2549 return PyBlockSuccessors(block, operation, startIndex, length, step);
2550 }
2551
2552 PyOperationRef operation;
2553 PyBlock block;
2554};
2555
2556/// A list of block predecessors. The (returned) predecessor list is
2557/// associated with the operation and block whose predecessors these are, and
2558/// thus extends the lifetime of this operation and block.
2559///
2560/// WARNING: This Sliceable is more expensive than the others here because
2561/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
2562/// operands) anew for each indexed access.
2563class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
2564public:
2565 static constexpr const char *pyClassName = "BlockPredecessors";
2566
2567 PyBlockPredecessors(PyBlock block, PyOperationRef operation,
2568 intptr_t startIndex = 0, intptr_t length = -1,
2569 intptr_t step = 1)
2570 : Sliceable(startIndex,
2571 length == -1 ? mlirBlockGetNumPredecessors(block.get())
2572 : length,
2573 step),
2574 operation(operation), block(block) {}
2575
2576private:
2577 /// Give the parent CRTP class access to hook implementations below.
2578 friend class Sliceable<PyBlockPredecessors, PyBlock>;
2579
2580 intptr_t getRawNumElements() {
2581 block.checkValid();
2582 return mlirBlockGetNumPredecessors(block.get());
2583 }
2584
2585 PyBlock getRawElement(intptr_t pos) {
2586 MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2587 return PyBlock(operation, block);
2588 }
2589
2590 PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
2591 intptr_t step) {
2592 return PyBlockPredecessors(block, operation, startIndex, length, step);
2593 }
2594
2595 PyOperationRef operation;
2596 PyBlock block;
2597};
2598
2599/// A list of operation attributes. Can be indexed by name, producing
2600/// attributes, or by index, producing named attributes.
2601class PyOpAttributeMap {
2602public:
2603 PyOpAttributeMap(PyOperationRef operation)
2604 : operation(std::move(operation)) {}
2605
2606 nb::typed<nb::object, PyAttribute>
2607 dunderGetItemNamed(const std::string &name) {
2608 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2609 toMlirStringRef(name));
2610 if (mlirAttributeIsNull(attr)) {
2611 throw nb::key_error("attempt to access a non-existent attribute");
2612 }
2613 return PyAttribute(operation->getContext(), attr).maybeDownCast();
2614 }
2615
2616 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2617 if (index < 0) {
2618 index += dunderLen();
2619 }
2620 if (index < 0 || index >= dunderLen()) {
2621 throw nb::index_error("attempt to access out of bounds attribute");
2622 }
2623 MlirNamedAttribute namedAttr =
2624 mlirOperationGetAttribute(operation->get(), index);
2625 return PyNamedAttribute(
2626 namedAttr.attribute,
2627 std::string(mlirIdentifierStr(namedAttr.name).data,
2628 mlirIdentifierStr(namedAttr.name).length));
2629 }
2630
2631 void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2633 attr);
2634 }
2635
2636 void dunderDelItem(const std::string &name) {
2637 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2638 toMlirStringRef(name));
2639 if (!removed)
2640 throw nb::key_error("attempt to delete a non-existent attribute");
2641 }
2642
2643 intptr_t dunderLen() {
2644 return mlirOperationGetNumAttributes(operation->get());
2645 }
2646
2647 bool dunderContains(const std::string &name) {
2648 return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2649 operation->get(), toMlirStringRef(name)));
2650 }
2651
2652 static void
2653 forEachAttr(MlirOperation op,
2654 llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
2655 intptr_t n = mlirOperationGetNumAttributes(op);
2656 for (intptr_t i = 0; i < n; ++i) {
2659 fn(name, na.attribute);
2660 }
2661 }
2662
2663 static void bind(nb::module_ &m) {
2664 nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2665 .def("__contains__", &PyOpAttributeMap::dunderContains, nb::arg("name"),
2666 "Checks if an attribute with the given name exists in the map.")
2667 .def("__len__", &PyOpAttributeMap::dunderLen,
2668 "Returns the number of attributes in the map.")
2669 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
2670 nb::arg("name"), "Gets an attribute by name.")
2671 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
2672 nb::arg("index"), "Gets a named attribute by index.")
2673 .def("__setitem__", &PyOpAttributeMap::dunderSetItem, nb::arg("name"),
2674 nb::arg("attr"), "Sets an attribute with the given name.")
2675 .def("__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg("name"),
2676 "Deletes an attribute with the given name.")
2677 .def(
2678 "__iter__",
2679 [](PyOpAttributeMap &self) {
2680 nb::list keys;
2681 PyOpAttributeMap::forEachAttr(
2682 self.operation->get(),
2683 [&](MlirStringRef name, MlirAttribute) {
2684 keys.append(nb::str(name.data, name.length));
2685 });
2686 return nb::iter(keys);
2687 },
2688 "Iterates over attribute names.")
2689 .def(
2690 "keys",
2691 [](PyOpAttributeMap &self) {
2692 nb::list out;
2693 PyOpAttributeMap::forEachAttr(
2694 self.operation->get(),
2695 [&](MlirStringRef name, MlirAttribute) {
2696 out.append(nb::str(name.data, name.length));
2697 });
2698 return out;
2699 },
2700 "Returns a list of attribute names.")
2701 .def(
2702 "values",
2703 [](PyOpAttributeMap &self) {
2704 nb::list out;
2705 PyOpAttributeMap::forEachAttr(
2706 self.operation->get(),
2707 [&](MlirStringRef, MlirAttribute attr) {
2708 out.append(PyAttribute(self.operation->getContext(), attr)
2709 .maybeDownCast());
2710 });
2711 return out;
2712 },
2713 "Returns a list of attribute values.")
2714 .def(
2715 "items",
2716 [](PyOpAttributeMap &self) {
2717 nb::list out;
2718 PyOpAttributeMap::forEachAttr(
2719 self.operation->get(),
2720 [&](MlirStringRef name, MlirAttribute attr) {
2721 out.append(nb::make_tuple(
2722 nb::str(name.data, name.length),
2723 PyAttribute(self.operation->getContext(), attr)
2724 .maybeDownCast()));
2725 });
2726 return out;
2727 },
2728 "Returns a list of `(name, attribute)` tuples.");
2729 }
2730
2731private:
2732 PyOperationRef operation;
2733};
2734
2735// see
2736// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
2737
2738#ifndef _Py_CAST
2739#define _Py_CAST(type, expr) ((type)(expr))
2740#endif
2741
2742// Static inline functions should use _Py_NULL rather than using directly NULL
2743// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
2744// _Py_NULL is defined as nullptr.
2745#ifndef _Py_NULL
2746#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
2747 (defined(__cplusplus) && __cplusplus >= 201103)
2748#define _Py_NULL nullptr
2749#else
2750#define _Py_NULL NULL
2751#endif
2752#endif
2753
2754// Python 3.10.0a3
2755#if PY_VERSION_HEX < 0x030A00A3
2756
2757// bpo-42262 added Py_XNewRef()
2758#if !defined(Py_XNewRef)
2759[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
2760 Py_XINCREF(obj);
2761 return obj;
2762}
2763#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
2764#endif
2765
2766// bpo-42262 added Py_NewRef()
2767#if !defined(Py_NewRef)
2768[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
2769 Py_INCREF(obj);
2770 return obj;
2771}
2772#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
2773#endif
2774
2775#endif // Python 3.10.0a3
2776
2777// Python 3.9.0b1
2778#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
2779
2780// bpo-40429 added PyThreadState_GetFrame()
2781PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
2782 assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
2783 return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
2784}
2785
2786// bpo-40421 added PyFrame_GetBack()
2787PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
2788 assert(frame != _Py_NULL && "expected frame != _Py_NULL");
2789 return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
2790}
2791
2792// bpo-40421 added PyFrame_GetCode()
2793PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
2794 assert(frame != _Py_NULL && "expected frame != _Py_NULL");
2795 assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
2796 return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
2797}
2798
2799#endif // Python 3.9.0b1
2800
/// Builds an MLIR location that mirrors the Python call stack of the calling
/// thread.
///
/// Starting at the thread's current Python frame and following
/// PyFrame_GetBack, every frame whose code filename passes
/// isUserTracebackFilename() is converted to a name location (function name
/// wrapping a file location created in `ctx`) and stored in `frames`,
/// innermost frame first, until `framesLimit` frames have been kept.
///
/// Returns:
///   - an unknown location when no frame was kept;
///   - the single name location when exactly one frame was kept;
///   - otherwise a call-site location whose callee is the innermost kept frame
///     and whose caller is the nested call-site chain of the remaining frames.
///
/// Throws nb::python_error when PyCode_Addr2Location returns 0 (only in the
/// branch compiled for PY_VERSION_HEX >= 0x030B00F0).
2801MlirLocation tracebackToLocation(MlirContext ctx) {
2802 size_t framesLimit =
  // NOTE(review): the initializer expression of framesLimit (original line
  // 2803) is absent from this listing; restore it from the real source.
2804 // Use a thread_local here to avoid requiring a large amount of space.
2805 thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
2806 frames;
  // Number of entries of `frames` filled so far.
  // NOTE(review): the loop below bounds `count` by framesLimit only; this
  // presumably relies on framesLimit never exceeding kMaxFrames - confirm.
2807 size_t count = 0;
2808
  // Hold the GIL while Python frame and code objects are inspected.
2809 nb::gil_scoped_acquire acquire;
2810 PyThreadState *tstate = PyThreadState_GET();
2811 PyFrameObject *next;
2812 PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
2813 // In the increment expression:
2814 // 1. get the next prev frame;
2815 // 2. decrement the ref count on the current frame (in order that it can get
2816 // gc'd, along with any objects in its closure and etc);
2817 // 3. set current = next.
2818 for (; pyFrame != nullptr && count < framesLimit;
2819 next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
    // NOTE(review): the PyFrame_GetCode shim defined just above this function
    // returns Py_NewRef(frame->f_code), i.e. a new reference, and no matching
    // decref of `code` is visible in this loop (neither on the `continue` path
    // nor at the end of the body) - looks like one reference is leaked per
    // visited frame; confirm.
2820 PyCodeObject *code = PyFrame_GetCode(pyFrame);
2821 auto fileNameStr =
2822 nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
2823 llvm::StringRef fileName(fileNameStr);
    // Frames from files not accepted by the traceback filter are skipped and
    // do not count towards framesLimit.
2824 if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
2825 continue;
2826
2827 // co_qualname and PyCode_Addr2Location added in py3.11
2828#if PY_VERSION_HEX < 0x030B00F0
    // Older Python: plain function name, line number only, column fixed to 0.
2829 std::string name =
2830 nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
2831 llvm::StringRef funcName(name);
2832 int startLine = PyFrame_GetLineNumber(pyFrame);
2833 MlirLocation loc =
2834 mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
2835#else
    // Newer Python: qualified name plus the full line/column range of the
    // instruction at PyFrame_GetLasti(pyFrame).
2836 std::string name =
2837 nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
2838 llvm::StringRef funcName(name);
2839 int startLine, startCol, endLine, endCol;
2840 int lasti = PyFrame_GetLasti(pyFrame);
    // NOTE(review): throwing here leaves the loop without running the
    // increment expression or the Py_XDECREF after the loop, so `pyFrame` is
    // presumably not released on this path - confirm.
2841 if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
2842 &endCol)) {
2843 throw nb::python_error();
2844 }
2845 MlirLocation loc = mlirLocationFileLineColRangeGet(
2846 ctx, wrap(fileName), startLine, startCol, endLine, endCol);
2847#endif
2848
    // Wrap the file location in a name location carrying the function name.
2849 frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
2850 ++count;
2851 }
2852 // When the loop breaks (after the last iter), current frame (if non-null)
2853 // is leaked without this.
2854 Py_XDECREF(pyFrame);
2855
  // No user frame was found: fall back to an unknown location.
2856 if (count == 0)
2857 return mlirLocationUnknownGet(ctx);
2858
  // frames[0] is the innermost kept frame and becomes the callee.
2859 MlirLocation callee = frames[0];
2860 assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
2861 if (count == 1)
2862 return callee;
2863
  // Fold the remaining frames from the outermost one (frames[count - 1]) down
  // to frames[1] into a nested call-site chain, then attach it to the callee.
2864 MlirLocation caller = frames[count - 1];
2865 assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
2866 for (int i = count - 2; i >= 1; i--)
2867 caller = mlirLocationCallSiteGet(frames[i], caller);
2868
2869 return mlirLocationCallSiteGet(callee, caller);
2870}
2871
2873maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
  // Chooses the location to use: an explicitly supplied `location` wins;
  // otherwise, when the code reaches the end of this function, a location is
  // derived from the Python traceback via tracebackToLocation().
  //
  // NOTE(review): the embedded line numbers of this listing skip 2872 (the
  // line before the function name, presumably the return type), 2877, 2879 and
  // 2881, so the statements on those lines are not visible here; the comments
  // below describe only what is shown.
  // An explicit location, when provided, is returned unchanged.
2874 if (location.has_value())
2875 return location.value();
  // Traceback-derived locations are only produced when enabled globally.
2876 if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
  // NOTE(review): the statement executed when tracebacks are disabled
  // (original line 2877) is missing from this listing.
2878
  // NOTE(review): the declarations of `ctx` and `ref` (original lines 2879 and
  // 2881) are missing from this listing; `ctx.get()` supplies the MlirContext
  // for the traceback location and `ref` is paired with it in the result.
2880 MlirLocation mlirLoc = tracebackToLocation(ctx.get());
2882 return {ref, mlirLoc};
2883}
2884
2885} // namespace
2886
2887//------------------------------------------------------------------------------
2888// Populates the core exports of the 'ir' submodule.
2889//------------------------------------------------------------------------------
2890
2891void mlir::python::populateIRCore(nb::module_ &m) {
2892 // disable leak warnings which tend to be false positives.
2893 nb::set_leak_warnings(false);
2894 //----------------------------------------------------------------------------
2895 // Enums.
2896 //----------------------------------------------------------------------------
2897 nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
2898 .value("ERROR", MlirDiagnosticError)
2899 .value("WARNING", MlirDiagnosticWarning)
2900 .value("NOTE", MlirDiagnosticNote)
2901 .value("REMARK", MlirDiagnosticRemark);
2902
2903 nb::enum_<MlirWalkOrder>(m, "WalkOrder")
2904 .value("PRE_ORDER", MlirWalkPreOrder)
2905 .value("POST_ORDER", MlirWalkPostOrder);
2906
2907 nb::enum_<MlirWalkResult>(m, "WalkResult")
2908 .value("ADVANCE", MlirWalkResultAdvance)
2909 .value("INTERRUPT", MlirWalkResultInterrupt)
2910 .value("SKIP", MlirWalkResultSkip);
2911
2912 //----------------------------------------------------------------------------
2913 // Mapping of Diagnostics.
2914 //----------------------------------------------------------------------------
2915 nb::class_<PyDiagnostic>(m, "Diagnostic")
2916 .def_prop_ro("severity", &PyDiagnostic::getSeverity,
2917 "Returns the severity of the diagnostic.")
2918 .def_prop_ro("location", &PyDiagnostic::getLocation,
2919 "Returns the location associated with the diagnostic.")
2920 .def_prop_ro("message", &PyDiagnostic::getMessage,
2921 "Returns the message text of the diagnostic.")
2922 .def_prop_ro("notes", &PyDiagnostic::getNotes,
2923 "Returns a tuple of attached note diagnostics.")
2924 .def(
2925 "__str__",
2926 [](PyDiagnostic &self) -> nb::str {
2927 if (!self.isValid())
2928 return nb::str("<Invalid Diagnostic>");
2929 return self.getMessage();
2930 },
2931 "Returns the diagnostic message as a string.");
2932
2933 nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2934 .def(
2935 "__init__",
2937 new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2938 },
2939 "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
2940 .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
2941 "The severity level of the diagnostic.")
2942 .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
2943 "The location associated with the diagnostic.")
2944 .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
2945 "The message text of the diagnostic.")
2946 .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
2947 "List of attached note diagnostics.")
2948 .def(
2949 "__str__",
2950 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
2951 "Returns the diagnostic message as a string.");
2952
2953 nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
2954 .def("detach", &PyDiagnosticHandler::detach,
2955 "Detaches the diagnostic handler from the context.")
2956 .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
2957 "Returns True if the handler is attached to a context.")
2958 .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
2959 "Returns True if an error was encountered during diagnostic "
2960 "handling.")
2961 .def("__enter__", &PyDiagnosticHandler::contextEnter,
2962 "Enters the diagnostic handler as a context manager.")
2963 .def("__exit__", &PyDiagnosticHandler::contextExit,
2964 nb::arg("exc_type").none(), nb::arg("exc_value").none(),
2965 nb::arg("traceback").none(),
2966 "Exits the diagnostic handler context manager.");
2967
2968 // Expose DefaultThreadPool to python
2969 nb::class_<PyThreadPool>(m, "ThreadPool")
2970 .def(
2971 "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
2972 "Creates a new thread pool with default concurrency.")
2973 .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
2974 "Returns the maximum number of threads in the pool.")
2975 .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
2976 "Returns the raw pointer to the LLVM thread pool as a string.");
2977
2978 nb::class_<PyMlirContext>(m, "Context")
2979 .def(
2980 "__init__",
2981 [](PyMlirContext &self) {
2982 MlirContext context = mlirContextCreateWithThreading(false);
2983 new (&self) PyMlirContext(context);
2984 },
2985 R"(
2986 Creates a new MLIR context.
2987
2988 The context is the top-level container for all MLIR objects. It owns the storage
2989 for types, attributes, locations, and other core IR objects. A context can be
2990 configured to allow or disallow unregistered dialects and can have dialects
2991 loaded on-demand.)")
2992 .def_static("_get_live_count", &PyMlirContext::getLiveCount,
2993 "Gets the number of live Context objects.")
2994 .def(
2995 "_get_context_again",
2996 [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
2998 return ref.releaseObject();
2999 },
3000 "Gets another reference to the same context.")
3001 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
3002 "Gets the number of live modules owned by this context.")
3004 "Gets a capsule wrapping the MlirContext.")
3007 "Creates a Context from a capsule wrapping MlirContext.")
3008 .def("__enter__", &PyMlirContext::contextEnter,
3009 "Enters the context as a context manager.")
3010 .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
3011 nb::arg("exc_value").none(), nb::arg("traceback").none(),
3012 "Exits the context manager.")
3013 .def_prop_ro_static(
3014 "current",
3015 [](nb::object & /*class*/)
3016 -> std::optional<nb::typed<nb::object, PyMlirContext>> {
3018 if (!context)
3019 return {};
3020 return nb::cast(context);
3021 },
3022 nb::sig("def current(/) -> Context | None"),
3023 "Gets the Context bound to the current thread or returns None if no "
3024 "context is set.")
3025 .def_prop_ro(
3026 "dialects",
3027 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3028 "Gets a container for accessing dialects by name.")
3029 .def_prop_ro(
3030 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3031 "Alias for `dialects`.")
3032 .def(
3033 "get_dialect_descriptor",
3034 [=](PyMlirContext &self, std::string &name) {
3035 MlirDialect dialect = mlirContextGetOrLoadDialect(
3036 self.get(), {name.data(), name.size()});
3037 if (mlirDialectIsNull(dialect)) {
3038 throw nb::value_error(
3039 (Twine("Dialect '") + name + "' not found").str().c_str());
3040 }
3041 return PyDialectDescriptor(self.getRef(), dialect);
3042 },
3043 nb::arg("dialect_name"),
3044 "Gets or loads a dialect by name, returning its descriptor object.")
3045 .def_prop_rw(
3046 "allow_unregistered_dialects",
3047 [](PyMlirContext &self) -> bool {
3049 },
3050 [](PyMlirContext &self, bool value) {
3052 },
3053 "Controls whether unregistered dialects are allowed in this context.")
3054 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
3055 nb::arg("callback"),
3056 "Attaches a diagnostic handler that will receive callbacks.")
3057 .def(
3058 "enable_multithreading",
3059 [](PyMlirContext &self, bool enable) {
3060 mlirContextEnableMultithreading(self.get(), enable);
3061 },
3062 nb::arg("enable"),
3063 R"(
3064 Enables or disables multi-threading support in the context.
3065
3066 Args:
3067 enable: Whether to enable (True) or disable (False) multi-threading.
3068 )")
3069 .def(
3070 "set_thread_pool",
3071 [](PyMlirContext &self, PyThreadPool &pool) {
3072 // we should disable multi-threading first before setting
3073 // new thread pool otherwise the assert in
3074 // MLIRContext::setThreadPool will be raised.
3075 mlirContextEnableMultithreading(self.get(), false);
3076 mlirContextSetThreadPool(self.get(), pool.get());
3077 },
3078 R"(
3079 Sets a custom thread pool for the context to use.
3080
3081 Args:
3082 pool: A ThreadPool object to use for parallel operations.
3083
3084 Note:
3085 Multi-threading is automatically disabled before setting the thread pool.)")
3086 .def(
3087 "get_num_threads",
3088 [](PyMlirContext &self) {
3089 return mlirContextGetNumThreads(self.get());
3090 },
3091 "Gets the number of threads in the context's thread pool.")
3092 .def(
3093 "_mlir_thread_pool_ptr",
3094 [](PyMlirContext &self) {
3095 MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
3096 std::stringstream ss;
3097 ss << pool.ptr;
3098 return ss.str();
3099 },
3100 "Gets the raw pointer to the LLVM thread pool as a string.")
3101 .def(
3102 "is_registered_operation",
3103 [](PyMlirContext &self, std::string &name) {
3105 self.get(), MlirStringRef{name.data(), name.size()});
3106 },
3107 nb::arg("operation_name"),
3108 R"(
3109 Checks whether an operation with the given name is registered.
3110
3111 Args:
3112 operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
3113
3114 Returns:
3115 True if the operation is registered, False otherwise.)")
3116 .def(
3117 "append_dialect_registry",
3118 [](PyMlirContext &self, PyDialectRegistry &registry) {
3119 mlirContextAppendDialectRegistry(self.get(), registry);
3120 },
3121 nb::arg("registry"),
3122 R"(
3123 Appends the contents of a dialect registry to the context.
3124
3125 Args:
3126 registry: A DialectRegistry containing dialects to append.)")
3127 .def_prop_rw("emit_error_diagnostics",
3130 R"(
3131 Controls whether error diagnostics are emitted to diagnostic handlers.
3132
3133 By default, error diagnostics are captured and reported through MLIRError exceptions.)")
3134 .def(
3135 "load_all_available_dialects",
3136 [](PyMlirContext &self) {
3138 },
3139 R"(
3140 Loads all dialects available in the registry into the context.
3141
3142 This eagerly loads all dialects that have been registered, making them
3143 immediately available for use.)");
3144
3145 //----------------------------------------------------------------------------
3146 // Mapping of PyDialectDescriptor
3147 //----------------------------------------------------------------------------
3148 nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
3149 .def_prop_ro(
3150 "namespace",
3151 [](PyDialectDescriptor &self) {
3153 return nb::str(ns.data, ns.length);
3154 },
3155 "Returns the namespace of the dialect.")
3156 .def(
3157 "__repr__",
3158 [](PyDialectDescriptor &self) {
3160 std::string repr("<DialectDescriptor ");
3161 repr.append(ns.data, ns.length);
3162 repr.append(">");
3163 return repr;
3164 },
3165 nb::sig("def __repr__(self) -> str"),
3166 "Returns a string representation of the dialect descriptor.");
3167
3168 //----------------------------------------------------------------------------
3169 // Mapping of PyDialects
3170 //----------------------------------------------------------------------------
3171 nb::class_<PyDialects>(m, "Dialects")
3172 .def(
3173 "__getitem__",
3174 [=](PyDialects &self, std::string keyName) {
3175 MlirDialect dialect =
3176 self.getDialectForKey(keyName, /*attrError=*/false);
3177 nb::object descriptor =
3178 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3179 return createCustomDialectWrapper(keyName, std::move(descriptor));
3180 },
3181 "Gets a dialect by name using subscript notation.")
3182 .def(
3183 "__getattr__",
3184 [=](PyDialects &self, std::string attrName) {
3185 MlirDialect dialect =
3186 self.getDialectForKey(attrName, /*attrError=*/true);
3187 nb::object descriptor =
3188 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3189 return createCustomDialectWrapper(attrName, std::move(descriptor));
3190 },
3191 "Gets a dialect by name using attribute notation.");
3192
3193 //----------------------------------------------------------------------------
3194 // Mapping of PyDialect
3195 //----------------------------------------------------------------------------
3196 nb::class_<PyDialect>(m, "Dialect")
3197 .def(nb::init<nb::object>(), nb::arg("descriptor"),
3198 "Creates a Dialect from a DialectDescriptor.")
3199 .def_prop_ro(
3200 "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
3201 "Returns the DialectDescriptor for this dialect.")
3202 .def(
3203 "__repr__",
3204 [](const nb::object &self) {
3205 auto clazz = self.attr("__class__");
3206 return nb::str("<Dialect ") +
3207 self.attr("descriptor").attr("namespace") +
3208 nb::str(" (class ") + clazz.attr("__module__") +
3209 nb::str(".") + clazz.attr("__name__") + nb::str(")>");
3210 },
3211 nb::sig("def __repr__(self) -> str"),
3212 "Returns a string representation of the dialect.");
3213
3214 //----------------------------------------------------------------------------
3215 // Mapping of PyDialectRegistry
3216 //----------------------------------------------------------------------------
3217 nb::class_<PyDialectRegistry>(m, "DialectRegistry")
3219 "Gets a capsule wrapping the MlirDialectRegistry.")
3222 "Creates a DialectRegistry from a capsule wrapping "
3223 "`MlirDialectRegistry`.")
3224 .def(nb::init<>(), "Creates a new empty dialect registry.");
3225
3226 //----------------------------------------------------------------------------
3227 // Mapping of Location
3228 //----------------------------------------------------------------------------
3229 nb::class_<PyLocation>(m, "Location")
3231 "Gets a capsule wrapping the MlirLocation.")
3233 "Creates a Location from a capsule wrapping MlirLocation.")
3234 .def("__enter__", &PyLocation::contextEnter,
3235 "Enters the location as a context manager.")
3236 .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
3237 nb::arg("exc_value").none(), nb::arg("traceback").none(),
3238 "Exits the location context manager.")
3239 .def(
3240 "__eq__",
3241 [](PyLocation &self, PyLocation &other) -> bool {
3242 return mlirLocationEqual(self, other);
3243 },
3244 "Compares two locations for equality.")
3245 .def(
3246 "__eq__", [](PyLocation &self, nb::object other) { return false; },
3247 "Compares location with non-location object (always returns False).")
3248 .def_prop_ro_static(
3249 "current",
3250 [](nb::object & /*class*/) -> std::optional<PyLocation *> {
3252 if (!loc)
3253 return std::nullopt;
3254 return loc;
3255 },
3256 // clang-format off
3257 nb::sig("def current(/) -> Location | None"),
3258 // clang-format on
3259 "Gets the Location bound to the current thread or raises ValueError.")
3260 .def_static(
3261 "unknown",
3262 [](DefaultingPyMlirContext context) {
3263 return PyLocation(context->getRef(),
3264 mlirLocationUnknownGet(context->get()));
3265 },
3266 nb::arg("context") = nb::none(),
3267 "Gets a Location representing an unknown location.")
3268 .def_static(
3269 "callsite",
3270 [](PyLocation callee, const std::vector<PyLocation> &frames,
3271 DefaultingPyMlirContext context) {
3272 if (frames.empty())
3273 throw nb::value_error("No caller frames provided.");
3274 MlirLocation caller = frames.back().get();
3275 for (const PyLocation &frame :
3276 llvm::reverse(llvm::ArrayRef(frames).drop_back()))
3277 caller = mlirLocationCallSiteGet(frame.get(), caller);
3278 return PyLocation(context->getRef(),
3279 mlirLocationCallSiteGet(callee.get(), caller));
3280 },
3281 nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
3282 "Gets a Location representing a caller and callsite.")
3283 .def("is_a_callsite", mlirLocationIsACallSite,
3284 "Returns True if this location is a CallSiteLoc.")
3285 .def_prop_ro(
3286 "callee",
3287 [](PyLocation &self) {
3288 return PyLocation(self.getContext(),
3290 },
3291 "Gets the callee location from a CallSiteLoc.")
3292 .def_prop_ro(
3293 "caller",
3294 [](PyLocation &self) {
3295 return PyLocation(self.getContext(),
3297 },
3298 "Gets the caller location from a CallSiteLoc.")
3299 .def_static(
3300 "file",
3301 [](std::string filename, int line, int col,
3302 DefaultingPyMlirContext context) {
3303 return PyLocation(
3304 context->getRef(),
3306 context->get(), toMlirStringRef(filename), line, col));
3307 },
3308 nb::arg("filename"), nb::arg("line"), nb::arg("col"),
3309 nb::arg("context") = nb::none(),
3310 "Gets a Location representing a file, line and column.")
3311 .def_static(
3312 "file",
3313 [](std::string filename, int startLine, int startCol, int endLine,
3314 int endCol, DefaultingPyMlirContext context) {
3315 return PyLocation(context->getRef(),
3317 context->get(), toMlirStringRef(filename),
3318 startLine, startCol, endLine, endCol));
3319 },
3320 nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
3321 nb::arg("end_line"), nb::arg("end_col"),
3322 nb::arg("context") = nb::none(),
3323 "Gets a Location representing a file, line and column range.")
3324 .def("is_a_file", mlirLocationIsAFileLineColRange,
3325 "Returns True if this location is a FileLineColLoc.")
3326 .def_prop_ro(
3327 "filename",
3328 [](MlirLocation loc) {
3329 return mlirIdentifierStr(
3331 },
3332 "Gets the filename from a FileLineColLoc.")
3333 .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
3334 "Gets the start line number from a `FileLineColLoc`.")
3335 .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
3336 "Gets the start column number from a `FileLineColLoc`.")
3337 .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
3338 "Gets the end line number from a `FileLineColLoc`.")
3339 .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
3340 "Gets the end column number from a `FileLineColLoc`.")
3341 .def_static(
3342 "fused",
3343 [](const std::vector<PyLocation> &pyLocations,
3344 std::optional<PyAttribute> metadata,
3345 DefaultingPyMlirContext context) {
3346 llvm::SmallVector<MlirLocation, 4> locations;
3347 locations.reserve(pyLocations.size());
3348 for (auto &pyLocation : pyLocations)
3349 locations.push_back(pyLocation.get());
3350 MlirLocation location = mlirLocationFusedGet(
3351 context->get(), locations.size(), locations.data(),
3352 metadata ? metadata->get() : MlirAttribute{0});
3353 return PyLocation(context->getRef(), location);
3354 },
3355 nb::arg("locations"), nb::arg("metadata") = nb::none(),
3356 nb::arg("context") = nb::none(),
3357 "Gets a Location representing a fused location with optional "
3358 "metadata.")
3359 .def("is_a_fused", mlirLocationIsAFused,
3360 "Returns True if this location is a `FusedLoc`.")
3361 .def_prop_ro(
3362 "locations",
3363 [](PyLocation &self) {
3364 unsigned numLocations = mlirLocationFusedGetNumLocations(self);
3365 std::vector<MlirLocation> locations(numLocations);
3366 if (numLocations)
3367 mlirLocationFusedGetLocations(self, locations.data());
3368 std::vector<PyLocation> pyLocations{};
3369 pyLocations.reserve(numLocations);
3370 for (unsigned i = 0; i < numLocations; ++i)
3371 pyLocations.emplace_back(self.getContext(), locations[i]);
3372 return pyLocations;
3373 },
3374 "Gets the list of locations from a `FusedLoc`.")
3375 .def_static(
3376 "name",
3377 [](std::string name, std::optional<PyLocation> childLoc,
3378 DefaultingPyMlirContext context) {
3379 return PyLocation(
3380 context->getRef(),
3382 context->get(), toMlirStringRef(name),
3383 childLoc ? childLoc->get()
3384 : mlirLocationUnknownGet(context->get())));
3385 },
3386 nb::arg("name"), nb::arg("childLoc") = nb::none(),
3387 nb::arg("context") = nb::none(),
3388 "Gets a Location representing a named location with optional child "
3389 "location.")
3390 .def("is_a_name", mlirLocationIsAName,
3391 "Returns True if this location is a `NameLoc`.")
3392 .def_prop_ro(
3393 "name_str",
3394 [](MlirLocation loc) {
3396 },
3397 "Gets the name string from a `NameLoc`.")
3398 .def_prop_ro(
3399 "child_loc",
3400 [](PyLocation &self) {
3401 return PyLocation(self.getContext(),
3403 },
3404 "Gets the child location from a `NameLoc`.")
3405 .def_static(
3406 "from_attr",
3407 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
3408 return PyLocation(context->getRef(),
3409 mlirLocationFromAttribute(attribute));
3410 },
3411 nb::arg("attribute"), nb::arg("context") = nb::none(),
3412 "Gets a Location from a `LocationAttr`.")
3413 .def_prop_ro(
3414 "context",
3415 [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
3416 return self.getContext().getObject();
3417 },
3418 "Context that owns the `Location`.")
3419 .def_prop_ro(
3420 "attr",
3421 [](PyLocation &self) {
3422 return PyAttribute(self.getContext(),
3424 },
3425 "Get the underlying `LocationAttr`.")
3426 .def(
3427 "emit_error",
3428 [](PyLocation &self, std::string message) {
3429 mlirEmitError(self, message.c_str());
3430 },
3431 nb::arg("message"),
3432 R"(
3433 Emits an error diagnostic at this location.
3434
3435 Args:
3436 message: The error message to emit.)")
3437 .def(
3438 "__repr__",
3439 [](PyLocation &self) {
3440 PyPrintAccumulator printAccum;
3441 mlirLocationPrint(self, printAccum.getCallback(),
3442 printAccum.getUserData());
3443 return printAccum.join();
3444 },
3445 "Returns the assembly representation of the location.");
3446
3447 //----------------------------------------------------------------------------
3448 // Mapping of Module
3449 //----------------------------------------------------------------------------
3450 nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3452 "Gets a capsule wrapping the MlirModule.")
3454 R"(
3455 Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
3456
3457 This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
3458 prevent double-frees (of the underlying `mlir::Module`).)")
3459 .def("_clear_mlir_module", &PyModule::clearMlirModule,
3460 R"(
3461 Clears the internal MLIR module reference.
3462
3463 This is used internally to prevent double-free when ownership is transferred
3464 via the C API capsule mechanism. Not intended for normal use.)")
3465 .def_static(
3466 "parse",
3467 [](const std::string &moduleAsm, DefaultingPyMlirContext context)
3468 -> nb::typed<nb::object, PyModule> {
3469 PyMlirContext::ErrorCapture errors(context->getRef());
3470 MlirModule module = mlirModuleCreateParse(
3471 context->get(), toMlirStringRef(moduleAsm));
3472 if (mlirModuleIsNull(module))
3473 throw MLIRError("Unable to parse module assembly", errors.take());
3474 return PyModule::forModule(module).releaseObject();
3475 },
3476 nb::arg("asm"), nb::arg("context") = nb::none(),
3478 .def_static(
3479 "parse",
3480 [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
3481 -> nb::typed<nb::object, PyModule> {
3482 PyMlirContext::ErrorCapture errors(context->getRef());
3483 MlirModule module = mlirModuleCreateParse(
3484 context->get(), toMlirStringRef(moduleAsm));
3485 if (mlirModuleIsNull(module))
3486 throw MLIRError("Unable to parse module assembly", errors.take());
3487 return PyModule::forModule(module).releaseObject();
3488 },
3489 nb::arg("asm"), nb::arg("context") = nb::none(),
3491 .def_static(
3492 "parseFile",
3493 [](const std::string &path, DefaultingPyMlirContext context)
3494 -> nb::typed<nb::object, PyModule> {
3495 PyMlirContext::ErrorCapture errors(context->getRef());
3496 MlirModule module = mlirModuleCreateParseFromFile(
3497 context->get(), toMlirStringRef(path));
3498 if (mlirModuleIsNull(module))
3499 throw MLIRError("Unable to parse module assembly", errors.take());
3500 return PyModule::forModule(module).releaseObject();
3501 },
3502 nb::arg("path"), nb::arg("context") = nb::none(),
3504 .def_static(
3505 "create",
3506 [](const std::optional<PyLocation> &loc)
3507 -> nb::typed<nb::object, PyModule> {
3508 PyLocation pyLoc = maybeGetTracebackLocation(loc);
3509 MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
3510 return PyModule::forModule(module).releaseObject();
3511 },
3512 nb::arg("loc") = nb::none(), "Creates an empty module.")
3513 .def_prop_ro(
3514 "context",
3515 [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
3516 return self.getContext().getObject();
3517 },
3518 "Context that created the `Module`.")
3519 .def_prop_ro(
3520 "operation",
3521 [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
3524 self.getRef().releaseObject())
3525 .releaseObject();
3526 },
3527 "Accesses the module as an operation.")
3528 .def_prop_ro(
3529 "body",
3530 [](PyModule &self) {
3532 self.getContext(), mlirModuleGetOperation(self.get()),
3533 self.getRef().releaseObject());
3534 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3535 return returnBlock;
3536 },
3537 "Return the block for this module.")
3538 .def(
3539 "dump",
3540 [](PyModule &self) {
3542 },
3544 .def(
3545 "__str__",
3546 [](const nb::object &self) {
3547 // Defer to the operation's __str__.
3548 return self.attr("operation").attr("__str__")();
3549 },
3550 nb::sig("def __str__(self) -> str"),
3551 R"(
3552 Gets the assembly form of the operation with default options.
3553
3554 If more advanced control over the assembly formatting or I/O options is needed,
3555 use the dedicated print or get_asm method, which supports keyword arguments to
3556 customize behavior.
3557 )")
3558 .def(
3559 "__eq__",
3560 [](PyModule &self, PyModule &other) {
3561 return mlirModuleEqual(self.get(), other.get());
3562 },
3563 "other"_a, "Compares two modules for equality.")
3564 .def(
3565 "__hash__",
3566 [](PyModule &self) { return mlirModuleHashValue(self.get()); },
3567 "Returns the hash value of the module.");
3568
3569 //----------------------------------------------------------------------------
3570 // Mapping of Operation.
3571 //----------------------------------------------------------------------------
3572 nb::class_<PyOperationBase>(m, "_OperationBase")
3573 .def_prop_ro(
3575 [](PyOperationBase &self) {
3576 return self.getOperation().getCapsule();
3577 },
3578 "Gets a capsule wrapping the `MlirOperation`.")
3579 .def(
3580 "__eq__",
3581 [](PyOperationBase &self, PyOperationBase &other) {
3582 return mlirOperationEqual(self.getOperation().get(),
3583 other.getOperation().get());
3584 },
3585 "Compares two operations for equality.")
3586 .def(
3587 "__eq__",
3588 [](PyOperationBase &self, nb::object other) { return false; },
3589 "Compares operation with non-operation object (always returns "
3590 "False).")
3591 .def(
3592 "__hash__",
3593 [](PyOperationBase &self) {
3594 return mlirOperationHashValue(self.getOperation().get());
3595 },
3596 "Returns the hash value of the operation.")
3597 .def_prop_ro(
3598 "attributes",
3599 [](PyOperationBase &self) {
3600 return PyOpAttributeMap(self.getOperation().getRef());
3601 },
3602 "Returns a dictionary-like map of operation attributes.")
3603 .def_prop_ro(
3604 "context",
3605 [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
3606 PyOperation &concreteOperation = self.getOperation();
3607 concreteOperation.checkValid();
3608 return concreteOperation.getContext().getObject();
3609 },
3610 "Context that owns the operation.")
3611 .def_prop_ro(
3612 "name",
3613 [](PyOperationBase &self) {
3614 auto &concreteOperation = self.getOperation();
3615 concreteOperation.checkValid();
3616 MlirOperation operation = concreteOperation.get();
3617 return mlirIdentifierStr(mlirOperationGetName(operation));
3618 },
3619 "Returns the fully qualified name of the operation.")
3620 .def_prop_ro(
3621 "operands",
3622 [](PyOperationBase &self) {
3623 return PyOpOperandList(self.getOperation().getRef());
3624 },
3625 "Returns the list of operation operands.")
3626 .def_prop_ro(
3627 "regions",
3628 [](PyOperationBase &self) {
3629 return PyRegionList(self.getOperation().getRef());
3630 },
3631 "Returns the list of operation regions.")
3632 .def_prop_ro(
3633 "results",
3634 [](PyOperationBase &self) {
3635 return PyOpResultList(self.getOperation().getRef());
3636 },
3637 "Returns the list of Operation results.")
3638 .def_prop_ro(
3639 "result",
3640 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
3641 auto &operation = self.getOperation();
3642 return PyOpResult(operation.getRef(), getUniqueResult(operation))
3643 .maybeDownCast();
3644 },
3645 "Shortcut to get an op result if it has only one (throws an error "
3646 "otherwise).")
3647 .def_prop_rw(
3648 "location",
3649 [](PyOperationBase &self) {
3650 PyOperation &operation = self.getOperation();
3651 return PyLocation(operation.getContext(),
3652 mlirOperationGetLocation(operation.get()));
3653 },
3654 [](PyOperationBase &self, const PyLocation &location) {
3655 PyOperation &operation = self.getOperation();
3656 mlirOperationSetLocation(operation.get(), location.get());
3657 },
3658 nb::for_getter("Returns the source location the operation was "
3659 "defined or derived from."),
3660 nb::for_setter("Sets the source location the operation was defined "
3661 "or derived from."))
3662 .def_prop_ro(
3663 "parent",
3664 [](PyOperationBase &self)
3665 -> std::optional<nb::typed<nb::object, PyOperation>> {
3666 auto parent = self.getOperation().getParentOperation();
3667 if (parent)
3668 return parent->getObject();
3669 return {};
3670 },
3671 "Returns the parent operation, or `None` if at top level.")
3672 .def(
3673 "__str__",
3674 [](PyOperationBase &self) {
3675 return self.getAsm(/*binary=*/false,
3676 /*largeElementsLimit=*/std::nullopt,
3677 /*largeResourceLimit=*/std::nullopt,
3678 /*enableDebugInfo=*/false,
3679 /*prettyDebugInfo=*/false,
3680 /*printGenericOpForm=*/false,
3681 /*useLocalScope=*/false,
3682 /*useNameLocAsPrefix=*/false,
3683 /*assumeVerified=*/false,
3684 /*skipRegions=*/false);
3685 },
3686 nb::sig("def __str__(self) -> str"),
3687 "Returns the assembly form of the operation.")
3688 .def("print",
3689 nb::overload_cast<PyAsmState &, nb::object, bool>(
3691 nb::arg("state"), nb::arg("file") = nb::none(),
3692 nb::arg("binary") = false,
3693 R"(
3694 Prints the assembly form of the operation to a file like object.
3695
3696 Args:
3697 state: `AsmState` capturing the operation numbering and flags.
3698 file: Optional file like object to write to. Defaults to sys.stdout.
3699 binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
3700 .def("print",
3701 nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
3702 bool, bool, bool, bool, bool, bool, nb::object,
3703 bool, bool>(&PyOperationBase::print),
3704 // Careful: Lots of arguments must match up with print method.
3705 nb::arg("large_elements_limit") = nb::none(),
3706 nb::arg("large_resource_limit") = nb::none(),
3707 nb::arg("enable_debug_info") = false,
3708 nb::arg("pretty_debug_info") = false,
3709 nb::arg("print_generic_op_form") = false,
3710 nb::arg("use_local_scope") = false,
3711 nb::arg("use_name_loc_as_prefix") = false,
3712 nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
3713 nb::arg("binary") = false, nb::arg("skip_regions") = false,
3714 R"(
3715 Prints the assembly form of the operation to a file like object.
3716
3717 Args:
3718 large_elements_limit: Whether to elide elements attributes above this
3719 number of elements. Defaults to None (no limit).
3720 large_resource_limit: Whether to elide resource attributes above this
3721 number of characters. Defaults to None (no limit). If large_elements_limit
3722 is set and this is None, the behavior will be to use large_elements_limit
3723 as large_resource_limit.
3724 enable_debug_info: Whether to print debug/location information. Defaults
3725 to False.
3726 pretty_debug_info: Whether to format debug information for easier reading
3727 by a human (warning: the result is unparseable). Defaults to False.
3728 print_generic_op_form: Whether to print the generic assembly forms of all
3729 ops. Defaults to False.
3730 use_local_scope: Whether to print in a way that is more optimized for
3731 multi-threaded access but may not be consistent with how the overall
3732 module prints.
3733 use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
3734 prefixes for the SSA identifiers. Defaults to False.
3735 assume_verified: By default, if not printing generic form, the verifier
3736 will be run and if it fails, generic form will be printed with a comment
3737 about failed verification. While a reasonable default for interactive use,
3738 for systematic use, it is often better for the caller to verify explicitly
3739 and report failures in a more robust fashion. Set this to True if doing this
3740 in order to avoid running a redundant verification. If the IR is actually
3741 invalid, behavior is undefined.
3742 file: The file like object to write to. Defaults to sys.stdout.
3743 binary: Whether to write bytes (True) or str (False). Defaults to False.
3744 skip_regions: Whether to skip printing regions. Defaults to False.)")
3745 .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
3746 nb::arg("desired_version") = nb::none(),
3747 R"(
3748 Write the bytecode form of the operation to a file like object.
3749
3750 Args:
3751 file: The file like object to write to.
3752 desired_version: Optional version of bytecode to emit.
3753 Returns:
3754 The bytecode writer status.)")
3755 .def("get_asm", &PyOperationBase::getAsm,
3756 // Careful: Lots of arguments must match up with get_asm method.
3757 nb::arg("binary") = false,
3758 nb::arg("large_elements_limit") = nb::none(),
3759 nb::arg("large_resource_limit") = nb::none(),
3760 nb::arg("enable_debug_info") = false,
3761 nb::arg("pretty_debug_info") = false,
3762 nb::arg("print_generic_op_form") = false,
3763 nb::arg("use_local_scope") = false,
3764 nb::arg("use_name_loc_as_prefix") = false,
3765 nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
3766 R"(
3767 Gets the assembly form of the operation with all options available.
3768
3769 Args:
3770 binary: Whether to return a bytes (True) or str (False) object. Defaults to
3771 False.
3772 ... others ...: See the print() method for common keyword arguments for
3773 configuring the printout.
3774 Returns:
3775 Either a bytes or str object, depending on the setting of the `binary`
3776 argument.)")
3777 .def("verify", &PyOperationBase::verify,
3778 "Verify the operation. Raises MLIRError if verification fails, and "
3779 "returns true otherwise.")
3780 .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
3781 "Puts self immediately after the other operation in its parent "
3782 "block.")
3783 .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
3784 "Puts self immediately before the other operation in its parent "
3785 "block.")
3786 .def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
3787 nb::arg("other"),
3788 R"(
3789 Checks if this operation is before another in the same block.
3790
3791 Args:
3792 other: Another operation in the same parent block.
3793
3794 Returns:
3795 True if this operation is before `other` in the operation list of the parent block.)")
3796 .def(
3797 "clone",
3798 [](PyOperationBase &self,
3799 const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
3800 return self.getOperation().clone(ip);
3801 },
3802 nb::arg("ip") = nb::none(),
3803 R"(
3804 Creates a deep copy of the operation.
3805
3806 Args:
3807 ip: Optional insertion point where the cloned operation should be inserted.
3808 If None, the current insertion point is used. If False, the operation
3809 remains detached.
3810
3811 Returns:
3812 A new Operation that is a clone of this operation.)")
3813 .def(
3814 "detach_from_parent",
3815 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
3816 PyOperation &operation = self.getOperation();
3817 operation.checkValid();
3818 if (!operation.isAttached())
3819 throw nb::value_error("Detached operation has no parent.");
3820
3821 operation.detachFromParent();
3822 return operation.createOpView();
3823 },
3824 "Detaches the operation from its parent block.")
3825 .def_prop_ro(
3826 "attached",
3827 [](PyOperationBase &self) {
3828 PyOperation &operation = self.getOperation();
3829 operation.checkValid();
3830 return operation.isAttached();
3831 },
3832 "Reports if the operation is attached to its parent block.")
3833 .def(
3834 "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
3835 R"(
3836 Erases the operation and frees its memory.
3837
3838 Note:
3839 After erasing, any Python references to the operation become invalid.)")
3840 .def("walk", &PyOperationBase::walk, nb::arg("callback"),
3841 nb::arg("walk_order") = MlirWalkPostOrder,
3842 // clang-format off
3843 nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
3844 // clang-format on
3845 R"(
3846 Walks the operation tree with a callback function.
3847
3848 Args:
3849 callback: A callable that takes an Operation and returns a WalkResult.
3850 walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
3851
// Python class `Operation`: bindings for PyOperation, registered with
// PyOperationBase as its base class.
3852 nb::class_<PyOperation, PyOperationBase>(m, "Operation")
// `Operation.create(...)`: generic factory. The lambda only adapts the Python
// arguments; construction itself is delegated to PyOperation::create.
3853 .def_static(
3854 "create",
3855 [](std::string_view name,
3856 std::optional<std::vector<PyType *>> results,
3857 std::optional<std::vector<PyValue *>> operands,
3858 std::optional<nb::dict> attributes,
3859 std::optional<std::vector<PyBlock *>> successors, int regions,
3860 const std::optional<PyLocation> &location,
3861 const nb::object &maybeIp,
3862 bool inferType) -> nb::typed<nb::object, PyOperation> {
3863 // Unpack/validate operands.
// A null entry (Python `None` inside the sequence) is rejected with a
// ValueError before anything is created.
3864 llvm::SmallVector<MlirValue, 4> mlirOperands;
3865 if (operands) {
3866 mlirOperands.reserve(operands->size());
3867 for (PyValue *operand : *operands) {
3868 if (!operand)
3869 throw nb::value_error("operand value cannot be None");
3870 mlirOperands.push_back(operand->get());
3871 }
3872 }
3873
// The (possibly absent) `loc` argument is resolved through
// maybeGetTracebackLocation before being forwarded.
3874 PyLocation pyLoc = maybeGetTracebackLocation(location);
3875 return PyOperation::create(name, results, mlirOperands, attributes,
3876 successors, regions, pyLoc, maybeIp,
3877 inferType);
3878 },
// Keyword names exposed to Python. Note the location keyword is `loc`; the
// docstring below must use the same spelling.
3879 nb::arg("name"), nb::arg("results") = nb::none(),
3880 nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
3881 nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
3882 nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
3883 nb::arg("infer_type") = false,
3884 R"(
3885 Creates a new operation.
3886
3887 Args:
3888 name: Operation name (e.g. `dialect.operation`).
3889 results: Optional sequence of Type representing op result types.
3890 operands: Optional operands of the operation.
3891 attributes: Optional Dict of {str: Attribute}.
3892 successors: Optional List of Block for the operation's successors.
3893 regions: Number of regions to create (default = 0).
3894 loc: Optional Location object (defaults to resolve from context manager).
3895 ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
3896 infer_type: Whether to infer result types (default = False).
3897 Returns:
3898 A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
// `Operation.parse(source, *, source_name="", context=None)`: forwards to
// PyOperation::parse and returns the parsed operation via createOpView().
// NOTE(review): listing line 3902 is absent from this view; it presumably
// declares the `context` parameter that the body dereferences - confirm
// against the real source.
3899 .def_static(
3900 "parse",
3901 [](const std::string &sourceStr, const std::string &sourceName,
3903 -> nb::typed<nb::object, PyOpView> {
3904 return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3905 ->createOpView();
3906 },
3907 nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
3908 nb::arg("context") = nb::none(),
3909 "Parses an operation. Supports both text assembly format and binary "
3910 "bytecode format.")
// NOTE(review): listing lines 3911 and 3913-3914 are absent from this view,
// so only the docstrings of the two capsule bindings are visible here.
3912 "Gets a capsule wrapping the MlirOperation.")
3915 "Creates an Operation from a capsule wrapping MlirOperation.")
// `operation`: takes the Python object itself and returns it unchanged.
3916 .def_prop_ro(
3917 "operation",
3918 [](nb::object self) -> nb::typed<nb::object, PyOperation> {
3919 return self;
3920 },
3921 "Returns self (the operation).")
// `opview`: result of PyOperation::createOpView() for this operation.
3922 .def_prop_ro(
3923 "opview",
3924 [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
3925 return self.createOpView();
3926 },
3927 R"(
3928 Returns an OpView of this operation.
3929
3930 Note:
3931 If the operation has a registered and loaded dialect then this OpView will
3932 be concrete wrapper class.)")
// `block`: bound directly to PyOperation::getBlock.
3933 .def_prop_ro("block", &PyOperation::getBlock,
3934 "Returns the block containing this operation.")
// `successors`: a PyOpSuccessors built from a reference to the operation.
3935 .def_prop_ro(
3936 "successors",
3937 [](PyOperationBase &self) {
3938 return PyOpSuccessors(self.getOperation().getRef());
3939 },
3940 "Returns the list of Operation successors.")
// `_set_invalid`: underscore-prefixed hook bound to PyOperation::setInvalid.
3941 .def("_set_invalid", &PyOperation::setInvalid,
3942 "Invalidate the operation.");
3943
// Python class `OpView`, registered with PyOperationBase as its base. The
// class object is kept in `opViewClass` so that class-level attributes can be
// attached to it after this statement.
3944 auto opViewClass =
3945 nb::class_<PyOpView, PyOperationBase>(m, "OpView")
// Constructor overload taking an existing `Operation` object.
3946 .def(nb::init<nb::typed<nb::object, PyOperation>>(),
3947 nb::arg("operation"))
// Constructor overload taking the pieces needed to build an operation: name,
// region spec, operand/result segment specs, results, operands, attributes,
// successors, regions, location and insertion point.
3948 .def(
3949 "__init__",
3950 [](PyOpView *self, std::string_view name,
3951 std::tuple<int, bool> opRegionSpec,
3952 nb::object operandSegmentSpecObj,
3953 nb::object resultSegmentSpecObj,
3954 std::optional<nb::list> resultTypeList, nb::list operandList,
3955 std::optional<nb::dict> attributes,
3956 std::optional<std::vector<PyBlock *>> successors,
3957 std::optional<int> regions,
3958 const std::optional<PyLocation> &location,
3959 const nb::object &maybeIp) {
3960 PyLocation pyLoc = maybeGetTracebackLocation(location);
// NOTE(review): listing line 3961 is absent from this view; it opens the call
// whose argument list continues below and closes with `));` - confirm against
// the real source.
3962 name, opRegionSpec, operandSegmentSpecObj,
3963 resultSegmentSpecObj, resultTypeList, operandList,
3964 attributes, successors, regions, pyLoc, maybeIp));
3965 },
// Only `name` and `opRegionSpec` are required; every other keyword defaults
// to None.
3966 nb::arg("name"), nb::arg("opRegionSpec"),
3967 nb::arg("operandSegmentSpecObj") = nb::none(),
3968 nb::arg("resultSegmentSpecObj") = nb::none(),
3969 nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(),
3970 nb::arg("attributes") = nb::none(),
3971 nb::arg("successors") = nb::none(),
3972 nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
3973 nb::arg("ip") = nb::none())
// `operation`: the Operation object held by this view.
3974 .def_prop_ro(
3975 "operation",
3976 [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
3977 return self.getOperationObject();
3978 })
// `opview`: takes the Python object itself and returns it unchanged.
3979 .def_prop_ro("opview",
3980 [](nb::object self) -> nb::typed<nb::object, PyOpView> {
3981 return self;
3982 })
// `__str__`: delegates to str() of the underlying Operation object.
3983 .def(
3984 "__str__",
3985 [](PyOpView &self) { return nb::str(self.getOperationObject()); })
// `successors`: a PyOpSuccessors built from a reference to the operation.
3986 .def_prop_ro(
3987 "successors",
3988 [](PyOperationBase &self) {
3989 return PyOpSuccessors(self.getOperation().getRef());
3990 },
3991 "Returns the list of Operation successors.")
// `_set_invalid`: calls setInvalid() on the underlying operation.
3992 .def(
3993 "_set_invalid",
3994 [](PyOpView &self) { self.getOperation().setInvalid(); },
3995 "Invalidate the operation.");
3996 opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
3997 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
3998 opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
3999 // It is faster to pass the operation_name, ods_regions, and
4000 // ods_operand_segments/ods_result_segments as arguments to the constructor,
4001 // rather than to access them as attributes.
4002 opViewClass.attr("build_generic") = classmethod(
4003 [](nb::handle cls, std::optional<nb::list> resultTypeList,
4004 nb::list operandList, std::optional<nb::dict> attributes,
4005 std::optional<std::vector<PyBlock *>> successors,
4006 std::optional<int> regions, std::optional<PyLocation> location,
4007 const nb::object &maybeIp) {
4008 std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
4009 std::tuple<int, bool> opRegionSpec =
4010 nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
4011 nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
4012 nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
4013 PyLocation pyLoc = maybeGetTracebackLocation(location);
4014 return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
4015 resultSegmentSpec, resultTypeList,
4016 operandList, attributes, successors,
4017 regions, pyLoc, maybeIp);
4018 },
4019 nb::arg("cls"), nb::arg("results") = nb::none(),
4020 nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
4021 nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(),
4022 nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
4023 "Builds a specific, generated OpView based on class level attributes.");
// `OpView.parse(cls, source, *, source_name="", context=None)` classmethod.
// Parses `source` with PyOperation::parse, compares the parsed operation's
// name with `cls.OPERATION_NAME`, raises MLIRError on a mismatch, and
// otherwise returns PyOpView::constructDerived(cls, <parsed operation>).
4024 opViewClass.attr("parse") = classmethod(
4025 [](const nb::object &cls, const std::string &sourceStr,
4026 const std::string &sourceName,
4027 DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
4028 PyOperationRef parsed =
4029 PyOperation::parse(context->getRef(), sourceStr, sourceName);
4030
4031 // Check if the expected operation was parsed, and cast to the
4032 // appropriate `OpView` subclass if successful.
4033 // NOTE: This accesses attributes that have been automatically added to
4034 // `OpView` subclasses, and is not intended to be used on `OpView`
4035 // directly.
4036 std::string clsOpName =
4037 nb::cast<std::string>(cls.attr("OPERATION_NAME"));
// NOTE(review): listing line 4039 (the initializer of `identifier`) is absent
// from this view; presumably it fetches the name of the parsed operation -
// confirm against the real source.
4038 MlirStringRef identifier =
4040 std::string_view parsedOpName(identifier.data, identifier.length);
4041 if (clsOpName != parsedOpName)
4042 throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
4043 parsedOpName + "'");
4044 return PyOpView::constructDerived(cls, parsed.getObject());
4045 },
// `source_name` and `context` are keyword-only (they follow nb::kw_only()).
4046 nb::arg("cls"), nb::arg("source"), nb::kw_only(),
4047 nb::arg("source_name") = "", nb::arg("context") = nb::none(),
4048 "Parses a specific, generated OpView based on class level attributes.");
4049
4050 //----------------------------------------------------------------------------
4051 // Mapping of PyRegion.
4052 //----------------------------------------------------------------------------
4053 nb::class_<PyRegion>(m, "Region")
4054 .def_prop_ro(
4055 "blocks",
4056 [](PyRegion &self) {
4057 return PyBlockList(self.getParentOperation(), self.get());
4058 },
4059 "Returns a forward-optimized sequence of blocks.")
4060 .def_prop_ro(
4061 "owner",
4062 [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
4063 return self.getParentOperation()->createOpView();
4064 },
4065 "Returns the operation owning this region.")
4066 .def(
4067 "__iter__",
4068 [](PyRegion &self) {
4069 self.checkValid();
4070 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
4071 return PyBlockIterator(self.getParentOperation(), firstBlock);
4072 },
4073 "Iterates over blocks in the region.")
4074 .def(
4075 "__eq__",
4076 [](PyRegion &self, PyRegion &other) {
4077 return self.get().ptr == other.get().ptr;
4078 },
4079 "Compares two regions for pointer equality.")
4080 .def(
4081 "__eq__", [](PyRegion &self, nb::object &other) { return false; },
4082 "Compares region with non-region object (always returns False).");
4083
4084 //----------------------------------------------------------------------------
4085 // Mapping of PyBlock.
4086 //----------------------------------------------------------------------------
// Python class `Block`: bindings for PyBlock.
4087 nb::class_<PyBlock>(m, "Block")
// NOTE(review): listing line 4088 is absent from this view, so only the
// docstring of the capsule binding is visible here.
4089 "Gets a capsule wrapping the MlirBlock.")
// `owner`: result of createOpView() on the block's parent operation.
4090 .def_prop_ro(
4091 "owner",
4092 [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
4093 return self.getParentOperation()->createOpView();
4094 },
4095 "Returns the owning operation of this block.")
// `region`: PyRegion built from mlirBlockGetParentRegion.
4096 .def_prop_ro(
4097 "region",
4098 [](PyBlock &self) {
4099 MlirRegion region = mlirBlockGetParentRegion(self.get());
4100 return PyRegion(self.getParentOperation(), region);
4101 },
4102 "Returns the owning region of this block.")
// `arguments`: argument list view over this block.
4103 .def_prop_ro(
4104 "arguments",
4105 [](PyBlock &self) {
4106 return PyBlockArgumentList(self.getParentOperation(), self.get());
4107 },
4108 "Returns a list of block arguments.")
// `add_argument(type, loc)`: forwards to mlirBlockAddArgument and wraps the
// result in a PyBlockArgument. Both arguments are required.
4109 .def(
4110 "add_argument",
4111 [](PyBlock &self, const PyType &type, const PyLocation &loc) {
4112 return PyBlockArgument(self.getParentOperation(),
4113 mlirBlockAddArgument(self.get(), type, loc));
4114 },
4115 "type"_a, "loc"_a,
4116 R"(
4117 Appends an argument of the specified type to the block.
4118
4119 Args:
4120 type: The type of the argument to add.
4121 loc: The source location for the argument.
4122
4123 Returns:
4124 The newly added block argument.)")
// `erase_argument(index)`: forwards to mlirBlockEraseArgument. No range check
// on `index` is performed in this lambda.
4125 .def(
4126 "erase_argument",
4127 [](PyBlock &self, unsigned index) {
4128 return mlirBlockEraseArgument(self.get(), index);
4129 },
4130 nb::arg("index"),
4131 R"(
4132 Erases the argument at the specified index.
4133
4134 Args:
4135 index: The index of the argument to erase.)")
// `operations`: operation list view over this block.
4136 .def_prop_ro(
4137 "operations",
4138 [](PyBlock &self) {
4139 return PyOperationList(self.getParentOperation(), self.get());
4140 },
4141 "Returns a forward-optimized sequence of operations.")
// `Block.create_at_start(parent, arg_types=[], arg_locs=None)`: validates the
// region, creates a block via createBlock, and inserts it at index 0.
4142 .def_static(
4143 "create_at_start",
4144 [](PyRegion &parent, const nb::sequence &pyArgTypes,
4145 const std::optional<nb::sequence> &pyArgLocs) {
4146 parent.checkValid();
4147 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
4148 mlirRegionInsertOwnedBlock(parent, 0, block);
4149 return PyBlock(parent.getParentOperation(), block);
4150 },
4151 nb::arg("parent"), nb::arg("arg_types") = nb::list(),
4152 nb::arg("arg_locs") = std::nullopt,
4153 "Creates and returns a new Block at the beginning of the given "
4154 "region (with given argument types and locations).")
// `append_to(region)`: appends this block to `region`.
4155 .def(
4156 "append_to",
4157 [](PyBlock &self, PyRegion &region) {
4158 MlirBlock b = self.get();
// NOTE(review): listing lines 4159-4160 are absent from this view; given the
// docstring below they presumably detach the block from its current region
// first - confirm against the real source.
4161 mlirRegionAppendOwnedBlock(region.get(), b);
4162 },
4163 nb::arg("region"),
4164 R"(
4165 Appends this block to a region.
4166
4167 Transfers ownership if the block is currently owned by another region.
4168
4169 Args:
4170 region: The region to append the block to.)")
// `create_before(*arg_types, arg_locs=None)`: argument types are taken as
// positional varargs; `arg_locs` is keyword-only. The new block is inserted
// before this block in this block's parent region.
4171 .def(
4172 "create_before",
4173 [](PyBlock &self, const nb::args &pyArgTypes,
4174 const std::optional<nb::sequence> &pyArgLocs) {
4175 self.checkValid();
4176 MlirBlock block =
4177 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4178 MlirRegion region = mlirBlockGetParentRegion(self.get());
4179 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
4180 return PyBlock(self.getParentOperation(), block);
4181 },
4182 nb::arg("arg_types"), nb::kw_only(),
4183 nb::arg("arg_locs") = std::nullopt,
4184 "Creates and returns a new Block before this block "
4185 "(with given argument types and locations).")
// `create_after(*arg_types, arg_locs=None)`: same shape as `create_before`,
// but the new block is inserted after this block.
4186 .def(
4187 "create_after",
4188 [](PyBlock &self, const nb::args &pyArgTypes,
4189 const std::optional<nb::sequence> &pyArgLocs) {
4190 self.checkValid();
4191 MlirBlock block =
4192 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4193 MlirRegion region = mlirBlockGetParentRegion(self.get());
4194 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
4195 return PyBlock(self.getParentOperation(), block);
4196 },
4197 nb::arg("arg_types"), nb::kw_only(),
4198 nb::arg("arg_locs") = std::nullopt,
4199 "Creates and returns a new Block after this block "
4200 "(with given argument types and locations).")
// `__iter__`: validates the block, then returns an operation iterator.
4201 .def(
4202 "__iter__",
4203 [](PyBlock &self) {
4204 self.checkValid();
// NOTE(review): listing line 4206 (the initializer of `firstOperation`) is
// absent from this view.
4205 MlirOperation firstOperation =
4207 return PyOperationIterator(self.getParentOperation(),
4208 firstOperation);
4209 },
4210 "Iterates over operations in the block.")
// `__eq__` (Block vs Block): pointer identity of the wrapped blocks.
4211 .def(
4212 "__eq__",
4213 [](PyBlock &self, PyBlock &other) {
4214 return self.get().ptr == other.get().ptr;
4215 },
4216 "Compares two blocks for pointer equality.")
// `__eq__` (Block vs anything else): never equal.
4217 .def(
4218 "__eq__", [](PyBlock &self, nb::object &other) { return false; },
4219 "Compares block with non-block object (always returns False).")
// `__hash__`: hashes the same pointer that `__eq__` compares.
4220 .def(
4221 "__hash__",
4222 [](PyBlock &self) {
4223 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4224 },
4225 "Returns the hash value of the block.")
// `__str__`: validates the block, prints it through mlirBlockPrint into a
// PyPrintAccumulator and returns the joined text.
4226 .def(
4227 "__str__",
4228 [](PyBlock &self) {
4229 self.checkValid();
4230 PyPrintAccumulator printAccum;
4231 mlirBlockPrint(self.get(), printAccum.getCallback(),
4232 printAccum.getUserData());
4233 return printAccum.join();
4234 },
4235 "Returns the assembly form of the block.")
// `append(operation)`: detaches the operation first if it is attached,
// appends it with mlirBlockAppendOwnedOperation, then marks it attached to
// this block's parent operation object.
4236 .def(
4237 "append",
4238 [](PyBlock &self, PyOperationBase &operation) {
4239 if (operation.getOperation().isAttached())
4240 operation.getOperation().detachFromParent();
4241
4242 MlirOperation mlirOperation = operation.getOperation().get();
4243 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
4244 operation.getOperation().setAttached(
4245 self.getParentOperation().getObject());
4246 },
4247 nb::arg("operation"),
4248 R"(
4249 Appends an operation to this block.
4250
4251 If the operation is currently in another block, it will be moved.
4252
4253 Args:
4254 operation: The operation to append to the block.)")
// `successors`: PyBlockSuccessors view for this block.
4255 .def_prop_ro(
4256 "successors",
4257 [](PyBlock &self) {
4258 return PyBlockSuccessors(self, self.getParentOperation());
4259 },
4260 "Returns the list of Block successors.")
// `predecessors`: PyBlockPredecessors view for this block.
4261 .def_prop_ro(
4262 "predecessors",
4263 [](PyBlock &self) {
4264 return PyBlockPredecessors(self, self.getParentOperation());
4265 },
4266 "Returns the list of Block predecessors.");
4267
4268 //----------------------------------------------------------------------------
4269 // Mapping of PyInsertionPoint.
4270 //----------------------------------------------------------------------------
4271
// Python class `InsertionPoint`: bindings for PyInsertionPoint.
4272 nb::class_<PyInsertionPoint>(m, "InsertionPoint")
// Constructor overload taking a Block.
4273 .def(nb::init<PyBlock &>(), nb::arg("block"),
4274 "Inserts after the last operation but still inside the block.")
// Context-manager protocol. All three `__exit__` arguments accept None.
4275 .def("__enter__", &PyInsertionPoint::contextEnter,
4276 "Enters the insertion point as a context manager.")
4277 .def("__exit__", &PyInsertionPoint::contextExit,
4278 nb::arg("exc_type").none(), nb::arg("exc_value").none(),
4279 nb::arg("traceback").none(),
4280 "Exits the insertion point context manager.")
// `InsertionPoint.current` (static read-only property): raises ValueError
// when no insertion point is available.
4281 .def_prop_ro_static(
4282 "current",
4283 [](nb::object & /*class*/) {
// NOTE(review): listing line 4284, which defines `ip`, is absent from this
// view.
4285 if (!ip)
4286 throw nb::value_error("No current InsertionPoint");
4287 return ip;
4288 },
4289 nb::sig("def current(/) -> InsertionPoint"),
4290 "Gets the InsertionPoint bound to the current thread or raises "
4291 "ValueError if none has been set.")
// Constructor overload taking an operation to insert before.
4292 .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
4293 "Inserts before a referenced operation.")
// Static factories, each bound directly to the PyInsertionPoint member of
// the corresponding name.
4294 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
4295 nb::arg("block"),
4296 R"(
4297 Creates an insertion point at the beginning of a block.
4298
4299 Args:
4300 block: The block at whose beginning operations should be inserted.
4301
4302 Returns:
4303 An InsertionPoint at the block's beginning.)")
4304 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
4305 nb::arg("block"),
4306 R"(
4307 Creates an insertion point before a block's terminator.
4308
4309 Args:
4310 block: The block whose terminator to insert before.
4311
4312 Returns:
4313 An InsertionPoint before the terminator.
4314
4315 Raises:
4316 ValueError: If the block has no terminator.)")
4317 .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
4318 R"(
4319 Creates an insertion point immediately after an operation.
4320
4321 Args:
4322 operation: The operation after which to insert.
4323
4324 Returns:
4325 An InsertionPoint after the operation.)")
// `insert(operation)`: bound directly to PyInsertionPoint::insert.
4326 .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
4327 R"(
4328 Inserts an operation at this insertion point.
4329
4330 Args:
4331 operation: The operation to insert.)")
// `block`: the block returned by PyInsertionPoint::getBlock().
4332 .def_prop_ro(
4333 "block", [](PyInsertionPoint &self) { return self.getBlock(); },
4334 "Returns the block that this `InsertionPoint` points to.")
// `ref_operation`: the reference operation's Python object, or None when
// getRefOperation() yields an empty value.
4335 .def_prop_ro(
4336 "ref_operation",
4337 [](PyInsertionPoint &self)
4338 -> std::optional<nb::typed<nb::object, PyOperation>> {
4339 auto refOperation = self.getRefOperation();
4340 if (refOperation)
4341 return refOperation->getObject();
4342 return {};
4343 },
4344 "The reference operation before which new operations are "
4345 "inserted, or None if the insertion point is at the end of "
4346 "the block.");
4347
4348 //----------------------------------------------------------------------------
4349 // Mapping of PyAttribute.
4350 //----------------------------------------------------------------------------
// Python class `Attribute`: bindings for PyAttribute.
4351 nb::class_<PyAttribute>(m, "Attribute")
4352 // Delegate to the PyAttribute copy constructor, which will also lifetime
4353 // extend the backing context which owns the MlirAttribute.
4354 .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
4355 "Casts the passed attribute to the generic `Attribute`.")
// NOTE(review): listing lines 4356 and 4359 are absent from this view, so the
// two capsule bindings below are only partially visible.
4357 "Gets a capsule wrapping the MlirAttribute.")
4358 .def_static(
4360 "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
// `Attribute.parse(asm, context=None)`: parses with mlirAttributeParseGet
// while an ErrorCapture is active; a null result raises MLIRError carrying
// the captured errors. On success the result goes through maybeDownCast().
4361 .def_static(
4362 "parse",
4363 [](const std::string &attrSpec, DefaultingPyMlirContext context)
4364 -> nb::typed<nb::object, PyAttribute> {
4365 PyMlirContext::ErrorCapture errors(context->getRef());
4366 MlirAttribute attr = mlirAttributeParseGet(
4367 context->get(), toMlirStringRef(attrSpec));
4368 if (mlirAttributeIsNull(attr))
4369 throw MLIRError("Unable to parse attribute", errors.take());
4370 return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
4371 },
4372 nb::arg("asm"), nb::arg("context") = nb::none(),
4373 "Parses an attribute from an assembly form. Raises an `MLIRError` on "
4374 "failure.")
// `context`: Python object of the attribute's context.
4375 .def_prop_ro(
4376 "context",
4377 [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
4378 return self.getContext().getObject();
4379 },
4380 "Context that owns the `Attribute`.")
// `type`: PyType built from mlirAttributeGetType, passed through
// maybeDownCast().
4381 .def_prop_ro(
4382 "type",
4383 [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
4384 return PyType(self.getContext(), mlirAttributeGetType(self))
4385 .maybeDownCast();
4386 },
4387 "Returns the type of the `Attribute`.")
// `get_named(name)`: builds a PyNamedAttribute from this attribute and `name`.
// NOTE(review): keep_alive<0, 1> presumably ties the lifetime of `self` to the
// returned object - confirm against the nanobind documentation.
4388 .def(
4389 "get_named",
4390 [](PyAttribute &self, std::string name) {
4391 return PyNamedAttribute(self, std::move(name));
4392 },
4393 nb::keep_alive<0, 1>(),
4394 R"(
4395 Binds a name to the attribute, creating a `NamedAttribute`.
4396
4397 Args:
4398 name: The name to bind to the `Attribute`.
4399
4400 Returns:
4401 A `NamedAttribute` with the given name and this attribute.)")
// `__eq__` (Attribute vs Attribute): uses PyAttribute's operator==.
4402 .def(
4403 "__eq__",
4404 [](PyAttribute &self, PyAttribute &other) { return self == other; },
4405 "Compares two attributes for equality.")
// `__eq__` (Attribute vs anything else): never equal.
4406 .def(
4407 "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
4408 "Compares attribute with non-attribute object (always returns "
4409 "False).")
// `__hash__`: hash of the wrapped MlirAttribute pointer.
4410 .def(
4411 "__hash__",
4412 [](PyAttribute &self) {
4413 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4414 },
4415 "Returns the hash value of the attribute.")
// `dump()`: calls mlirAttributeDump.
// NOTE(review): listing line 4418, which closes this `.def(`, is absent from
// this view.
4416 .def(
4417 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
4419 .def(
4420 "__str__",
4421 [](PyAttribute &self) {
4422 PyPrintAccumulator printAccum;
4423 mlirAttributePrint(self, printAccum.getCallback(),
4424 printAccum.getUserData());
4425 return printAccum.join();
4426 },
4427 "Returns the assembly form of the Attribute.")
// `__repr__`: the printed attribute wrapped as `Attribute(...)`.
4428 .def(
4429 "__repr__",
4430 [](PyAttribute &self) {
4431 // Generally, assembly formats are not printed for __repr__ because
4432 // this can cause exceptionally long debug output and exceptions.
4433 // However, attribute values are generally considered useful and
4434 // are printed. This may need to be re-evaluated if debug dumps end
4435 // up being excessive.
4436 PyPrintAccumulator printAccum;
4437 printAccum.parts.append("Attribute(");
4438 mlirAttributePrint(self, printAccum.getCallback(),
4439 printAccum.getUserData());
4440 printAccum.parts.append(")");
4441 return printAccum.join();
4442 },
4443 "Returns a string representation of the attribute.")
// `typeid`: PyTypeID from mlirAttributeGetTypeID. A null TypeID is only
// guarded by an assert here, not by a Python exception.
4444 .def_prop_ro(
4445 "typeid",
4446 [](PyAttribute &self) {
4447 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
4448 assert(!mlirTypeIDIsNull(mlirTypeID) &&
4449 "mlirTypeID was expected to be non-null.");
4450 return PyTypeID(mlirTypeID);
4451 },
4452 "Returns the `TypeID` of the attribute.")
// Down-cast helper returning self.maybeDownCast().
// NOTE(review): listing line 4454, which carries the Python name of this
// binding, is absent from this view.
4453 .def(
4455 [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
4456 return self.maybeDownCast();
4457 },
4458 "Downcasts the attribute to a more specific attribute if possible.");
4459
4460 //----------------------------------------------------------------------------
4461 // Mapping of PyNamedAttribute
4462 //----------------------------------------------------------------------------
// Python class `NamedAttribute`: bindings for PyNamedAttribute.
4463 nb::class_<PyNamedAttribute>(m, "NamedAttribute")
// `__repr__`: assembles `NamedAttribute(<name>=<attribute>)` from parts
// appended to a PyPrintAccumulator.
4464 .def(
4465 "__repr__",
4466 [](PyNamedAttribute &self) {
4467 PyPrintAccumulator printAccum;
4468 printAccum.parts.append("NamedAttribute(");
4469 printAccum.parts.append(
4470 nb::str(mlirIdentifierStr(self.namedAttr.name).data,
// NOTE(review): listing lines 4471 and 4473 are absent from this view: the
// end of the nb::str(...) call above and the start of the print call whose
// trailing arguments appear below.
4472 printAccum.parts.append("=");
4474 printAccum.getCallback(),
4475 printAccum.getUserData());
4476 printAccum.parts.append(")");
4477 return printAccum.join();
4478 },
4479 "Returns a string representation of the named attribute.")
// `name`: the result of mlirIdentifierStr on the stored name.
4480 .def_prop_ro(
4481 "name",
4482 [](PyNamedAttribute &self) {
4483 return mlirIdentifierStr(self.namedAttr.name);
4484 },
4485 "The name of the `NamedAttribute` binding.")
// `attr`: returns the stored attribute member. The Python-visible signature
// is overridden with nb::sig to advertise `Attribute` as the return type.
4486 .def_prop_ro(
4487 "attr",
4488 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
4489 nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
4490 "The underlying generic attribute of the `NamedAttribute` binding.");
4491
4492 //----------------------------------------------------------------------------
4493 // Mapping of PyType.
4494 //----------------------------------------------------------------------------
// Python class `Type`: bindings for PyType.
4495 nb::class_<PyType>(m, "Type")
4496 // Delegate to the PyType copy constructor, which will also lifetime
4497 // extend the backing context which owns the MlirType.
4498 .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
4499 "Casts the passed type to the generic `Type`.")
// NOTE(review): listing lines 4500 and 4502 are absent from this view, so
// only the docstrings of the two capsule bindings are visible here.
4501 "Gets a capsule wrapping the `MlirType`.")
4503 "Creates a Type from a capsule wrapping `MlirType`.")
// `Type.parse(asm, context=None)`: parses with mlirTypeParseGet while an
// ErrorCapture is active; a null result raises MLIRError carrying the
// captured errors. On success the result goes through maybeDownCast().
4504 .def_static(
4505 "parse",
4506 [](std::string typeSpec,
4507 DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
4508 PyMlirContext::ErrorCapture errors(context->getRef());
4509 MlirType type =
4510 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
4511 if (mlirTypeIsNull(type))
4512 throw MLIRError("Unable to parse type", errors.take());
4513 return PyType(context.get()->getRef(), type).maybeDownCast();
4514 },
4515 nb::arg("asm"), nb::arg("context") = nb::none(),
4516 R"(
4517 Parses the assembly form of a type.
4518
4519 Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
4520
4521 See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
// `context`: Python object of the type's context.
4522 .def_prop_ro(
4523 "context",
4524 [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
4525 return self.getContext().getObject();
4526 },
4527 "Context that owns the `Type`.")
// `__eq__` (Type vs Type): uses PyType's operator==.
4528 .def(
4529 "__eq__", [](PyType &self, PyType &other) { return self == other; },
4530 "Compares two types for equality.")
// `__eq__` (Type vs anything else): never equal. `other` is explicitly
// marked as accepting None.
4531 .def(
4532 "__eq__", [](PyType &self, nb::object &other) { return false; },
4533 nb::arg("other").none(),
4534 "Compares type with non-type object (always returns False).")
// `__hash__`: hash of the wrapped MlirType pointer.
4535 .def(
4536 "__hash__",
4537 [](PyType &self) {
4538 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4539 },
4540 "Returns the hash value of the `Type`.")
// `dump()`: calls mlirTypeDump; documented with the shared kDumpDocstring.
4541 .def(
4542 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
// `__str__`: prints through mlirTypePrint into a PyPrintAccumulator.
4543 .def(
4544 "__str__",
4545 [](PyType &self) {
4546 PyPrintAccumulator printAccum;
4547 mlirTypePrint(self, printAccum.getCallback(),
4548 printAccum.getUserData());
4549 return printAccum.join();
4550 },
4551 "Returns the assembly form of the `Type`.")
// `__repr__`: the printed type wrapped as `Type(...)`.
4552 .def(
4553 "__repr__",
4554 [](PyType &self) {
4555 // Generally, assembly formats are not printed for __repr__ because
4556 // this can cause exceptionally long debug output and exceptions.
4557 // However, types are an exception as they typically have compact
4558 // assembly forms and printing them is useful.
4559 PyPrintAccumulator printAccum;
4560 printAccum.parts.append("Type(");
4561 mlirTypePrint(self, printAccum.getCallback(),
4562 printAccum.getUserData());
4563 printAccum.parts.append(")");
4564 return printAccum.join();
4565 },
4566 "Returns a string representation of the `Type`.")
// Down-cast helper returning self.maybeDownCast().
// NOTE(review): listing line 4568, which carries the Python name of this
// binding, is absent from this view.
4567 .def(
4569 [](PyType &self) -> nb::typed<nb::object, PyType> {
4570 return self.maybeDownCast();
4571 },
4572 "Downcasts the Type to a more specific `Type` if possible.")
// `typeid`: returns the PyTypeID when it is non-null; otherwise raises
// ValueError whose message is the Python repr of the type followed by
// " has no typeid.".
4573 .def_prop_ro(
4574 "typeid",
4575 [](PyType &self) {
4576 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
4577 if (!mlirTypeIDIsNull(mlirTypeID))
4578 return PyTypeID(mlirTypeID);
4579 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
4580 throw nb::value_error(
4581 (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
4582 },
4583 "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
4584 "`Type` has no "
4585 "`TypeID`.");
4586
4587 //----------------------------------------------------------------------------
4588 // Mapping of PyTypeID.
4589 //----------------------------------------------------------------------------
4590 nb::class_<PyTypeID>(m, "TypeID")
4592 "Gets a capsule wrapping the `MlirTypeID`.")
4594 "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
4595 // Note, this tests whether the underlying TypeIDs are the same,
4596 // not whether the wrapper MlirTypeIDs are the same, nor whether
4597 // the Python objects are the same (i.e., PyTypeID is a value type).
4598 .def(
4599 "__eq__",
4600 [](PyTypeID &self, PyTypeID &other) { return self == other; },
4601 "Compares two `TypeID`s for equality.")
4602 .def(
4603 "__eq__",
4604 [](PyTypeID &self, const nb::object &other) { return false; },
4605 "Compares TypeID with non-TypeID object (always returns False).")
4606 // Note, this gives the hash value of the underlying TypeID, not the
4607 // hash value of the Python object, nor the hash value of the
4608 // MlirTypeID wrapper.
4609 .def(
4610 "__hash__",
4611 [](PyTypeID &self) {
4612 return static_cast<size_t>(mlirTypeIDHashValue(self));
4613 },
4614 "Returns the hash value of the `TypeID`.");
4615
4616 //----------------------------------------------------------------------------
4617 // Mapping of Value.
4618 //----------------------------------------------------------------------------
4619 m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
4620
4621 nb::class_<PyValue>(m, "Value", nb::is_generic(),
4622 nb::sig("class Value(Generic[_T])"))
4623 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
4624 "Creates a Value reference from another `Value`.")
4626 "Gets a capsule wrapping the `MlirValue`.")
4628 "Creates a `Value` from a capsule wrapping `MlirValue`.")
4629 .def_prop_ro(
4630 "context",
4631 [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
4632 return self.getParentOperation()->getContext().getObject();
4633 },
4634 "Context in which the value lives.")
4635 .def(
4636 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
4638 .def_prop_ro(
4639 "owner",
4640 [](PyValue &self) -> nb::object {
4641 MlirValue v = self.get();
4642 if (mlirValueIsAOpResult(v)) {
4643 assert(mlirOperationEqual(self.getParentOperation()->get(),
4644 mlirOpResultGetOwner(self.get())) &&
4645 "expected the owner of the value in Python to match "
4646 "that in "
4647 "the IR");
4648 return self.getParentOperation().getObject();
4649 }
4650
4652 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
4653 return nb::cast(PyBlock(self.getParentOperation(), block));
4654 }
4655
4656 assert(false && "Value must be a block argument or an op result");
4657 return nb::none();
4658 },
4659 "Returns the owner of the value (`Operation` for results, `Block` "
4660 "for "
4661 "arguments).")
4662 .def_prop_ro(
4663 "uses",
4664 [](PyValue &self) {
4665 return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
4666 },
4667 "Returns an iterator over uses of this value.")
4668 .def(
4669 "__eq__",
4670 [](PyValue &self, PyValue &other) {
4671 return self.get().ptr == other.get().ptr;
4672 },
4673 "Compares two values for pointer equality.")
4674 .def(
4675 "__eq__", [](PyValue &self, nb::object other) { return false; },
4676 "Compares value with non-value object (always returns False).")
4677 .def(
4678 "__hash__",
4679 [](PyValue &self) {
4680 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4681 },
4682 "Returns the hash value of the value.")
4683 .def(
4684 "__str__",
4685 [](PyValue &self) {
4686 PyPrintAccumulator printAccum;
4687 printAccum.parts.append("Value(");
4688 mlirValuePrint(self.get(), printAccum.getCallback(),
4689 printAccum.getUserData());
4690 printAccum.parts.append(")");
4691 return printAccum.join();
4692 },
4693 R"(
4694 Returns the string form of the value.
4695
4696 If the value is a block argument, this is the assembly form of its type and the
4697 position in the argument list. If the value is an operation result, this is
4698 equivalent to printing the operation that produced it.
4699 )")
4700 .def(
4701 "get_name",
4702 [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
4703 PyPrintAccumulator printAccum;
4704 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
4705 if (useLocalScope)
4707 if (useNameLocAsPrefix)
4709 MlirAsmState valueState =
4710 mlirAsmStateCreateForValue(self.get(), flags);
4711 mlirValuePrintAsOperand(self.get(), valueState,
4712 printAccum.getCallback(),
4713 printAccum.getUserData());
4715 mlirAsmStateDestroy(valueState);
4716 return printAccum.join();
4717 },
4718 nb::arg("use_local_scope") = false,
4719 nb::arg("use_name_loc_as_prefix") = false,
4720 R"(
4721 Returns the string form of value as an operand.
4722
4723 Args:
4724 use_local_scope: Whether to use local scope for naming.
4725 use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
4726
4727 Returns:
4728 The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
4729 .def(
4730 "get_name",
4731 [](PyValue &self, PyAsmState &state) {
4732 PyPrintAccumulator printAccum;
4733 MlirAsmState valueState = state.get();
4734 mlirValuePrintAsOperand(self.get(), valueState,
4735 printAccum.getCallback(),
4736 printAccum.getUserData());
4737 return printAccum.join();
4738 },
4739 nb::arg("state"),
4740 "Returns the string form of value as an operand (i.e., the ValueID).")
4741 .def_prop_ro(
4742 "type",
4743 [](PyValue &self) -> nb::typed<nb::object, PyType> {
4744 return PyType(self.getParentOperation()->getContext(),
4745 mlirValueGetType(self.get()))
4746 .maybeDownCast();
4747 },
4748 "Returns the type of the value.")
4749 .def(
4750 "set_type",
4751 [](PyValue &self, const PyType &type) {
4752 mlirValueSetType(self.get(), type);
4753 },
4754 nb::arg("type"), "Sets the type of the value.",
4755 nb::sig("def set_type(self, type: _T)"))
4756 .def(
4757 "replace_all_uses_with",
4758 [](PyValue &self, PyValue &with) {
4759 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
4760 },
4761 "Replace all uses of value with the new value, updating anything in "
4762 "the IR that uses `self` to use the other value instead.")
4763 .def(
4764 "replace_all_uses_except",
4765 [](PyValue &self, PyValue &with, PyOperation &exception) {
4766 MlirOperation exceptedUser = exception.get();
4767 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4768 },
4769 nb::arg("with_"), nb::arg("exceptions"),
4771 .def(
4772 "replace_all_uses_except",
4773 [](PyValue &self, PyValue &with, const nb::list &exceptions) {
4774 // Convert Python list to a SmallVector of MlirOperations
4775 llvm::SmallVector<MlirOperation> exceptionOps;
4776 for (nb::handle exception : exceptions) {
4777 exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
4778 }
4779
4781 self, with, static_cast<intptr_t>(exceptionOps.size()),
4782 exceptionOps.data());
4783 },
4784 nb::arg("with_"), nb::arg("exceptions"),
4786 .def(
4787 "replace_all_uses_except",
4788 [](PyValue &self, PyValue &with, PyOperation &exception) {
4789 MlirOperation exceptedUser = exception.get();
4790 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4791 },
4792 nb::arg("with_"), nb::arg("exceptions"),
4794 .def(
4795 "replace_all_uses_except",
4796 [](PyValue &self, PyValue &with,
4797 std::vector<PyOperation> &exceptions) {
4798 // Convert Python list to a SmallVector of MlirOperations
4799 llvm::SmallVector<MlirOperation> exceptionOps;
4800 for (PyOperation &exception : exceptions)
4801 exceptionOps.push_back(exception);
4803 self, with, static_cast<intptr_t>(exceptionOps.size()),
4804 exceptionOps.data());
4805 },
4806 nb::arg("with_"), nb::arg("exceptions"),
4808 .def(
4810 [](PyValue &self) -> nb::typed<nb::object, PyValue> {
4811 return self.maybeDownCast();
4812 },
4813 "Downcasts the `Value` to a more specific kind if possible.")
4814 .def_prop_ro(
4815 "location",
4816 [](MlirValue self) {
4817 return PyLocation(
4819 mlirValueGetLocation(self));
4820 },
4821 "Returns the source location of the value.");
4822
4823 PyBlockArgument::bind(m);
4824 PyOpResult::bind(m);
4825 PyOpOperand::bind(m);
4826
4827 nb::class_<PyAsmState>(m, "AsmState")
4828 .def(nb::init<PyValue &, bool>(), nb::arg("value"),
4829 nb::arg("use_local_scope") = false,
4830 R"(
4831 Creates an `AsmState` for consistent SSA value naming.
4832
4833 Args:
4834 value: The value to create state for.
4835 use_local_scope: Whether to use local scope for naming.)")
4836 .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
4837 nb::arg("use_local_scope") = false,
4838 R"(
4839 Creates an AsmState for consistent SSA value naming.
4840
4841 Args:
4842 op: The operation to create state for.
4843 use_local_scope: Whether to use local scope for naming.)");
4844
4845 //----------------------------------------------------------------------------
4846 // Mapping of SymbolTable.
4847 //----------------------------------------------------------------------------
4848 nb::class_<PySymbolTable>(m, "SymbolTable")
4849 .def(nb::init<PyOperationBase &>(),
4850 R"(
4851 Creates a symbol table for an operation.
4852
4853 Args:
4854 operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
4855
4856 Raises:
4857 TypeError: If the operation is not a symbol table.)")
4858 .def(
4859 "__getitem__",
4860 [](PySymbolTable &self,
4861 const std::string &name) -> nb::typed<nb::object, PyOpView> {
4862 return self.dunderGetItem(name);
4863 },
4864 R"(
4865 Looks up a symbol by name in the symbol table.
4866
4867 Args:
4868 name: The name of the symbol to look up.
4869
4870 Returns:
4871 The operation defining the symbol.
4872
4873 Raises:
4874 KeyError: If the symbol is not found.)")
4875 .def("insert", &PySymbolTable::insert, nb::arg("operation"),
4876 R"(
4877 Inserts a symbol operation into the symbol table.
4878
4879 Args:
4880 operation: An operation with a symbol name to insert.
4881
4882 Returns:
4883 The symbol name attribute of the inserted operation.
4884
4885 Raises:
4886 ValueError: If the operation does not have a symbol name.)")
4887 .def("erase", &PySymbolTable::erase, nb::arg("operation"),
4888 R"(
4889 Erases a symbol operation from the symbol table.
4890
4891 Args:
4892 operation: The symbol operation to erase.
4893
4894 Note:
4895 The operation is also erased from the IR and invalidated.)")
4896 .def("__delitem__", &PySymbolTable::dunderDel,
4897 "Deletes a symbol by name from the symbol table.")
4898 .def(
4899 "__contains__",
4900 [](PySymbolTable &table, const std::string &name) {
4901 return !mlirOperationIsNull(mlirSymbolTableLookup(
4902 table, mlirStringRefCreate(name.data(), name.length())));
4903 },
4904 "Checks if a symbol with the given name exists in the table.")
4905 // Static helpers.
4906 .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
4907 nb::arg("symbol"), nb::arg("name"),
4908 "Sets the symbol name for a symbol operation.")
4909 .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
4910 nb::arg("symbol"),
4911 "Gets the symbol name from a symbol operation.")
4912 .def_static("get_visibility", &PySymbolTable::getVisibility,
4913 nb::arg("symbol"),
4914 "Gets the visibility attribute of a symbol operation.")
4915 .def_static("set_visibility", &PySymbolTable::setVisibility,
4916 nb::arg("symbol"), nb::arg("visibility"),
4917 "Sets the visibility attribute of a symbol operation.")
4918 .def_static("replace_all_symbol_uses",
4919 &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
4920 nb::arg("new_symbol"), nb::arg("from_op"),
4921 "Replaces all uses of a symbol with a new symbol name within "
4922 "the given operation.")
4923 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
4924 nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
4925 nb::arg("callback"),
4926 "Walks symbol tables starting from an operation with a "
4927 "callback function.");
4928
4929 // Container bindings.
4930 PyBlockArgumentList::bind(m);
4931 PyBlockIterator::bind(m);
4932 PyBlockList::bind(m);
4933 PyBlockSuccessors::bind(m);
4934 PyBlockPredecessors::bind(m);
4935 PyOperationIterator::bind(m);
4936 PyOperationList::bind(m);
4937 PyOpAttributeMap::bind(m);
4938 PyOpOperandIterator::bind(m);
4939 PyOpOperandList::bind(m);
4941 PyOpSuccessors::bind(m);
4942 PyRegionIterator::bind(m);
4943 PyRegionList::bind(m);
4944
4945 // Debug bindings.
4947
4948 // Attribute builder getter.
4950
4951 nb::register_exception_translator([](const std::exception_ptr &p,
4952 void *payload) {
4953 // We can't define exceptions with custom fields through nanobind, so
4954 // instead the exception class is defined in Python and imported here.
4955 try {
4956 if (p)
4957 std::rethrow_exception(p);
4958 } catch (const MLIRError &e) {
4959 nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
4960 .attr("MLIRError")(e.message, e.errorDiagnostics);
4961 PyErr_SetObject(PyExc_Exception, obj.ptr());
4962 }
4963 });
4964}
void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Definition Debug.cpp:28
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition Debug.cpp:20
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition Debug.cpp:16
static const char kDumpDocstring[]
Definition IRAffine.cpp:37
static MlirStringRef toMlirStringRef(const std::string &s)
Definition IRCore.cpp:77
static const char kModuleParseDocstring[]
Definition IRCore.cpp:36
static std::vector< nb::typed< nb::object, PyType > > getValueTypes(Container &container, PyMlirContextRef &context)
Returns the list of types of the values held by container.
Definition IRCore.cpp:1541
static nb::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition IRCore.cpp:59
static MlirValue getUniqueResult(MlirOperation operation)
Definition IRCore.cpp:1708
#define Py_XNewRef(obj)
Definition IRCore.cpp:2763
#define _Py_CAST(type, expr)
Definition IRCore.cpp:2739
static MlirValue getOpResultOrValue(nb::handle operand)
Definition IRCore.cpp:1723
#define Py_NewRef(obj)
Definition IRCore.cpp:2772
#define _Py_NULL
Definition IRCore.cpp:2750
static void maybeInsertOperation(PyOperationRef &op, const nb::object &maybeIp)
Definition IRCore.cpp:1294
static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor)
Definition IRCore.cpp:65
static MlirBlock createBlock(const nb::sequence &pyArgTypes, const std::optional< nb::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition IRCore.cpp:91
static void populateResultTypes(StringRef name, nb::list resultTypeList, const nb::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition IRCore.cpp:1611
static const char kValueReplaceAllUsesExceptDocstring[]
Definition IRCore.cpp:47
MlirContext mlirModuleGetContext(MlirModule module)
Definition IR.cpp:446
size_t mlirModuleHashValue(MlirModule mod)
Definition IR.cpp:472
intptr_t mlirBlockGetNumPredecessors(MlirBlock block)
Definition IR.cpp:1090
MlirIdentifier mlirOperationGetName(MlirOperation op)
Definition IR.cpp:669
bool mlirValueIsABlockArgument(MlirValue value)
Definition IR.cpp:1110
intptr_t mlirOperationGetNumRegions(MlirOperation op)
Definition IR.cpp:681
MlirBlock mlirOperationGetBlock(MlirOperation op)
Definition IR.cpp:673
void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Definition IR.cpp:1127
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition IR.cpp:521
MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Definition IR.cpp:732
MlirModule mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName)
Definition IR.cpp:437
MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Definition IR.cpp:178
intptr_t mlirOperationGetNumResults(MlirOperation op)
Definition IR.cpp:728
void mlirOperationDestroy(MlirOperation op)
Definition IR.cpp:639
MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Definition IR.cpp:1275
MlirType mlirValueGetType(MlirValue value)
Definition IR.cpp:1146
void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Definition IR.cpp:1076
MlirOpPrintingFlags mlirOpPrintingFlagsCreate()
Definition IR.cpp:202
bool mlirModuleEqual(MlirModule lhs, MlirModule rhs)
Definition IR.cpp:468
void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Definition IR.cpp:210
void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Definition IR.cpp:793
MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Definition IR.cpp:705
void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Definition IR.cpp:220
MlirOperation mlirModuleGetOperation(MlirModule module)
Definition IR.cpp:460
void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
Definition IR.cpp:215
void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Definition IR.cpp:233
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Definition IR.cpp:1122
MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Definition IR.cpp:740
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Definition IR.cpp:1294
bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Definition IR.cpp:643
void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Definition IR.cpp:237
void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Definition IR.cpp:252
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1086
void mlirModuleDestroy(MlirModule module)
Definition IR.cpp:454
MlirModule mlirModuleCreateEmpty(MlirLocation location)
Definition IR.cpp:425
void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Definition IR.cpp:225
MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Definition IR.cpp:677
void mlirValueSetType(MlirValue value, MlirType type)
Definition IR.cpp:1150
intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Definition IR.cpp:736
MlirDialect mlirAttributeGetDialect(MlirAttribute attr)
Definition IR.cpp:1290
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Definition IR.cpp:415
void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Definition IR.cpp:812
void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Definition IR.cpp:717
MlirOperation mlirOpResultGetOwner(MlirValue value)
Definition IR.cpp:1137
MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Definition IR.cpp:429
size_t mlirOperationHashValue(MlirOperation op)
Definition IR.cpp:647
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Definition IR.cpp:504
MlirOperation mlirOperationClone(MlirOperation op)
Definition IR.cpp:635
MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Definition IR.cpp:1118
void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc)
Definition IR.cpp:1132
MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Definition IR.cpp:713
MlirLocation mlirOperationGetLocation(MlirOperation op)
Definition IR.cpp:655
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:807
MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr)
Definition IR.cpp:1286
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition IR.cpp:513
void mlirOperationSetLocation(MlirOperation op, MlirLocation loc)
Definition IR.cpp:659
MlirType mlirAttributeGetType(MlirAttribute attribute)
Definition IR.cpp:1279
bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:817
bool mlirValueIsAOpResult(MlirValue value)
Definition IR.cpp:1114
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1095
MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Definition IR.cpp:685
MlirOperation mlirOperationCreate(MlirOperationState *state)
Definition IR.cpp:589
void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Definition IR.cpp:256
MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Definition IR.cpp:1271
intptr_t mlirBlockGetNumSuccessors(MlirBlock block)
Definition IR.cpp:1082
MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Definition IR.cpp:802
void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Definition IR.cpp:206
void mlirValueDump(MlirValue value)
Definition IR.cpp:1154
void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Definition IR.cpp:1260
MlirBlock mlirModuleGetBody(MlirModule module)
Definition IR.cpp:450
MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Definition IR.cpp:626
void mlirAsmStateDestroy(MlirAsmState state)
Destroys an MlirAsmState created with one of the mlirAsmStateCreate* functions (e.g. mlirAsmStateCreateForValue).
Definition IR.cpp:196
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:651
intptr_t mlirOpResultGetResultNumber(MlirValue value)
Definition IR.cpp:1141
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition IR.cpp:517
MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate()
Definition IR.cpp:248
void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags)
Definition IR.cpp:229
void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Definition IR.cpp:241
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition IR.cpp:509
MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Definition IR.cpp:480
intptr_t mlirOperationGetNumOperands(MlirOperation op)
Definition IR.cpp:709
void mlirTypeDump(MlirType type)
Definition IR.cpp:1265
intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Definition IR.cpp:798
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition Interop.h:348
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition Interop.h:216
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition Interop.h:118
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition Interop.h:189
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition Interop.h:367
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition Interop.h:330
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition Interop.h:180
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts a MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition Interop.h:357
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition Interop.h:198
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition Interop.h:255
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition Interop.h:245
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition Interop.h:57
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition Interop.h:454
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition Interop.h:445
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition Interop.h:273
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition Interop.h:264
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition Interop.h:235
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static std::string diag(const llvm::Value &value)
A list of operation results.
Definition IRCore.cpp:1556
static void bindDerived(ClassTy &c)
Definition IRCore.cpp:1569
PyOpResultList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:1561
static constexpr const char * pyClassName
Definition IRCore.cpp:1558
Sliceable< PyOpResultList, PyOpResult > SliceableT
Definition IRCore.cpp:1559
PyOperationRef & getOperation()
Definition IRCore.cpp:1584
Python wrapper for MlirOpResult.
Definition IRCore.cpp:1512
static constexpr IsAFunctionTy isaFunction
Definition IRCore.cpp:1514
static void bindDerived(ClassTy &c)
Definition IRCore.cpp:1518
static constexpr const char * pyClassName
Definition IRCore.cpp:1515
Accumulates into a file, either writing text (default) or binary.
MlirStringCallback getCallback()
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
static void bind(nanobind::module_ &m)
nanobind::class_< PyOpResultList > ClassTy
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
Base class for all objects that directly or indirectly depend on an MlirContext.
Definition IRModule.h:284
PyMlirContextRef & getContext()
Accesses the context reference.
Definition IRModule.h:292
BaseContextObject(PyMlirContextRef ref)
Definition IRModule.h:286
static PyLocation & resolve()
Definition IRCore.cpp:951
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRModule.h:273
static PyMlirContext & resolve()
Definition IRCore.cpp:672
ReferrentTy * get() const
Wrapper around an MlirAsmState.
Definition IRModule.h:778
MlirAsmState get()
Definition IRModule.h:803
Wrapper around the generic MlirAttribute.
Definition IRModule.h:1008
static PyAttribute createFromCapsule(const nanobind::object &capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition IRCore.cpp:2036
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition IRModule.h:1010
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition IRCore.cpp:2032
nanobind::object maybeDownCast()
Definition IRCore.cpp:2044
MlirAttribute get() const
Definition IRModule.h:1014
bool operator==(const PyAttribute &other) const
Definition IRCore.cpp:2028
Wrapper around an MlirBlock.
Definition IRModule.h:813
PyOperationRef & getParentOperation()
Definition IRModule.h:821
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirBlock.
Definition IRCore.cpp:196
Represents a diagnostic handler attached to the context.
Definition IRModule.h:380
PyDiagnosticHandler(MlirContext context, nanobind::object callback)
Definition IRCore.cpp:828
nanobind::object contextEnter()
Definition IRModule.h:391
void detach()
Detaches the handler. Does nothing if not attached.
Definition IRCore.cpp:834
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRModule.h:392
Python class mirroring the C MlirDiagnostic struct.
Definition IRModule.h:330
nanobind::tuple getNotes()
Definition IRCore.cpp:872
nanobind::str getMessage()
Definition IRCore.cpp:864
DiagnosticInfo getInfo()
Definition IRCore.cpp:888
PyDiagnostic(MlirDiagnostic diagnostic)
Definition IRModule.h:332
MlirDiagnosticSeverity getSeverity()
Definition IRCore.cpp:852
Wrapper around an MlirDialect.
Definition IRModule.h:435
Wrapper around an MlirDialectRegistry.
Definition IRModule.h:472
nanobind::object getCapsule()
Definition IRCore.cpp:913
static PyDialectRegistry createFromCapsule(nanobind::object capsule)
Definition IRCore.cpp:917
User-level dialect object.
Definition IRModule.h:459
nanobind::object getDescriptor()
Definition IRModule.h:463
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition IRModule.h:448
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition IRCore.cpp:900
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false)
Adds a user-friendly Attribute builder.
Definition IRModule.cpp:73
TracebackLoc & getTracebackLoc()
Definition Globals.h:153
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.h:39
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition IRModule.cpp:155
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition IRModule.cpp:169
std::optional< nanobind::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition IRModule.cpp:184
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition IRModule.cpp:132
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition IRModule.cpp:142
An insertion point maintains a pointer to a Block and a reference operation.
Definition IRModule.h:837
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition IRCore.cpp:1992
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition IRCore.cpp:1979
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:2018
static PyInsertionPoint after(PyOperationBase &op)
Shortcut to create an insertion point to the node after the specified operation.
Definition IRCore.cpp:2001
PyInsertionPoint(const PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition IRCore.cpp:1944
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition IRCore.cpp:1953
static nanobind::object contextEnter(nanobind::object insertionPoint)
Enter and exit the context manager.
Definition IRCore.cpp:2014
std::optional< PyOperationRef > & getRefOperation()
Definition IRModule.h:865
Wrapper around an MlirLocation.
Definition IRModule.h:299
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition IRCore.cpp:929
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition IRModule.h:301
static PyLocation createFromCapsule(nanobind::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition IRCore.cpp:933
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:945
static nanobind::object contextEnter(nanobind::object location)
Enter and exit the context manager.
Definition IRCore.cpp:941
MlirLocation get() const
Definition IRModule.h:305
MlirContext get()
Accesses the underlying MlirContext.
Definition IRModule.h:204
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition IRModule.h:208
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition IRCore.cpp:593
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition IRCore.cpp:2012
nanobind::object attachDiagnosticHandler(nanobind::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition IRCore.cpp:608
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition IRCore.cpp:557
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:602
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:568
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition IRCore.cpp:561
static nanobind::object contextEnter(nanobind::object context)
Enter and exit the context manager.
Definition IRCore.cpp:598
void setEmitErrorDiagnostics(bool value)
Controls whether error diagnostics should be propagated to diagnostic handlers, instead of being capt...
Definition IRModule.h:240
MlirModule get()
Gets the backing MlirModule.
Definition IRModule.h:522
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition IRCore.cpp:978
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition IRCore.cpp:1003
PyModuleRef getRef()
Gets a strong reference to this module.
Definition IRModule.h:525
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition IRCore.cpp:1010
PyModule(PyModule &)=delete
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition IRModule.h:1034
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition IRCore.cpp:2062
MlirNamedAttribute namedAttr
Definition IRModule.h:1043
Template for a reference to a concrete type which captures a python reference to its underlying pytho...
Definition IRModule.h:53
nanobind::object getObject()
Definition IRModule.h:91
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
Definition IRModule.h:79
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition IRModule.h:723
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition IRModule.h:726
static nanobind::object constructDerived(const nanobind::object &cls, const nanobind::object &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition IRCore.cpp:1926
static nanobind::object buildGeneric(std::string_view name, std::tuple< int, bool > opRegionSpec, nanobind::object operandSegmentSpecObj, nanobind::object resultSegmentSpecObj, std::optional< nanobind::list > resultTypeList, nanobind::list operandList, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, std::optional< int > regions, PyLocation &location, const nanobind::object &maybeIp)
Definition IRCore.cpp:1742
nanobind::object getOperationObject()
Definition IRModule.h:728
PyOpView(const nanobind::object &operationObject)
Definition IRCore.cpp:1934
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition IRModule.h:552
void walk(std::function< MlirWalkResult(MlirOperation)> callback, MlirWalkOrder walkOrder)
Definition IRCore.cpp:1167
bool isBeforeInBlock(PyOperationBase &other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IRCore.cpp:1245
nanobind::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions)
Definition IRCore.cpp:1199
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition IRCore.cpp:1227
void print(std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
void writeBytecode(const nanobind::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition IRCore.cpp:1145
void moveBefore(PyOperationBase &other)
Definition IRCore.cpp:1236
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
bool verify()
Verify the operation.
Definition IRCore.cpp:1253
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition IRModule.h:630
PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
Definition IRCore.cpp:1018
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition IRCore.cpp:1444
static nanobind::object createFromCapsule(const nanobind::object &capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition IRCore.cpp:1285
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition IRCore.cpp:1280
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Creates a detached operation.
Definition IRCore.cpp:1070
nanobind::object clone(const nanobind::object &ip)
Clones this operation.
Definition IRCore.cpp:1424
PyOperationRef getRef()
Definition IRModule.h:643
MlirOperation get() const
Definition IRModule.h:638
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:1063
void setAttached(const nanobind::object &parent=nanobind::object())
Definition IRModule.h:648
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition IRCore.cpp:1261
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1433
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition IRCore.cpp:1271
static nanobind::object create(std::string_view name, std::optional< std::vector< PyType * > > results, llvm::ArrayRef< MlirValue > operands, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, int regions, PyLocation &location, const nanobind::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition IRCore.cpp:1309
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition IRCore.cpp:1079
void setInvalid()
Invalidate the operation.
Definition IRModule.h:690
Wrapper around an MlirRegion.
Definition IRModule.h:759
PyOperationRef & getParentOperation()
Definition IRModule.h:768
Bindings for MLIR symbol tables.
Definition IRModule.h:1266
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition IRCore.cpp:2197
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, nanobind::object callback)
Walks all symbol tables under and including 'from'.
Definition IRCore.cpp:2283
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition IRCore.cpp:2271
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition IRCore.cpp:2253
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition IRCore.cpp:2227
PyStringAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition IRCore.cpp:2202
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition IRCore.cpp:2187
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition IRCore.cpp:2166
static PyStringAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition IRCore.cpp:2214
static PyStringAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition IRCore.cpp:2242
nanobind::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition IRCore.cpp:2174
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition IRCore.cpp:692
static void popLocation(PyLocation &location)
Definition IRCore.cpp:804
static nanobind::object pushLocation(nanobind::object location)
Definition IRCore.cpp:795
static nanobind::object pushContext(nanobind::object context)
Definition IRCore.cpp:754
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition IRCore.cpp:749
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition IRCore.cpp:784
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint)
Definition IRCore.cpp:772
static void popContext(PyMlirContext &context)
Definition IRCore.cpp:761
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and returns nullptr if not defined.
Definition IRCore.cpp:744
static PyMlirContext * getDefaultContext()
Gets the top of stack context and returns nullptr if not defined.
Definition IRCore.cpp:739
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition IRCore.cpp:687
PyThreadContextEntry(FrameKind frameKind, nanobind::object context, nanobind::object insertionPoint, nanobind::object location)
Definition IRModule.h:120
PyInsertionPoint * getInsertionPoint()
Definition IRCore.cpp:727
Wrapper around MlirLlvmThreadPool. The Python object owns the C++ thread pool.
Definition IRModule.h:168
std::string _mlir_thread_pool_ptr() const
Definition IRModule.h:179
int getMaxConcurrency() const
Definition IRModule.h:176
MlirLlvmThreadPool get()
Definition IRModule.h:177
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition IRModule.h:904
static PyTypeID createFromCapsule(nanobind::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition IRCore.cpp:2112
bool operator==(const PyTypeID &other) const
Definition IRCore.cpp:2118
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition IRCore.cpp:2108
PyTypeID(MlirTypeID typeID)
Definition IRModule.h:906
Wrapper around the generic MlirType.
Definition IRModule.h:878
PyType(PyMlirContextRef contextRef, MlirType type)
Definition IRModule.h:880
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition IRCore.cpp:2078
MlirType get() const
Definition IRModule.h:884
static PyType createFromCapsule(nanobind::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition IRCore.cpp:2082
nanobind::object maybeDownCast()
Definition IRCore.cpp:2090
bool operator==(const PyType &other) const
Definition IRCore.cpp:2074
Wrapper around the generic MlirValue.
Definition IRModule.h:1167
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition IRModule.h:1173
static PyValue createFromCapsule(nanobind::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition IRCore.cpp:2145
nanobind::object maybeDownCast()
Definition IRCore.cpp:2130
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition IRCore.cpp:2126
PyOperationRef & getParentOperation()
Definition IRModule.h:1178
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
MlirDiagnosticSeverity
Severity of a diagnostic.
Definition Diagnostics.h:32
@ MlirDiagnosticNote
Definition Diagnostics.h:35
@ MlirDiagnosticRemark
Definition Diagnostics.h:36
@ MlirDiagnosticWarning
Definition Diagnostics.h:34
@ MlirDiagnosticError
Definition Diagnostics.h:33
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition Diagnostics.h:24
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition IR.cpp:265
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition IR.h:851
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v)
Gets the location of the value.
Definition IR.cpp:1197
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context)
Gets the number of threads of the thread pool of the context when multithreading is enabled.
Definition IR.cpp:117
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition IR.cpp:841
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates a File/Line/Column location owned by the given context.
Definition IR.cpp:273
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition IR.cpp:1379
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition IR.cpp:137
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location)
Getter for end_column of FileLineColRange.
Definition IR.cpp:311
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition IR.cpp:1358
MlirWalkOrder
Traversal order for operation walk.
Definition IR.h:844
@ MlirWalkPreOrder
Definition IR.h:845
@ MlirWalkPostOrder
Definition IR.h:846
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition IR.cpp:1306
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGetChildLoc(MlirLocation location)
Getter for childLoc of Name.
Definition IR.cpp:392
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition IR.cpp:1363
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition IR.cpp:84
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition IR.cpp:1327
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition IR.cpp:1240
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition IR.cpp:1223
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition IR.cpp:73
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition IR.cpp:946
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location)
Checks whether the given location is a FileLineColRange.
Definition IR.cpp:321
MLIR_CAPI_EXPORTED unsigned mlirLocationFusedGetNumLocations(MlirLocation location)
Getter for number of locations fused together.
Definition IR.cpp:355
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition IR.cpp:1179
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition IR.cpp:1162
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition IR.cpp:403
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition IR.cpp:1211
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location)
Getter for filename of FileLineColRange.
Definition IR.cpp:289
MLIR_CAPI_EXPORTED void mlirLocationFusedGetLocations(MlirLocation location, MlirLocation *locationsCPtr)
Getter for locations of Fused.
Definition IR.cpp:361
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding userData to callback...
Definition IR.cpp:1298
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition IR.cpp:985
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition IR.cpp:865
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
Definition IR.cpp:1183
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition IR.cpp:833
MlirWalkResult
Operation walk result.
Definition IR.h:837
@ MlirWalkResultInterrupt
Definition IR.h:839
@ MlirWalkResultSkip
Definition IR.h:840
@ MlirWalkResultAdvance
Definition IR.h:838
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition IR.cpp:926
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1152
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e. 'dialect.operation') is registered with the context.
Definition IR.cpp:100
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition IR.cpp:1054
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartLine(MlirLocation location)
Getter for start_line of FileLineColRange.
Definition IR.cpp:293
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition IR.cpp:347
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition IR.cpp:1035
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition IR.cpp:95
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location)
Checks whether the given location is a CallSite.
Definition IR.cpp:343
struct MlirNamedAttribute MlirNamedAttribute
Definition IR.h:80
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition IR.cpp:932
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition IR.cpp:969
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition IR.h:933
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition IR.cpp:1010
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition IR.cpp:1072
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition IR.cpp:1353
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition IR.cpp:1244
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, unsigned start_line, unsigned start_col, unsigned end_line, unsigned end_col)
Creates a File/Line/Column range location owned by the given context.
Definition IR.cpp:281
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition IR.cpp:1209
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition IR.cpp:1343
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition IR.cpp:407
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location)
Getter for start_column of FileLineColRange.
Definition IR.cpp:299
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location)
Checks whether the given location is a Fused.
Definition IR.cpp:375
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition IR.h:370
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition IR.cpp:1058
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition IR.cpp:827
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition IR.cpp:1169
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool)
Sets the thread pool of the context explicitly, enabling multithreading in the process.
Definition IR.cpp:112
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition IR.cpp:857
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition IR.cpp:1256
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition IR.cpp:1219
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location)
Getter for caller of CallSite.
Definition IR.cpp:334
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition IR.cpp:1000
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationNameGetName(MlirLocation location)
Getter for name of Name.
Definition IR.cpp:388
MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IR.cpp:869
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition IR.cpp:269
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition IR.cpp:1248
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition IR.cpp:1339
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition IR.h:182
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition IR.cpp:989
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition IR.cpp:325
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location)
Checks whether the given location is a Name.
Definition IR.cpp:399
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition IR.h:244
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition IR.cpp:887
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
Definition IR.cpp:55
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition IR.cpp:981
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition IR.cpp:411
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition IR.cpp:1063
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition IR.cpp:1304
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition IR.cpp:1368
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition IR.cpp:922
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition IR.cpp:993
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition IR.h:872
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition IR.cpp:1252
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition IR.cpp:1315
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition IR.cpp:108
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context)
Gets the thread pool of the context when multithreading is enabled; otherwise an assertion is raised.
Definition IR.cpp:121
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndLine(MlirLocation location)
Getter for end_line of FileLineColRange.
Definition IR.cpp:305
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition IR.cpp:379
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Set threading mode (must be set to false to use mlir-print-ir-after-all).
Definition IR.cpp:104
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location)
Getter for callee of CallSite.
Definition IR.cpp:329
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v)
Gets the context that a value was created with.
Definition IR.cpp:1201
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition IR.cpp:1335
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition IR.cpp:909
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition IR.cpp:1049
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition IR.cpp:855
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition IR.h:1242
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition IR.cpp:77
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition IR.cpp:861
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding userData to callback.
Definition IR.cpp:1156
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition IR.cpp:848
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition IR.cpp:71
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition IR.cpp:915
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:82
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:138
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition Support.cpp:51
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:132
struct MlirStringRef MlirStringRef
Definition Support.h:77
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:127
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition Support.h:163
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition Support.cpp:47
PyObjectRef< PyOperation > PyOperationRef
Definition IRModule.h:604
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRModule.h:190
void populateIRCore(nanobind::module_ &m)
PyObjectRef< PyModule > PyModuleRef
Definition IRModule.h:511
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition Diagnostics.h:26
A logical result value, essentially a boolean with named states.
Definition Support.h:116
MlirAttribute attribute
Definition IR.h:78
MlirIdentifier name
Definition IR.h:77
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:73
const char * data
Pointer to the first symbol.
Definition Support.h:74
size_t length
Length of the fragment.
Definition Support.h:75
static bool dunderContains(const std::string &attributeKind)
Definition IRCore.cpp:160
static void dunderSetItemNamed(const std::string &attributeKind, nb::callable func, bool replace)
Definition IRCore.cpp:169
static nb::callable dunderGetItemNamed(const std::string &attributeKind)
Definition IRCore.cpp:163
static void bind(nb::module_ &m)
Definition IRCore.cpp:175
Wrapper for the global LLVM debugging flag.
Definition IRCore.cpp:116
static void bind(nb::module_ &m)
Definition IRCore.cpp:127
static void set(nb::object &o, bool enable)
Definition IRCore.cpp:117
static bool get(const nb::object &)
Definition IRCore.cpp:122
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition IRModule.h:1318
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Definition IRModule.h:1323
Materialized diagnostic information.
Definition IRModule.h:342
std::vector< DiagnosticInfo > notes
Definition IRModule.h:346
RAII object that captures any error diagnostics emitted to the provided context.
Definition IRModule.h:408
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition IRModule.h:418