MLIR 22.0.0git
IRCore.cpp
Go to the documentation of this file.
1//===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Globals.h"
10#include "IRModule.h"
11#include "NanobindUtils.h"
12#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
14#include "mlir-c/Debug.h"
15#include "mlir-c/Diagnostics.h"
16#include "mlir-c/IR.h"
17#include "mlir-c/Support.h"
20#include "nanobind/nanobind.h"
21#include "nanobind/typing.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/SmallVector.h"
24
25#include <optional>
26
27namespace nb = nanobind;
28using namespace nb::literals;
29using namespace mlir;
30using namespace mlir::python;
31
33using llvm::StringRef;
34using llvm::Twine;
35
/// Docstring text for the module-parsing entry point (presumably attached to
/// the Python `Module.parse` binding; the binding itself is not visible in
/// this excerpt — confirm against the registration site).
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";
43
/// Docstring text for debug "dump" bindings (which bindings use it is not
/// visible in this excerpt).
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
46
48 R"(Replace all uses of this value with the `with` value, except for those
49in `exceptions`. `exceptions` can be either a single operation or a list of
50operations.
51)";
52
53//------------------------------------------------------------------------------
54// Utilities.
55//------------------------------------------------------------------------------
56
57/// Helper for creating an @classmethod.
58template <class Func, typename... Args>
59static nb::object classmethod(Func f, Args... args) {
60 nb::object cf = nb::cpp_function(f, args...);
61 return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
62}
63
64static nb::object
65createCustomDialectWrapper(const std::string &dialectNamespace,
66 nb::object dialectDescriptor) {
67 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
68 if (!dialectClass) {
69 // Use the base class.
70 return nb::cast(PyDialect(std::move(dialectDescriptor)));
71 }
72
73 // Create the custom implementation.
74 return (*dialectClass)(std::move(dialectDescriptor));
75}
76
77static MlirStringRef toMlirStringRef(const std::string &s) {
78 return mlirStringRefCreate(s.data(), s.size());
79}
80
81static MlirStringRef toMlirStringRef(std::string_view s) {
82 return mlirStringRefCreate(s.data(), s.size());
83}
84
85static MlirStringRef toMlirStringRef(const nb::bytes &s) {
86 return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
87}
88
89/// Create a block, using the current location context if no locations are
90/// specified.
91static MlirBlock createBlock(const nb::sequence &pyArgTypes,
92 const std::optional<nb::sequence> &pyArgLocs) {
93 SmallVector<MlirType> argTypes;
94 argTypes.reserve(nb::len(pyArgTypes));
95 for (const auto &pyType : pyArgTypes)
96 argTypes.push_back(nb::cast<PyType &>(pyType));
97
99 if (pyArgLocs) {
100 argLocs.reserve(nb::len(*pyArgLocs));
101 for (const auto &pyLoc : *pyArgLocs)
102 argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
103 } else if (!argTypes.empty()) {
104 argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
105 }
106
107 if (argTypes.size() != argLocs.size())
108 throw nb::value_error(("Expected " + Twine(argTypes.size()) +
109 " locations, got: " + Twine(argLocs.size()))
110 .str()
111 .c_str());
112 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
113}
114
115/// Wrapper for the global LLVM debugging flag.
117 static void set(nb::object &o, bool enable) {
118 nb::ft_lock_guard lock(mutex);
119 mlirEnableGlobalDebug(enable);
120 }
121
122 static bool get(const nb::object &) {
123 nb::ft_lock_guard lock(mutex);
125 }
126
127 static void bind(nb::module_ &m) {
128 // Debug flags.
129 nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
130 .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
131 &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
132 .def_static(
133 "set_types",
134 [](const std::string &type) {
135 nb::ft_lock_guard lock(mutex);
136 mlirSetGlobalDebugType(type.c_str());
137 },
138 "types"_a, "Sets specific debug types to be produced by LLVM.")
139 .def_static(
140 "set_types",
141 [](const std::vector<std::string> &types) {
142 std::vector<const char *> pointers;
143 pointers.reserve(types.size());
144 for (const std::string &str : types)
145 pointers.push_back(str.c_str());
146 nb::ft_lock_guard lock(mutex);
147 mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
148 },
149 "types"_a,
150 "Sets multiple specific debug types to be produced by LLVM.");
151 }
152
153private:
154 static nb::ft_mutex mutex;
155};
156
// Out-of-line definition of the static mutex that PyGlobalDebugFlag holds
// around every call into the MLIR global-debug C API.
nb::ft_mutex PyGlobalDebugFlag::mutex;
158
160 static bool dunderContains(const std::string &attributeKind) {
161 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
162 }
163 static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
164 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
165 if (!builder)
166 throw nb::key_error(attributeKind.c_str());
167 return *builder;
168 }
169 static void dunderSetItemNamed(const std::string &attributeKind,
170 nb::callable func, bool replace) {
171 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
172 replace);
173 }
174
175 static void bind(nb::module_ &m) {
176 nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
177 .def_static("contains", &PyAttrBuilderMap::dunderContains,
178 "attribute_kind"_a,
179 "Checks whether an attribute builder is registered for the "
180 "given attribute kind.")
181 .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
182 "attribute_kind"_a,
183 "Gets the registered attribute builder for the given "
184 "attribute kind.")
185 .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
186 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
187 "Register an attribute builder for building MLIR "
188 "attributes from Python values.");
189 }
190};
191
192//------------------------------------------------------------------------------
193// PyBlock
194//------------------------------------------------------------------------------
195
197 return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
198}
199
200//------------------------------------------------------------------------------
201// Collections.
202//------------------------------------------------------------------------------
203
204namespace {
205
206class PyRegionIterator {
207public:
208 PyRegionIterator(PyOperationRef operation, int nextIndex)
209 : operation(std::move(operation)), nextIndex(nextIndex) {}
210
211 PyRegionIterator &dunderIter() { return *this; }
212
213 PyRegion dunderNext() {
214 operation->checkValid();
215 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
216 throw nb::stop_iteration();
217 }
218 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
219 return PyRegion(operation, region);
220 }
221
222 static void bind(nb::module_ &m) {
223 nb::class_<PyRegionIterator>(m, "RegionIterator")
224 .def("__iter__", &PyRegionIterator::dunderIter,
225 "Returns an iterator over the regions in the operation.")
226 .def("__next__", &PyRegionIterator::dunderNext,
227 "Returns the next region in the iteration.");
228 }
229
230private:
231 PyOperationRef operation;
232 intptr_t nextIndex = 0;
233};
234
235/// Regions of an op are fixed length and indexed numerically so are represented
236/// with a sequence-like container.
237class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
238public:
239 static constexpr const char *pyClassName = "RegionSequence";
240
241 PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
242 intptr_t length = -1, intptr_t step = 1)
243 : Sliceable(startIndex,
244 length == -1 ? mlirOperationGetNumRegions(operation->get())
245 : length,
246 step),
247 operation(std::move(operation)) {}
248
249 PyRegionIterator dunderIter() {
250 operation->checkValid();
251 return PyRegionIterator(operation, startIndex);
252 }
253
254 static void bindDerived(ClassTy &c) {
255 c.def("__iter__", &PyRegionList::dunderIter,
256 "Returns an iterator over the regions in the sequence.");
257 }
258
259private:
260 /// Give the parent CRTP class access to hook implementations below.
261 friend class Sliceable<PyRegionList, PyRegion>;
262
263 intptr_t getRawNumElements() {
264 operation->checkValid();
265 return mlirOperationGetNumRegions(operation->get());
266 }
267
268 PyRegion getRawElement(intptr_t pos) {
269 operation->checkValid();
270 return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
271 }
272
273 PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
274 return PyRegionList(operation, startIndex, length, step);
275 }
276
277 PyOperationRef operation;
278};
279
280class PyBlockIterator {
281public:
282 PyBlockIterator(PyOperationRef operation, MlirBlock next)
283 : operation(std::move(operation)), next(next) {}
284
285 PyBlockIterator &dunderIter() { return *this; }
286
287 PyBlock dunderNext() {
288 operation->checkValid();
289 if (mlirBlockIsNull(next)) {
290 throw nb::stop_iteration();
291 }
292
293 PyBlock returnBlock(operation, next);
294 next = mlirBlockGetNextInRegion(next);
295 return returnBlock;
296 }
297
298 static void bind(nb::module_ &m) {
299 nb::class_<PyBlockIterator>(m, "BlockIterator")
300 .def("__iter__", &PyBlockIterator::dunderIter,
301 "Returns an iterator over the blocks in the operation's region.")
302 .def("__next__", &PyBlockIterator::dunderNext,
303 "Returns the next block in the iteration.");
304 }
305
306private:
307 PyOperationRef operation;
308 MlirBlock next;
309};
310
311/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
312/// we present them as a more full-featured list-like container but optimize
313/// it for forward iteration. Blocks are always owned by a region.
314class PyBlockList {
315public:
316 PyBlockList(PyOperationRef operation, MlirRegion region)
317 : operation(std::move(operation)), region(region) {}
318
319 PyBlockIterator dunderIter() {
320 operation->checkValid();
321 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
322 }
323
324 intptr_t dunderLen() {
325 operation->checkValid();
326 intptr_t count = 0;
327 MlirBlock block = mlirRegionGetFirstBlock(region);
328 while (!mlirBlockIsNull(block)) {
329 count += 1;
330 block = mlirBlockGetNextInRegion(block);
331 }
332 return count;
333 }
334
335 PyBlock dunderGetItem(intptr_t index) {
336 operation->checkValid();
337 if (index < 0) {
338 index += dunderLen();
339 }
340 if (index < 0) {
341 throw nb::index_error("attempt to access out of bounds block");
342 }
343 MlirBlock block = mlirRegionGetFirstBlock(region);
344 while (!mlirBlockIsNull(block)) {
345 if (index == 0) {
346 return PyBlock(operation, block);
347 }
348 block = mlirBlockGetNextInRegion(block);
349 index -= 1;
350 }
351 throw nb::index_error("attempt to access out of bounds block");
352 }
353
354 PyBlock appendBlock(const nb::args &pyArgTypes,
355 const std::optional<nb::sequence> &pyArgLocs) {
356 operation->checkValid();
357 MlirBlock block =
358 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
359 mlirRegionAppendOwnedBlock(region, block);
360 return PyBlock(operation, block);
361 }
362
363 static void bind(nb::module_ &m) {
364 nb::class_<PyBlockList>(m, "BlockList")
365 .def("__getitem__", &PyBlockList::dunderGetItem,
366 "Returns the block at the specified index.")
367 .def("__iter__", &PyBlockList::dunderIter,
368 "Returns an iterator over blocks in the operation's region.")
369 .def("__len__", &PyBlockList::dunderLen,
370 "Returns the number of blocks in the operation's region.")
371 .def("append", &PyBlockList::appendBlock,
372 R"(
373 Appends a new block, with argument types as positional args.
374
375 Returns:
376 The created block.
377 )",
378 nb::arg("args"), nb::kw_only(),
379 nb::arg("arg_locs") = std::nullopt);
380 }
381
382private:
383 PyOperationRef operation;
384 MlirRegion region;
385};
386
387class PyOperationIterator {
388public:
389 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
390 : parentOperation(std::move(parentOperation)), next(next) {}
391
392 PyOperationIterator &dunderIter() { return *this; }
393
394 nb::typed<nb::object, PyOpView> dunderNext() {
395 parentOperation->checkValid();
396 if (mlirOperationIsNull(next)) {
397 throw nb::stop_iteration();
398 }
399
400 PyOperationRef returnOperation =
401 PyOperation::forOperation(parentOperation->getContext(), next);
402 next = mlirOperationGetNextInBlock(next);
403 return returnOperation->createOpView();
404 }
405
406 static void bind(nb::module_ &m) {
407 nb::class_<PyOperationIterator>(m, "OperationIterator")
408 .def("__iter__", &PyOperationIterator::dunderIter,
409 "Returns an iterator over the operations in an operation's block.")
410 .def("__next__", &PyOperationIterator::dunderNext,
411 "Returns the next operation in the iteration.");
412 }
413
414private:
415 PyOperationRef parentOperation;
416 MlirOperation next;
417};
418
419/// Operations are exposed by the C-API as a forward-only linked list. In
420/// Python, we present them as a more full-featured list-like container but
421/// optimize it for forward iteration. Iterable operations are always owned
422/// by a block.
423class PyOperationList {
424public:
425 PyOperationList(PyOperationRef parentOperation, MlirBlock block)
426 : parentOperation(std::move(parentOperation)), block(block) {}
427
428 PyOperationIterator dunderIter() {
429 parentOperation->checkValid();
430 return PyOperationIterator(parentOperation,
432 }
433
434 intptr_t dunderLen() {
435 parentOperation->checkValid();
436 intptr_t count = 0;
437 MlirOperation childOp = mlirBlockGetFirstOperation(block);
438 while (!mlirOperationIsNull(childOp)) {
439 count += 1;
440 childOp = mlirOperationGetNextInBlock(childOp);
441 }
442 return count;
443 }
444
445 nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
446 parentOperation->checkValid();
447 if (index < 0) {
448 index += dunderLen();
449 }
450 if (index < 0) {
451 throw nb::index_error("attempt to access out of bounds operation");
452 }
453 MlirOperation childOp = mlirBlockGetFirstOperation(block);
454 while (!mlirOperationIsNull(childOp)) {
455 if (index == 0) {
456 return PyOperation::forOperation(parentOperation->getContext(), childOp)
457 ->createOpView();
458 }
459 childOp = mlirOperationGetNextInBlock(childOp);
460 index -= 1;
461 }
462 throw nb::index_error("attempt to access out of bounds operation");
463 }
464
465 static void bind(nb::module_ &m) {
466 nb::class_<PyOperationList>(m, "OperationList")
467 .def("__getitem__", &PyOperationList::dunderGetItem,
468 "Returns the operation at the specified index.")
469 .def("__iter__", &PyOperationList::dunderIter,
470 "Returns an iterator over operations in the list.")
471 .def("__len__", &PyOperationList::dunderLen,
472 "Returns the number of operations in the list.");
473 }
474
475private:
476 PyOperationRef parentOperation;
477 MlirBlock block;
478};
479
480class PyOpOperand {
481public:
482 PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
483
484 nb::typed<nb::object, PyOpView> getOwner() {
485 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
486 PyMlirContextRef context =
488 return PyOperation::forOperation(context, owner)->createOpView();
489 }
490
491 size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
492
493 static void bind(nb::module_ &m) {
494 nb::class_<PyOpOperand>(m, "OpOperand")
495 .def_prop_ro("owner", &PyOpOperand::getOwner,
496 "Returns the operation that owns this operand.")
497 .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
498 "Returns the operand number in the owning operation.");
499 }
500
501private:
502 MlirOpOperand opOperand;
503};
504
505class PyOpOperandIterator {
506public:
507 PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
508
509 PyOpOperandIterator &dunderIter() { return *this; }
510
511 PyOpOperand dunderNext() {
512 if (mlirOpOperandIsNull(opOperand))
513 throw nb::stop_iteration();
514
515 PyOpOperand returnOpOperand(opOperand);
516 opOperand = mlirOpOperandGetNextUse(opOperand);
517 return returnOpOperand;
518 }
519
520 static void bind(nb::module_ &m) {
521 nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
522 .def("__iter__", &PyOpOperandIterator::dunderIter,
523 "Returns an iterator over operands.")
524 .def("__next__", &PyOpOperandIterator::dunderNext,
525 "Returns the next operand in the iteration.");
526 }
527
528private:
529 MlirOpOperand opOperand;
530};
531
532} // namespace
533
534//------------------------------------------------------------------------------
535// PyMlirContext
536//------------------------------------------------------------------------------
537
538PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
539 nb::gil_scoped_acquire acquire;
540 nb::ft_lock_guard lock(live_contexts_mutex);
541 auto &liveContexts = getLiveContexts();
542 liveContexts[context.ptr] = this;
543}
544
546 // Note that the only public way to construct an instance is via the
547 // forContext method, which always puts the associated handle into
548 // liveContexts.
549 nb::gil_scoped_acquire acquire;
550 {
551 nb::ft_lock_guard lock(live_contexts_mutex);
552 getLiveContexts().erase(context.ptr);
553 }
554 mlirContextDestroy(context);
555}
556
558 return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
559}
560
561nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
562 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
563 if (mlirContextIsNull(rawContext))
564 throw nb::python_error();
565 return forContext(rawContext).releaseObject();
566}
567
569 nb::gil_scoped_acquire acquire;
570 nb::ft_lock_guard lock(live_contexts_mutex);
571 auto &liveContexts = getLiveContexts();
572 auto it = liveContexts.find(context.ptr);
573 if (it == liveContexts.end()) {
574 // Create.
575 PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
576 nb::object pyRef = nb::cast(unownedContextWrapper);
577 assert(pyRef && "cast to nb::object failed");
578 liveContexts[context.ptr] = unownedContextWrapper;
579 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
580 }
581 // Use existing.
582 nb::object pyRef = nb::cast(it->second);
583 return PyMlirContextRef(it->second, std::move(pyRef));
584}
585
// Out-of-line definition of the mutex that guards the map returned by
// PyMlirContext::getLiveContexts().
nb::ft_mutex PyMlirContext::live_contexts_mutex;
587
588PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
589 static LiveContextMap liveContexts;
590 return liveContexts;
591}
592
594 nb::ft_lock_guard lock(live_contexts_mutex);
595 return getLiveContexts().size();
596}
597
598nb::object PyMlirContext::contextEnter(nb::object context) {
599 return PyThreadContextEntry::pushContext(context);
600}
601
602void PyMlirContext::contextExit(const nb::object &excType,
603 const nb::object &excVal,
604 const nb::object &excTb) {
606}
607
608nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
609 // Note that ownership is transferred to the delete callback below by way of
610 // an explicit inc_ref (borrow).
611 PyDiagnosticHandler *pyHandler =
612 new PyDiagnosticHandler(get(), std::move(callback));
613 nb::object pyHandlerObject =
614 nb::cast(pyHandler, nb::rv_policy::take_ownership);
615 (void)pyHandlerObject.inc_ref();
616
617 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
618 // guaranteed to be known to pybind.
619 auto handlerCallback =
620 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
621 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
622 nb::object pyDiagnosticObject =
623 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
624
625 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
626 bool result = false;
627 {
628 // Since this can be called from arbitrary C++ contexts, always get the
629 // gil.
630 nb::gil_scoped_acquire gil;
631 try {
632 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
633 } catch (std::exception &e) {
634 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
635 e.what());
636 pyHandler->hadError = true;
637 }
638 }
639
640 pyDiagnostic->invalidate();
642 };
643 auto deleteCallback = +[](void *userData) {
644 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
645 assert(pyHandler->registeredID && "handler is not registered");
646 pyHandler->registeredID.reset();
647
648 // Decrement reference, balancing the inc_ref() above.
649 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
650 pyHandlerObject.dec_ref();
651 };
652
653 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
654 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
655 return pyHandlerObject;
656}
657
658MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
659 void *userData) {
660 auto *self = static_cast<ErrorCapture *>(userData);
661 // Check if the context requested we emit errors instead of capturing them.
662 if (self->ctx->emitErrorDiagnostics)
664
667
668 self->errors.emplace_back(PyDiagnostic(diag).getInfo());
670}
671
674 if (!context) {
675 throw std::runtime_error(
676 "An MLIR function requires a Context but none was provided in the call "
677 "or from the surrounding environment. Either pass to the function with "
678 "a 'context=' argument or establish a default using 'with Context():'");
679 }
680 return *context;
681}
682
683//------------------------------------------------------------------------------
684// PyThreadContextEntry management
685//------------------------------------------------------------------------------
686
687std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
688 static thread_local std::vector<PyThreadContextEntry> stack;
689 return stack;
690}
691
693 auto &stack = getStack();
694 if (stack.empty())
695 return nullptr;
696 return &stack.back();
697}
698
699void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
700 nb::object insertionPoint,
701 nb::object location) {
702 auto &stack = getStack();
703 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
704 std::move(location));
705 // If the new stack has more than one entry and the context of the new top
706 // entry matches the previous, copy the insertionPoint and location from the
707 // previous entry if missing from the new top entry.
708 if (stack.size() > 1) {
709 auto &prev = *(stack.rbegin() + 1);
710 auto &current = stack.back();
711 if (current.context.is(prev.context)) {
712 // Default non-context objects from the previous entry.
713 if (!current.insertionPoint)
714 current.insertionPoint = prev.insertionPoint;
715 if (!current.location)
716 current.location = prev.location;
717 }
718 }
719}
720
722 if (!context)
723 return nullptr;
724 return nb::cast<PyMlirContext *>(context);
725}
726
728 if (!insertionPoint)
729 return nullptr;
730 return nb::cast<PyInsertionPoint *>(insertionPoint);
731}
732
734 if (!location)
735 return nullptr;
736 return nb::cast<PyLocation *>(location);
737}
738
740 auto *tos = getTopOfStack();
741 return tos ? tos->getContext() : nullptr;
742}
743
745 auto *tos = getTopOfStack();
746 return tos ? tos->getInsertionPoint() : nullptr;
747}
748
750 auto *tos = getTopOfStack();
751 return tos ? tos->getLocation() : nullptr;
752}
753
754nb::object PyThreadContextEntry::pushContext(nb::object context) {
755 push(FrameKind::Context, /*context=*/context,
756 /*insertionPoint=*/nb::object(),
757 /*location=*/nb::object());
758 return context;
759}
760
762 auto &stack = getStack();
763 if (stack.empty())
764 throw std::runtime_error("Unbalanced Context enter/exit");
765 auto &tos = stack.back();
766 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
767 throw std::runtime_error("Unbalanced Context enter/exit");
768 stack.pop_back();
769}
770
771nb::object
772PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
773 PyInsertionPoint &insertionPoint =
774 nb::cast<PyInsertionPoint &>(insertionPointObj);
775 nb::object contextObj =
776 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
778 /*context=*/contextObj,
779 /*insertionPoint=*/insertionPointObj,
780 /*location=*/nb::object());
781 return insertionPointObj;
782}
783
785 auto &stack = getStack();
786 if (stack.empty())
787 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
788 auto &tos = stack.back();
789 if (tos.frameKind != FrameKind::InsertionPoint &&
790 tos.getInsertionPoint() != &insertionPoint)
791 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
792 stack.pop_back();
793}
794
795nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
796 PyLocation &location = nb::cast<PyLocation &>(locationObj);
797 nb::object contextObj = location.getContext().getObject();
798 push(FrameKind::Location, /*context=*/contextObj,
799 /*insertionPoint=*/nb::object(),
800 /*location=*/locationObj);
801 return locationObj;
802}
803
805 auto &stack = getStack();
806 if (stack.empty())
807 throw std::runtime_error("Unbalanced Location enter/exit");
808 auto &tos = stack.back();
809 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
810 throw std::runtime_error("Unbalanced Location enter/exit");
811 stack.pop_back();
812}
813
814//------------------------------------------------------------------------------
815// PyDiagnostic*
816//------------------------------------------------------------------------------
817
819 valid = false;
820 if (materializedNotes) {
821 for (nb::handle noteObject : *materializedNotes) {
822 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
823 note->invalidate();
824 }
825 }
826}
827
829 nb::object callback)
830 : context(context), callback(std::move(callback)) {}
831
833
835 if (!registeredID)
836 return;
837 MlirDiagnosticHandlerID localID = *registeredID;
838 mlirContextDetachDiagnosticHandler(context, localID);
839 assert(!registeredID && "should have unregistered");
840 // Not strictly necessary but keeps stale pointers from being around to cause
841 // issues.
842 context = {nullptr};
843}
844
845void PyDiagnostic::checkValid() {
846 if (!valid) {
847 throw std::invalid_argument(
848 "Diagnostic is invalid (used outside of callback)");
849 }
850}
851
853 checkValid();
854 return mlirDiagnosticGetSeverity(diagnostic);
855}
856
858 checkValid();
859 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
860 MlirContext context = mlirLocationGetContext(loc);
861 return PyLocation(PyMlirContext::forContext(context), loc);
862}
863
865 checkValid();
866 nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
867 PyFileAccumulator accum(fileObject, /*binary=*/false);
868 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
869 return nb::cast<nb::str>(fileObject.attr("getvalue")());
870}
871
873 checkValid();
874 if (materializedNotes)
875 return *materializedNotes;
876 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
877 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
878 for (intptr_t i = 0; i < numNotes; ++i) {
879 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
880 nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
881 PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
882 }
883 materializedNotes = std::move(notes);
884
885 return *materializedNotes;
886}
887
889 std::vector<DiagnosticInfo> notes;
890 for (nb::handle n : getNotes())
891 notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
892 return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
893 std::move(notes)};
894}
895
896//------------------------------------------------------------------------------
897// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
898//------------------------------------------------------------------------------
899
900MlirDialect PyDialects::getDialectForKey(const std::string &key,
901 bool attrError) {
902 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
903 {key.data(), key.size()});
904 if (mlirDialectIsNull(dialect)) {
905 std::string msg = (Twine("Dialect '") + key + "' not found").str();
906 if (attrError)
907 throw nb::attribute_error(msg.c_str());
908 throw nb::index_error(msg.c_str());
909 }
910 return dialect;
911}
912
914 return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
915}
916
918 MlirDialectRegistry rawRegistry =
920 if (mlirDialectRegistryIsNull(rawRegistry))
921 throw nb::python_error();
922 return PyDialectRegistry(rawRegistry);
923}
924
925//------------------------------------------------------------------------------
926// PyLocation
927//------------------------------------------------------------------------------
928
930 return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
931}
932
934 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
935 if (mlirLocationIsNull(rawLoc))
936 throw nb::python_error();
938 rawLoc);
939}
940
941nb::object PyLocation::contextEnter(nb::object locationObj) {
942 return PyThreadContextEntry::pushLocation(locationObj);
943}
944
945void PyLocation::contextExit(const nb::object &excType,
946 const nb::object &excVal,
947 const nb::object &excTb) {
949}
950
953 if (!location) {
954 throw std::runtime_error(
955 "An MLIR function requires a Location but none was provided in the "
956 "call or from the surrounding environment. Either pass to the function "
957 "with a 'loc=' argument or establish a default using 'with loc:'");
958 }
959 return *location;
960}
961
962//------------------------------------------------------------------------------
963// PyModule
964//------------------------------------------------------------------------------
965
966PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
967 : BaseContextObject(std::move(contextRef)), module(module) {}
968
970 nb::gil_scoped_acquire acquire;
971 auto &liveModules = getContext()->liveModules;
972 assert(liveModules.count(module.ptr) == 1 &&
973 "destroying module not in live map");
974 liveModules.erase(module.ptr);
975 mlirModuleDestroy(module);
976}
977
979 MlirContext context = mlirModuleGetContext(module);
980 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
981
982 nb::gil_scoped_acquire acquire;
983 auto &liveModules = contextRef->liveModules;
984 auto it = liveModules.find(module.ptr);
985 if (it == liveModules.end()) {
986 // Create.
987 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
988 // Note that the default return value policy on cast is automatic_reference,
989 // which does not take ownership (delete will not be called).
990 // Just be explicit.
991 nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
992 unownedModule->handle = pyRef;
993 liveModules[module.ptr] =
994 std::make_pair(unownedModule->handle, unownedModule);
995 return PyModuleRef(unownedModule, std::move(pyRef));
996 }
997 // Use existing.
998 PyModule *existing = it->second.second;
999 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1000 return PyModuleRef(existing, std::move(pyRef));
1001}
1002
1003nb::object PyModule::createFromCapsule(nb::object capsule) {
1004 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1005 if (mlirModuleIsNull(rawModule))
1006 throw nb::python_error();
1007 return forModule(rawModule).releaseObject();
1008}
1009
1011 return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
1012}
1013
1014//------------------------------------------------------------------------------
1015// PyOperation
1016//------------------------------------------------------------------------------
1017
1018PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1019 : BaseContextObject(std::move(contextRef)), operation(operation) {}
1020
1022 // If the operation has already been invalidated there is nothing to do.
1023 if (!valid)
1024 return;
1025 // Otherwise, invalidate the operation when it is attached.
1026 if (isAttached())
1027 setInvalid();
1028 else {
1029 // And destroy it when it is detached, i.e. owned by Python.
1030 erase();
1031 }
1032}
1033
1034namespace {
1035
1036// Constructs a new object of type T in-place on the Python heap, returning a
1037// PyObjectRef to it, loosely analogous to std::make_shared<T>().
1038template <typename T, class... Args>
1039PyObjectRef<T> makeObjectRef(Args &&...args) {
1040 nb::handle type = nb::type<T>();
1041 nb::object instance = nb::inst_alloc(type);
1042 T *ptr = nb::inst_ptr<T>(instance);
1043 new (ptr) T(std::forward<Args>(args)...);
1044 nb::inst_mark_ready(instance);
1045 return PyObjectRef<T>(ptr, std::move(instance));
1046}
1047
1048} // namespace
1049
1050PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1051 MlirOperation operation,
1052 nb::object parentKeepAlive) {
1053 // Create.
1054 PyOperationRef unownedOperation =
1055 makeObjectRef<PyOperation>(std::move(contextRef), operation);
1056 unownedOperation->handle = unownedOperation.getObject();
1057 if (parentKeepAlive) {
1058 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1059 }
1060 return unownedOperation;
1061}
1062
1064 MlirOperation operation,
1065 nb::object parentKeepAlive) {
1066 return createInstance(std::move(contextRef), operation,
1067 std::move(parentKeepAlive));
1068}
1069
1071 MlirOperation operation,
1072 nb::object parentKeepAlive) {
1073 PyOperationRef created = createInstance(std::move(contextRef), operation,
1074 std::move(parentKeepAlive));
1075 created->attached = false;
1076 return created;
1077}
1078
1080 const std::string &sourceStr,
1081 const std::string &sourceName) {
1082 PyMlirContext::ErrorCapture errors(contextRef);
1083 MlirOperation op =
1084 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1085 toMlirStringRef(sourceName));
1086 if (mlirOperationIsNull(op))
1087 throw MLIRError("Unable to parse operation assembly", errors.take());
1088 return PyOperation::createDetached(std::move(contextRef), op);
1089}
1090
1092 if (!valid) {
1093 throw std::runtime_error("the operation has been invalidated");
1094 }
1095}
1096
1097void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1098 std::optional<int64_t> largeResourceLimit,
1099 bool enableDebugInfo, bool prettyDebugInfo,
1100 bool printGenericOpForm, bool useLocalScope,
1101 bool useNameLocAsPrefix, bool assumeVerified,
1102 nb::object fileObject, bool binary,
1103 bool skipRegions) {
1104 PyOperation &operation = getOperation();
1105 operation.checkValid();
1106 if (fileObject.is_none())
1107 fileObject = nb::module_::import_("sys").attr("stdout");
1108
1109 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1110 if (largeElementsLimit)
1111 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1112 if (largeResourceLimit)
1113 mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit);
1114 if (enableDebugInfo)
1115 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1116 /*prettyForm=*/prettyDebugInfo);
1117 if (printGenericOpForm)
1119 if (useLocalScope)
1121 if (assumeVerified)
1123 if (skipRegions)
1125 if (useNameLocAsPrefix)
1127
1128 PyFileAccumulator accum(fileObject, binary);
1129 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1130 accum.getUserData());
1132}
1133
1134void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1135 bool binary) {
1136 PyOperation &operation = getOperation();
1137 operation.checkValid();
1138 if (fileObject.is_none())
1139 fileObject = nb::module_::import_("sys").attr("stdout");
1140 PyFileAccumulator accum(fileObject, binary);
1141 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1142 accum.getUserData());
1143}
1144
1145void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
1146 std::optional<int64_t> bytecodeVersion) {
1147 PyOperation &operation = getOperation();
1148 operation.checkValid();
1149 PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
1150
1151 if (!bytecodeVersion.has_value())
1152 return mlirOperationWriteBytecode(operation, accum.getCallback(),
1153 accum.getUserData());
1154
1155 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1158 operation, config, accum.getCallback(), accum.getUserData());
1161 throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
1162 Twine(*bytecodeVersion))
1163 .str()
1164 .c_str());
1165}
1166
1168 std::function<MlirWalkResult(MlirOperation)> callback,
1169 MlirWalkOrder walkOrder) {
1170 PyOperation &operation = getOperation();
1171 operation.checkValid();
1172 struct UserData {
1173 std::function<MlirWalkResult(MlirOperation)> callback;
1174 bool gotException;
1175 std::string exceptionWhat;
1176 nb::object exceptionType;
1177 };
1178 UserData userData{callback, false, {}, {}};
1179 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1180 void *userData) {
1181 UserData *calleeUserData = static_cast<UserData *>(userData);
1182 try {
1183 return (calleeUserData->callback)(op);
1184 } catch (nb::python_error &e) {
1185 calleeUserData->gotException = true;
1186 calleeUserData->exceptionWhat = std::string(e.what());
1187 calleeUserData->exceptionType = nb::borrow(e.type());
1188 return MlirWalkResult::MlirWalkResultInterrupt;
1189 }
1190 };
1191 mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1192 if (userData.gotException) {
1193 std::string message("Exception raised in callback: ");
1194 message.append(userData.exceptionWhat);
1195 throw std::runtime_error(message);
1196 }
1197}
1198
1199nb::object PyOperationBase::getAsm(bool binary,
1200 std::optional<int64_t> largeElementsLimit,
1201 std::optional<int64_t> largeResourceLimit,
1202 bool enableDebugInfo, bool prettyDebugInfo,
1203 bool printGenericOpForm, bool useLocalScope,
1204 bool useNameLocAsPrefix, bool assumeVerified,
1205 bool skipRegions) {
1206 nb::object fileObject;
1207 if (binary) {
1208 fileObject = nb::module_::import_("io").attr("BytesIO")();
1209 } else {
1210 fileObject = nb::module_::import_("io").attr("StringIO")();
1211 }
1212 print(/*largeElementsLimit=*/largeElementsLimit,
1213 /*largeResourceLimit=*/largeResourceLimit,
1214 /*enableDebugInfo=*/enableDebugInfo,
1215 /*prettyDebugInfo=*/prettyDebugInfo,
1216 /*printGenericOpForm=*/printGenericOpForm,
1217 /*useLocalScope=*/useLocalScope,
1218 /*useNameLocAsPrefix=*/useNameLocAsPrefix,
1219 /*assumeVerified=*/assumeVerified,
1220 /*fileObject=*/fileObject,
1221 /*binary=*/binary,
1222 /*skipRegions=*/skipRegions);
1223
1224 return fileObject.attr("getvalue")();
1225}
1226
1228 PyOperation &operation = getOperation();
1229 PyOperation &otherOp = other.getOperation();
1230 operation.checkValid();
1231 otherOp.checkValid();
1232 mlirOperationMoveAfter(operation, otherOp);
1233 operation.parentKeepAlive = otherOp.parentKeepAlive;
1234}
1235
1237 PyOperation &operation = getOperation();
1238 PyOperation &otherOp = other.getOperation();
1239 operation.checkValid();
1240 otherOp.checkValid();
1241 mlirOperationMoveBefore(operation, otherOp);
1242 operation.parentKeepAlive = otherOp.parentKeepAlive;
1243}
1244
1246 PyOperation &operation = getOperation();
1247 PyOperation &otherOp = other.getOperation();
1248 operation.checkValid();
1249 otherOp.checkValid();
1250 return mlirOperationIsBeforeInBlock(operation, otherOp);
1251}
1252
1254 PyOperation &op = getOperation();
1256 if (!mlirOperationVerify(op.get()))
1257 throw MLIRError("Verification failed", errors.take());
1258 return true;
1259}
1260
1261std::optional<PyOperationRef> PyOperation::getParentOperation() {
1262 checkValid();
1263 if (!isAttached())
1264 throw nb::value_error("Detached operations have no parent");
1265 MlirOperation operation = mlirOperationGetParentOperation(get());
1266 if (mlirOperationIsNull(operation))
1267 return {};
1268 return PyOperation::forOperation(getContext(), operation);
1269}
1270
1272 checkValid();
1273 std::optional<PyOperationRef> parentOperation = getParentOperation();
1274 MlirBlock block = mlirOperationGetBlock(get());
1275 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1276 assert(parentOperation && "Operation has no parent");
1277 return PyBlock{std::move(*parentOperation), block};
1278}
1279
1281 checkValid();
1282 return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1283}
1284
1285nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
1286 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1287 if (mlirOperationIsNull(rawOperation))
1288 throw nb::python_error();
1289 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1290 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1291 .releaseObject();
1292}
1293
1295 const nb::object &maybeIp) {
1296 // InsertPoint active?
1297 if (!maybeIp.is(nb::cast(false))) {
1298 PyInsertionPoint *ip;
1299 if (maybeIp.is_none()) {
1301 } else {
1302 ip = nb::cast<PyInsertionPoint *>(maybeIp);
1303 }
1304 if (ip)
1305 ip->insert(*op.get());
1306 }
1307}
1308
1309nb::object PyOperation::create(std::string_view name,
1310 std::optional<std::vector<PyType *>> results,
1312 std::optional<nb::dict> attributes,
1313 std::optional<std::vector<PyBlock *>> successors,
1314 int regions, PyLocation &location,
1315 const nb::object &maybeIp, bool inferType) {
1317 llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
1319
1320 // General parameter validation.
1321 if (regions < 0)
1322 throw nb::value_error("number of regions must be >= 0");
1323
1324 // Unpack/validate results.
1325 if (results) {
1326 mlirResults.reserve(results->size());
1327 for (PyType *result : *results) {
1328 // TODO: Verify result type originate from the same context.
1329 if (!result)
1330 throw nb::value_error("result type cannot be None");
1331 mlirResults.push_back(*result);
1332 }
1333 }
1334 // Unpack/validate attributes.
1335 if (attributes) {
1336 mlirAttributes.reserve(attributes->size());
1337 for (std::pair<nb::handle, nb::handle> it : *attributes) {
1338 std::string key;
1339 try {
1340 key = nb::cast<std::string>(it.first);
1341 } catch (nb::cast_error &err) {
1342 std::string msg = "Invalid attribute key (not a string) when "
1343 "attempting to create the operation \"" +
1344 std::string(name) + "\" (" + err.what() + ")";
1345 throw nb::type_error(msg.c_str());
1346 }
1347 try {
1348 auto &attribute = nb::cast<PyAttribute &>(it.second);
1349 // TODO: Verify attribute originates from the same context.
1350 mlirAttributes.emplace_back(std::move(key), attribute);
1351 } catch (nb::cast_error &err) {
1352 std::string msg = "Invalid attribute value for the key \"" + key +
1353 "\" when attempting to create the operation \"" +
1354 std::string(name) + "\" (" + err.what() + ")";
1355 throw nb::type_error(msg.c_str());
1356 } catch (std::runtime_error &) {
1357 // This exception seems thrown when the value is "None".
1358 std::string msg =
1359 "Found an invalid (`None`?) attribute value for the key \"" + key +
1360 "\" when attempting to create the operation \"" +
1361 std::string(name) + "\"";
1362 throw std::runtime_error(msg);
1363 }
1364 }
1365 }
1366 // Unpack/validate successors.
1367 if (successors) {
1368 mlirSuccessors.reserve(successors->size());
1369 for (auto *successor : *successors) {
1370 // TODO: Verify successor originate from the same context.
1371 if (!successor)
1372 throw nb::value_error("successor block cannot be None");
1373 mlirSuccessors.push_back(successor->get());
1374 }
1375 }
1376
1377 // Apply unpacked/validated to the operation state. Beyond this
1378 // point, exceptions cannot be thrown or else the state will leak.
1379 MlirOperationState state =
1380 mlirOperationStateGet(toMlirStringRef(name), location);
1381 if (!operands.empty())
1382 mlirOperationStateAddOperands(&state, operands.size(), operands.data());
1383 state.enableResultTypeInference = inferType;
1384 if (!mlirResults.empty())
1385 mlirOperationStateAddResults(&state, mlirResults.size(),
1386 mlirResults.data());
1387 if (!mlirAttributes.empty()) {
1388 // Note that the attribute names directly reference bytes in
1389 // mlirAttributes, so that vector must not be changed from here
1390 // on.
1391 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1392 mlirNamedAttributes.reserve(mlirAttributes.size());
1393 for (auto &it : mlirAttributes)
1394 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1396 toMlirStringRef(it.first)),
1397 it.second));
1398 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1399 mlirNamedAttributes.data());
1400 }
1401 if (!mlirSuccessors.empty())
1402 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1403 mlirSuccessors.data());
1404 if (regions) {
1406 mlirRegions.resize(regions);
1407 for (int i = 0; i < regions; ++i)
1408 mlirRegions[i] = mlirRegionCreate();
1409 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1410 mlirRegions.data());
1411 }
1412
1413 // Construct the operation.
1414 PyMlirContext::ErrorCapture errors(location.getContext());
1415 MlirOperation operation = mlirOperationCreate(&state);
1416 if (!operation.ptr)
1417 throw MLIRError("Operation creation failed", errors.take());
1418 PyOperationRef created =
1419 PyOperation::createDetached(location.getContext(), operation);
1420 maybeInsertOperation(created, maybeIp);
1421
1422 return created.getObject();
1423}
1424
1425nb::object PyOperation::clone(const nb::object &maybeIp) {
1426 MlirOperation clonedOperation = mlirOperationClone(operation);
1427 PyOperationRef cloned =
1428 PyOperation::createDetached(getContext(), clonedOperation);
1429 maybeInsertOperation(cloned, maybeIp);
1430
1431 return cloned->createOpView();
1432}
1433
1435 checkValid();
1436 MlirIdentifier ident = mlirOperationGetName(get());
1437 MlirStringRef identStr = mlirIdentifierStr(ident);
1438 auto operationCls = PyGlobals::get().lookupOperationClass(
1439 StringRef(identStr.data, identStr.length));
1440 if (operationCls)
1441 return PyOpView::constructDerived(*operationCls, getRef().getObject());
1442 return nb::cast(PyOpView(getRef().getObject()));
1443}
1444
1446 checkValid();
1447 setInvalid();
1448 mlirOperationDestroy(operation);
1449}
1450
1451namespace {
1452/// CRTP base class for Python MLIR values that subclass Value and should be
1453/// castable from it. The value hierarchy is one level deep and is not supposed
1454/// to accommodate other levels unless core MLIR changes.
1455template <typename DerivedTy>
1456class PyConcreteValue : public PyValue {
1457public:
1458 // Derived classes must define statics for:
1459 // IsAFunctionTy isaFunction
1460 // const char *pyClassName
1461 // and redefine bindDerived.
1462 using ClassTy = nb::class_<DerivedTy, PyValue>;
1463 using IsAFunctionTy = bool (*)(MlirValue);
1464
1465 PyConcreteValue() = default;
1466 PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1467 : PyValue(operationRef, value) {}
1468 PyConcreteValue(PyValue &orig)
1469 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1470
1471 /// Attempts to cast the original value to the derived type and throws on
1472 /// type mismatches.
1473 static MlirValue castFrom(PyValue &orig) {
1474 if (!DerivedTy::isaFunction(orig.get())) {
1475 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
1476 throw nb::value_error((Twine("Cannot cast value to ") +
1477 DerivedTy::pyClassName + " (from " + origRepr +
1478 ")")
1479 .str()
1480 .c_str());
1481 }
1482 return orig.get();
1483 }
1484
1485 /// Binds the Python module objects to functions of this class.
1486 static void bind(nb::module_ &m) {
1487 auto cls = ClassTy(
1488 m, DerivedTy::pyClassName, nb::is_generic(),
1489 nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
1490 .str()
1491 .c_str()));
1492 cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
1493 cls.def_static(
1494 "isinstance",
1495 [](PyValue &otherValue) -> bool {
1496 return DerivedTy::isaFunction(otherValue);
1497 },
1498 nb::arg("other_value"));
1500 [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
1501 return self.maybeDownCast();
1502 });
1503 DerivedTy::bindDerived(cls);
1504 }
1505
1506 /// Implemented by derived classes to add methods to the Python subclass.
1507 static void bindDerived(ClassTy &m) {}
1508};
1509
1510} // namespace
1511
1512/// Python wrapper for MlirOpResult.
1513class PyOpResult : public PyConcreteValue<PyOpResult> {
1514public:
1515 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1516 static constexpr const char *pyClassName = "OpResult";
1517 using PyConcreteValue::PyConcreteValue;
1518
1519 static void bindDerived(ClassTy &c) {
1520 c.def_prop_ro(
1521 "owner",
1522 [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
1523 assert(mlirOperationEqual(self.getParentOperation()->get(),
1524 mlirOpResultGetOwner(self.get())) &&
1525 "expected the owner of the value in Python to match that in "
1526 "the IR");
1527 return self.getParentOperation().getObject();
1528 },
1529 "Returns the operation that produces this result.");
1530 c.def_prop_ro(
1531 "result_number",
1532 [](PyOpResult &self) {
1533 return mlirOpResultGetResultNumber(self.get());
1534 },
1535 "Returns the position of this result in the operation's result list.");
1536 }
1537};
1538
1539/// Returns the list of types of the values held by container.
1540template <typename Container>
1541static std::vector<nb::typed<nb::object, PyType>>
1542getValueTypes(Container &container, PyMlirContextRef &context) {
1543 std::vector<nb::typed<nb::object, PyType>> result;
1544 result.reserve(container.size());
1545 for (int i = 0, e = container.size(); i < e; ++i) {
1546 result.push_back(PyType(context->getRef(),
1547 mlirValueGetType(container.getElement(i).get()))
1548 .maybeDownCast());
1549 }
1550 return result;
1551}
1552
1553/// A list of operation results. Internally, these are stored as consecutive
1554/// elements, random access is cheap. The (returned) result list is associated
1555/// with the operation whose results these are, and thus extends the lifetime of
1556/// this operation.
1557class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1558public:
1559 static constexpr const char *pyClassName = "OpResultList";
1561
1563 intptr_t length = -1, intptr_t step = 1)
1565 length == -1 ? mlirOperationGetNumResults(operation->get())
1566 : length,
1567 step),
1568 operation(std::move(operation)) {}
1569
1570 static void bindDerived(ClassTy &c) {
1571 c.def_prop_ro(
1572 "types",
1573 [](PyOpResultList &self) {
1574 return getValueTypes(self, self.operation->getContext());
1575 },
1576 "Returns a list of types for all results in this result list.");
1577 c.def_prop_ro(
1578 "owner",
1579 [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
1580 return self.operation->createOpView();
1581 },
1582 "Returns the operation that owns this result list.");
1583 }
1584
1585 PyOperationRef &getOperation() { return operation; }
1586
1587private:
1588 /// Give the parent CRTP class access to hook implementations below.
1589 friend class Sliceable<PyOpResultList, PyOpResult>;
1590
1591 intptr_t getRawNumElements() {
1592 operation->checkValid();
1593 return mlirOperationGetNumResults(operation->get());
1594 }
1595
1596 PyOpResult getRawElement(intptr_t index) {
1597 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1598 return PyOpResult(value);
1599 }
1600
1601 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1602 return PyOpResultList(operation, startIndex, length, step);
1603 }
1604
1605 PyOperationRef operation;
1606};
1607
1608//------------------------------------------------------------------------------
1609// PyOpView
1610//------------------------------------------------------------------------------
1611
1612static void populateResultTypes(StringRef name, nb::list resultTypeList,
1613 const nb::object &resultSegmentSpecObj,
1614 std::vector<int32_t> &resultSegmentLengths,
1615 std::vector<PyType *> &resultTypes) {
1616 resultTypes.reserve(resultTypeList.size());
1617 if (resultSegmentSpecObj.is_none()) {
1618 // Non-variadic result unpacking.
1619 for (const auto &it : llvm::enumerate(resultTypeList)) {
1620 try {
1621 resultTypes.push_back(nb::cast<PyType *>(it.value()));
1622 if (!resultTypes.back())
1623 throw nb::cast_error();
1624 } catch (nb::cast_error &err) {
1625 throw nb::value_error((llvm::Twine("Result ") +
1626 llvm::Twine(it.index()) + " of operation \"" +
1627 name + "\" must be a Type (" + err.what() + ")")
1628 .str()
1629 .c_str());
1630 }
1631 }
1632 } else {
1633 // Sized result unpacking.
1634 auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1635 if (resultSegmentSpec.size() != resultTypeList.size()) {
1636 throw nb::value_error((llvm::Twine("Operation \"") + name +
1637 "\" requires " +
1638 llvm::Twine(resultSegmentSpec.size()) +
1639 " result segments but was provided " +
1640 llvm::Twine(resultTypeList.size()))
1641 .str()
1642 .c_str());
1643 }
1644 resultSegmentLengths.reserve(resultTypeList.size());
1645 for (const auto &it :
1646 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1647 int segmentSpec = std::get<1>(it.value());
1648 if (segmentSpec == 1 || segmentSpec == 0) {
1649 // Unpack unary element.
1650 try {
1651 auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
1652 if (resultType) {
1653 resultTypes.push_back(resultType);
1654 resultSegmentLengths.push_back(1);
1655 } else if (segmentSpec == 0) {
1656 // Allowed to be optional.
1657 resultSegmentLengths.push_back(0);
1658 } else {
1659 throw nb::value_error(
1660 (llvm::Twine("Result ") + llvm::Twine(it.index()) +
1661 " of operation \"" + name +
1662 "\" must be a Type (was None and result is not optional)")
1663 .str()
1664 .c_str());
1665 }
1666 } catch (nb::cast_error &err) {
1667 throw nb::value_error((llvm::Twine("Result ") +
1668 llvm::Twine(it.index()) + " of operation \"" +
1669 name + "\" must be a Type (" + err.what() +
1670 ")")
1671 .str()
1672 .c_str());
1673 }
1674 } else if (segmentSpec == -1) {
1675 // Unpack sequence by appending.
1676 try {
1677 if (std::get<0>(it.value()).is_none()) {
1678 // Treat it as an empty list.
1679 resultSegmentLengths.push_back(0);
1680 } else {
1681 // Unpack the list.
1682 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1683 for (nb::handle segmentItem : segment) {
1684 resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1685 if (!resultTypes.back()) {
1686 throw nb::type_error("contained a None item");
1687 }
1688 }
1689 resultSegmentLengths.push_back(nb::len(segment));
1690 }
1691 } catch (std::exception &err) {
1692 // NOTE: Sloppy to be using a catch-all here, but there are at least
1693 // three different unrelated exceptions that can be thrown in the
1694 // above "casts". Just keep the scope above small and catch them all.
1695 throw nb::value_error((llvm::Twine("Result ") +
1696 llvm::Twine(it.index()) + " of operation \"" +
1697 name + "\" must be a Sequence of Types (" +
1698 err.what() + ")")
1699 .str()
1700 .c_str());
1701 }
1702 } else {
1703 throw nb::value_error("Unexpected segment spec");
1704 }
1705 }
1706 }
1707}
1708
1709static MlirValue getUniqueResult(MlirOperation operation) {
1710 auto numResults = mlirOperationGetNumResults(operation);
1711 if (numResults != 1) {
1712 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1713 throw nb::value_error((Twine("Cannot call .result on operation ") +
1714 StringRef(name.data, name.length) + " which has " +
1715 Twine(numResults) +
1716 " results (it is only valid for operations with a "
1717 "single result)")
1718 .str()
1719 .c_str());
1720 }
1721 return mlirOperationGetResult(operation, 0);
1722}
1723
1724static MlirValue getOpResultOrValue(nb::handle operand) {
1725 if (operand.is_none()) {
1726 throw nb::value_error("contained a None item");
1727 }
1728 PyOperationBase *op;
1729 if (nb::try_cast<PyOperationBase *>(operand, op)) {
1730 return getUniqueResult(op->getOperation());
1731 }
1732 PyOpResultList *opResultList;
1733 if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1734 return getUniqueResult(opResultList->getOperation()->get());
1735 }
1736 PyValue *value;
1737 if (nb::try_cast<PyValue *>(operand, value)) {
1738 return value->get();
1739 }
1740 throw nb::value_error("is not a Value");
1741}
1742
1744 std::string_view name, std::tuple<int, bool> opRegionSpec,
1745 nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1746 std::optional<nb::list> resultTypeList, nb::list operandList,
1747 std::optional<nb::dict> attributes,
1748 std::optional<std::vector<PyBlock *>> successors,
1749 std::optional<int> regions, PyLocation &location,
1750 const nb::object &maybeIp) {
1751 PyMlirContextRef context = location.getContext();
1752
1753 // Class level operation construction metadata.
1754 // Operand and result segment specs are either none, which does no
1755 // variadic unpacking, or a list of ints with segment sizes, where each
1756 // element is either a positive number (typically 1 for a scalar) or -1 to
1757 // indicate that it is derived from the length of the same-indexed operand
1758 // or result (implying that it is a list at that position).
1759 std::vector<int32_t> operandSegmentLengths;
1760 std::vector<int32_t> resultSegmentLengths;
1761
1762 // Validate/determine region count.
1763 int opMinRegionCount = std::get<0>(opRegionSpec);
1764 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1765 if (!regions) {
1766 regions = opMinRegionCount;
1767 }
1768 if (*regions < opMinRegionCount) {
1769 throw nb::value_error(
1770 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1771 llvm::Twine(opMinRegionCount) +
1772 " regions but was built with regions=" + llvm::Twine(*regions))
1773 .str()
1774 .c_str());
1775 }
1776 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1777 throw nb::value_error(
1778 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1779 llvm::Twine(opMinRegionCount) +
1780 " regions but was built with regions=" + llvm::Twine(*regions))
1781 .str()
1782 .c_str());
1783 }
1784
1785 // Unpack results.
1786 std::vector<PyType *> resultTypes;
1787 if (resultTypeList.has_value()) {
1788 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1789 resultSegmentLengths, resultTypes);
1790 }
1791
1792 // Unpack operands.
1794 operands.reserve(operands.size());
1795 if (operandSegmentSpecObj.is_none()) {
1796 // Non-sized operand unpacking.
1797 for (const auto &it : llvm::enumerate(operandList)) {
1798 try {
1799 operands.push_back(getOpResultOrValue(it.value()));
1800 } catch (nb::builtin_exception &err) {
1801 throw nb::value_error((llvm::Twine("Operand ") +
1802 llvm::Twine(it.index()) + " of operation \"" +
1803 name + "\" must be a Value (" + err.what() + ")")
1804 .str()
1805 .c_str());
1806 }
1807 }
1808 } else {
1809 // Sized operand unpacking.
1810 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1811 if (operandSegmentSpec.size() != operandList.size()) {
1812 throw nb::value_error((llvm::Twine("Operation \"") + name +
1813 "\" requires " +
1814 llvm::Twine(operandSegmentSpec.size()) +
1815 "operand segments but was provided " +
1816 llvm::Twine(operandList.size()))
1817 .str()
1818 .c_str());
1819 }
1820 operandSegmentLengths.reserve(operandList.size());
1821 for (const auto &it :
1822 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1823 int segmentSpec = std::get<1>(it.value());
1824 if (segmentSpec == 1 || segmentSpec == 0) {
1825 // Unpack unary element.
1826 auto &operand = std::get<0>(it.value());
1827 if (!operand.is_none()) {
1828 try {
1829
1830 operands.push_back(getOpResultOrValue(operand));
1831 } catch (nb::builtin_exception &err) {
1832 throw nb::value_error((llvm::Twine("Operand ") +
1833 llvm::Twine(it.index()) +
1834 " of operation \"" + name +
1835 "\" must be a Value (" + err.what() + ")")
1836 .str()
1837 .c_str());
1838 }
1839
1840 operandSegmentLengths.push_back(1);
1841 } else if (segmentSpec == 0) {
1842 // Allowed to be optional.
1843 operandSegmentLengths.push_back(0);
1844 } else {
1845 throw nb::value_error(
1846 (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
1847 " of operation \"" + name +
1848 "\" must be a Value (was None and operand is not optional)")
1849 .str()
1850 .c_str());
1851 }
1852 } else if (segmentSpec == -1) {
1853 // Unpack sequence by appending.
1854 try {
1855 if (std::get<0>(it.value()).is_none()) {
1856 // Treat it as an empty list.
1857 operandSegmentLengths.push_back(0);
1858 } else {
1859 // Unpack the list.
1860 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1861 for (nb::handle segmentItem : segment) {
1862 operands.push_back(getOpResultOrValue(segmentItem));
1863 }
1864 operandSegmentLengths.push_back(nb::len(segment));
1865 }
1866 } catch (std::exception &err) {
1867 // NOTE: Sloppy to be using a catch-all here, but there are at least
1868 // three different unrelated exceptions that can be thrown in the
1869 // above "casts". Just keep the scope above small and catch them all.
1870 throw nb::value_error((llvm::Twine("Operand ") +
1871 llvm::Twine(it.index()) + " of operation \"" +
1872 name + "\" must be a Sequence of Values (" +
1873 err.what() + ")")
1874 .str()
1875 .c_str());
1876 }
1877 } else {
1878 throw nb::value_error("Unexpected segment spec");
1879 }
1880 }
1881 }
1882
1883 // Merge operand/result segment lengths into attributes if needed.
1884 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1885 // Dup.
1886 if (attributes) {
1887 attributes = nb::dict(*attributes);
1888 } else {
1889 attributes = nb::dict();
1890 }
1891 if (attributes->contains("resultSegmentSizes") ||
1892 attributes->contains("operandSegmentSizes")) {
1893 throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
1894 "'operandSegmentSizes' attribute is unsupported. "
1895 "Use Operation.create for such low-level access.");
1896 }
1897
1898 // Add resultSegmentSizes attribute.
1899 if (!resultSegmentLengths.empty()) {
1900 MlirAttribute segmentLengthAttr =
1901 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1902 resultSegmentLengths.data());
1903 (*attributes)["resultSegmentSizes"] =
1904 PyAttribute(context, segmentLengthAttr);
1905 }
1906
1907 // Add operandSegmentSizes attribute.
1908 if (!operandSegmentLengths.empty()) {
1909 MlirAttribute segmentLengthAttr =
1910 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1911 operandSegmentLengths.data());
1912 (*attributes)["operandSegmentSizes"] =
1913 PyAttribute(context, segmentLengthAttr);
1914 }
1915 }
1916
1917 // Delegate to create.
1918 return PyOperation::create(name,
1919 /*results=*/std::move(resultTypes),
1920 /*operands=*/operands,
1921 /*attributes=*/std::move(attributes),
1922 /*successors=*/std::move(successors),
1923 /*regions=*/*regions, location, maybeIp,
1924 !resultTypeList);
1925}
1926
1927nb::object PyOpView::constructDerived(const nb::object &cls,
1928 const nb::object &operation) {
1929 nb::handle opViewType = nb::type<PyOpView>();
1930 nb::object instance = cls.attr("__new__")(cls);
1931 opViewType.attr("__init__")(instance, operation);
1932 return instance;
1933}
1934
/// Wraps an existing operation object as an OpView.
///
/// `operationObject` may be any Python object castable to PyOperationBase;
/// the view stores the underlying PyOperation together with the Python object
/// of that operation's reference.
PyOpView::PyOpView(const nb::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
      // NOTE(review): this reads the already-initialized `operation` member
      // (there is no parameter of that name), so it relies on `operation`
      // being declared before `operationObject` in the class -- confirm in
      // IRModule.h.
      operationObject(operation.getRef().getObject()) {}
1940
1941//------------------------------------------------------------------------------
1942// PyInsertionPoint.
1943//------------------------------------------------------------------------------
1944
1945PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
1946
1948 : refOperation(beforeOperationBase.getOperation().getRef()),
1949 block((*refOperation)->getBlock()) {}
1950
1952 : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
1953
1955 PyOperation &operation = operationBase.getOperation();
1956 if (operation.isAttached())
1957 throw nb::value_error(
1958 "Attempt to insert operation that is already attached");
1959 block.getParentOperation()->checkValid();
1960 MlirOperation beforeOp = {nullptr};
1961 if (refOperation) {
1962 // Insert before operation.
1963 (*refOperation)->checkValid();
1964 beforeOp = (*refOperation)->get();
1965 } else {
1966 // Insert at end (before null) is only valid if the block does not
1967 // already end in a known terminator (violating this will cause assertion
1968 // failures later).
1969 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1970 throw nb::index_error("Cannot insert operation at the end of a block "
1971 "that already has a terminator. Did you mean to "
1972 "use 'InsertionPoint.at_block_terminator(block)' "
1973 "versus 'InsertionPoint(block)'?");
1974 }
1975 }
1976 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1977 operation.setAttached();
1978}
1979
1981 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1982 if (mlirOperationIsNull(firstOp)) {
1983 // Just insert at end.
1984 return PyInsertionPoint(block);
1985 }
1986
1987 // Insert before first op.
1989 block.getParentOperation()->getContext(), firstOp);
1990 return PyInsertionPoint{block, std::move(firstOpRef)};
1991}
1992
1994 MlirOperation terminator = mlirBlockGetTerminator(block.get());
1995 if (mlirOperationIsNull(terminator))
1996 throw nb::value_error("Block has no terminator");
1997 PyOperationRef terminatorOpRef = PyOperation::forOperation(
1998 block.getParentOperation()->getContext(), terminator);
1999 return PyInsertionPoint{block, std::move(terminatorOpRef)};
2000}
2001
2003 PyOperation &operation = op.getOperation();
2004 PyBlock block = operation.getBlock();
2005 MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
2006 if (mlirOperationIsNull(nextOperation))
2007 return PyInsertionPoint(block);
2009 block.getParentOperation()->getContext(), nextOperation);
2010 return PyInsertionPoint{block, std::move(nextOpRef)};
2011}
2012
2013size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
2014
2015nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
2016 return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
2017}
2018
2019void PyInsertionPoint::contextExit(const nb::object &excType,
2020 const nb::object &excVal,
2021 const nb::object &excTb) {
2023}
2024
2025//------------------------------------------------------------------------------
2026// PyAttribute.
2027//------------------------------------------------------------------------------
2028
2029bool PyAttribute::operator==(const PyAttribute &other) const {
2030 return mlirAttributeEqual(attr, other.attr);
2031}
2032
2034 return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
2035}
2036
2038 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
2039 if (mlirAttributeIsNull(rawAttr))
2040 throw nb::python_error();
2041 return PyAttribute(
2043}
2044
2046 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get());
2047 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2048 "mlirTypeID was expected to be non-null.");
2049 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
2050 mlirTypeID, mlirAttributeGetDialect(this->get()));
2051 // nb::rv_policy::move means use std::move to move the return value
2052 // contents into a new instance that will be owned by Python.
2053 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2054 if (!typeCaster)
2055 return thisObj;
2056 return typeCaster.value()(thisObj);
2057}
2058
2059//------------------------------------------------------------------------------
2060// PyNamedAttribute.
2061//------------------------------------------------------------------------------
2062
2063PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
2064 : ownedName(new std::string(std::move(ownedName))) {
2067 toMlirStringRef(*this->ownedName)),
2068 attr);
2069}
2070
2071//------------------------------------------------------------------------------
2072// PyType.
2073//------------------------------------------------------------------------------
2074
2075bool PyType::operator==(const PyType &other) const {
2076 return mlirTypeEqual(type, other.type);
2077}
2078
2080 return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
2081}
2082
2084 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
2085 if (mlirTypeIsNull(rawType))
2086 throw nb::python_error();
2088 rawType);
2089}
2090
2092 MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get());
2093 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2094 "mlirTypeID was expected to be non-null.");
2095 std::optional<nb::callable> typeCaster = PyGlobals::get().lookupTypeCaster(
2096 mlirTypeID, mlirTypeGetDialect(this->get()));
2097 // nb::rv_policy::move means use std::move to move the return value
2098 // contents into a new instance that will be owned by Python.
2099 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2100 if (!typeCaster)
2101 return thisObj;
2102 return typeCaster.value()(thisObj);
2103}
2104
2105//------------------------------------------------------------------------------
2106// PyTypeID.
2107//------------------------------------------------------------------------------
2108
2110 return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
2111}
2112
2114 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
2115 if (mlirTypeIDIsNull(mlirTypeID))
2116 throw nb::python_error();
2117 return PyTypeID(mlirTypeID);
2118}
2119bool PyTypeID::operator==(const PyTypeID &other) const {
2120 return mlirTypeIDEqual(typeID, other.typeID);
2121}
2122
2123//------------------------------------------------------------------------------
2124// PyValue and subclasses.
2125//------------------------------------------------------------------------------
2126
2128 return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
2129}
2130
2132 MlirType type = mlirValueGetType(get());
2133 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2134 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2135 "mlirTypeID was expected to be non-null.");
2136 std::optional<nb::callable> valueCaster =
2138 // nb::rv_policy::move means use std::move to move the return value
2139 // contents into a new instance that will be owned by Python.
2140 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2141 if (!valueCaster)
2142 return thisObj;
2143 return valueCaster.value()(thisObj);
2144}
2145
2147 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2148 if (mlirValueIsNull(value))
2149 throw nb::python_error();
2150 MlirOperation owner;
2151 if (mlirValueIsAOpResult(value))
2152 owner = mlirOpResultGetOwner(value);
2153 if (mlirValueIsABlockArgument(value))
2155 if (mlirOperationIsNull(owner))
2156 throw nb::python_error();
2157 MlirContext ctx = mlirOperationGetContext(owner);
2158 PyOperationRef ownerRef =
2160 return PyValue(ownerRef, value);
2161}
2162
2163//------------------------------------------------------------------------------
2164// PySymbolTable.
2165//------------------------------------------------------------------------------
2166
2168 : operation(operation.getOperation().getRef()) {
2169 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2170 if (mlirSymbolTableIsNull(symbolTable)) {
2171 throw nb::type_error("Operation is not a Symbol Table.");
2172 }
2173}
2174
2175nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2176 operation->checkValid();
2177 MlirOperation symbol = mlirSymbolTableLookup(
2178 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2179 if (mlirOperationIsNull(symbol))
2180 throw nb::key_error(
2181 ("Symbol '" + name + "' not in the symbol table.").c_str());
2182
2183 return PyOperation::forOperation(operation->getContext(), symbol,
2184 operation.getObject())
2185 ->createOpView();
2186}
2187
2189 operation->checkValid();
2190 symbol.getOperation().checkValid();
2191 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2192 // The operation is also erased, so we must invalidate it. There may be Python
2193 // references to this operation so we don't want to delete it from the list of
2194 // live operations here.
2195 symbol.getOperation().valid = false;
2196}
2197
2198void PySymbolTable::dunderDel(const std::string &name) {
2199 nb::object operation = dunderGetItem(name);
2200 erase(nb::cast<PyOperationBase &>(operation));
2201}
2202
2204 operation->checkValid();
2205 symbol.getOperation().checkValid();
2206 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2208 if (mlirAttributeIsNull(symbolAttr))
2209 throw nb::value_error("Expected operation to have a symbol name.");
2210 return PyStringAttribute(
2211 symbol.getOperation().getContext(),
2212 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
2213}
2214
2216 // Op must already be a symbol.
2217 PyOperation &operation = symbol.getOperation();
2218 operation.checkValid();
2220 MlirAttribute existingNameAttr =
2221 mlirOperationGetAttributeByName(operation.get(), attrName);
2222 if (mlirAttributeIsNull(existingNameAttr))
2223 throw nb::value_error("Expected operation to have a symbol name.");
2224 return PyStringAttribute(symbol.getOperation().getContext(),
2225 existingNameAttr);
2226}
2227
2229 const std::string &name) {
2230 // Op must already be a symbol.
2231 PyOperation &operation = symbol.getOperation();
2232 operation.checkValid();
2234 MlirAttribute existingNameAttr =
2235 mlirOperationGetAttributeByName(operation.get(), attrName);
2236 if (mlirAttributeIsNull(existingNameAttr))
2237 throw nb::value_error("Expected operation to have a symbol name.");
2238 MlirAttribute newNameAttr =
2239 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2240 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2241}
2242
2244 PyOperation &operation = symbol.getOperation();
2245 operation.checkValid();
2247 MlirAttribute existingVisAttr =
2248 mlirOperationGetAttributeByName(operation.get(), attrName);
2249 if (mlirAttributeIsNull(existingVisAttr))
2250 throw nb::value_error("Expected operation to have a symbol visibility.");
2251 return PyStringAttribute(symbol.getOperation().getContext(), existingVisAttr);
2252}
2253
2255 const std::string &visibility) {
2256 if (visibility != "public" && visibility != "private" &&
2257 visibility != "nested")
2258 throw nb::value_error(
2259 "Expected visibility to be 'public', 'private' or 'nested'");
2260 PyOperation &operation = symbol.getOperation();
2261 operation.checkValid();
2263 MlirAttribute existingVisAttr =
2264 mlirOperationGetAttributeByName(operation.get(), attrName);
2265 if (mlirAttributeIsNull(existingVisAttr))
2266 throw nb::value_error("Expected operation to have a symbol visibility.");
2267 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2268 toMlirStringRef(visibility));
2269 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2270}
2271
2272void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2273 const std::string &newSymbol,
2274 PyOperationBase &from) {
2275 PyOperation &fromOperation = from.getOperation();
2276 fromOperation.checkValid();
2278 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2279 from.getOperation())))
2280
2281 throw nb::value_error("Symbol rename failed");
2282}
2283
2285 bool allSymUsesVisible,
2286 nb::object callback) {
2287 PyOperation &fromOperation = from.getOperation();
2288 fromOperation.checkValid();
2289 struct UserData {
2290 PyMlirContextRef context;
2291 nb::object callback;
2292 bool gotException;
2293 std::string exceptionWhat;
2294 nb::object exceptionType;
2295 };
2296 UserData userData{
2297 fromOperation.getContext(), std::move(callback), false, {}, {}};
2299 fromOperation.get(), allSymUsesVisible,
2300 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2301 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2302 auto pyFoundOp =
2303 PyOperation::forOperation(calleeUserData->context, foundOp);
2304 if (calleeUserData->gotException)
2305 return;
2306 try {
2307 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2308 } catch (nb::python_error &e) {
2309 calleeUserData->gotException = true;
2310 calleeUserData->exceptionWhat = e.what();
2311 calleeUserData->exceptionType = nb::borrow(e.type());
2312 }
2313 },
2314 static_cast<void *>(&userData));
2315 if (userData.gotException) {
2316 std::string message("Exception raised in callback: ");
2317 message.append(userData.exceptionWhat);
2318 throw std::runtime_error(message);
2319 }
2320}
2321
2322namespace {
2323
2324/// Python wrapper for MlirBlockArgument.
2325class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2326public:
2327 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2328 static constexpr const char *pyClassName = "BlockArgument";
2329 using PyConcreteValue::PyConcreteValue;
2330
2331 static void bindDerived(ClassTy &c) {
2332 c.def_prop_ro(
2333 "owner",
2334 [](PyBlockArgument &self) {
2335 return PyBlock(self.getParentOperation(),
2336 mlirBlockArgumentGetOwner(self.get()));
2337 },
2338 "Returns the block that owns this argument.");
2339 c.def_prop_ro(
2340 "arg_number",
2341 [](PyBlockArgument &self) {
2342 return mlirBlockArgumentGetArgNumber(self.get());
2343 },
2344 "Returns the position of this argument in the block's argument list.");
2345 c.def(
2346 "set_type",
2347 [](PyBlockArgument &self, PyType type) {
2348 return mlirBlockArgumentSetType(self.get(), type);
2349 },
2350 nb::arg("type"), "Sets the type of this block argument.");
2351 c.def(
2352 "set_location",
2353 [](PyBlockArgument &self, PyLocation loc) {
2354 return mlirBlockArgumentSetLocation(self.get(), loc);
2355 },
2356 nb::arg("loc"), "Sets the location of this block argument.");
2357 }
2358};
2359
2360/// A list of block arguments. Internally, these are stored as consecutive
2361/// elements, random access is cheap. The argument list is associated with the
2362/// operation that contains the block (detached blocks are not allowed in
2363/// Python bindings) and extends its lifetime.
2364class PyBlockArgumentList
2365 : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2366public:
2367 static constexpr const char *pyClassName = "BlockArgumentList";
2368 using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
2369
2370 PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2371 intptr_t startIndex = 0, intptr_t length = -1,
2372 intptr_t step = 1)
2373 : Sliceable(startIndex,
2374 length == -1 ? mlirBlockGetNumArguments(block) : length,
2375 step),
2376 operation(std::move(operation)), block(block) {}
2377
2378 static void bindDerived(ClassTy &c) {
2379 c.def_prop_ro(
2380 "types",
2381 [](PyBlockArgumentList &self) {
2382 return getValueTypes(self, self.operation->getContext());
2383 },
2384 "Returns a list of types for all arguments in this argument list.");
2385 }
2386
2387private:
2388 /// Give the parent CRTP class access to hook implementations below.
2389 friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2390
2391 /// Returns the number of arguments in the list.
2392 intptr_t getRawNumElements() {
2393 operation->checkValid();
2394 return mlirBlockGetNumArguments(block);
2395 }
2396
2397 /// Returns `pos`-the element in the list.
2398 PyBlockArgument getRawElement(intptr_t pos) {
2399 MlirValue argument = mlirBlockGetArgument(block, pos);
2400 return PyBlockArgument(operation, argument);
2401 }
2402
2403 /// Returns a sublist of this list.
2404 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2405 intptr_t step) {
2406 return PyBlockArgumentList(operation, block, startIndex, length, step);
2407 }
2408
2409 PyOperationRef operation;
2410 MlirBlock block;
2411};
2412
2413/// A list of operation operands. Internally, these are stored as consecutive
2414/// elements, random access is cheap. The (returned) operand list is associated
2415/// with the operation whose operands these are, and thus extends the lifetime
2416/// of this operation.
2417class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2418public:
2419 static constexpr const char *pyClassName = "OpOperandList";
2420 using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2421
2422 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2423 intptr_t length = -1, intptr_t step = 1)
2424 : Sliceable(startIndex,
2425 length == -1 ? mlirOperationGetNumOperands(operation->get())
2426 : length,
2427 step),
2428 operation(operation) {}
2429
2430 void dunderSetItem(intptr_t index, PyValue value) {
2431 index = wrapIndex(index);
2432 mlirOperationSetOperand(operation->get(), index, value.get());
2433 }
2434
2435 static void bindDerived(ClassTy &c) {
2436 c.def("__setitem__", &PyOpOperandList::dunderSetItem, nb::arg("index"),
2437 nb::arg("value"),
2438 "Sets the operand at the specified index to a new value.");
2439 }
2440
2441private:
2442 /// Give the parent CRTP class access to hook implementations below.
2443 friend class Sliceable<PyOpOperandList, PyValue>;
2444
2445 intptr_t getRawNumElements() {
2446 operation->checkValid();
2447 return mlirOperationGetNumOperands(operation->get());
2448 }
2449
2450 PyValue getRawElement(intptr_t pos) {
2451 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2452 MlirOperation owner;
2453 if (mlirValueIsAOpResult(operand))
2454 owner = mlirOpResultGetOwner(operand);
2455 else if (mlirValueIsABlockArgument(operand))
2457 else
2458 assert(false && "Value must be an block arg or op result.");
2459 PyOperationRef pyOwner =
2460 PyOperation::forOperation(operation->getContext(), owner);
2461 return PyValue(pyOwner, operand);
2462 }
2463
2464 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2465 return PyOpOperandList(operation, startIndex, length, step);
2466 }
2467
2468 PyOperationRef operation;
2469};
2470
2471/// A list of operation successors. Internally, these are stored as consecutive
2472/// elements, random access is cheap. The (returned) successor list is
2473/// associated with the operation whose successors these are, and thus extends
2474/// the lifetime of this operation.
2475class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2476public:
2477 static constexpr const char *pyClassName = "OpSuccessors";
2478
2479 PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2480 intptr_t length = -1, intptr_t step = 1)
2481 : Sliceable(startIndex,
2482 length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2483 : length,
2484 step),
2485 operation(operation) {}
2486
2487 void dunderSetItem(intptr_t index, PyBlock block) {
2488 index = wrapIndex(index);
2489 mlirOperationSetSuccessor(operation->get(), index, block.get());
2490 }
2491
2492 static void bindDerived(ClassTy &c) {
2493 c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nb::arg("index"),
2494 nb::arg("block"), "Sets the successor block at the specified index.");
2495 }
2496
2497private:
2498 /// Give the parent CRTP class access to hook implementations below.
2499 friend class Sliceable<PyOpSuccessors, PyBlock>;
2500
2501 intptr_t getRawNumElements() {
2502 operation->checkValid();
2503 return mlirOperationGetNumSuccessors(operation->get());
2504 }
2505
2506 PyBlock getRawElement(intptr_t pos) {
2507 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2508 return PyBlock(operation, block);
2509 }
2510
2511 PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2512 return PyOpSuccessors(operation, startIndex, length, step);
2513 }
2514
2515 PyOperationRef operation;
2516};
2517
2518/// A list of block successors. Internally, these are stored as consecutive
2519/// elements, random access is cheap. The (returned) successor list is
2520/// associated with the operation and block whose successors these are, and thus
2521/// extends the lifetime of this operation and block.
2522class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
2523public:
2524 static constexpr const char *pyClassName = "BlockSuccessors";
2525
2526 PyBlockSuccessors(PyBlock block, PyOperationRef operation,
2527 intptr_t startIndex = 0, intptr_t length = -1,
2528 intptr_t step = 1)
2529 : Sliceable(startIndex,
2530 length == -1 ? mlirBlockGetNumSuccessors(block.get())
2531 : length,
2532 step),
2533 operation(operation), block(block) {}
2534
2535private:
2536 /// Give the parent CRTP class access to hook implementations below.
2537 friend class Sliceable<PyBlockSuccessors, PyBlock>;
2538
2539 intptr_t getRawNumElements() {
2540 block.checkValid();
2541 return mlirBlockGetNumSuccessors(block.get());
2542 }
2543
2544 PyBlock getRawElement(intptr_t pos) {
2545 MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2546 return PyBlock(operation, block);
2547 }
2548
2549 PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2550 return PyBlockSuccessors(block, operation, startIndex, length, step);
2551 }
2552
2553 PyOperationRef operation;
2554 PyBlock block;
2555};
2556
2557/// A list of block predecessors. The (returned) predecessor list is
2558/// associated with the operation and block whose predecessors these are, and
2559/// thus extends the lifetime of this operation and block.
2560///
2561/// WARNING: This Sliceable is more expensive than the others here because
2562/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
2563/// operands) anew for each indexed access.
2564class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
2565public:
2566 static constexpr const char *pyClassName = "BlockPredecessors";
2567
2568 PyBlockPredecessors(PyBlock block, PyOperationRef operation,
2569 intptr_t startIndex = 0, intptr_t length = -1,
2570 intptr_t step = 1)
2571 : Sliceable(startIndex,
2572 length == -1 ? mlirBlockGetNumPredecessors(block.get())
2573 : length,
2574 step),
2575 operation(operation), block(block) {}
2576
2577private:
2578 /// Give the parent CRTP class access to hook implementations below.
2579 friend class Sliceable<PyBlockPredecessors, PyBlock>;
2580
2581 intptr_t getRawNumElements() {
2582 block.checkValid();
2583 return mlirBlockGetNumPredecessors(block.get());
2584 }
2585
2586 PyBlock getRawElement(intptr_t pos) {
2587 MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2588 return PyBlock(operation, block);
2589 }
2590
2591 PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
2592 intptr_t step) {
2593 return PyBlockPredecessors(block, operation, startIndex, length, step);
2594 }
2595
2596 PyOperationRef operation;
2597 PyBlock block;
2598};
2599
2600/// A list of operation attributes. Can be indexed by name, producing
2601/// attributes, or by index, producing named attributes.
2602class PyOpAttributeMap {
2603public:
2604 PyOpAttributeMap(PyOperationRef operation)
2605 : operation(std::move(operation)) {}
2606
2607 nb::typed<nb::object, PyAttribute>
2608 dunderGetItemNamed(const std::string &name) {
2609 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2610 toMlirStringRef(name));
2611 if (mlirAttributeIsNull(attr)) {
2612 throw nb::key_error("attempt to access a non-existent attribute");
2613 }
2614 return PyAttribute(operation->getContext(), attr).maybeDownCast();
2615 }
2616
2617 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2618 if (index < 0) {
2619 index += dunderLen();
2620 }
2621 if (index < 0 || index >= dunderLen()) {
2622 throw nb::index_error("attempt to access out of bounds attribute");
2623 }
2624 MlirNamedAttribute namedAttr =
2625 mlirOperationGetAttribute(operation->get(), index);
2626 return PyNamedAttribute(
2627 namedAttr.attribute,
2628 std::string(mlirIdentifierStr(namedAttr.name).data,
2629 mlirIdentifierStr(namedAttr.name).length));
2630 }
2631
2632 void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2634 attr);
2635 }
2636
2637 void dunderDelItem(const std::string &name) {
2638 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2639 toMlirStringRef(name));
2640 if (!removed)
2641 throw nb::key_error("attempt to delete a non-existent attribute");
2642 }
2643
2644 intptr_t dunderLen() {
2645 return mlirOperationGetNumAttributes(operation->get());
2646 }
2647
2648 bool dunderContains(const std::string &name) {
2649 return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2650 operation->get(), toMlirStringRef(name)));
2651 }
2652
2653 static void
2654 forEachAttr(MlirOperation op,
2655 llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
2656 intptr_t n = mlirOperationGetNumAttributes(op);
2657 for (intptr_t i = 0; i < n; ++i) {
2660 fn(name, na.attribute);
2661 }
2662 }
2663
2664 static void bind(nb::module_ &m) {
2665 nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2666 .def("__contains__", &PyOpAttributeMap::dunderContains, nb::arg("name"),
2667 "Checks if an attribute with the given name exists in the map.")
2668 .def("__len__", &PyOpAttributeMap::dunderLen,
2669 "Returns the number of attributes in the map.")
2670 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
2671 nb::arg("name"), "Gets an attribute by name.")
2672 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
2673 nb::arg("index"), "Gets a named attribute by index.")
2674 .def("__setitem__", &PyOpAttributeMap::dunderSetItem, nb::arg("name"),
2675 nb::arg("attr"), "Sets an attribute with the given name.")
2676 .def("__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg("name"),
2677 "Deletes an attribute with the given name.")
2678 .def(
2679 "__iter__",
2680 [](PyOpAttributeMap &self) {
2681 nb::list keys;
2682 PyOpAttributeMap::forEachAttr(
2683 self.operation->get(),
2684 [&](MlirStringRef name, MlirAttribute) {
2685 keys.append(nb::str(name.data, name.length));
2686 });
2687 return nb::iter(keys);
2688 },
2689 "Iterates over attribute names.")
2690 .def(
2691 "keys",
2692 [](PyOpAttributeMap &self) {
2693 nb::list out;
2694 PyOpAttributeMap::forEachAttr(
2695 self.operation->get(),
2696 [&](MlirStringRef name, MlirAttribute) {
2697 out.append(nb::str(name.data, name.length));
2698 });
2699 return out;
2700 },
2701 "Returns a list of attribute names.")
2702 .def(
2703 "values",
2704 [](PyOpAttributeMap &self) {
2705 nb::list out;
2706 PyOpAttributeMap::forEachAttr(
2707 self.operation->get(),
2708 [&](MlirStringRef, MlirAttribute attr) {
2709 out.append(PyAttribute(self.operation->getContext(), attr)
2710 .maybeDownCast());
2711 });
2712 return out;
2713 },
2714 "Returns a list of attribute values.")
2715 .def(
2716 "items",
2717 [](PyOpAttributeMap &self) {
2718 nb::list out;
2719 PyOpAttributeMap::forEachAttr(
2720 self.operation->get(),
2721 [&](MlirStringRef name, MlirAttribute attr) {
2722 out.append(nb::make_tuple(
2723 nb::str(name.data, name.length),
2724 PyAttribute(self.operation->getContext(), attr)
2725 .maybeDownCast()));
2726 });
2727 return out;
2728 },
2729 "Returns a list of `(name, attribute)` tuples.");
2730 }
2731
2732private:
2733 PyOperationRef operation;
2734};
2735
// Backports of newer CPython C-API helpers for older interpreters, adapted
// from pythoncapi_compat. PyThreadState_GetFrame, PyFrame_GetBack and
// PyFrame_GetCode are used by tracebackToLocation in this file.
// see
// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h

// Fallback for the C-style cast helper macro used by the shims below.
#ifndef _Py_CAST
#define _Py_CAST(type, expr) ((type)(expr))
#endif

// Static inline functions should use _Py_NULL rather than using NULL
// directly, to prevent C++ compiler warnings. On C23 and newer and on C++11
// and newer, _Py_NULL is defined as nullptr.
#ifndef _Py_NULL
#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) ||               \
    (defined(__cplusplus) && __cplusplus >= 201103)
#define _Py_NULL nullptr
#else
#define _Py_NULL NULL
#endif
#endif

// Python 3.10.0a3
#if PY_VERSION_HEX < 0x030A00A3

// bpo-42262 added Py_XNewRef()
// Increments the refcount of `obj` (which may be null) and returns it.
#if !defined(Py_XNewRef)
[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
  Py_XINCREF(obj);
  return obj;
}
#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
#endif

// bpo-42262 added Py_NewRef()
// Increments the refcount of non-null `obj` and returns it.
#if !defined(Py_NewRef)
[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
  Py_INCREF(obj);
  return obj;
}
#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
#endif

#endif // Python 3.10.0a3

// Python 3.9.0b1
// Each shim below returns a new (strong) reference obtained through
// Py_XNewRef/Py_NewRef, so the caller is responsible for releasing it.
#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)

// bpo-40429 added PyThreadState_GetFrame()
// Returns the current frame of `tstate`, or null if there is none.
PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
  assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
  return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
}

// bpo-40421 added PyFrame_GetBack()
// Returns the calling frame of `frame`, or null for the outermost frame.
PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
  return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
}

// bpo-40421 added PyFrame_GetCode()
// Returns the code object executing in `frame` (never null).
PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
  assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
  return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
}

#endif // Python 3.9.0b1
2801
2802MlirLocation tracebackToLocation(MlirContext ctx) {
2803 size_t framesLimit =
2805 // Use a thread_local here to avoid requiring a large amount of space.
2806 thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
2807 frames;
2808 size_t count = 0;
2809
2810 nb::gil_scoped_acquire acquire;
2811 PyThreadState *tstate = PyThreadState_GET();
2812 PyFrameObject *next;
2813 PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
2814 // In the increment expression:
2815 // 1. get the next prev frame;
2816 // 2. decrement the ref count on the current frame (in order that it can get
2817 // gc'd, along with any objects in its closure and etc);
2818 // 3. set current = next.
2819 for (; pyFrame != nullptr && count < framesLimit;
2820 next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
2821 PyCodeObject *code = PyFrame_GetCode(pyFrame);
2822 auto fileNameStr =
2823 nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
2824 llvm::StringRef fileName(fileNameStr);
2825 if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
2826 continue;
2827
2828 // co_qualname and PyCode_Addr2Location added in py3.11
2829#if PY_VERSION_HEX < 0x030B00F0
2830 std::string name =
2831 nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
2832 llvm::StringRef funcName(name);
2833 int startLine = PyFrame_GetLineNumber(pyFrame);
2834 MlirLocation loc =
2835 mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
2836#else
2837 std::string name =
2838 nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
2839 llvm::StringRef funcName(name);
2840 int startLine, startCol, endLine, endCol;
2841 int lasti = PyFrame_GetLasti(pyFrame);
2842 if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
2843 &endCol)) {
2844 throw nb::python_error();
2845 }
2846 MlirLocation loc = mlirLocationFileLineColRangeGet(
2847 ctx, wrap(fileName), startLine, startCol, endLine, endCol);
2848#endif
2849
2850 frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
2851 ++count;
2852 }
2853 // When the loop breaks (after the last iter), current frame (if non-null)
2854 // is leaked without this.
2855 Py_XDECREF(pyFrame);
2856
2857 if (count == 0)
2858 return mlirLocationUnknownGet(ctx);
2859
2860 MlirLocation callee = frames[0];
2861 assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
2862 if (count == 1)
2863 return callee;
2864
2865 MlirLocation caller = frames[count - 1];
2866 assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
2867 for (int i = count - 2; i >= 1; i--)
2868 caller = mlirLocationCallSiteGet(frames[i], caller);
2869
2870 return mlirLocationCallSiteGet(callee, caller);
2871}
2872
2874maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
2875 if (location.has_value())
2876 return location.value();
2877 if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
2879
2881 MlirLocation mlirLoc = tracebackToLocation(ctx.get());
2883 return {ref, mlirLoc};
2884}
2885
2886} // namespace
2887
2888//------------------------------------------------------------------------------
2889// Populates the core exports of the 'ir' submodule.
2890//------------------------------------------------------------------------------
2891
2892void mlir::python::populateIRCore(nb::module_ &m) {
2893 // disable leak warnings which tend to be false positives.
2894 nb::set_leak_warnings(false);
2895 //----------------------------------------------------------------------------
2896 // Enums.
2897 //----------------------------------------------------------------------------
2898 nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
2899 .value("ERROR", MlirDiagnosticError)
2900 .value("WARNING", MlirDiagnosticWarning)
2901 .value("NOTE", MlirDiagnosticNote)
2902 .value("REMARK", MlirDiagnosticRemark);
2903
2904 nb::enum_<MlirWalkOrder>(m, "WalkOrder")
2905 .value("PRE_ORDER", MlirWalkPreOrder)
2906 .value("POST_ORDER", MlirWalkPostOrder);
2907
2908 nb::enum_<MlirWalkResult>(m, "WalkResult")
2909 .value("ADVANCE", MlirWalkResultAdvance)
2910 .value("INTERRUPT", MlirWalkResultInterrupt)
2911 .value("SKIP", MlirWalkResultSkip);
2912
2913 //----------------------------------------------------------------------------
2914 // Mapping of Diagnostics.
2915 //----------------------------------------------------------------------------
2916 nb::class_<PyDiagnostic>(m, "Diagnostic")
2917 .def_prop_ro("severity", &PyDiagnostic::getSeverity,
2918 "Returns the severity of the diagnostic.")
2919 .def_prop_ro("location", &PyDiagnostic::getLocation,
2920 "Returns the location associated with the diagnostic.")
2921 .def_prop_ro("message", &PyDiagnostic::getMessage,
2922 "Returns the message text of the diagnostic.")
2923 .def_prop_ro("notes", &PyDiagnostic::getNotes,
2924 "Returns a tuple of attached note diagnostics.")
2925 .def(
2926 "__str__",
2927 [](PyDiagnostic &self) -> nb::str {
2928 if (!self.isValid())
2929 return nb::str("<Invalid Diagnostic>");
2930 return self.getMessage();
2931 },
2932 "Returns the diagnostic message as a string.");
2933
2934 nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2935 .def(
2936 "__init__",
2938 new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2939 },
2940 "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
2941 .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
2942 "The severity level of the diagnostic.")
2943 .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
2944 "The location associated with the diagnostic.")
2945 .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
2946 "The message text of the diagnostic.")
2947 .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
2948 "List of attached note diagnostics.")
2949 .def(
2950 "__str__",
2951 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
2952 "Returns the diagnostic message as a string.");
2953
2954 nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
2955 .def("detach", &PyDiagnosticHandler::detach,
2956 "Detaches the diagnostic handler from the context.")
2957 .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
2958 "Returns True if the handler is attached to a context.")
2959 .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
2960 "Returns True if an error was encountered during diagnostic "
2961 "handling.")
2962 .def("__enter__", &PyDiagnosticHandler::contextEnter,
2963 "Enters the diagnostic handler as a context manager.")
2964 .def("__exit__", &PyDiagnosticHandler::contextExit,
2965 nb::arg("exc_type").none(), nb::arg("exc_value").none(),
2966 nb::arg("traceback").none(),
2967 "Exits the diagnostic handler context manager.");
2968
2969 // Expose DefaultThreadPool to python
2970 nb::class_<PyThreadPool>(m, "ThreadPool")
2971 .def(
2972 "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
2973 "Creates a new thread pool with default concurrency.")
2974 .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
2975 "Returns the maximum number of threads in the pool.")
2976 .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
2977 "Returns the raw pointer to the LLVM thread pool as a string.");
2978
2979 nb::class_<PyMlirContext>(m, "Context")
2980 .def(
2981 "__init__",
2982 [](PyMlirContext &self) {
2983 MlirContext context = mlirContextCreateWithThreading(false);
2984 new (&self) PyMlirContext(context);
2985 },
2986 R"(
2987 Creates a new MLIR context.
2988
2989 The context is the top-level container for all MLIR objects. It owns the storage
2990 for types, attributes, locations, and other core IR objects. A context can be
2991 configured to allow or disallow unregistered dialects and can have dialects
2992 loaded on-demand.)")
2993 .def_static("_get_live_count", &PyMlirContext::getLiveCount,
2994 "Gets the number of live Context objects.")
2995 .def(
2996 "_get_context_again",
2997 [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
2999 return ref.releaseObject();
3000 },
3001 "Gets another reference to the same context.")
3002 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
3003 "Gets the number of live modules owned by this context.")
3005 "Gets a capsule wrapping the MlirContext.")
3008 "Creates a Context from a capsule wrapping MlirContext.")
3009 .def("__enter__", &PyMlirContext::contextEnter,
3010 "Enters the context as a context manager.")
3011 .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
3012 nb::arg("exc_value").none(), nb::arg("traceback").none(),
3013 "Exits the context manager.")
3014 .def_prop_ro_static(
3015 "current",
3016 [](nb::object & /*class*/)
3017 -> std::optional<nb::typed<nb::object, PyMlirContext>> {
3019 if (!context)
3020 return {};
3021 return nb::cast(context);
3022 },
3023 nb::sig("def current(/) -> Context | None"),
3024 "Gets the Context bound to the current thread or returns None if no "
3025 "context is set.")
3026 .def_prop_ro(
3027 "dialects",
3028 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3029 "Gets a container for accessing dialects by name.")
3030 .def_prop_ro(
3031 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3032 "Alias for `dialects`.")
3033 .def(
3034 "get_dialect_descriptor",
3035 [=](PyMlirContext &self, std::string &name) {
3036 MlirDialect dialect = mlirContextGetOrLoadDialect(
3037 self.get(), {name.data(), name.size()});
3038 if (mlirDialectIsNull(dialect)) {
3039 throw nb::value_error(
3040 (Twine("Dialect '") + name + "' not found").str().c_str());
3041 }
3042 return PyDialectDescriptor(self.getRef(), dialect);
3043 },
3044 nb::arg("dialect_name"),
3045 "Gets or loads a dialect by name, returning its descriptor object.")
3046 .def_prop_rw(
3047 "allow_unregistered_dialects",
3048 [](PyMlirContext &self) -> bool {
3050 },
3051 [](PyMlirContext &self, bool value) {
3053 },
3054 "Controls whether unregistered dialects are allowed in this context.")
3055 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
3056 nb::arg("callback"),
3057 "Attaches a diagnostic handler that will receive callbacks.")
3058 .def(
3059 "enable_multithreading",
3060 [](PyMlirContext &self, bool enable) {
3061 mlirContextEnableMultithreading(self.get(), enable);
3062 },
3063 nb::arg("enable"),
3064 R"(
3065 Enables or disables multi-threading support in the context.
3066
3067 Args:
3068 enable: Whether to enable (True) or disable (False) multi-threading.
3069 )")
3070 .def(
3071 "set_thread_pool",
3072 [](PyMlirContext &self, PyThreadPool &pool) {
3073 // we should disable multi-threading first before setting
3074 // new thread pool otherwise the assert in
3075 // MLIRContext::setThreadPool will be raised.
3076 mlirContextEnableMultithreading(self.get(), false);
3077 mlirContextSetThreadPool(self.get(), pool.get());
3078 },
3079 R"(
3080 Sets a custom thread pool for the context to use.
3081
3082 Args:
3083 pool: A ThreadPool object to use for parallel operations.
3084
3085 Note:
3086 Multi-threading is automatically disabled before setting the thread pool.)")
3087 .def(
3088 "get_num_threads",
3089 [](PyMlirContext &self) {
3090 return mlirContextGetNumThreads(self.get());
3091 },
3092 "Gets the number of threads in the context's thread pool.")
3093 .def(
3094 "_mlir_thread_pool_ptr",
3095 [](PyMlirContext &self) {
3096 MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
3097 std::stringstream ss;
3098 ss << pool.ptr;
3099 return ss.str();
3100 },
3101 "Gets the raw pointer to the LLVM thread pool as a string.")
3102 .def(
3103 "is_registered_operation",
3104 [](PyMlirContext &self, std::string &name) {
3106 self.get(), MlirStringRef{name.data(), name.size()});
3107 },
3108 nb::arg("operation_name"),
3109 R"(
3110 Checks whether an operation with the given name is registered.
3111
3112 Args:
3113 operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
3114
3115 Returns:
3116 True if the operation is registered, False otherwise.)")
3117 .def(
3118 "append_dialect_registry",
3119 [](PyMlirContext &self, PyDialectRegistry &registry) {
3120 mlirContextAppendDialectRegistry(self.get(), registry);
3121 },
3122 nb::arg("registry"),
3123 R"(
3124 Appends the contents of a dialect registry to the context.
3125
3126 Args:
3127 registry: A DialectRegistry containing dialects to append.)")
3128 .def_prop_rw("emit_error_diagnostics",
3131 R"(
3132 Controls whether error diagnostics are emitted to diagnostic handlers.
3133
3134 By default, error diagnostics are captured and reported through MLIRError exceptions.)")
3135 .def(
3136 "load_all_available_dialects",
3137 [](PyMlirContext &self) {
3139 },
3140 R"(
3141 Loads all dialects available in the registry into the context.
3142
3143 This eagerly loads all dialects that have been registered, making them
3144 immediately available for use.)");
3145
3146 //----------------------------------------------------------------------------
3147 // Mapping of PyDialectDescriptor
3148 //----------------------------------------------------------------------------
3149 nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
3150 .def_prop_ro(
3151 "namespace",
3152 [](PyDialectDescriptor &self) {
3154 return nb::str(ns.data, ns.length);
3155 },
3156 "Returns the namespace of the dialect.")
3157 .def(
3158 "__repr__",
3159 [](PyDialectDescriptor &self) {
3161 std::string repr("<DialectDescriptor ");
3162 repr.append(ns.data, ns.length);
3163 repr.append(">");
3164 return repr;
3165 },
3166 nb::sig("def __repr__(self) -> str"),
3167 "Returns a string representation of the dialect descriptor.");
3168
3169 //----------------------------------------------------------------------------
3170 // Mapping of PyDialects
3171 //----------------------------------------------------------------------------
3172 nb::class_<PyDialects>(m, "Dialects")
3173 .def(
3174 "__getitem__",
3175 [=](PyDialects &self, std::string keyName) {
3176 MlirDialect dialect =
3177 self.getDialectForKey(keyName, /*attrError=*/false);
3178 nb::object descriptor =
3179 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3180 return createCustomDialectWrapper(keyName, std::move(descriptor));
3181 },
3182 "Gets a dialect by name using subscript notation.")
3183 .def(
3184 "__getattr__",
3185 [=](PyDialects &self, std::string attrName) {
3186 MlirDialect dialect =
3187 self.getDialectForKey(attrName, /*attrError=*/true);
3188 nb::object descriptor =
3189 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
3190 return createCustomDialectWrapper(attrName, std::move(descriptor));
3191 },
3192 "Gets a dialect by name using attribute notation.");
3193
3194 //----------------------------------------------------------------------------
3195 // Mapping of PyDialect
3196 //----------------------------------------------------------------------------
3197 nb::class_<PyDialect>(m, "Dialect")
3198 .def(nb::init<nb::object>(), nb::arg("descriptor"),
3199 "Creates a Dialect from a DialectDescriptor.")
3200 .def_prop_ro(
3201 "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
3202 "Returns the DialectDescriptor for this dialect.")
3203 .def(
3204 "__repr__",
3205 [](const nb::object &self) {
3206 auto clazz = self.attr("__class__");
3207 return nb::str("<Dialect ") +
3208 self.attr("descriptor").attr("namespace") +
3209 nb::str(" (class ") + clazz.attr("__module__") +
3210 nb::str(".") + clazz.attr("__name__") + nb::str(")>");
3211 },
3212 nb::sig("def __repr__(self) -> str"),
3213 "Returns a string representation of the dialect.");
3214
3215 //----------------------------------------------------------------------------
3216 // Mapping of PyDialectRegistry
3217 //----------------------------------------------------------------------------
3218 nb::class_<PyDialectRegistry>(m, "DialectRegistry")
3220 "Gets a capsule wrapping the MlirDialectRegistry.")
3223 "Creates a DialectRegistry from a capsule wrapping "
3224 "`MlirDialectRegistry`.")
3225 .def(nb::init<>(), "Creates a new empty dialect registry.");
3226
3227 //----------------------------------------------------------------------------
3228 // Mapping of Location
3229 //----------------------------------------------------------------------------
3230 nb::class_<PyLocation>(m, "Location")
3232 "Gets a capsule wrapping the MlirLocation.")
3234 "Creates a Location from a capsule wrapping MlirLocation.")
3235 .def("__enter__", &PyLocation::contextEnter,
3236 "Enters the location as a context manager.")
3237 .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
3238 nb::arg("exc_value").none(), nb::arg("traceback").none(),
3239 "Exits the location context manager.")
3240 .def(
3241 "__eq__",
3242 [](PyLocation &self, PyLocation &other) -> bool {
3243 return mlirLocationEqual(self, other);
3244 },
3245 "Compares two locations for equality.")
3246 .def(
3247 "__eq__", [](PyLocation &self, nb::object other) { return false; },
3248 "Compares location with non-location object (always returns False).")
3249 .def_prop_ro_static(
3250 "current",
3251 [](nb::object & /*class*/) -> std::optional<PyLocation *> {
3253 if (!loc)
3254 return std::nullopt;
3255 return loc;
3256 },
3257 // clang-format off
3258 nb::sig("def current(/) -> Location | None"),
3259 // clang-format on
3260 "Gets the Location bound to the current thread or raises ValueError.")
3261 .def_static(
3262 "unknown",
3263 [](DefaultingPyMlirContext context) {
3264 return PyLocation(context->getRef(),
3265 mlirLocationUnknownGet(context->get()));
3266 },
3267 nb::arg("context") = nb::none(),
3268 "Gets a Location representing an unknown location.")
3269 .def_static(
3270 "callsite",
3271 [](PyLocation callee, const std::vector<PyLocation> &frames,
3272 DefaultingPyMlirContext context) {
3273 if (frames.empty())
3274 throw nb::value_error("No caller frames provided.");
3275 MlirLocation caller = frames.back().get();
3276 for (const PyLocation &frame :
3277 llvm::reverse(llvm::ArrayRef(frames).drop_back()))
3278 caller = mlirLocationCallSiteGet(frame.get(), caller);
3279 return PyLocation(context->getRef(),
3280 mlirLocationCallSiteGet(callee.get(), caller));
3281 },
3282 nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
3283 "Gets a Location representing a caller and callsite.")
3284 .def("is_a_callsite", mlirLocationIsACallSite,
3285 "Returns True if this location is a CallSiteLoc.")
3286 .def_prop_ro(
3287 "callee",
3288 [](PyLocation &self) {
3289 return PyLocation(self.getContext(),
3291 },
3292 "Gets the callee location from a CallSiteLoc.")
3293 .def_prop_ro(
3294 "caller",
3295 [](PyLocation &self) {
3296 return PyLocation(self.getContext(),
3298 },
3299 "Gets the caller location from a CallSiteLoc.")
3300 .def_static(
3301 "file",
3302 [](std::string filename, int line, int col,
3303 DefaultingPyMlirContext context) {
3304 return PyLocation(
3305 context->getRef(),
3307 context->get(), toMlirStringRef(filename), line, col));
3308 },
3309 nb::arg("filename"), nb::arg("line"), nb::arg("col"),
3310 nb::arg("context") = nb::none(),
3311 "Gets a Location representing a file, line and column.")
3312 .def_static(
3313 "file",
3314 [](std::string filename, int startLine, int startCol, int endLine,
3315 int endCol, DefaultingPyMlirContext context) {
3316 return PyLocation(context->getRef(),
3318 context->get(), toMlirStringRef(filename),
3319 startLine, startCol, endLine, endCol));
3320 },
3321 nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
3322 nb::arg("end_line"), nb::arg("end_col"),
3323 nb::arg("context") = nb::none(),
3324 "Gets a Location representing a file, line and column range.")
3325 .def("is_a_file", mlirLocationIsAFileLineColRange,
3326 "Returns True if this location is a FileLineColLoc.")
3327 .def_prop_ro(
3328 "filename",
3329 [](MlirLocation loc) {
3330 return mlirIdentifierStr(
3332 },
3333 "Gets the filename from a FileLineColLoc.")
3334 .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
3335 "Gets the start line number from a `FileLineColLoc`.")
3336 .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
3337 "Gets the start column number from a `FileLineColLoc`.")
3338 .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
3339 "Gets the end line number from a `FileLineColLoc`.")
3340 .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
3341 "Gets the end column number from a `FileLineColLoc`.")
3342 .def_static(
3343 "fused",
3344 [](const std::vector<PyLocation> &pyLocations,
3345 std::optional<PyAttribute> metadata,
3346 DefaultingPyMlirContext context) {
3347 llvm::SmallVector<MlirLocation, 4> locations;
3348 locations.reserve(pyLocations.size());
3349 for (auto &pyLocation : pyLocations)
3350 locations.push_back(pyLocation.get());
3351 MlirLocation location = mlirLocationFusedGet(
3352 context->get(), locations.size(), locations.data(),
3353 metadata ? metadata->get() : MlirAttribute{0});
3354 return PyLocation(context->getRef(), location);
3355 },
3356 nb::arg("locations"), nb::arg("metadata") = nb::none(),
3357 nb::arg("context") = nb::none(),
3358 "Gets a Location representing a fused location with optional "
3359 "metadata.")
3360 .def("is_a_fused", mlirLocationIsAFused,
3361 "Returns True if this location is a `FusedLoc`.")
3362 .def_prop_ro(
3363 "locations",
3364 [](PyLocation &self) {
3365 unsigned numLocations = mlirLocationFusedGetNumLocations(self);
3366 std::vector<MlirLocation> locations(numLocations);
3367 if (numLocations)
3368 mlirLocationFusedGetLocations(self, locations.data());
3369 std::vector<PyLocation> pyLocations{};
3370 pyLocations.reserve(numLocations);
3371 for (unsigned i = 0; i < numLocations; ++i)
3372 pyLocations.emplace_back(self.getContext(), locations[i]);
3373 return pyLocations;
3374 },
3375 "Gets the list of locations from a `FusedLoc`.")
3376 .def_static(
3377 "name",
3378 [](std::string name, std::optional<PyLocation> childLoc,
3379 DefaultingPyMlirContext context) {
3380 return PyLocation(
3381 context->getRef(),
3383 context->get(), toMlirStringRef(name),
3384 childLoc ? childLoc->get()
3385 : mlirLocationUnknownGet(context->get())));
3386 },
3387 nb::arg("name"), nb::arg("childLoc") = nb::none(),
3388 nb::arg("context") = nb::none(),
3389 "Gets a Location representing a named location with optional child "
3390 "location.")
3391 .def("is_a_name", mlirLocationIsAName,
3392 "Returns True if this location is a `NameLoc`.")
3393 .def_prop_ro(
3394 "name_str",
3395 [](MlirLocation loc) {
3397 },
3398 "Gets the name string from a `NameLoc`.")
3399 .def_prop_ro(
3400 "child_loc",
3401 [](PyLocation &self) {
3402 return PyLocation(self.getContext(),
3404 },
3405 "Gets the child location from a `NameLoc`.")
3406 .def_static(
3407 "from_attr",
3408 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
3409 return PyLocation(context->getRef(),
3410 mlirLocationFromAttribute(attribute));
3411 },
3412 nb::arg("attribute"), nb::arg("context") = nb::none(),
3413 "Gets a Location from a `LocationAttr`.")
3414 .def_prop_ro(
3415 "context",
3416 [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
3417 return self.getContext().getObject();
3418 },
3419 "Context that owns the `Location`.")
3420 .def_prop_ro(
3421 "attr",
3422 [](PyLocation &self) {
3423 return PyAttribute(self.getContext(),
3425 },
3426 "Get the underlying `LocationAttr`.")
3427 .def(
3428 "emit_error",
3429 [](PyLocation &self, std::string message) {
3430 mlirEmitError(self, message.c_str());
3431 },
3432 nb::arg("message"),
3433 R"(
3434 Emits an error diagnostic at this location.
3435
3436 Args:
3437 message: The error message to emit.)")
3438 .def(
3439 "__repr__",
3440 [](PyLocation &self) {
3441 PyPrintAccumulator printAccum;
3442 mlirLocationPrint(self, printAccum.getCallback(),
3443 printAccum.getUserData());
3444 return printAccum.join();
3445 },
3446 "Returns the assembly representation of the location.");
3447
3448 //----------------------------------------------------------------------------
3449 // Mapping of Module
3450 //----------------------------------------------------------------------------
3451 nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3453 "Gets a capsule wrapping the MlirModule.")
3455 R"(
3456 Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
3457
3458 This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
3459 prevent double-frees (of the underlying `mlir::Module`).)")
3460 .def("_clear_mlir_module", &PyModule::clearMlirModule,
3461 R"(
3462 Clears the internal MLIR module reference.
3463
3464 This is used internally to prevent double-free when ownership is transferred
3465 via the C API capsule mechanism. Not intended for normal use.)")
3466 .def_static(
3467 "parse",
3468 [](const std::string &moduleAsm, DefaultingPyMlirContext context)
3469 -> nb::typed<nb::object, PyModule> {
3470 PyMlirContext::ErrorCapture errors(context->getRef());
3471 MlirModule module = mlirModuleCreateParse(
3472 context->get(), toMlirStringRef(moduleAsm));
3473 if (mlirModuleIsNull(module))
3474 throw MLIRError("Unable to parse module assembly", errors.take());
3475 return PyModule::forModule(module).releaseObject();
3476 },
3477 nb::arg("asm"), nb::arg("context") = nb::none(),
3479 .def_static(
3480 "parse",
3481 [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
3482 -> nb::typed<nb::object, PyModule> {
3483 PyMlirContext::ErrorCapture errors(context->getRef());
3484 MlirModule module = mlirModuleCreateParse(
3485 context->get(), toMlirStringRef(moduleAsm));
3486 if (mlirModuleIsNull(module))
3487 throw MLIRError("Unable to parse module assembly", errors.take());
3488 return PyModule::forModule(module).releaseObject();
3489 },
3490 nb::arg("asm"), nb::arg("context") = nb::none(),
3492 .def_static(
3493 "parseFile",
3494 [](const std::string &path, DefaultingPyMlirContext context)
3495 -> nb::typed<nb::object, PyModule> {
3496 PyMlirContext::ErrorCapture errors(context->getRef());
3497 MlirModule module = mlirModuleCreateParseFromFile(
3498 context->get(), toMlirStringRef(path));
3499 if (mlirModuleIsNull(module))
3500 throw MLIRError("Unable to parse module assembly", errors.take());
3501 return PyModule::forModule(module).releaseObject();
3502 },
3503 nb::arg("path"), nb::arg("context") = nb::none(),
3505 .def_static(
3506 "create",
3507 [](const std::optional<PyLocation> &loc)
3508 -> nb::typed<nb::object, PyModule> {
3509 PyLocation pyLoc = maybeGetTracebackLocation(loc);
3510 MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
3511 return PyModule::forModule(module).releaseObject();
3512 },
3513 nb::arg("loc") = nb::none(), "Creates an empty module.")
3514 .def_prop_ro(
3515 "context",
3516 [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
3517 return self.getContext().getObject();
3518 },
3519 "Context that created the `Module`.")
3520 .def_prop_ro(
3521 "operation",
3522 [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
3525 self.getRef().releaseObject())
3526 .releaseObject();
3527 },
3528 "Accesses the module as an operation.")
3529 .def_prop_ro(
3530 "body",
3531 [](PyModule &self) {
3533 self.getContext(), mlirModuleGetOperation(self.get()),
3534 self.getRef().releaseObject());
3535 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3536 return returnBlock;
3537 },
3538 "Return the block for this module.")
3539 .def(
3540 "dump",
3541 [](PyModule &self) {
3543 },
3545 .def(
3546 "__str__",
3547 [](const nb::object &self) {
3548 // Defer to the operation's __str__.
3549 return self.attr("operation").attr("__str__")();
3550 },
3551 nb::sig("def __str__(self) -> str"),
3552 R"(
3553 Gets the assembly form of the operation with default options.
3554
3555 If more advanced control over the assembly formatting or I/O options is needed,
3556 use the dedicated print or get_asm method, which supports keyword arguments to
3557 customize behavior.
3558 )")
3559 .def(
3560 "__eq__",
3561 [](PyModule &self, PyModule &other) {
3562 return mlirModuleEqual(self.get(), other.get());
3563 },
3564 "other"_a, "Compares two modules for equality.")
3565 .def(
3566 "__hash__",
3567 [](PyModule &self) { return mlirModuleHashValue(self.get()); },
3568 "Returns the hash value of the module.");
3569
3570 //----------------------------------------------------------------------------
3571 // Mapping of Operation.
3572 //----------------------------------------------------------------------------
3573 nb::class_<PyOperationBase>(m, "_OperationBase")
3574 .def_prop_ro(
3576 [](PyOperationBase &self) {
3577 return self.getOperation().getCapsule();
3578 },
3579 "Gets a capsule wrapping the `MlirOperation`.")
3580 .def(
3581 "__eq__",
3582 [](PyOperationBase &self, PyOperationBase &other) {
3583 return mlirOperationEqual(self.getOperation().get(),
3584 other.getOperation().get());
3585 },
3586 "Compares two operations for equality.")
3587 .def(
3588 "__eq__",
3589 [](PyOperationBase &self, nb::object other) { return false; },
3590 "Compares operation with non-operation object (always returns "
3591 "False).")
3592 .def(
3593 "__hash__",
3594 [](PyOperationBase &self) {
3595 return mlirOperationHashValue(self.getOperation().get());
3596 },
3597 "Returns the hash value of the operation.")
3598 .def_prop_ro(
3599 "attributes",
3600 [](PyOperationBase &self) {
3601 return PyOpAttributeMap(self.getOperation().getRef());
3602 },
3603 "Returns a dictionary-like map of operation attributes.")
3604 .def_prop_ro(
3605 "context",
3606 [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
3607 PyOperation &concreteOperation = self.getOperation();
3608 concreteOperation.checkValid();
3609 return concreteOperation.getContext().getObject();
3610 },
3611 "Context that owns the operation.")
3612 .def_prop_ro(
3613 "name",
3614 [](PyOperationBase &self) {
3615 auto &concreteOperation = self.getOperation();
3616 concreteOperation.checkValid();
3617 MlirOperation operation = concreteOperation.get();
3618 return mlirIdentifierStr(mlirOperationGetName(operation));
3619 },
3620 "Returns the fully qualified name of the operation.")
3621 .def_prop_ro(
3622 "operands",
3623 [](PyOperationBase &self) {
3624 return PyOpOperandList(self.getOperation().getRef());
3625 },
3626 "Returns the list of operation operands.")
3627 .def_prop_ro(
3628 "regions",
3629 [](PyOperationBase &self) {
3630 return PyRegionList(self.getOperation().getRef());
3631 },
3632 "Returns the list of operation regions.")
3633 .def_prop_ro(
3634 "results",
3635 [](PyOperationBase &self) {
3636 return PyOpResultList(self.getOperation().getRef());
3637 },
3638 "Returns the list of Operation results.")
3639 .def_prop_ro(
3640 "result",
3641 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
3642 auto &operation = self.getOperation();
3643 return PyOpResult(operation.getRef(), getUniqueResult(operation))
3644 .maybeDownCast();
3645 },
3646 "Shortcut to get an op result if it has only one (throws an error "
3647 "otherwise).")
3648 .def_prop_rw(
3649 "location",
3650 [](PyOperationBase &self) {
3651 PyOperation &operation = self.getOperation();
3652 return PyLocation(operation.getContext(),
3653 mlirOperationGetLocation(operation.get()));
3654 },
3655 [](PyOperationBase &self, const PyLocation &location) {
3656 PyOperation &operation = self.getOperation();
3657 mlirOperationSetLocation(operation.get(), location.get());
3658 },
3659 nb::for_getter("Returns the source location the operation was "
3660 "defined or derived from."),
3661 nb::for_setter("Sets the source location the operation was defined "
3662 "or derived from."))
3663 .def_prop_ro(
3664 "parent",
3665 [](PyOperationBase &self)
3666 -> std::optional<nb::typed<nb::object, PyOperation>> {
3667 auto parent = self.getOperation().getParentOperation();
3668 if (parent)
3669 return parent->getObject();
3670 return {};
3671 },
3672 "Returns the parent operation, or `None` if at top level.")
3673 .def(
3674 "__str__",
3675 [](PyOperationBase &self) {
3676 return self.getAsm(/*binary=*/false,
3677 /*largeElementsLimit=*/std::nullopt,
3678 /*largeResourceLimit=*/std::nullopt,
3679 /*enableDebugInfo=*/false,
3680 /*prettyDebugInfo=*/false,
3681 /*printGenericOpForm=*/false,
3682 /*useLocalScope=*/false,
3683 /*useNameLocAsPrefix=*/false,
3684 /*assumeVerified=*/false,
3685 /*skipRegions=*/false);
3686 },
3687 nb::sig("def __str__(self) -> str"),
3688 "Returns the assembly form of the operation.")
3689 .def("print",
3690 nb::overload_cast<PyAsmState &, nb::object, bool>(
3692 nb::arg("state"), nb::arg("file") = nb::none(),
3693 nb::arg("binary") = false,
3694 R"(
3695 Prints the assembly form of the operation to a file like object.
3696
3697 Args:
3698 state: `AsmState` capturing the operation numbering and flags.
3699 file: Optional file like object to write to. Defaults to sys.stdout.
3700 binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
3701 .def("print",
3702 nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
3703 bool, bool, bool, bool, bool, bool, nb::object,
3704 bool, bool>(&PyOperationBase::print),
3705 // Careful: Lots of arguments must match up with print method.
3706 nb::arg("large_elements_limit") = nb::none(),
3707 nb::arg("large_resource_limit") = nb::none(),
3708 nb::arg("enable_debug_info") = false,
3709 nb::arg("pretty_debug_info") = false,
3710 nb::arg("print_generic_op_form") = false,
3711 nb::arg("use_local_scope") = false,
3712 nb::arg("use_name_loc_as_prefix") = false,
3713 nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
3714 nb::arg("binary") = false, nb::arg("skip_regions") = false,
3715 R"(
3716 Prints the assembly form of the operation to a file like object.
3717
3718 Args:
3719 large_elements_limit: Whether to elide elements attributes above this
3720 number of elements. Defaults to None (no limit).
3721 large_resource_limit: Whether to elide resource attributes above this
3722 number of characters. Defaults to None (no limit). If large_elements_limit
3723 is set and this is None, the behavior will be to use large_elements_limit
3724 as large_resource_limit.
3725 enable_debug_info: Whether to print debug/location information. Defaults
3726 to False.
3727 pretty_debug_info: Whether to format debug information for easier reading
3728 by a human (warning: the result is unparseable). Defaults to False.
3729 print_generic_op_form: Whether to print the generic assembly forms of all
3730 ops. Defaults to False.
3731 use_local_scope: Whether to print in a way that is more optimized for
3732 multi-threaded access but may not be consistent with how the overall
3733 module prints.
3734 use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
3735 prefixes for the SSA identifiers. Defaults to False.
3736 assume_verified: By default, if not printing generic form, the verifier
3737 will be run and if it fails, generic form will be printed with a comment
3738 about failed verification. While a reasonable default for interactive use,
3739 for systematic use, it is often better for the caller to verify explicitly
3740 and report failures in a more robust fashion. Set this to True if doing this
3741 in order to avoid running a redundant verification. If the IR is actually
3742 invalid, behavior is undefined.
3743 file: The file like object to write to. Defaults to sys.stdout.
3744 binary: Whether to write bytes (True) or str (False). Defaults to False.
3745 skip_regions: Whether to skip printing regions. Defaults to False.)")
3746 .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
3747 nb::arg("desired_version") = nb::none(),
3748 R"(
3749 Write the bytecode form of the operation to a file like object.
3750
3751 Args:
3752 file: The file like object to write to.
3753 desired_version: Optional version of bytecode to emit.
3754 Returns:
3755 The bytecode writer status.)")
3756 .def("get_asm", &PyOperationBase::getAsm,
3757 // Careful: Lots of arguments must match up with get_asm method.
3758 nb::arg("binary") = false,
3759 nb::arg("large_elements_limit") = nb::none(),
3760 nb::arg("large_resource_limit") = nb::none(),
3761 nb::arg("enable_debug_info") = false,
3762 nb::arg("pretty_debug_info") = false,
3763 nb::arg("print_generic_op_form") = false,
3764 nb::arg("use_local_scope") = false,
3765 nb::arg("use_name_loc_as_prefix") = false,
3766 nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
3767 R"(
3768 Gets the assembly form of the operation with all options available.
3769
3770 Args:
3771 binary: Whether to return a bytes (True) or str (False) object. Defaults to
3772 False.
3773 ... others ...: See the print() method for common keyword arguments for
3774 configuring the printout.
3775 Returns:
3776 Either a bytes or str object, depending on the setting of the `binary`
3777 argument.)")
3778 .def("verify", &PyOperationBase::verify,
3779 "Verify the operation. Raises MLIRError if verification fails, and "
3780 "returns true otherwise.")
3781 .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
3782 "Puts self immediately after the other operation in its parent "
3783 "block.")
3784 .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
3785 "Puts self immediately before the other operation in its parent "
3786 "block.")
3787 .def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
3788 nb::arg("other"),
3789 R"(
3790 Checks if this operation is before another in the same block.
3791
3792 Args:
3793 other: Another operation in the same parent block.
3794
3795 Returns:
3796 True if this operation is before `other` in the operation list of the parent block.)")
3797 .def(
3798 "clone",
3799 [](PyOperationBase &self,
3800 const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
3801 return self.getOperation().clone(ip);
3802 },
3803 nb::arg("ip") = nb::none(),
3804 R"(
3805 Creates a deep copy of the operation.
3806
3807 Args:
3808 ip: Optional insertion point where the cloned operation should be inserted.
3809 If None, the current insertion point is used. If False, the operation
3810 remains detached.
3811
3812 Returns:
3813 A new Operation that is a clone of this operation.)")
3814 .def(
3815 "detach_from_parent",
3816 [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
3817 PyOperation &operation = self.getOperation();
3818 operation.checkValid();
3819 if (!operation.isAttached())
3820 throw nb::value_error("Detached operation has no parent.");
3821
3822 operation.detachFromParent();
3823 return operation.createOpView();
3824 },
3825 "Detaches the operation from its parent block.")
3826 .def_prop_ro(
3827 "attached",
3828 [](PyOperationBase &self) {
3829 PyOperation &operation = self.getOperation();
3830 operation.checkValid();
3831 return operation.isAttached();
3832 },
3833 "Reports if the operation is attached to its parent block.")
3834 .def(
3835 "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
3836 R"(
3837 Erases the operation and frees its memory.
3838
3839 Note:
3840 After erasing, any Python references to the operation become invalid.)")
3841 .def("walk", &PyOperationBase::walk, nb::arg("callback"),
3842 nb::arg("walk_order") = MlirWalkPostOrder,
3843 // clang-format off
3844 nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
3845 // clang-format on
3846 R"(
3847 Walks the operation tree with a callback function.
3848
3849 Args:
3850 callback: A callable that takes an Operation and returns a WalkResult.
3851 walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
3852
3853 nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3854 .def_static(
3855 "create",
3856 [](std::string_view name,
3857 std::optional<std::vector<PyType *>> results,
3858 std::optional<std::vector<PyValue *>> operands,
3859 std::optional<nb::dict> attributes,
3860 std::optional<std::vector<PyBlock *>> successors, int regions,
3861 const std::optional<PyLocation> &location,
3862 const nb::object &maybeIp,
3863 bool inferType) -> nb::typed<nb::object, PyOperation> {
3864 // Unpack/validate operands.
3865 llvm::SmallVector<MlirValue, 4> mlirOperands;
3866 if (operands) {
3867 mlirOperands.reserve(operands->size());
3868 for (PyValue *operand : *operands) {
3869 if (!operand)
3870 throw nb::value_error("operand value cannot be None");
3871 mlirOperands.push_back(operand->get());
3872 }
3873 }
3874
3875 PyLocation pyLoc = maybeGetTracebackLocation(location);
3876 return PyOperation::create(name, results, mlirOperands, attributes,
3877 successors, regions, pyLoc, maybeIp,
3878 inferType);
3879 },
3880 nb::arg("name"), nb::arg("results") = nb::none(),
3881 nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
3882 nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
3883 nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
3884 nb::arg("infer_type") = false,
3885 R"(
3886 Creates a new operation.
3887
3888 Args:
3889 name: Operation name (e.g. `dialect.operation`).
3890 results: Optional sequence of Type representing op result types.
3891 operands: Optional operands of the operation.
3892 attributes: Optional Dict of {str: Attribute}.
3893 successors: Optional List of Block for the operation's successors.
3894 regions: Number of regions to create (default = 0).
3895 location: Optional Location object (defaults to resolve from context manager).
3896 ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
3897 infer_type: Whether to infer result types (default = False).
3898 Returns:
3899 A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
3900 .def_static(
3901 "parse",
3902 [](const std::string &sourceStr, const std::string &sourceName,
3904 -> nb::typed<nb::object, PyOpView> {
3905 return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3906 ->createOpView();
3907 },
3908 nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
3909 nb::arg("context") = nb::none(),
3910 "Parses an operation. Supports both text assembly format and binary "
3911 "bytecode format.")
3913 "Gets a capsule wrapping the MlirOperation.")
3916 "Creates an Operation from a capsule wrapping MlirOperation.")
3917 .def_prop_ro(
3918 "operation",
3919 [](nb::object self) -> nb::typed<nb::object, PyOperation> {
3920 return self;
3921 },
3922 "Returns self (the operation).")
3923 .def_prop_ro(
3924 "opview",
3925 [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
3926 return self.createOpView();
3927 },
3928 R"(
3929 Returns an OpView of this operation.
3930
3931 Note:
3932 If the operation has a registered and loaded dialect then this OpView will
3933 be concrete wrapper class.)")
3934 .def_prop_ro("block", &PyOperation::getBlock,
3935 "Returns the block containing this operation.")
3936 .def_prop_ro(
3937 "successors",
3938 [](PyOperationBase &self) {
3939 return PyOpSuccessors(self.getOperation().getRef());
3940 },
3941 "Returns the list of Operation successors.")
3942 .def(
3943 "replace_uses_of_with",
3944 [](PyOperation &self, PyValue &of, PyValue &with) {
3945 mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
3946 },
3947 "of"_a, "with_"_a,
3948 "Replaces uses of the 'of' value with the 'with' value inside the "
3949 "operation.")
3950 .def("_set_invalid", &PyOperation::setInvalid,
3951 "Invalidate the operation.");
3952
3953 auto opViewClass =
3954 nb::class_<PyOpView, PyOperationBase>(m, "OpView")
3955 .def(nb::init<nb::typed<nb::object, PyOperation>>(),
3956 nb::arg("operation"))
3957 .def(
3958 "__init__",
3959 [](PyOpView *self, std::string_view name,
3960 std::tuple<int, bool> opRegionSpec,
3961 nb::object operandSegmentSpecObj,
3962 nb::object resultSegmentSpecObj,
3963 std::optional<nb::list> resultTypeList, nb::list operandList,
3964 std::optional<nb::dict> attributes,
3965 std::optional<std::vector<PyBlock *>> successors,
3966 std::optional<int> regions,
3967 const std::optional<PyLocation> &location,
3968 const nb::object &maybeIp) {
3969 PyLocation pyLoc = maybeGetTracebackLocation(location);
3971 name, opRegionSpec, operandSegmentSpecObj,
3972 resultSegmentSpecObj, resultTypeList, operandList,
3973 attributes, successors, regions, pyLoc, maybeIp));
3974 },
3975 nb::arg("name"), nb::arg("opRegionSpec"),
3976 nb::arg("operandSegmentSpecObj") = nb::none(),
3977 nb::arg("resultSegmentSpecObj") = nb::none(),
3978 nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(),
3979 nb::arg("attributes") = nb::none(),
3980 nb::arg("successors") = nb::none(),
3981 nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
3982 nb::arg("ip") = nb::none())
3983 .def_prop_ro(
3984 "operation",
3985 [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
3986 return self.getOperationObject();
3987 })
3988 .def_prop_ro("opview",
3989 [](nb::object self) -> nb::typed<nb::object, PyOpView> {
3990 return self;
3991 })
3992 .def(
3993 "__str__",
3994 [](PyOpView &self) { return nb::str(self.getOperationObject()); })
3995 .def_prop_ro(
3996 "successors",
3997 [](PyOperationBase &self) {
3998 return PyOpSuccessors(self.getOperation().getRef());
3999 },
4000 "Returns the list of Operation successors.")
4001 .def(
4002 "_set_invalid",
4003 [](PyOpView &self) { self.getOperation().setInvalid(); },
4004 "Invalidate the operation.");
4005 opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
4006 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
4007 opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
4008 // It is faster to pass the operation_name, ods_regions, and
4009 // ods_operand_segments/ods_result_segments as arguments to the constructor,
4010 // rather than to access them as attributes.
4011 opViewClass.attr("build_generic") = classmethod(
4012 [](nb::handle cls, std::optional<nb::list> resultTypeList,
4013 nb::list operandList, std::optional<nb::dict> attributes,
4014 std::optional<std::vector<PyBlock *>> successors,
4015 std::optional<int> regions, std::optional<PyLocation> location,
4016 const nb::object &maybeIp) {
4017 std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
4018 std::tuple<int, bool> opRegionSpec =
4019 nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
4020 nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
4021 nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
4022 PyLocation pyLoc = maybeGetTracebackLocation(location);
4023 return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
4024 resultSegmentSpec, resultTypeList,
4025 operandList, attributes, successors,
4026 regions, pyLoc, maybeIp);
4027 },
4028 nb::arg("cls"), nb::arg("results") = nb::none(),
4029 nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
4030 nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(),
4031 nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
4032 "Builds a specific, generated OpView based on class level attributes.");
4033 opViewClass.attr("parse") = classmethod(
4034 [](const nb::object &cls, const std::string &sourceStr,
4035 const std::string &sourceName,
4036 DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
4037 PyOperationRef parsed =
4038 PyOperation::parse(context->getRef(), sourceStr, sourceName);
4039
4040 // Check if the expected operation was parsed, and cast to to the
4041 // appropriate `OpView` subclass if successful.
4042 // NOTE: This accesses attributes that have been automatically added to
4043 // `OpView` subclasses, and is not intended to be used on `OpView`
4044 // directly.
4045 std::string clsOpName =
4046 nb::cast<std::string>(cls.attr("OPERATION_NAME"));
4047 MlirStringRef identifier =
4049 std::string_view parsedOpName(identifier.data, identifier.length);
4050 if (clsOpName != parsedOpName)
4051 throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
4052 parsedOpName + "'");
4053 return PyOpView::constructDerived(cls, parsed.getObject());
4054 },
4055 nb::arg("cls"), nb::arg("source"), nb::kw_only(),
4056 nb::arg("source_name") = "", nb::arg("context") = nb::none(),
4057 "Parses a specific, generated OpView based on class level attributes.");
4058
4059 //----------------------------------------------------------------------------
4060 // Mapping of PyRegion.
4061 //----------------------------------------------------------------------------
4062 nb::class_<PyRegion>(m, "Region")
4063 .def_prop_ro(
4064 "blocks",
4065 [](PyRegion &self) {
4066 return PyBlockList(self.getParentOperation(), self.get());
4067 },
4068 "Returns a forward-optimized sequence of blocks.")
4069 .def_prop_ro(
4070 "owner",
4071 [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
4072 return self.getParentOperation()->createOpView();
4073 },
4074 "Returns the operation owning this region.")
4075 .def(
4076 "__iter__",
4077 [](PyRegion &self) {
4078 self.checkValid();
4079 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
4080 return PyBlockIterator(self.getParentOperation(), firstBlock);
4081 },
4082 "Iterates over blocks in the region.")
4083 .def(
4084 "__eq__",
4085 [](PyRegion &self, PyRegion &other) {
4086 return self.get().ptr == other.get().ptr;
4087 },
4088 "Compares two regions for pointer equality.")
4089 .def(
4090 "__eq__", [](PyRegion &self, nb::object &other) { return false; },
4091 "Compares region with non-region object (always returns False).");
4092
4093 //----------------------------------------------------------------------------
4094 // Mapping of PyBlock.
4095 //----------------------------------------------------------------------------
4096 nb::class_<PyBlock>(m, "Block")
4098 "Gets a capsule wrapping the MlirBlock.")
4099 .def_prop_ro(
4100 "owner",
4101 [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
4102 return self.getParentOperation()->createOpView();
4103 },
4104 "Returns the owning operation of this block.")
4105 .def_prop_ro(
4106 "region",
4107 [](PyBlock &self) {
4108 MlirRegion region = mlirBlockGetParentRegion(self.get());
4109 return PyRegion(self.getParentOperation(), region);
4110 },
4111 "Returns the owning region of this block.")
4112 .def_prop_ro(
4113 "arguments",
4114 [](PyBlock &self) {
4115 return PyBlockArgumentList(self.getParentOperation(), self.get());
4116 },
4117 "Returns a list of block arguments.")
4118 .def(
4119 "add_argument",
4120 [](PyBlock &self, const PyType &type, const PyLocation &loc) {
4121 return PyBlockArgument(self.getParentOperation(),
4122 mlirBlockAddArgument(self.get(), type, loc));
4123 },
4124 "type"_a, "loc"_a,
4125 R"(
4126 Appends an argument of the specified type to the block.
4127
4128 Args:
4129 type: The type of the argument to add.
4130 loc: The source location for the argument.
4131
4132 Returns:
4133 The newly added block argument.)")
4134 .def(
4135 "erase_argument",
4136 [](PyBlock &self, unsigned index) {
4137 return mlirBlockEraseArgument(self.get(), index);
4138 },
4139 nb::arg("index"),
4140 R"(
4141 Erases the argument at the specified index.
4142
4143 Args:
4144 index: The index of the argument to erase.)")
4145 .def_prop_ro(
4146 "operations",
4147 [](PyBlock &self) {
4148 return PyOperationList(self.getParentOperation(), self.get());
4149 },
4150 "Returns a forward-optimized sequence of operations.")
4151 .def_static(
4152 "create_at_start",
4153 [](PyRegion &parent, const nb::sequence &pyArgTypes,
4154 const std::optional<nb::sequence> &pyArgLocs) {
4155 parent.checkValid();
4156 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
4157 mlirRegionInsertOwnedBlock(parent, 0, block);
4158 return PyBlock(parent.getParentOperation(), block);
4159 },
4160 nb::arg("parent"), nb::arg("arg_types") = nb::list(),
4161 nb::arg("arg_locs") = std::nullopt,
4162 "Creates and returns a new Block at the beginning of the given "
4163 "region (with given argument types and locations).")
4164 .def(
4165 "append_to",
4166 [](PyBlock &self, PyRegion &region) {
4167 MlirBlock b = self.get();
4170 mlirRegionAppendOwnedBlock(region.get(), b);
4171 },
4172 nb::arg("region"),
4173 R"(
4174 Appends this block to a region.
4175
4176 Transfers ownership if the block is currently owned by another region.
4177
4178 Args:
4179 region: The region to append the block to.)")
4180 .def(
4181 "create_before",
4182 [](PyBlock &self, const nb::args &pyArgTypes,
4183 const std::optional<nb::sequence> &pyArgLocs) {
4184 self.checkValid();
4185 MlirBlock block =
4186 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4187 MlirRegion region = mlirBlockGetParentRegion(self.get());
4188 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
4189 return PyBlock(self.getParentOperation(), block);
4190 },
4191 nb::arg("arg_types"), nb::kw_only(),
4192 nb::arg("arg_locs") = std::nullopt,
4193 "Creates and returns a new Block before this block "
4194 "(with given argument types and locations).")
4195 .def(
4196 "create_after",
4197 [](PyBlock &self, const nb::args &pyArgTypes,
4198 const std::optional<nb::sequence> &pyArgLocs) {
4199 self.checkValid();
4200 MlirBlock block =
4201 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4202 MlirRegion region = mlirBlockGetParentRegion(self.get());
4203 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
4204 return PyBlock(self.getParentOperation(), block);
4205 },
4206 nb::arg("arg_types"), nb::kw_only(),
4207 nb::arg("arg_locs") = std::nullopt,
4208 "Creates and returns a new Block after this block "
4209 "(with given argument types and locations).")
4210 .def(
4211 "__iter__",
4212 [](PyBlock &self) {
4213 self.checkValid();
4214 MlirOperation firstOperation =
4216 return PyOperationIterator(self.getParentOperation(),
4217 firstOperation);
4218 },
4219 "Iterates over operations in the block.")
4220 .def(
4221 "__eq__",
4222 [](PyBlock &self, PyBlock &other) {
4223 return self.get().ptr == other.get().ptr;
4224 },
4225 "Compares two blocks for pointer equality.")
4226 .def(
4227 "__eq__", [](PyBlock &self, nb::object &other) { return false; },
4228 "Compares block with non-block object (always returns False).")
4229 .def(
4230 "__hash__",
4231 [](PyBlock &self) {
4232 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4233 },
4234 "Returns the hash value of the block.")
4235 .def(
4236 "__str__",
4237 [](PyBlock &self) {
4238 self.checkValid();
4239 PyPrintAccumulator printAccum;
4240 mlirBlockPrint(self.get(), printAccum.getCallback(),
4241 printAccum.getUserData());
4242 return printAccum.join();
4243 },
4244 "Returns the assembly form of the block.")
4245 .def(
4246 "append",
4247 [](PyBlock &self, PyOperationBase &operation) {
4248 if (operation.getOperation().isAttached())
4249 operation.getOperation().detachFromParent();
4250
4251 MlirOperation mlirOperation = operation.getOperation().get();
4252 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
4253 operation.getOperation().setAttached(
4254 self.getParentOperation().getObject());
4255 },
4256 nb::arg("operation"),
4257 R"(
4258 Appends an operation to this block.
4259
4260 If the operation is currently in another block, it will be moved.
4261
4262 Args:
4263 operation: The operation to append to the block.)")
4264 .def_prop_ro(
4265 "successors",
4266 [](PyBlock &self) {
4267 return PyBlockSuccessors(self, self.getParentOperation());
4268 },
4269 "Returns the list of Block successors.")
4270 .def_prop_ro(
4271 "predecessors",
4272 [](PyBlock &self) {
4273 return PyBlockPredecessors(self, self.getParentOperation());
4274 },
4275 "Returns the list of Block predecessors.");
4276
4277 //----------------------------------------------------------------------------
4278 // Mapping of PyInsertionPoint.
4279 //----------------------------------------------------------------------------
4280
4281 nb::class_<PyInsertionPoint>(m, "InsertionPoint")
4282 .def(nb::init<PyBlock &>(), nb::arg("block"),
4283 "Inserts after the last operation but still inside the block.")
4284 .def("__enter__", &PyInsertionPoint::contextEnter,
4285 "Enters the insertion point as a context manager.")
4286 .def("__exit__", &PyInsertionPoint::contextExit,
4287 nb::arg("exc_type").none(), nb::arg("exc_value").none(),
4288 nb::arg("traceback").none(),
4289 "Exits the insertion point context manager.")
4290 .def_prop_ro_static(
4291 "current",
4292 [](nb::object & /*class*/) {
4294 if (!ip)
4295 throw nb::value_error("No current InsertionPoint");
4296 return ip;
4297 },
4298 nb::sig("def current(/) -> InsertionPoint"),
4299 "Gets the InsertionPoint bound to the current thread or raises "
4300 "ValueError if none has been set.")
4301 .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
4302 "Inserts before a referenced operation.")
4303 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
4304 nb::arg("block"),
4305 R"(
4306 Creates an insertion point at the beginning of a block.
4307
4308 Args:
4309 block: The block at whose beginning operations should be inserted.
4310
4311 Returns:
4312 An InsertionPoint at the block's beginning.)")
4313 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
4314 nb::arg("block"),
4315 R"(
4316 Creates an insertion point before a block's terminator.
4317
4318 Args:
4319 block: The block whose terminator to insert before.
4320
4321 Returns:
4322 An InsertionPoint before the terminator.
4323
4324 Raises:
4325 ValueError: If the block has no terminator.)")
4326 .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
4327 R"(
4328 Creates an insertion point immediately after an operation.
4329
4330 Args:
4331 operation: The operation after which to insert.
4332
4333 Returns:
4334 An InsertionPoint after the operation.)")
4335 .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
4336 R"(
4337 Inserts an operation at this insertion point.
4338
4339 Args:
4340 operation: The operation to insert.)")
4341 .def_prop_ro(
4342 "block", [](PyInsertionPoint &self) { return self.getBlock(); },
4343 "Returns the block that this `InsertionPoint` points to.")
4344 .def_prop_ro(
4345 "ref_operation",
4346 [](PyInsertionPoint &self)
4347 -> std::optional<nb::typed<nb::object, PyOperation>> {
4348 auto refOperation = self.getRefOperation();
4349 if (refOperation)
4350 return refOperation->getObject();
4351 return {};
4352 },
4353 "The reference operation before which new operations are "
4354 "inserted, or None if the insertion point is at the end of "
4355 "the block.");
4356
4357 //----------------------------------------------------------------------------
4358 // Mapping of PyAttribute.
4359 //----------------------------------------------------------------------------
4360 nb::class_<PyAttribute>(m, "Attribute")
4361 // Delegate to the PyAttribute copy constructor, which will also lifetime
4362 // extend the backing context which owns the MlirAttribute.
4363 .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
4364 "Casts the passed attribute to the generic `Attribute`.")
4366 "Gets a capsule wrapping the MlirAttribute.")
4367 .def_static(
4369 "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
4370 .def_static(
4371 "parse",
4372 [](const std::string &attrSpec, DefaultingPyMlirContext context)
4373 -> nb::typed<nb::object, PyAttribute> {
4374 PyMlirContext::ErrorCapture errors(context->getRef());
4375 MlirAttribute attr = mlirAttributeParseGet(
4376 context->get(), toMlirStringRef(attrSpec));
4377 if (mlirAttributeIsNull(attr))
4378 throw MLIRError("Unable to parse attribute", errors.take());
4379 return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
4380 },
4381 nb::arg("asm"), nb::arg("context") = nb::none(),
4382 "Parses an attribute from an assembly form. Raises an `MLIRError` on "
4383 "failure.")
4384 .def_prop_ro(
4385 "context",
4386 [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
4387 return self.getContext().getObject();
4388 },
4389 "Context that owns the `Attribute`.")
4390 .def_prop_ro(
4391 "type",
4392 [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
4393 return PyType(self.getContext(), mlirAttributeGetType(self))
4394 .maybeDownCast();
4395 },
4396 "Returns the type of the `Attribute`.")
4397 .def(
4398 "get_named",
4399 [](PyAttribute &self, std::string name) {
4400 return PyNamedAttribute(self, std::move(name));
4401 },
4402 nb::keep_alive<0, 1>(),
4403 R"(
4404 Binds a name to the attribute, creating a `NamedAttribute`.
4405
4406 Args:
4407 name: The name to bind to the `Attribute`.
4408
4409 Returns:
4410 A `NamedAttribute` with the given name and this attribute.)")
4411 .def(
4412 "__eq__",
4413 [](PyAttribute &self, PyAttribute &other) { return self == other; },
4414 "Compares two attributes for equality.")
4415 .def(
4416 "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
4417 "Compares attribute with non-attribute object (always returns "
4418 "False).")
4419 .def(
4420 "__hash__",
4421 [](PyAttribute &self) {
4422 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4423 },
4424 "Returns the hash value of the attribute.")
4425 .def(
4426 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
4428 .def(
4429 "__str__",
4430 [](PyAttribute &self) {
4431 PyPrintAccumulator printAccum;
4432 mlirAttributePrint(self, printAccum.getCallback(),
4433 printAccum.getUserData());
4434 return printAccum.join();
4435 },
4436 "Returns the assembly form of the Attribute.")
4437 .def(
4438 "__repr__",
4439 [](PyAttribute &self) {
4440 // Generally, assembly formats are not printed for __repr__ because
4441 // this can cause exceptionally long debug output and exceptions.
4442 // However, attribute values are generally considered useful and
4443 // are printed. This may need to be re-evaluated if debug dumps end
4444 // up being excessive.
4445 PyPrintAccumulator printAccum;
4446 printAccum.parts.append("Attribute(");
4447 mlirAttributePrint(self, printAccum.getCallback(),
4448 printAccum.getUserData());
4449 printAccum.parts.append(")");
4450 return printAccum.join();
4451 },
4452 "Returns a string representation of the attribute.")
4453 .def_prop_ro(
4454 "typeid",
4455 [](PyAttribute &self) {
4456 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
4457 assert(!mlirTypeIDIsNull(mlirTypeID) &&
4458 "mlirTypeID was expected to be non-null.");
4459 return PyTypeID(mlirTypeID);
4460 },
4461 "Returns the `TypeID` of the attribute.")
4462 .def(
4464 [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
4465 return self.maybeDownCast();
4466 },
4467 "Downcasts the attribute to a more specific attribute if possible.");
4468
4469 //----------------------------------------------------------------------------
4470 // Mapping of PyNamedAttribute
4471 //----------------------------------------------------------------------------
4472 nb::class_<PyNamedAttribute>(m, "NamedAttribute")
4473 .def(
4474 "__repr__",
4475 [](PyNamedAttribute &self) {
4476 PyPrintAccumulator printAccum;
4477 printAccum.parts.append("NamedAttribute(");
4478 printAccum.parts.append(
4479 nb::str(mlirIdentifierStr(self.namedAttr.name).data,
4481 printAccum.parts.append("=");
4483 printAccum.getCallback(),
4484 printAccum.getUserData());
4485 printAccum.parts.append(")");
4486 return printAccum.join();
4487 },
4488 "Returns a string representation of the named attribute.")
4489 .def_prop_ro(
4490 "name",
4491 [](PyNamedAttribute &self) {
4492 return mlirIdentifierStr(self.namedAttr.name);
4493 },
4494 "The name of the `NamedAttribute` binding.")
4495 .def_prop_ro(
4496 "attr",
4497 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
4498 nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
4499 "The underlying generic attribute of the `NamedAttribute` binding.");
4500
4501 //----------------------------------------------------------------------------
4502 // Mapping of PyType.
4503 //----------------------------------------------------------------------------
4504 nb::class_<PyType>(m, "Type")
4505 // Delegate to the PyType copy constructor, which will also lifetime
4506 // extend the backing context which owns the MlirType.
4507 .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
4508 "Casts the passed type to the generic `Type`.")
4510 "Gets a capsule wrapping the `MlirType`.")
4512 "Creates a Type from a capsule wrapping `MlirType`.")
4513 .def_static(
4514 "parse",
4515 [](std::string typeSpec,
4516 DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
4517 PyMlirContext::ErrorCapture errors(context->getRef());
4518 MlirType type =
4519 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
4520 if (mlirTypeIsNull(type))
4521 throw MLIRError("Unable to parse type", errors.take());
4522 return PyType(context.get()->getRef(), type).maybeDownCast();
4523 },
4524 nb::arg("asm"), nb::arg("context") = nb::none(),
4525 R"(
4526 Parses the assembly form of a type.
4527
4528 Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
4529
4530 See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
4531 .def_prop_ro(
4532 "context",
4533 [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
4534 return self.getContext().getObject();
4535 },
4536 "Context that owns the `Type`.")
4537 .def(
4538 "__eq__", [](PyType &self, PyType &other) { return self == other; },
4539 "Compares two types for equality.")
4540 .def(
4541 "__eq__", [](PyType &self, nb::object &other) { return false; },
4542 nb::arg("other").none(),
4543 "Compares type with non-type object (always returns False).")
4544 .def(
4545 "__hash__",
4546 [](PyType &self) {
4547 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4548 },
4549 "Returns the hash value of the `Type`.")
4550 .def(
4551 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
4552 .def(
4553 "__str__",
4554 [](PyType &self) {
4555 PyPrintAccumulator printAccum;
4556 mlirTypePrint(self, printAccum.getCallback(),
4557 printAccum.getUserData());
4558 return printAccum.join();
4559 },
4560 "Returns the assembly form of the `Type`.")
4561 .def(
4562 "__repr__",
4563 [](PyType &self) {
4564 // Generally, assembly formats are not printed for __repr__ because
4565 // this can cause exceptionally long debug output and exceptions.
4566 // However, types are an exception as they typically have compact
4567 // assembly forms and printing them is useful.
4568 PyPrintAccumulator printAccum;
4569 printAccum.parts.append("Type(");
4570 mlirTypePrint(self, printAccum.getCallback(),
4571 printAccum.getUserData());
4572 printAccum.parts.append(")");
4573 return printAccum.join();
4574 },
4575 "Returns a string representation of the `Type`.")
4576 .def(
4578 [](PyType &self) -> nb::typed<nb::object, PyType> {
4579 return self.maybeDownCast();
4580 },
4581 "Downcasts the Type to a more specific `Type` if possible.")
4582 .def_prop_ro(
4583 "typeid",
4584 [](PyType &self) {
4585 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
4586 if (!mlirTypeIDIsNull(mlirTypeID))
4587 return PyTypeID(mlirTypeID);
4588 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
4589 throw nb::value_error(
4590 (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
4591 },
4592 "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
4593 "`Type` has no "
4594 "`TypeID`.");
4595
4596 //----------------------------------------------------------------------------
4597 // Mapping of PyTypeID.
4598 //----------------------------------------------------------------------------
4599 nb::class_<PyTypeID>(m, "TypeID")
4601 "Gets a capsule wrapping the `MlirTypeID`.")
4603 "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
4604 // Note, this tests whether the underlying TypeIDs are the same,
4605 // not whether the wrapper MlirTypeIDs are the same, nor whether
4606 // the Python objects are the same (i.e., PyTypeID is a value type).
4607 .def(
4608 "__eq__",
4609 [](PyTypeID &self, PyTypeID &other) { return self == other; },
4610 "Compares two `TypeID`s for equality.")
4611 .def(
4612 "__eq__",
4613 [](PyTypeID &self, const nb::object &other) { return false; },
4614 "Compares TypeID with non-TypeID object (always returns False).")
4615 // Note, this gives the hash value of the underlying TypeID, not the
4616 // hash value of the Python object, nor the hash value of the
4617 // MlirTypeID wrapper.
4618 .def(
4619 "__hash__",
4620 [](PyTypeID &self) {
4621 return static_cast<size_t>(mlirTypeIDHashValue(self));
4622 },
4623 "Returns the hash value of the `TypeID`.");
4624
4625 //----------------------------------------------------------------------------
4626 // Mapping of Value.
4627 //----------------------------------------------------------------------------
4628 m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
4629
4630 nb::class_<PyValue>(m, "Value", nb::is_generic(),
4631 nb::sig("class Value(Generic[_T])"))
4632 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
4633 "Creates a Value reference from another `Value`.")
4635 "Gets a capsule wrapping the `MlirValue`.")
4637 "Creates a `Value` from a capsule wrapping `MlirValue`.")
4638 .def_prop_ro(
4639 "context",
4640 [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
4641 return self.getParentOperation()->getContext().getObject();
4642 },
4643 "Context in which the value lives.")
4644 .def(
4645 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
4647 .def_prop_ro(
4648 "owner",
4649 [](PyValue &self) -> nb::object {
4650 MlirValue v = self.get();
4651 if (mlirValueIsAOpResult(v)) {
4652 assert(mlirOperationEqual(self.getParentOperation()->get(),
4653 mlirOpResultGetOwner(self.get())) &&
4654 "expected the owner of the value in Python to match "
4655 "that in "
4656 "the IR");
4657 return self.getParentOperation().getObject();
4658 }
4659
4661 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
4662 return nb::cast(PyBlock(self.getParentOperation(), block));
4663 }
4664
4665 assert(false && "Value must be a block argument or an op result");
4666 return nb::none();
4667 },
4668 "Returns the owner of the value (`Operation` for results, `Block` "
4669 "for "
4670 "arguments).")
4671 .def_prop_ro(
4672 "uses",
4673 [](PyValue &self) {
4674 return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
4675 },
4676 "Returns an iterator over uses of this value.")
4677 .def(
4678 "__eq__",
4679 [](PyValue &self, PyValue &other) {
4680 return self.get().ptr == other.get().ptr;
4681 },
4682 "Compares two values for pointer equality.")
4683 .def(
4684 "__eq__", [](PyValue &self, nb::object other) { return false; },
4685 "Compares value with non-value object (always returns False).")
4686 .def(
4687 "__hash__",
4688 [](PyValue &self) {
4689 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4690 },
4691 "Returns the hash value of the value.")
4692 .def(
4693 "__str__",
4694 [](PyValue &self) {
4695 PyPrintAccumulator printAccum;
4696 printAccum.parts.append("Value(");
4697 mlirValuePrint(self.get(), printAccum.getCallback(),
4698 printAccum.getUserData());
4699 printAccum.parts.append(")");
4700 return printAccum.join();
4701 },
4702 R"(
4703 Returns the string form of the value.
4704
4705 If the value is a block argument, this is the assembly form of its type and the
4706 position in the argument list. If the value is an operation result, this is
4707 equivalent to printing the operation that produced it.
4708 )")
4709 .def(
4710 "get_name",
4711 [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
4712 PyPrintAccumulator printAccum;
4713 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
4714 if (useLocalScope)
4716 if (useNameLocAsPrefix)
4718 MlirAsmState valueState =
4719 mlirAsmStateCreateForValue(self.get(), flags);
4720 mlirValuePrintAsOperand(self.get(), valueState,
4721 printAccum.getCallback(),
4722 printAccum.getUserData());
4724 mlirAsmStateDestroy(valueState);
4725 return printAccum.join();
4726 },
4727 nb::arg("use_local_scope") = false,
4728 nb::arg("use_name_loc_as_prefix") = false,
4729 R"(
4730 Returns the string form of value as an operand.
4731
4732 Args:
4733 use_local_scope: Whether to use local scope for naming.
4734 use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
4735
4736 Returns:
4737 The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
4738 .def(
4739 "get_name",
4740 [](PyValue &self, PyAsmState &state) {
4741 PyPrintAccumulator printAccum;
4742 MlirAsmState valueState = state.get();
4743 mlirValuePrintAsOperand(self.get(), valueState,
4744 printAccum.getCallback(),
4745 printAccum.getUserData());
4746 return printAccum.join();
4747 },
4748 nb::arg("state"),
4749 "Returns the string form of value as an operand (i.e., the ValueID).")
4750 .def_prop_ro(
4751 "type",
4752 [](PyValue &self) -> nb::typed<nb::object, PyType> {
4753 return PyType(self.getParentOperation()->getContext(),
4754 mlirValueGetType(self.get()))
4755 .maybeDownCast();
4756 },
4757 "Returns the type of the value.")
4758 .def(
4759 "set_type",
4760 [](PyValue &self, const PyType &type) {
4761 mlirValueSetType(self.get(), type);
4762 },
4763 nb::arg("type"), "Sets the type of the value.",
4764 nb::sig("def set_type(self, type: _T)"))
4765 .def(
4766 "replace_all_uses_with",
4767 [](PyValue &self, PyValue &with) {
4768 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
4769 },
4770 "Replace all uses of value with the new value, updating anything in "
4771 "the IR that uses `self` to use the other value instead.")
4772 .def(
4773 "replace_all_uses_except",
4774 [](PyValue &self, PyValue &with, PyOperation &exception) {
4775 MlirOperation exceptedUser = exception.get();
4776 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4777 },
4778 nb::arg("with_"), nb::arg("exceptions"),
4780 .def(
4781 "replace_all_uses_except",
4782 [](PyValue &self, PyValue &with, const nb::list &exceptions) {
4783 // Convert Python list to a SmallVector of MlirOperations
4784 llvm::SmallVector<MlirOperation> exceptionOps;
4785 for (nb::handle exception : exceptions) {
4786 exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
4787 }
4788
4790 self, with, static_cast<intptr_t>(exceptionOps.size()),
4791 exceptionOps.data());
4792 },
4793 nb::arg("with_"), nb::arg("exceptions"),
4795 .def(
4796 "replace_all_uses_except",
4797 [](PyValue &self, PyValue &with, PyOperation &exception) {
4798 MlirOperation exceptedUser = exception.get();
4799 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4800 },
4801 nb::arg("with_"), nb::arg("exceptions"),
4803 .def(
4804 "replace_all_uses_except",
4805 [](PyValue &self, PyValue &with,
4806 std::vector<PyOperation> &exceptions) {
4807 // Convert Python list to a SmallVector of MlirOperations
4808 llvm::SmallVector<MlirOperation> exceptionOps;
4809 for (PyOperation &exception : exceptions)
4810 exceptionOps.push_back(exception);
4812 self, with, static_cast<intptr_t>(exceptionOps.size()),
4813 exceptionOps.data());
4814 },
4815 nb::arg("with_"), nb::arg("exceptions"),
4817 .def(
4819 [](PyValue &self) -> nb::typed<nb::object, PyValue> {
4820 return self.maybeDownCast();
4821 },
4822 "Downcasts the `Value` to a more specific kind if possible.")
4823 .def_prop_ro(
4824 "location",
4825 [](MlirValue self) {
4826 return PyLocation(
4828 mlirValueGetLocation(self));
4829 },
4830 "Returns the source location of the value.");
4831
4832 PyBlockArgument::bind(m);
4833 PyOpResult::bind(m);
4834 PyOpOperand::bind(m);
4835
4836 nb::class_<PyAsmState>(m, "AsmState")
4837 .def(nb::init<PyValue &, bool>(), nb::arg("value"),
4838 nb::arg("use_local_scope") = false,
4839 R"(
4840 Creates an `AsmState` for consistent SSA value naming.
4841
4842 Args:
4843 value: The value to create state for.
4844 use_local_scope: Whether to use local scope for naming.)")
4845 .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
4846 nb::arg("use_local_scope") = false,
4847 R"(
4848 Creates an AsmState for consistent SSA value naming.
4849
4850 Args:
4851 op: The operation to create state for.
4852 use_local_scope: Whether to use local scope for naming.)");
4853
4854 //----------------------------------------------------------------------------
4855 // Mapping of SymbolTable.
4856 //----------------------------------------------------------------------------
4857 nb::class_<PySymbolTable>(m, "SymbolTable")
4858 .def(nb::init<PyOperationBase &>(),
4859 R"(
4860 Creates a symbol table for an operation.
4861
4862 Args:
4863 operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
4864
4865 Raises:
4866 TypeError: If the operation is not a symbol table.)")
4867 .def(
4868 "__getitem__",
4869 [](PySymbolTable &self,
4870 const std::string &name) -> nb::typed<nb::object, PyOpView> {
4871 return self.dunderGetItem(name);
4872 },
4873 R"(
4874 Looks up a symbol by name in the symbol table.
4875
4876 Args:
4877 name: The name of the symbol to look up.
4878
4879 Returns:
4880 The operation defining the symbol.
4881
4882 Raises:
4883 KeyError: If the symbol is not found.)")
4884 .def("insert", &PySymbolTable::insert, nb::arg("operation"),
4885 R"(
4886 Inserts a symbol operation into the symbol table.
4887
4888 Args:
4889 operation: An operation with a symbol name to insert.
4890
4891 Returns:
4892 The symbol name attribute of the inserted operation.
4893
4894 Raises:
4895 ValueError: If the operation does not have a symbol name.)")
4896 .def("erase", &PySymbolTable::erase, nb::arg("operation"),
4897 R"(
4898 Erases a symbol operation from the symbol table.
4899
4900 Args:
4901 operation: The symbol operation to erase.
4902
4903 Note:
4904 The operation is also erased from the IR and invalidated.)")
4905 .def("__delitem__", &PySymbolTable::dunderDel,
4906 "Deletes a symbol by name from the symbol table.")
4907 .def(
4908 "__contains__",
4909 [](PySymbolTable &table, const std::string &name) {
4910 return !mlirOperationIsNull(mlirSymbolTableLookup(
4911 table, mlirStringRefCreate(name.data(), name.length())));
4912 },
4913 "Checks if a symbol with the given name exists in the table.")
4914 // Static helpers.
4915 .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
4916 nb::arg("symbol"), nb::arg("name"),
4917 "Sets the symbol name for a symbol operation.")
4918 .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
4919 nb::arg("symbol"),
4920 "Gets the symbol name from a symbol operation.")
4921 .def_static("get_visibility", &PySymbolTable::getVisibility,
4922 nb::arg("symbol"),
4923 "Gets the visibility attribute of a symbol operation.")
4924 .def_static("set_visibility", &PySymbolTable::setVisibility,
4925 nb::arg("symbol"), nb::arg("visibility"),
4926 "Sets the visibility attribute of a symbol operation.")
4927 .def_static("replace_all_symbol_uses",
4928 &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
4929 nb::arg("new_symbol"), nb::arg("from_op"),
4930 "Replaces all uses of a symbol with a new symbol name within "
4931 "the given operation.")
4932 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
4933 nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
4934 nb::arg("callback"),
4935 "Walks symbol tables starting from an operation with a "
4936 "callback function.");
4937
4938 // Container bindings.
4939 PyBlockArgumentList::bind(m);
4940 PyBlockIterator::bind(m);
4941 PyBlockList::bind(m);
4942 PyBlockSuccessors::bind(m);
4943 PyBlockPredecessors::bind(m);
4944 PyOperationIterator::bind(m);
4945 PyOperationList::bind(m);
4946 PyOpAttributeMap::bind(m);
4947 PyOpOperandIterator::bind(m);
4948 PyOpOperandList::bind(m);
4950 PyOpSuccessors::bind(m);
4951 PyRegionIterator::bind(m);
4952 PyRegionList::bind(m);
4953
4954 // Debug bindings.
4956
4957 // Attribute builder getter.
4959
4960 nb::register_exception_translator([](const std::exception_ptr &p,
4961 void *payload) {
4962 // We can't define exceptions with custom fields through pybind, so instead
4963 // the exception class is defined in python and imported here.
4964 try {
4965 if (p)
4966 std::rethrow_exception(p);
4967 } catch (const MLIRError &e) {
4968 nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
4969 .attr("MLIRError")(e.message, e.errorDiagnostics);
4970 PyErr_SetObject(PyExc_Exception, obj.ptr());
4971 }
4972 });
4973}
void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
Definition Debug.cpp:28
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
Definition Debug.cpp:20
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Returns true if the global debugging flag is set, false otherwise.
Definition Debug.cpp:18
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
Definition Debug.cpp:16
static const char kDumpDocstring[]
Definition IRAffine.cpp:37
static MlirStringRef toMlirStringRef(const std::string &s)
Definition IRCore.cpp:77
static const char kModuleParseDocstring[]
Definition IRCore.cpp:36
static std::vector< nb::typed< nb::object, PyType > > getValueTypes(Container &container, PyMlirContextRef &context)
Returns the list of types of the values held by container.
Definition IRCore.cpp:1542
static nb::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition IRCore.cpp:59
static MlirValue getUniqueResult(MlirOperation operation)
Definition IRCore.cpp:1709
#define Py_XNewRef(obj)
Definition IRCore.cpp:2764
#define _Py_CAST(type, expr)
Definition IRCore.cpp:2740
static MlirValue getOpResultOrValue(nb::handle operand)
Definition IRCore.cpp:1724
#define Py_NewRef(obj)
Definition IRCore.cpp:2773
#define _Py_NULL
Definition IRCore.cpp:2751
static void maybeInsertOperation(PyOperationRef &op, const nb::object &maybeIp)
Definition IRCore.cpp:1294
static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor)
Definition IRCore.cpp:65
static MlirBlock createBlock(const nb::sequence &pyArgTypes, const std::optional< nb::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition IRCore.cpp:91
static void populateResultTypes(StringRef name, nb::list resultTypeList, const nb::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
Definition IRCore.cpp:1612
static const char kValueReplaceAllUsesExceptDocstring[]
Definition IRCore.cpp:47
MlirContext mlirModuleGetContext(MlirModule module)
Definition IR.cpp:446
size_t mlirModuleHashValue(MlirModule mod)
Definition IR.cpp:472
intptr_t mlirBlockGetNumPredecessors(MlirBlock block)
Definition IR.cpp:1095
MlirIdentifier mlirOperationGetName(MlirOperation op)
Definition IR.cpp:669
bool mlirValueIsABlockArgument(MlirValue value)
Definition IR.cpp:1115
intptr_t mlirOperationGetNumRegions(MlirOperation op)
Definition IR.cpp:681
MlirBlock mlirOperationGetBlock(MlirOperation op)
Definition IR.cpp:673
void mlirBlockArgumentSetType(MlirValue value, MlirType type)
Definition IR.cpp:1132
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
Definition IR.cpp:521
MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
Definition IR.cpp:732
MlirModule mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName)
Definition IR.cpp:437
MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
Definition IR.cpp:178
intptr_t mlirOperationGetNumResults(MlirOperation op)
Definition IR.cpp:728
void mlirOperationDestroy(MlirOperation op)
Definition IR.cpp:639
MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Definition IR.cpp:1280
MlirType mlirValueGetType(MlirValue value)
Definition IR.cpp:1151
void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
Definition IR.cpp:1081
MlirOpPrintingFlags mlirOpPrintingFlagsCreate()
Definition IR.cpp:202
bool mlirModuleEqual(MlirModule lhs, MlirModule rhs)
Definition IR.cpp:468
void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Definition IR.cpp:210
void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
Definition IR.cpp:793
MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
Definition IR.cpp:705
void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Definition IR.cpp:220
MlirOperation mlirModuleGetOperation(MlirModule module)
Definition IR.cpp:460
void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
Definition IR.cpp:215
void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
Definition IR.cpp:233
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
Definition IR.cpp:1127
MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
Definition IR.cpp:740
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
Definition IR.cpp:1299
bool mlirOperationEqual(MlirOperation op, MlirOperation other)
Definition IR.cpp:643
void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
Definition IR.cpp:237
void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
Definition IR.cpp:252
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1091
void mlirModuleDestroy(MlirModule module)
Definition IR.cpp:454
MlirModule mlirModuleCreateEmpty(MlirLocation location)
Definition IR.cpp:425
void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Definition IR.cpp:225
MlirOperation mlirOperationGetParentOperation(MlirOperation op)
Definition IR.cpp:677
void mlirValueSetType(MlirValue value, MlirType type)
Definition IR.cpp:1155
intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
Definition IR.cpp:736
MlirDialect mlirAttributeGetDialect(MlirAttribute attr)
Definition IR.cpp:1295
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Definition IR.cpp:415
void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
Definition IR.cpp:812
void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
Definition IR.cpp:717
MlirOperation mlirOpResultGetOwner(MlirValue value)
Definition IR.cpp:1142
MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
Definition IR.cpp:429
size_t mlirOperationHashValue(MlirOperation op)
Definition IR.cpp:647
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
Definition IR.cpp:504
MlirOperation mlirOperationClone(MlirOperation op)
Definition IR.cpp:635
MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
Definition IR.cpp:1123
void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc)
Definition IR.cpp:1137
MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
Definition IR.cpp:713
MlirLocation mlirOperationGetLocation(MlirOperation op)
Definition IR.cpp:655
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:807
MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr)
Definition IR.cpp:1291
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
Definition IR.cpp:513
void mlirOperationSetLocation(MlirOperation op, MlirLocation loc)
Definition IR.cpp:659
MlirType mlirAttributeGetType(MlirAttribute attribute)
Definition IR.cpp:1284
bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
Definition IR.cpp:817
bool mlirValueIsAOpResult(MlirValue value)
Definition IR.cpp:1119
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos)
Definition IR.cpp:1100
MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
Definition IR.cpp:685
MlirOperation mlirOperationCreate(MlirOperationState *state)
Definition IR.cpp:589
void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
Definition IR.cpp:256
MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
Definition IR.cpp:1276
intptr_t mlirBlockGetNumSuccessors(MlirBlock block)
Definition IR.cpp:1087
MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
Definition IR.cpp:802
void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Definition IR.cpp:206
void mlirValueDump(MlirValue value)
Definition IR.cpp:1159
void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
Definition IR.cpp:1265
MlirBlock mlirModuleGetBody(MlirModule module)
Definition IR.cpp:450
MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
Definition IR.cpp:626
void mlirAsmStateDestroy(MlirAsmState state)
Destroys an asm state created with one of the mlirAsmStateCreate* functions (e.g. mlirAsmStateCreateForValue).
Definition IR.cpp:196
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:651
intptr_t mlirOpResultGetResultNumber(MlirValue value)
Definition IR.cpp:1146
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
Definition IR.cpp:517
MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate()
Definition IR.cpp:248
void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags)
Definition IR.cpp:229
void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
Definition IR.cpp:241
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
Definition IR.cpp:509
MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
Definition IR.cpp:480
intptr_t mlirOperationGetNumOperands(MlirOperation op)
Definition IR.cpp:709
void mlirTypeDump(MlirType type)
Definition IR.cpp:1270
intptr_t mlirOperationGetNumAttributes(MlirOperation op)
Definition IR.cpp:798
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
Definition Interop.h:348
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
Definition Interop.h:216
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
Definition Interop.h:118
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
Definition Interop.h:338
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
Definition Interop.h:189
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
Definition Interop.h:367
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
Definition Interop.h:330
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
Definition Interop.h:180
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
Definition Interop.h:282
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts a MlirContext from a capsule as produced from mlirPythonContextToCapsule.
Definition Interop.h:224
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
Definition Interop.h:357
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
Definition Interop.h:198
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
Definition Interop.h:255
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
Definition Interop.h:245
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition Interop.h:57
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
Definition Interop.h:376
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
Definition Interop.h:454
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
Definition Interop.h:445
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
Definition Interop.h:273
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
Definition Interop.h:264
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
Definition Interop.h:235
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static std::string diag(const llvm::Value &value)
A list of operation results.
Definition IRCore.cpp:1557
static void bindDerived(ClassTy &c)
Definition IRCore.cpp:1570
PyOpResultList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
Definition IRCore.cpp:1562
static constexpr const char * pyClassName
Definition IRCore.cpp:1559
Sliceable< PyOpResultList, PyOpResult > SliceableT
Definition IRCore.cpp:1560
PyOperationRef & getOperation()
Definition IRCore.cpp:1585
Python wrapper for MlirOpResult.
Definition IRCore.cpp:1513
static constexpr IsAFunctionTy isaFunction
Definition IRCore.cpp:1515
static void bindDerived(ClassTy &c)
Definition IRCore.cpp:1519
static constexpr const char * pyClassName
Definition IRCore.cpp:1516
Accumulates into a file, either writing text (default) or binary.
MlirStringCallback getCallback()
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
static void bind(nanobind::module_ &m)
nanobind::class_< PyOpResultList > ClassTy
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
Base class for all objects that directly or indirectly depend on an MlirContext.
Definition IRModule.h:284
PyMlirContextRef & getContext()
Accesses the context reference.
Definition IRModule.h:292
BaseContextObject(PyMlirContextRef ref)
Definition IRModule.h:286
static PyLocation & resolve()
Definition IRCore.cpp:951
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRModule.h:273
static PyMlirContext & resolve()
Definition IRCore.cpp:672
ReferrentTy * get() const
Wrapper around an MlirAsmState.
Definition IRModule.h:778
MlirAsmState get()
Definition IRModule.h:803
Wrapper around the generic MlirAttribute.
Definition IRModule.h:1008
static PyAttribute createFromCapsule(const nanobind::object &capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
Definition IRCore.cpp:2037
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
Definition IRModule.h:1010
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
Definition IRCore.cpp:2033
nanobind::object maybeDownCast()
Definition IRCore.cpp:2045
MlirAttribute get() const
Definition IRModule.h:1014
bool operator==(const PyAttribute &other) const
Definition IRCore.cpp:2029
Wrapper around an MlirBlock.
Definition IRModule.h:813
PyOperationRef & getParentOperation()
Definition IRModule.h:821
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirBlock.
Definition IRCore.cpp:196
Represents a diagnostic handler attached to the context.
Definition IRModule.h:380
PyDiagnosticHandler(MlirContext context, nanobind::object callback)
Definition IRCore.cpp:828
nanobind::object contextEnter()
Definition IRModule.h:391
void detach()
Detaches the handler. Does nothing if not attached.
Definition IRCore.cpp:834
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRModule.h:392
Python class mirroring the C MlirDiagnostic struct.
Definition IRModule.h:330
nanobind::tuple getNotes()
Definition IRCore.cpp:872
nanobind::str getMessage()
Definition IRCore.cpp:864
DiagnosticInfo getInfo()
Definition IRCore.cpp:888
PyDiagnostic(MlirDiagnostic diagnostic)
Definition IRModule.h:332
MlirDiagnosticSeverity getSeverity()
Definition IRCore.cpp:852
Wrapper around an MlirDialect.
Definition IRModule.h:435
Wrapper around an MlirDialectRegistry.
Definition IRModule.h:472
nanobind::object getCapsule()
Definition IRCore.cpp:913
static PyDialectRegistry createFromCapsule(nanobind::object capsule)
Definition IRCore.cpp:917
User-level dialect object.
Definition IRModule.h:459
nanobind::object getDescriptor()
Definition IRModule.h:463
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
Definition IRModule.h:448
MlirDialect getDialectForKey(const std::string &key, bool attrError)
Definition IRCore.cpp:900
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false)
Adds a user-friendly Attribute builder.
Definition IRModule.cpp:73
TracebackLoc & getTracebackLoc()
Definition Globals.h:153
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.h:39
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition IRModule.cpp:155
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition IRModule.cpp:169
std::optional< nanobind::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition IRModule.cpp:184
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition IRModule.cpp:132
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition IRModule.cpp:142
An insertion point maintains a pointer to a Block and a reference operation.
Definition IRModule.h:837
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
Definition IRCore.cpp:1993
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
Definition IRCore.cpp:1980
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:2019
static PyInsertionPoint after(PyOperationBase &op)
Shortcut to create an insertion point to the node after the specified operation.
Definition IRCore.cpp:2002
PyInsertionPoint(const PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
Definition IRCore.cpp:1945
void insert(PyOperationBase &operationBase)
Inserts an operation.
Definition IRCore.cpp:1954
static nanobind::object contextEnter(nanobind::object insertionPoint)
Enter and exit the context manager.
Definition IRCore.cpp:2015
std::optional< PyOperationRef > & getRefOperation()
Definition IRModule.h:865
Wrapper around an MlirLocation.
Definition IRModule.h:299
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
Definition IRCore.cpp:929
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
Definition IRModule.h:301
static PyLocation createFromCapsule(nanobind::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
Definition IRCore.cpp:933
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:945
static nanobind::object contextEnter(nanobind::object location)
Enter and exit the context manager.
Definition IRCore.cpp:941
MlirLocation get() const
Definition IRModule.h:305
MlirContext get()
Accesses the underlying MlirContext.
Definition IRModule.h:204
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition IRModule.h:208
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
Definition IRCore.cpp:593
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
Definition IRCore.cpp:2013
nanobind::object attachDiagnosticHandler(nanobind::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
Definition IRCore.cpp:608
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
Definition IRCore.cpp:557
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Definition IRCore.cpp:602
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:568
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
Definition IRCore.cpp:561
static nanobind::object contextEnter(nanobind::object context)
Enter and exit the context manager.
Definition IRCore.cpp:598
void setEmitErrorDiagnostics(bool value)
Controls whether error diagnostics should be propagated to diagnostic handlers, instead of being capt...
Definition IRModule.h:240
MlirModule get()
Gets the backing MlirModule.
Definition IRModule.h:522
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
Definition IRCore.cpp:978
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
Definition IRCore.cpp:1003
PyModuleRef getRef()
Gets a strong reference to this module.
Definition IRModule.h:525
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
Definition IRCore.cpp:1010
PyModule(PyModule &)=delete
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition IRModule.h:1034
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
Definition IRCore.cpp:2063
MlirNamedAttribute namedAttr
Definition IRModule.h:1043
Template for a reference to a concrete type which captures a python reference to its underlying pytho...
Definition IRModule.h:53
nanobind::object getObject()
Definition IRModule.h:91
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
Definition IRModule.h:79
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
Definition IRModule.h:723
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition IRModule.h:726
static nanobind::object constructDerived(const nanobind::object &cls, const nanobind::object &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
Definition IRCore.cpp:1927
static nanobind::object buildGeneric(std::string_view name, std::tuple< int, bool > opRegionSpec, nanobind::object operandSegmentSpecObj, nanobind::object resultSegmentSpecObj, std::optional< nanobind::list > resultTypeList, nanobind::list operandList, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, std::optional< int > regions, PyLocation &location, const nanobind::object &maybeIp)
Definition IRCore.cpp:1743
nanobind::object getOperationObject()
Definition IRModule.h:728
PyOpView(const nanobind::object &operationObject)
Definition IRCore.cpp:1935
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition IRModule.h:552
void walk(std::function< MlirWalkResult(MlirOperation)> callback, MlirWalkOrder walkOrder)
Definition IRCore.cpp:1167
bool isBeforeInBlock(PyOperationBase &other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IRCore.cpp:1245
nanobind::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions)
Definition IRCore.cpp:1199
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
Definition IRCore.cpp:1227
void print(std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
void writeBytecode(const nanobind::object &fileObject, std::optional< int64_t > bytecodeVersion)
Definition IRCore.cpp:1145
void moveBefore(PyOperationBase &other)
Definition IRCore.cpp:1236
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
bool verify()
Verify the operation.
Definition IRCore.cpp:1253
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
Definition IRModule.h:630
PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
Definition IRCore.cpp:1018
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
Definition IRCore.cpp:1445
static nanobind::object createFromCapsule(const nanobind::object &capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
Definition IRCore.cpp:1285
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
Definition IRCore.cpp:1280
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Creates a detached operation.
Definition IRCore.cpp:1070
nanobind::object clone(const nanobind::object &ip)
Clones this operation.
Definition IRCore.cpp:1425
PyOperationRef getRef()
Definition IRModule.h:643
MlirOperation get() const
Definition IRModule.h:638
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:1063
void setAttached(const nanobind::object &parent=nanobind::object())
Definition IRModule.h:648
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
Definition IRCore.cpp:1261
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1434
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
Definition IRCore.cpp:1271
static nanobind::object create(std::string_view name, std::optional< std::vector< PyType * > > results, llvm::ArrayRef< MlirValue > operands, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, int regions, PyLocation &location, const nanobind::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
Definition IRCore.cpp:1309
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
Definition IRCore.cpp:1079
void setInvalid()
Invalidate the operation.
Definition IRModule.h:690
Wrapper around an MlirRegion.
Definition IRModule.h:759
PyOperationRef & getParentOperation()
Definition IRModule.h:768
Bindings for MLIR symbol tables.
Definition IRModule.h:1266
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
Definition IRCore.cpp:2198
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, nanobind::object callback)
Walks all symbol tables under and including 'from'.
Definition IRCore.cpp:2284
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
Definition IRCore.cpp:2272
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
Definition IRCore.cpp:2254
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
Definition IRCore.cpp:2228
PyStringAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
Definition IRCore.cpp:2203
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
Definition IRCore.cpp:2188
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
Definition IRCore.cpp:2167
static PyStringAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
Definition IRCore.cpp:2215
static PyStringAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
Definition IRCore.cpp:2243
nanobind::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
Definition IRCore.cpp:2175
static PyThreadContextEntry * getTopOfStack()
Stack management.
Definition IRCore.cpp:692
static void popLocation(PyLocation &location)
Definition IRCore.cpp:804
static nanobind::object pushLocation(nanobind::object location)
Definition IRCore.cpp:795
static nanobind::object pushContext(nanobind::object context)
Definition IRCore.cpp:754
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
Definition IRCore.cpp:749
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
Definition IRCore.cpp:784
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint)
Definition IRCore.cpp:772
static void popContext(PyMlirContext &context)
Definition IRCore.cpp:761
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and return nullptr if not defined.
Definition IRCore.cpp:744
static PyMlirContext * getDefaultContext()
Gets the top of stack context and return nullptr if not defined.
Definition IRCore.cpp:739
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
Definition IRCore.cpp:687
PyThreadContextEntry(FrameKind frameKind, nanobind::object context, nanobind::object insertionPoint, nanobind::object location)
Definition IRModule.h:120
PyInsertionPoint * getInsertionPoint()
Definition IRCore.cpp:727
Wrapper around MlirLlvmThreadPool; the Python object owns the C++ thread pool.
Definition IRModule.h:168
std::string _mlir_thread_pool_ptr() const
Definition IRModule.h:179
int getMaxConcurrency() const
Definition IRModule.h:176
MlirLlvmThreadPool get()
Definition IRModule.h:177
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition IRModule.h:904
static PyTypeID createFromCapsule(nanobind::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
Definition IRCore.cpp:2113
bool operator==(const PyTypeID &other) const
Definition IRCore.cpp:2119
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
Definition IRCore.cpp:2109
PyTypeID(MlirTypeID typeID)
Definition IRModule.h:906
Wrapper around the generic MlirType.
Definition IRModule.h:878
PyType(PyMlirContextRef contextRef, MlirType type)
Definition IRModule.h:880
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
Definition IRCore.cpp:2079
MlirType get() const
Definition IRModule.h:884
static PyType createFromCapsule(nanobind::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
Definition IRCore.cpp:2083
nanobind::object maybeDownCast()
Definition IRCore.cpp:2091
bool operator==(const PyType &other) const
Definition IRCore.cpp:2075
Wrapper around the generic MlirValue.
Definition IRModule.h:1167
PyValue(PyOperationRef parentOperation, MlirValue value)
Definition IRModule.h:1173
static PyValue createFromCapsule(nanobind::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
Definition IRCore.cpp:2146
nanobind::object maybeDownCast()
Definition IRCore.cpp:2131
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
Definition IRCore.cpp:2127
PyOperationRef & getParentOperation()
Definition IRModule.h:1178
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
MlirDiagnosticSeverity
Severity of a diagnostic.
Definition Diagnostics.h:32
@ MlirDiagnosticNote
Definition Diagnostics.h:35
@ MlirDiagnosticRemark
Definition Diagnostics.h:36
@ MlirDiagnosticWarning
Definition Diagnostics.h:34
@ MlirDiagnosticError
Definition Diagnostics.h:33
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
Definition Diagnostics.h:41
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition Diagnostics.h:24
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
Definition IR.cpp:265
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
Definition IR.h:851
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v)
Gets the location of the value.
Definition IR.cpp:1202
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context)
Gets the number of threads of the thread pool of the context when multithreading is enabled.
Definition IR.cpp:117
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
Definition IR.cpp:841
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates a File/Line/Column location owned by the given context.
Definition IR.cpp:273
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
Definition IR.cpp:1384
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition IR.cpp:137
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location)
Getter for end_column of FileLineColRange.
Definition IR.cpp:311
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
Definition IR.cpp:1363
MlirWalkOrder
Traversal order for operation walk.
Definition IR.h:844
@ MlirWalkPreOrder
Definition IR.h:845
@ MlirWalkPostOrder
Definition IR.h:846
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition IR.cpp:1311
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGetChildLoc(MlirLocation location)
Getter for childLoc of Name.
Definition IR.cpp:392
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
Definition IR.cpp:1368
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
Definition IR.cpp:84
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition IR.cpp:1332
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
Definition IR.cpp:1245
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
Definition IR.cpp:1228
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
Definition IR.cpp:73
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
Definition IR.cpp:951
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location)
Checks whether the given location is a FileLineColRange.
Definition IR.cpp:321
MLIR_CAPI_EXPORTED unsigned mlirLocationFusedGetNumLocations(MlirLocation location)
Getter for number of locations fused together.
Definition IR.cpp:355
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
Definition IR.cpp:1184
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
Definition IR.cpp:1167
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
Definition IR.cpp:403
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
Definition IR.cpp:1216
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location)
Getter for filename of FileLineColRange.
Definition IR.cpp:289
MLIR_CAPI_EXPORTED void mlirLocationFusedGetLocations(MlirLocation location, MlirLocation *locationsCPtr)
Getter for locations of Fused.
Definition IR.cpp:361
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints an attribute by sending chunks of the string representation and forwarding `userData` to `callback`...
Definition IR.cpp:1303
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
Definition IR.cpp:990
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
Definition IR.cpp:865
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
Definition IR.cpp:1188
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
Definition IR.cpp:833
MlirWalkResult
Operation walk result.
Definition IR.h:837
@ MlirWalkResultInterrupt
Definition IR.h:839
@ MlirWalkResultSkip
Definition IR.h:840
@ MlirWalkResultAdvance
Definition IR.h:838
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
Definition IR.cpp:931
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1156
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e. 'dialect.operation') is registered with the context.
Definition IR.cpp:100
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
Definition IR.cpp:1059
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartLine(MlirLocation location)
Getter for start_line of FileLineColRange.
Definition IR.cpp:293
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
Definition IR.cpp:347
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
Definition IR.cpp:1040
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
Definition IR.h:104
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
Definition IR.cpp:95
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location)
Checks whether the given location is a CallSite.
Definition IR.cpp:343
struct MlirNamedAttribute MlirNamedAttribute
Definition IR.h:80
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
Definition IR.cpp:937
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
Definition IR.cpp:974
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
Definition IR.h:937
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
Definition IR.cpp:1015
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
Definition IR.cpp:1077
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
Definition IR.cpp:1358
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition IR.cpp:1249
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, unsigned start_line, unsigned start_col, unsigned end_line, unsigned end_col)
Creates a File/Line/Column range location owned by the given context.
Definition IR.cpp:281
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
Definition IR.cpp:1214
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
Definition IR.cpp:1348
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
Definition IR.cpp:407
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location)
Getter for start_column of FileLineColRange.
Definition IR.cpp:299
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location)
Checks whether the given location is a Fused.
Definition IR.cpp:375
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition IR.h:370
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
Definition IR.cpp:1063
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
Definition IR.cpp:827
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
Definition IR.cpp:1174
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool)
Sets the thread pool of the context explicitly, enabling multithreading in the process.
Definition IR.cpp:112
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
Definition IR.cpp:857
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition IR.cpp:1261
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
Definition IR.cpp:1224
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location)
Getter for caller of CallSite.
Definition IR.cpp:334
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
Definition IR.cpp:1005
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationNameGetName(MlirLocation location)
Getter for name of Name.
Definition IR.cpp:388
MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition IR.cpp:869
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
Definition IR.cpp:269
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
Definition IR.cpp:1253
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
Definition IR.cpp:1344
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
Definition IR.h:182
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
Definition IR.cpp:994
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
Definition IR.cpp:325
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location)
Checks whether the given location is a Name.
Definition IR.cpp:399
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
Definition IR.h:244
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
Definition IR.cpp:887
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit multithreading setting and transfers its ownership to the caller.
Definition IR.cpp:55
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition IR.cpp:986
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
Definition IR.cpp:411
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition IR.cpp:1068
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
Definition IR.cpp:1309
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
Definition IR.cpp:1373
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
Definition IR.cpp:927
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
Definition IR.cpp:998
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
Definition IR.h:876
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
Definition IR.cpp:1257
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition IR.cpp:1320
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
Definition IR.cpp:108
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context)
Gets the thread pool of the context when multithreading is enabled; otherwise, an assertion is raised.
Definition IR.cpp:121
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndLine(MlirLocation location)
Getter for end_line of FileLineColRange.
Definition IR.cpp:305
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
Definition IR.cpp:379
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Sets the threading mode (must be set to false to use mlir-print-ir-after-all).
Definition IR.cpp:104
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location)
Getter for callee of CallSite.
Definition IR.cpp:329
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v)
Gets the context that a value was created with.
Definition IR.cpp:1206
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
Definition IR.cpp:1340
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
Definition IR.cpp:914
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
Definition IR.cpp:1054
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
Definition IR.cpp:855
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
Definition IR.h:1246
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
Definition IR.cpp:77
MLIR_CAPI_EXPORTED void mlirOperationReplaceUsesOfWith(MlirOperation op, MlirValue of, MlirValue with)
Replace uses of 'of' value with the 'with' value inside the 'op' operation.
Definition IR.cpp:905
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
Definition IR.cpp:861
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a value by sending chunks of the string representation and forwarding `userData` to `callback`.
Definition IR.cpp:1161
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
Definition IR.cpp:848
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
Definition IR.cpp:71
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
Definition IR.cpp:920
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:82
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:138
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition Support.cpp:51
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:132
struct MlirStringRef MlirStringRef
Definition Support.h:77
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:127
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Definition Support.h:163
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition Support.cpp:47
PyObjectRef< PyOperation > PyOperationRef
Definition IRModule.h:604
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRModule.h:190
void populateIRCore(nanobind::module_ &m)
PyObjectRef< PyModule > PyModuleRef
Definition IRModule.h:511
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
Definition Diagnostics.h:26
A logical result value, essentially a boolean with named states.
Definition Support.h:116
MlirAttribute attribute
Definition IR.h:78
MlirIdentifier name
Definition IR.h:77
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:73
const char * data
Pointer to the first symbol.
Definition Support.h:74
size_t length
Length of the fragment.
Definition Support.h:75
static bool dunderContains(const std::string &attributeKind)
Definition IRCore.cpp:160
static void dunderSetItemNamed(const std::string &attributeKind, nb::callable func, bool replace)
Definition IRCore.cpp:169
static nb::callable dunderGetItemNamed(const std::string &attributeKind)
Definition IRCore.cpp:163
static void bind(nb::module_ &m)
Definition IRCore.cpp:175
Wrapper for the global LLVM debugging flag.
Definition IRCore.cpp:116
static void bind(nb::module_ &m)
Definition IRCore.cpp:127
static void set(nb::object &o, bool enable)
Definition IRCore.cpp:117
static bool get(const nb::object &)
Definition IRCore.cpp:122
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition IRModule.h:1318
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Definition IRModule.h:1323
Materialized diagnostic information.
Definition IRModule.h:342
std::vector< DiagnosticInfo > notes
Definition IRModule.h:346
RAII object that captures any error diagnostics emitted to the provided context.
Definition IRModule.h:408
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition IRModule.h:418