20#include "nanobind/nanobind.h"
21#include "nanobind/typing.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/SmallVector.h"
28using namespace nb::literals;
37 R
"(Parses a module's assembly format from a string.
39Returns a new MlirModule or raises an MLIRError if the parsing fails.
41See also: https://mlir.llvm.org/docs/LangRef/
45 "Dumps a debug representation of the object to stderr.";
48 R
"(Replace all uses of this value with the `with` value, except for those
49in `exceptions`. `exceptions` can be either a single operation or a list of
58template <
class Func,
typename... Args>
60 nb::object
cf = nb::cpp_function(f, args...);
61 return nb::borrow<nb::object>((PyClassMethod_New(
cf.ptr())));
66 nb::object dialectDescriptor) {
70 return nb::cast(
PyDialect(std::move(dialectDescriptor)));
74 return (*dialectClass)(std::move(dialectDescriptor));
92 const std::optional<nb::sequence> &pyArgLocs) {
94 argTypes.reserve(nb::len(pyArgTypes));
95 for (
const auto &pyType : pyArgTypes)
96 argTypes.push_back(nb::cast<PyType &>(pyType));
100 argLocs.reserve(nb::len(*pyArgLocs));
101 for (
const auto &pyLoc : *pyArgLocs)
102 argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
103 }
else if (!argTypes.empty()) {
107 if (argTypes.size() != argLocs.size())
108 throw nb::value_error((
"Expected " + Twine(argTypes.size()) +
109 " locations, got: " + Twine(argLocs.size()))
112 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
117 static void set(nb::object &o,
bool enable) {
118 nb::ft_lock_guard lock(mutex);
122 static bool get(
const nb::object &) {
123 nb::ft_lock_guard lock(mutex);
127 static void bind(nb::module_ &m) {
129 nb::class_<PyGlobalDebugFlag>(m,
"_GlobalDebug")
134 [](
const std::string &type) {
135 nb::ft_lock_guard lock(mutex);
138 "types"_a,
"Sets specific debug types to be produced by LLVM.")
141 [](
const std::vector<std::string> &types) {
142 std::vector<const char *> pointers;
143 pointers.reserve(types.size());
144 for (
const std::string &str : types)
145 pointers.push_back(str.c_str());
146 nb::ft_lock_guard lock(mutex);
150 "Sets multiple specific debug types to be produced by LLVM.");
154 static nb::ft_mutex mutex;
157nb::ft_mutex PyGlobalDebugFlag::mutex;
166 throw nb::key_error(attributeKind.c_str());
170 nb::callable
func,
bool replace) {
175 static void bind(nb::module_ &m) {
176 nb::class_<PyAttrBuilderMap>(m,
"AttrBuilder")
179 "Checks whether an attribute builder is registered for the "
180 "given attribute kind.")
183 "Gets the registered attribute builder for the given "
186 "attribute_kind"_a,
"attr_builder"_a,
"replace"_a =
false,
187 "Register an attribute builder for building MLIR "
188 "attributes from Python values.");
206class PyRegionIterator {
211 PyRegionIterator &dunderIter() {
return *
this; }
216 throw nb::stop_iteration();
219 return PyRegion(operation, region);
222 static void bind(nb::module_ &m) {
223 nb::class_<PyRegionIterator>(m,
"RegionIterator")
224 .def(
"__iter__", &PyRegionIterator::dunderIter,
225 "Returns an iterator over the regions in the operation.")
226 .def(
"__next__", &PyRegionIterator::dunderNext,
227 "Returns the next region in the iteration.");
237class PyRegionList :
public Sliceable<PyRegionList, PyRegion> {
239 static constexpr const char *pyClassName =
"RegionSequence";
242 intptr_t length = -1, intptr_t step = 1)
243 : Sliceable(startIndex,
247 operation(std::move(operation)) {}
249 PyRegionIterator dunderIter() {
251 return PyRegionIterator(operation, startIndex);
254 static void bindDerived(ClassTy &c) {
255 c.def(
"__iter__", &PyRegionList::dunderIter,
256 "Returns an iterator over the regions in the sequence.");
261 friend class Sliceable<PyRegionList, PyRegion>;
263 intptr_t getRawNumElements() {
264 operation->checkValid();
268 PyRegion getRawElement(intptr_t pos) {
269 operation->checkValid();
273 PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
274 return PyRegionList(operation, startIndex, length, step);
280class PyBlockIterator {
283 : operation(std::move(operation)), next(next) {}
285 PyBlockIterator &dunderIter() {
return *
this; }
287 PyBlock dunderNext() {
288 operation->checkValid();
290 throw nb::stop_iteration();
293 PyBlock returnBlock(operation, next);
298 static void bind(nb::module_ &m) {
299 nb::class_<PyBlockIterator>(m,
"BlockIterator")
300 .def(
"__iter__", &PyBlockIterator::dunderIter,
301 "Returns an iterator over the blocks in the operation's region.")
302 .def(
"__next__", &PyBlockIterator::dunderNext,
303 "Returns the next block in the iteration.");
317 : operation(std::move(operation)), region(region) {}
319 PyBlockIterator dunderIter() {
320 operation->checkValid();
324 intptr_t dunderLen() {
325 operation->checkValid();
335 PyBlock dunderGetItem(intptr_t index) {
336 operation->checkValid();
338 index += dunderLen();
341 throw nb::index_error(
"attempt to access out of bounds block");
346 return PyBlock(operation, block);
351 throw nb::index_error(
"attempt to access out of bounds block");
354 PyBlock appendBlock(
const nb::args &pyArgTypes,
355 const std::optional<nb::sequence> &pyArgLocs) {
356 operation->checkValid();
358 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
360 return PyBlock(operation, block);
363 static void bind(nb::module_ &m) {
364 nb::class_<PyBlockList>(m,
"BlockList")
365 .def(
"__getitem__", &PyBlockList::dunderGetItem,
366 "Returns the block at the specified index.")
367 .def(
"__iter__", &PyBlockList::dunderIter,
368 "Returns an iterator over blocks in the operation's region.")
369 .def(
"__len__", &PyBlockList::dunderLen,
370 "Returns the number of blocks in the operation's region.")
371 .def(
"append", &PyBlockList::appendBlock,
373 Appends a new block, with argument types as positional args.
378 nb::arg("args"), nb::kw_only(),
379 nb::arg(
"arg_locs") = std::nullopt);
387class PyOperationIterator {
389 PyOperationIterator(
PyOperationRef parentOperation, MlirOperation next)
390 : parentOperation(std::move(parentOperation)), next(next) {}
392 PyOperationIterator &dunderIter() {
return *
this; }
394 nb::typed<nb::object, PyOpView> dunderNext() {
395 parentOperation->checkValid();
396 if (mlirOperationIsNull(next)) {
397 throw nb::stop_iteration();
406 static void bind(nb::module_ &m) {
407 nb::class_<PyOperationIterator>(m,
"OperationIterator")
408 .def(
"__iter__", &PyOperationIterator::dunderIter,
409 "Returns an iterator over the operations in an operation's block.")
410 .def(
"__next__", &PyOperationIterator::dunderNext,
411 "Returns the next operation in the iteration.");
423class PyOperationList {
426 : parentOperation(std::move(parentOperation)), block(block) {}
428 PyOperationIterator dunderIter() {
429 parentOperation->checkValid();
430 return PyOperationIterator(parentOperation,
434 intptr_t dunderLen() {
435 parentOperation->checkValid();
438 while (!mlirOperationIsNull(childOp)) {
445 nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
446 parentOperation->checkValid();
448 index += dunderLen();
451 throw nb::index_error(
"attempt to access out of bounds operation");
454 while (!mlirOperationIsNull(childOp)) {
462 throw nb::index_error(
"attempt to access out of bounds operation");
465 static void bind(nb::module_ &m) {
466 nb::class_<PyOperationList>(m,
"OperationList")
467 .def(
"__getitem__", &PyOperationList::dunderGetItem,
468 "Returns the operation at the specified index.")
469 .def(
"__iter__", &PyOperationList::dunderIter,
470 "Returns an iterator over operations in the list.")
471 .def(
"__len__", &PyOperationList::dunderLen,
472 "Returns the number of operations in the list.");
482 PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
484 nb::typed<nb::object, PyOpView> getOwner() {
493 static void bind(nb::module_ &m) {
494 nb::class_<PyOpOperand>(m,
"OpOperand")
495 .def_prop_ro(
"owner", &PyOpOperand::getOwner,
496 "Returns the operation that owns this operand.")
497 .def_prop_ro(
"operand_number", &PyOpOperand::getOperandNumber,
498 "Returns the operand number in the owning operation.");
502 MlirOpOperand opOperand;
505class PyOpOperandIterator {
507 PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
509 PyOpOperandIterator &dunderIter() {
return *
this; }
511 PyOpOperand dunderNext() {
513 throw nb::stop_iteration();
515 PyOpOperand returnOpOperand(opOperand);
517 return returnOpOperand;
520 static void bind(nb::module_ &m) {
521 nb::class_<PyOpOperandIterator>(m,
"OpOperandIterator")
522 .def(
"__iter__", &PyOpOperandIterator::dunderIter,
523 "Returns an iterator over operands.")
524 .def(
"__next__", &PyOpOperandIterator::dunderNext,
525 "Returns the next operand in the iteration.");
529 MlirOpOperand opOperand;
539 nb::gil_scoped_acquire acquire;
540 nb::ft_lock_guard lock(live_contexts_mutex);
541 auto &liveContexts = getLiveContexts();
542 liveContexts[context.ptr] =
this;
549 nb::gil_scoped_acquire acquire;
551 nb::ft_lock_guard lock(live_contexts_mutex);
552 getLiveContexts().erase(context.ptr);
564 throw nb::python_error();
569 nb::gil_scoped_acquire acquire;
570 nb::ft_lock_guard lock(live_contexts_mutex);
571 auto &liveContexts = getLiveContexts();
572 auto it = liveContexts.find(context.ptr);
573 if (it == liveContexts.end()) {
576 nb::object pyRef = nb::cast(unownedContextWrapper);
577 assert(pyRef &&
"cast to nb::object failed");
578 liveContexts[context.ptr] = unownedContextWrapper;
582 nb::object pyRef = nb::cast(it->second);
586nb::ft_mutex PyMlirContext::live_contexts_mutex;
588PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
589 static LiveContextMap liveContexts;
594 nb::ft_lock_guard lock(live_contexts_mutex);
595 return getLiveContexts().size();
603 const nb::object &excVal,
604 const nb::object &excTb) {
613 nb::object pyHandlerObject =
614 nb::cast(pyHandler, nb::rv_policy::take_ownership);
615 (
void)pyHandlerObject.inc_ref();
619 auto handlerCallback =
622 nb::object pyDiagnosticObject =
623 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
630 nb::gil_scoped_acquire gil;
632 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
633 }
catch (std::exception &e) {
634 fprintf(stderr,
"MLIR Python Diagnostic handler raised exception: %s\n",
636 pyHandler->hadError =
true;
643 auto deleteCallback = +[](
void *userData) {
645 assert(pyHandler->registeredID &&
"handler is not registered");
646 pyHandler->registeredID.reset();
649 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
650 pyHandlerObject.dec_ref();
654 get(), handlerCallback,
static_cast<void *
>(pyHandler), deleteCallback);
655 return pyHandlerObject;
662 if (self->ctx->emitErrorDiagnostics)
675 throw std::runtime_error(
676 "An MLIR function requires a Context but none was provided in the call "
677 "or from the surrounding environment. Either pass to the function with "
678 "a 'context=' argument or establish a default using 'with Context():'");
688 static thread_local std::vector<PyThreadContextEntry> stack;
696 return &stack.back();
699void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
700 nb::object insertionPoint,
701 nb::object location) {
702 auto &stack = getStack();
703 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
704 std::move(location));
708 if (stack.size() > 1) {
709 auto &prev = *(stack.rbegin() + 1);
710 auto ¤t = stack.back();
711 if (current.context.is(prev.context)) {
713 if (!current.insertionPoint)
714 current.insertionPoint = prev.insertionPoint;
715 if (!current.location)
716 current.location = prev.location;
724 return nb::cast<PyMlirContext *>(context);
730 return nb::cast<PyInsertionPoint *>(insertionPoint);
736 return nb::cast<PyLocation *>(location);
741 return tos ? tos->getContext() :
nullptr;
746 return tos ? tos->getInsertionPoint() :
nullptr;
751 return tos ? tos->getLocation() :
nullptr;
764 throw std::runtime_error(
"Unbalanced Context enter/exit");
765 auto &tos = stack.back();
767 throw std::runtime_error(
"Unbalanced Context enter/exit");
774 nb::cast<PyInsertionPoint &>(insertionPointObj);
775 nb::object contextObj =
776 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
781 return insertionPointObj;
787 throw std::runtime_error(
"Unbalanced InsertionPoint enter/exit");
788 auto &tos = stack.back();
790 tos.getInsertionPoint() != &insertionPoint)
791 throw std::runtime_error(
"Unbalanced InsertionPoint enter/exit");
796 PyLocation &location = nb::cast<PyLocation &>(locationObj);
797 nb::object contextObj = location.getContext().getObject();
807 throw std::runtime_error(
"Unbalanced Location enter/exit");
808 auto &tos = stack.back();
810 throw std::runtime_error(
"Unbalanced Location enter/exit");
820 if (materializedNotes) {
821 for (nb::handle noteObject : *materializedNotes) {
822 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
830 : context(context), callback(std::move(callback)) {}
839 assert(!registeredID &&
"should have unregistered");
845void PyDiagnostic::checkValid() {
847 throw std::invalid_argument(
848 "Diagnostic is invalid (used outside of callback)");
866 nb::object fileObject = nb::module_::import_(
"io").attr(
"StringIO")();
869 return nb::cast<nb::str>(fileObject.attr(
"getvalue")());
874 if (materializedNotes)
875 return *materializedNotes;
877 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
878 for (intptr_t i = 0; i < numNotes; ++i) {
880 nb::object diagnostic = nb::cast(
PyDiagnostic(noteDiag));
881 PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
883 materializedNotes = std::move(notes);
885 return *materializedNotes;
889 std::vector<DiagnosticInfo> notes;
891 notes.emplace_back(nb::cast<PyDiagnostic>(n).
getInfo());
903 {key.data(), key.size()});
905 std::string msg = (Twine(
"Dialect '") + key +
"' not found").str();
907 throw nb::attribute_error(msg.c_str());
908 throw nb::index_error(msg.c_str());
918 MlirDialectRegistry rawRegistry =
921 throw nb::python_error();
936 throw nb::python_error();
946 const nb::object &excVal,
947 const nb::object &excTb) {
954 throw std::runtime_error(
955 "An MLIR function requires a Location but none was provided in the "
956 "call or from the surrounding environment. Either pass to the function "
957 "with a 'loc=' argument or establish a default using 'with loc:'");
970 nb::gil_scoped_acquire acquire;
971 auto &liveModules =
getContext()->liveModules;
972 assert(liveModules.count(module.ptr) == 1 &&
973 "destroying module not in live map");
974 liveModules.erase(module.ptr);
982 nb::gil_scoped_acquire acquire;
983 auto &liveModules = contextRef->liveModules;
984 auto it = liveModules.find(module.ptr);
985 if (it == liveModules.end()) {
991 nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
992 unownedModule->handle = pyRef;
993 liveModules[module.ptr] =
994 std::make_pair(unownedModule->handle, unownedModule);
995 return PyModuleRef(unownedModule, std::move(pyRef));
998 PyModule *existing = it->second.second;
999 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1005 if (mlirModuleIsNull(rawModule))
1006 throw nb::python_error();
1038template <
typename T,
class... Args>
1040 nb::handle type = nb::type<T>();
1041 nb::object instance = nb::inst_alloc(type);
1042 T *
ptr = nb::inst_ptr<T>(instance);
1043 new (
ptr) T(std::forward<Args>(args)...);
1044 nb::inst_mark_ready(instance);
1051 MlirOperation operation,
1052 nb::object parentKeepAlive) {
1055 makeObjectRef<PyOperation>(std::move(contextRef), operation);
1056 unownedOperation->handle = unownedOperation.
getObject();
1057 if (parentKeepAlive) {
1058 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1060 return unownedOperation;
1064 MlirOperation operation,
1065 nb::object parentKeepAlive) {
1066 return createInstance(std::move(contextRef), operation,
1067 std::move(parentKeepAlive));
1071 MlirOperation operation,
1072 nb::object parentKeepAlive) {
1073 PyOperationRef created = createInstance(std::move(contextRef), operation,
1074 std::move(parentKeepAlive));
1075 created->attached =
false;
1080 const std::string &sourceStr,
1081 const std::string &sourceName) {
1086 if (mlirOperationIsNull(op))
1087 throw MLIRError(
"Unable to parse operation assembly", errors.
take());
1093 throw std::runtime_error(
"the operation has been invalidated");
1098 std::optional<int64_t> largeResourceLimit,
1099 bool enableDebugInfo,
bool prettyDebugInfo,
1100 bool printGenericOpForm,
bool useLocalScope,
1101 bool useNameLocAsPrefix,
bool assumeVerified,
1102 nb::object fileObject,
bool binary,
1106 if (fileObject.is_none())
1107 fileObject = nb::module_::import_(
"sys").attr(
"stdout");
1110 if (largeElementsLimit)
1112 if (largeResourceLimit)
1114 if (enableDebugInfo)
1117 if (printGenericOpForm)
1125 if (useNameLocAsPrefix)
1130 accum.getUserData());
1138 if (fileObject.is_none())
1139 fileObject = nb::module_::import_(
"sys").attr(
"stdout");
1142 accum.getUserData());
1146 std::optional<int64_t> bytecodeVersion) {
1151 if (!bytecodeVersion.has_value())
1161 throw nb::value_error((Twine(
"Unable to honor desired bytecode version ") +
1162 Twine(*bytecodeVersion))
1175 std::string exceptionWhat;
1176 nb::object exceptionType;
1178 UserData userData{callback,
false, {}, {}};
1181 UserData *calleeUserData =
static_cast<UserData *
>(userData);
1183 return (calleeUserData->callback)(op);
1184 }
catch (nb::python_error &e) {
1185 calleeUserData->gotException =
true;
1186 calleeUserData->exceptionWhat = std::string(e.what());
1187 calleeUserData->exceptionType = nb::borrow(e.type());
1188 return MlirWalkResult::MlirWalkResultInterrupt;
1192 if (userData.gotException) {
1193 std::string message(
"Exception raised in callback: ");
1194 message.append(userData.exceptionWhat);
1195 throw std::runtime_error(message);
1200 std::optional<int64_t> largeElementsLimit,
1201 std::optional<int64_t> largeResourceLimit,
1202 bool enableDebugInfo,
bool prettyDebugInfo,
1203 bool printGenericOpForm,
bool useLocalScope,
1204 bool useNameLocAsPrefix,
bool assumeVerified,
1206 nb::object fileObject;
1208 fileObject = nb::module_::import_(
"io").attr(
"BytesIO")();
1210 fileObject = nb::module_::import_(
"io").attr(
"StringIO")();
1212 print(largeElementsLimit,
1224 return fileObject.attr(
"getvalue")();
1233 operation.parentKeepAlive = otherOp.parentKeepAlive;
1242 operation.parentKeepAlive = otherOp.parentKeepAlive;
1264 throw nb::value_error(
"Detached operations have no parent");
1266 if (mlirOperationIsNull(operation))
1275 assert(!
mlirBlockIsNull(block) &&
"Attached operation has null parent");
1276 assert(parentOperation &&
"Operation has no parent");
1277 return PyBlock{std::move(*parentOperation), block};
1287 if (mlirOperationIsNull(rawOperation))
1288 throw nb::python_error();
1295 const nb::object &maybeIp) {
1297 if (!maybeIp.is(nb::cast(
false))) {
1299 if (maybeIp.is_none()) {
1302 ip = nb::cast<PyInsertionPoint *>(maybeIp);
1310 std::optional<std::vector<PyType *>> results,
1312 std::optional<nb::dict> attributes,
1313 std::optional<std::vector<PyBlock *>> successors,
1315 const nb::object &maybeIp,
bool inferType) {
1322 throw nb::value_error(
"number of regions must be >= 0");
1326 mlirResults.reserve(results->size());
1330 throw nb::value_error(
"result type cannot be None");
1331 mlirResults.push_back(*
result);
1336 mlirAttributes.reserve(attributes->size());
1337 for (std::pair<nb::handle, nb::handle> it : *attributes) {
1340 key = nb::cast<std::string>(it.first);
1341 }
catch (nb::cast_error &err) {
1342 std::string msg =
"Invalid attribute key (not a string) when "
1343 "attempting to create the operation \"" +
1344 std::string(name) +
"\" (" + err.what() +
")";
1345 throw nb::type_error(msg.c_str());
1348 auto &attribute = nb::cast<PyAttribute &>(it.second);
1350 mlirAttributes.emplace_back(std::move(key), attribute);
1351 }
catch (nb::cast_error &err) {
1352 std::string msg =
"Invalid attribute value for the key \"" + key +
1353 "\" when attempting to create the operation \"" +
1354 std::string(name) +
"\" (" + err.what() +
")";
1355 throw nb::type_error(msg.c_str());
1356 }
catch (std::runtime_error &) {
1359 "Found an invalid (`None`?) attribute value for the key \"" + key +
1360 "\" when attempting to create the operation \"" +
1361 std::string(name) +
"\"";
1362 throw std::runtime_error(msg);
1368 mlirSuccessors.reserve(successors->size());
1369 for (
auto *successor : *successors) {
1372 throw nb::value_error(
"successor block cannot be None");
1373 mlirSuccessors.push_back(successor->get());
1379 MlirOperationState state =
1381 if (!operands.empty())
1383 state.enableResultTypeInference = inferType;
1384 if (!mlirResults.empty())
1386 mlirResults.data());
1387 if (!mlirAttributes.empty()) {
1392 mlirNamedAttributes.reserve(mlirAttributes.size());
1393 for (
auto &it : mlirAttributes)
1399 mlirNamedAttributes.data());
1401 if (!mlirSuccessors.empty())
1403 mlirSuccessors.data());
1406 mlirRegions.resize(regions);
1407 for (
int i = 0; i < regions; ++i)
1410 mlirRegions.data());
1417 throw MLIRError(
"Operation creation failed", errors.
take());
1455template <
typename DerivedTy>
1456class PyConcreteValue :
public PyValue {
1462 using ClassTy = nb::class_<DerivedTy, PyValue>;
1463 using IsAFunctionTy =
bool (*)(MlirValue);
1465 PyConcreteValue() =
default;
1467 :
PyValue(operationRef, value) {}
1468 PyConcreteValue(
PyValue &orig)
1469 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1473 static MlirValue castFrom(
PyValue &orig) {
1474 if (!DerivedTy::isaFunction(orig.
get())) {
1475 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
1476 throw nb::value_error((Twine(
"Cannot cast value to ") +
1477 DerivedTy::pyClassName +
" (from " + origRepr +
1486 static void bind(nb::module_ &m) {
1488 m, DerivedTy::pyClassName, nb::is_generic(),
1489 nb::sig((Twine(
"class ") + DerivedTy::pyClassName +
"(Value[_T])")
1492 cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg(
"value"));
1495 [](PyValue &otherValue) ->
bool {
1496 return DerivedTy::isaFunction(otherValue);
1498 nb::arg(
"other_value"));
1500 [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
1501 return self.maybeDownCast();
1503 DerivedTy::bindDerived(cls);
1507 static void bindDerived(ClassTy &m) {}
1517 using PyConcreteValue::PyConcreteValue;
1522 [](
PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
1525 "expected the owner of the value in Python to match that in "
1527 return self.getParentOperation().getObject();
1529 "Returns the operation that produces this result.");
1535 "Returns the position of this result in the operation's result list.");
1540template <
typename Container>
1541static std::vector<nb::typed<nb::object, PyType>>
1543 std::vector<nb::typed<nb::object, PyType>>
result;
1544 result.reserve(container.size());
1545 for (
int i = 0, e = container.size(); i < e; ++i) {
1568 operation(std::move(operation)) {}
1576 "Returns a list of types for all results in this result list.");
1582 "Returns the operation that owns this result list.");
1591 intptr_t getRawNumElements() {
1596 PyOpResult getRawElement(intptr_t index) {
1598 return PyOpResult(value);
1601 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1602 return PyOpResultList(operation, startIndex, length, step);
1613 const nb::object &resultSegmentSpecObj,
1614 std::vector<int32_t> &resultSegmentLengths,
1615 std::vector<PyType *> &resultTypes) {
1616 resultTypes.reserve(resultTypeList.size());
1617 if (resultSegmentSpecObj.is_none()) {
1619 for (
const auto &it : llvm::enumerate(resultTypeList)) {
1621 resultTypes.push_back(nb::cast<PyType *>(it.value()));
1622 if (!resultTypes.back())
1623 throw nb::cast_error();
1624 }
catch (nb::cast_error &err) {
1625 throw nb::value_error((llvm::Twine(
"Result ") +
1626 llvm::Twine(it.index()) +
" of operation \"" +
1627 name +
"\" must be a Type (" + err.what() +
")")
1634 auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1635 if (resultSegmentSpec.size() != resultTypeList.size()) {
1636 throw nb::value_error((llvm::Twine(
"Operation \"") + name +
1638 llvm::Twine(resultSegmentSpec.size()) +
1639 " result segments but was provided " +
1640 llvm::Twine(resultTypeList.size()))
1644 resultSegmentLengths.reserve(resultTypeList.size());
1645 for (
const auto &it :
1646 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1647 int segmentSpec = std::get<1>(it.value());
1648 if (segmentSpec == 1 || segmentSpec == 0) {
1651 auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
1653 resultTypes.push_back(resultType);
1654 resultSegmentLengths.push_back(1);
1655 }
else if (segmentSpec == 0) {
1657 resultSegmentLengths.push_back(0);
1659 throw nb::value_error(
1660 (llvm::Twine(
"Result ") + llvm::Twine(it.index()) +
1661 " of operation \"" + name +
1662 "\" must be a Type (was None and result is not optional)")
1666 }
catch (nb::cast_error &err) {
1667 throw nb::value_error((llvm::Twine(
"Result ") +
1668 llvm::Twine(it.index()) +
" of operation \"" +
1669 name +
"\" must be a Type (" + err.what() +
1674 }
else if (segmentSpec == -1) {
1677 if (std::get<0>(it.value()).is_none()) {
1679 resultSegmentLengths.push_back(0);
1682 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1683 for (nb::handle segmentItem : segment) {
1684 resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1685 if (!resultTypes.back()) {
1686 throw nb::type_error(
"contained a None item");
1689 resultSegmentLengths.push_back(nb::len(segment));
1691 }
catch (std::exception &err) {
1695 throw nb::value_error((llvm::Twine(
"Result ") +
1696 llvm::Twine(it.index()) +
" of operation \"" +
1697 name +
"\" must be a Sequence of Types (" +
1703 throw nb::value_error(
"Unexpected segment spec");
1711 if (numResults != 1) {
1713 throw nb::value_error((Twine(
"Cannot call .result on operation ") +
1714 StringRef(name.data, name.length) +
" which has " +
1716 " results (it is only valid for operations with a "
1725 if (operand.is_none()) {
1726 throw nb::value_error(
"contained a None item");
1729 if (nb::try_cast<PyOperationBase *>(operand, op)) {
1733 if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1737 if (nb::try_cast<PyValue *>(operand, value)) {
1738 return value->
get();
1740 throw nb::value_error(
"is not a Value");
1744 std::string_view name, std::tuple<int, bool> opRegionSpec,
1745 nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1746 std::optional<nb::list> resultTypeList, nb::list operandList,
1747 std::optional<nb::dict> attributes,
1748 std::optional<std::vector<PyBlock *>> successors,
1749 std::optional<int> regions,
PyLocation &location,
1750 const nb::object &maybeIp) {
1759 std::vector<int32_t> operandSegmentLengths;
1760 std::vector<int32_t> resultSegmentLengths;
1763 int opMinRegionCount = std::get<0>(opRegionSpec);
1764 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1766 regions = opMinRegionCount;
1768 if (*regions < opMinRegionCount) {
1769 throw nb::value_error(
1770 (llvm::Twine(
"Operation \"") + name +
"\" requires a minimum of " +
1771 llvm::Twine(opMinRegionCount) +
1772 " regions but was built with regions=" + llvm::Twine(*regions))
1776 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1777 throw nb::value_error(
1778 (llvm::Twine(
"Operation \"") + name +
"\" requires a maximum of " +
1779 llvm::Twine(opMinRegionCount) +
1780 " regions but was built with regions=" + llvm::Twine(*regions))
1786 std::vector<PyType *> resultTypes;
1787 if (resultTypeList.has_value()) {
1789 resultSegmentLengths, resultTypes);
1794 operands.reserve(operands.size());
1795 if (operandSegmentSpecObj.is_none()) {
1797 for (
const auto &it : llvm::enumerate(operandList)) {
1800 }
catch (nb::builtin_exception &err) {
1801 throw nb::value_error((llvm::Twine(
"Operand ") +
1802 llvm::Twine(it.index()) +
" of operation \"" +
1803 name +
"\" must be a Value (" + err.what() +
")")
1810 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1811 if (operandSegmentSpec.size() != operandList.size()) {
1812 throw nb::value_error((llvm::Twine(
"Operation \"") + name +
1814 llvm::Twine(operandSegmentSpec.size()) +
1815 "operand segments but was provided " +
1816 llvm::Twine(operandList.size()))
1820 operandSegmentLengths.reserve(operandList.size());
1821 for (
const auto &it :
1822 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1823 int segmentSpec = std::get<1>(it.value());
1824 if (segmentSpec == 1 || segmentSpec == 0) {
1826 auto &operand = std::get<0>(it.value());
1827 if (!operand.is_none()) {
1831 }
catch (nb::builtin_exception &err) {
1832 throw nb::value_error((llvm::Twine(
"Operand ") +
1833 llvm::Twine(it.index()) +
1834 " of operation \"" + name +
1835 "\" must be a Value (" + err.what() +
")")
1840 operandSegmentLengths.push_back(1);
1841 }
else if (segmentSpec == 0) {
1843 operandSegmentLengths.push_back(0);
1845 throw nb::value_error(
1846 (llvm::Twine(
"Operand ") + llvm::Twine(it.index()) +
1847 " of operation \"" + name +
1848 "\" must be a Value (was None and operand is not optional)")
1852 }
else if (segmentSpec == -1) {
1855 if (std::get<0>(it.value()).is_none()) {
1857 operandSegmentLengths.push_back(0);
1860 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1861 for (nb::handle segmentItem : segment) {
1864 operandSegmentLengths.push_back(nb::len(segment));
1866 }
catch (std::exception &err) {
1870 throw nb::value_error((llvm::Twine(
"Operand ") +
1871 llvm::Twine(it.index()) +
" of operation \"" +
1872 name +
"\" must be a Sequence of Values (" +
1878 throw nb::value_error(
"Unexpected segment spec");
1884 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1887 attributes = nb::dict(*attributes);
1889 attributes = nb::dict();
1891 if (attributes->contains(
"resultSegmentSizes") ||
1892 attributes->contains(
"operandSegmentSizes")) {
1893 throw nb::value_error(
"Manually setting a 'resultSegmentSizes' or "
1894 "'operandSegmentSizes' attribute is unsupported. "
1895 "Use Operation.create for such low-level access.");
1899 if (!resultSegmentLengths.empty()) {
1900 MlirAttribute segmentLengthAttr =
1902 resultSegmentLengths.data());
1903 (*attributes)[
"resultSegmentSizes"] =
1908 if (!operandSegmentLengths.empty()) {
1909 MlirAttribute segmentLengthAttr =
1911 operandSegmentLengths.data());
1912 (*attributes)[
"operandSegmentSizes"] =
1919 std::move(resultTypes),
1921 std::move(attributes),
1922 std::move(successors),
1923 *regions, location, maybeIp,
1928 const nb::object &operation) {
1929 nb::handle opViewType = nb::type<PyOpView>();
1930 nb::object instance = cls.attr(
"__new__")(cls);
1931 opViewType.attr(
"__init__")(instance, operation);
1939 operationObject(operation.getRef().getObject()) {}
1948 : refOperation(beforeOperationBase.getOperation().getRef()),
1949 block((*refOperation)->
getBlock()) {}
1952 : refOperation(beforeOperationRef), block((*refOperation)->
getBlock()) {}
1957 throw nb::value_error(
1958 "Attempt to insert operation that is already attached");
1959 block.getParentOperation()->checkValid();
1960 MlirOperation beforeOp = {
nullptr};
1963 (*refOperation)->checkValid();
1964 beforeOp = (*refOperation)->get();
1970 throw nb::index_error(
"Cannot insert operation at the end of a block "
1971 "that already has a terminator. Did you mean to "
1972 "use 'InsertionPoint.at_block_terminator(block)' "
1973 "versus 'InsertionPoint(block)'?");
1982 if (mlirOperationIsNull(firstOp)) {
1989 block.getParentOperation()->getContext(), firstOp);
1995 if (mlirOperationIsNull(terminator))
1996 throw nb::value_error(
"Block has no terminator");
1998 block.getParentOperation()->getContext(), terminator);
2006 if (mlirOperationIsNull(nextOperation))
2009 block.getParentOperation()->getContext(), nextOperation);
2020 const nb::object &excVal,
2021 const nb::object &excTb) {
2039 if (mlirAttributeIsNull(rawAttr))
2040 throw nb::python_error();
2048 "mlirTypeID was expected to be non-null.");
2053 nb::object thisObj = nb::cast(
this, nb::rv_policy::move);
2056 return typeCaster.value()(thisObj);
2064 : ownedName(new std::string(std::move(ownedName))) {
2086 throw nb::python_error();
2094 "mlirTypeID was expected to be non-null.");
2099 nb::object thisObj = nb::cast(
this, nb::rv_policy::move);
2102 return typeCaster.value()(thisObj);
2116 throw nb::python_error();
2135 "mlirTypeID was expected to be non-null.");
2136 std::optional<nb::callable> valueCaster =
2140 nb::object thisObj = nb::cast(
this, nb::rv_policy::move);
2143 return valueCaster.value()(thisObj);
2148 if (mlirValueIsNull(value))
2149 throw nb::python_error();
2150 MlirOperation owner;
2155 if (mlirOperationIsNull(owner))
2156 throw nb::python_error();
2160 return PyValue(ownerRef, value);
2168 : operation(operation.getOperation().getRef()) {
2171 throw nb::type_error(
"Operation is not a Symbol Table.");
2176 operation->checkValid();
2179 if (mlirOperationIsNull(symbol))
2180 throw nb::key_error(
2181 (
"Symbol '" + name +
"' not in the symbol table.").c_str());
2184 operation.getObject())
2189 operation->checkValid();
2200 erase(nb::cast<PyOperationBase &>(operation));
2204 operation->checkValid();
2208 if (mlirAttributeIsNull(symbolAttr))
2209 throw nb::value_error(
"Expected operation to have a symbol name.");
2220 MlirAttribute existingNameAttr =
2222 if (mlirAttributeIsNull(existingNameAttr))
2223 throw nb::value_error(
"Expected operation to have a symbol name.");
2229 const std::string &name) {
2234 MlirAttribute existingNameAttr =
2236 if (mlirAttributeIsNull(existingNameAttr))
2237 throw nb::value_error(
"Expected operation to have a symbol name.");
2238 MlirAttribute newNameAttr =
2247 MlirAttribute existingVisAttr =
2249 if (mlirAttributeIsNull(existingVisAttr))
2250 throw nb::value_error(
"Expected operation to have a symbol visibility.");
2255 const std::string &visibility) {
2256 if (visibility !=
"public" && visibility !=
"private" &&
2257 visibility !=
"nested")
2258 throw nb::value_error(
2259 "Expected visibility to be 'public', 'private' or 'nested'");
2263 MlirAttribute existingVisAttr =
2265 if (mlirAttributeIsNull(existingVisAttr))
2266 throw nb::value_error(
"Expected operation to have a symbol visibility.");
2273 const std::string &newSymbol,
2281 throw nb::value_error(
"Symbol rename failed");
2285 bool allSymUsesVisible,
2286 nb::object callback) {
2291 nb::object callback;
2293 std::string exceptionWhat;
2294 nb::object exceptionType;
2297 fromOperation.
getContext(), std::move(callback),
false, {}, {}};
2299 fromOperation.
get(), allSymUsesVisible,
2300 [](MlirOperation foundOp,
bool isVisible,
void *calleeUserDataVoid) {
2301 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2303 PyOperation::forOperation(calleeUserData->context, foundOp);
2304 if (calleeUserData->gotException)
2307 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2308 } catch (nb::python_error &e) {
2309 calleeUserData->gotException =
true;
2310 calleeUserData->exceptionWhat = e.what();
2311 calleeUserData->exceptionType = nb::borrow(e.type());
2314 static_cast<void *
>(&userData));
// NOTE(review): the Python exception raised by the callback is recorded in
// userData.exceptionType (see the catch handler above), but this path rethrows
// it as std::runtime_error carrying only the message text. The original Python
// exception type is therefore not preserved for callers (presumably it surfaces
// as RuntimeError) -- confirm whether that is intended or whether the stored
// type should be re-raised instead.
2315 if (userData.gotException) {
2316 std::string message(
"Exception raised in callback: ");
2317 message.append(userData.exceptionWhat);
2318 throw std::runtime_error(message);
2325class PyBlockArgument :
public PyConcreteValue<PyBlockArgument> {
2328 static constexpr const char *pyClassName =
"BlockArgument";
2329 using PyConcreteValue::PyConcreteValue;
2331 static void bindDerived(ClassTy &c) {
2334 [](PyBlockArgument &self) {
2335 return PyBlock(self.getParentOperation(),
2338 "Returns the block that owns this argument.");
2341 [](PyBlockArgument &self) {
2344 "Returns the position of this argument in the block's argument list.");
2347 [](PyBlockArgument &self, PyType type) {
2350 nb::arg(
"type"),
"Sets the type of this block argument.");
2353 [](PyBlockArgument &self, PyLocation loc) {
2356 nb::arg(
"loc"),
"Sets the location of this block argument.");
2364class PyBlockArgumentList
2365 :
public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2367 static constexpr const char *pyClassName =
"BlockArgumentList";
2368 using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
2371 intptr_t startIndex = 0, intptr_t length = -1,
2373 : Sliceable(startIndex,
2376 operation(std::move(operation)), block(block) {}
2378 static void bindDerived(ClassTy &c) {
2381 [](PyBlockArgumentList &self) {
2384 "Returns a list of types for all arguments in this argument list.");
2389 friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2392 intptr_t getRawNumElements() {
2398 PyBlockArgument getRawElement(intptr_t pos) {
2400 return PyBlockArgument(operation, argument);
2404 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2406 return PyBlockArgumentList(operation, block, startIndex, length, step);
2417class PyOpOperandList :
public Sliceable<PyOpOperandList, PyValue> {
2419 static constexpr const char *pyClassName =
"OpOperandList";
2420 using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2422 PyOpOperandList(
PyOperationRef operation, intptr_t startIndex = 0,
2423 intptr_t length = -1, intptr_t step = 1)
2424 : Sliceable(startIndex,
2428 operation(operation) {}
2430 void dunderSetItem(intptr_t index, PyValue value) {
2431 index = wrapIndex(index);
2435 static void bindDerived(ClassTy &c) {
2436 c.def(
"__setitem__", &PyOpOperandList::dunderSetItem, nb::arg(
"index"),
2438 "Sets the operand at the specified index to a new value.");
2443 friend class Sliceable<PyOpOperandList, PyValue>;
2445 intptr_t getRawNumElements() {
2450 PyValue getRawElement(intptr_t pos) {
2452 MlirOperation owner;
2458 assert(
false &&
"Value must be an block arg or op result.");
2461 return PyValue(pyOwner, operand);
2464 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2465 return PyOpOperandList(operation, startIndex, length, step);
2475class PyOpSuccessors :
public Sliceable<PyOpSuccessors, PyBlock> {
2477 static constexpr const char *pyClassName =
"OpSuccessors";
2479 PyOpSuccessors(
PyOperationRef operation, intptr_t startIndex = 0,
2480 intptr_t length = -1, intptr_t step = 1)
2481 : Sliceable(startIndex,
2485 operation(operation) {}
2487 void dunderSetItem(intptr_t index, PyBlock block) {
2488 index = wrapIndex(index);
2492 static void bindDerived(ClassTy &c) {
2493 c.def(
"__setitem__", &PyOpSuccessors::dunderSetItem, nb::arg(
"index"),
2494 nb::arg(
"block"),
"Sets the successor block at the specified index.");
2499 friend class Sliceable<PyOpSuccessors, PyBlock>;
2501 intptr_t getRawNumElements() {
2506 PyBlock getRawElement(intptr_t pos) {
2508 return PyBlock(operation, block);
2511 PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2512 return PyOpSuccessors(operation, startIndex, length, step);
2522class PyBlockSuccessors :
public Sliceable<PyBlockSuccessors, PyBlock> {
2524 static constexpr const char *pyClassName =
"BlockSuccessors";
2527 intptr_t startIndex = 0, intptr_t length = -1,
2529 : Sliceable(startIndex,
2533 operation(operation), block(block) {}
2537 friend class Sliceable<PyBlockSuccessors, PyBlock>;
2539 intptr_t getRawNumElements() {
2544 PyBlock getRawElement(intptr_t pos) {
2546 return PyBlock(operation, block);
2549 PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2550 return PyBlockSuccessors(block, operation, startIndex, length, step);
2564class PyBlockPredecessors :
public Sliceable<PyBlockPredecessors, PyBlock> {
2566 static constexpr const char *pyClassName =
"BlockPredecessors";
2569 intptr_t startIndex = 0, intptr_t length = -1,
2571 : Sliceable(startIndex,
2575 operation(operation), block(block) {}
2579 friend class Sliceable<PyBlockPredecessors, PyBlock>;
2581 intptr_t getRawNumElements() {
2586 PyBlock getRawElement(intptr_t pos) {
2588 return PyBlock(operation, block);
2591 PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
2593 return PyBlockPredecessors(block, operation, startIndex, length, step);
2602class PyOpAttributeMap {
2605 : operation(std::move(operation)) {}
2607 nb::typed<nb::object, PyAttribute>
2608 dunderGetItemNamed(
const std::string &name) {
2611 if (mlirAttributeIsNull(attr)) {
2612 throw nb::key_error(
"attempt to access a non-existent attribute");
2614 return PyAttribute(operation->
getContext(), attr).maybeDownCast();
2617 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2619 index += dunderLen();
2621 if (index < 0 || index >= dunderLen()) {
2622 throw nb::index_error(
"attempt to access out of bounds attribute");
2626 return PyNamedAttribute(
2632 void dunderSetItem(
const std::string &name,
const PyAttribute &attr) {
2637 void dunderDelItem(
const std::string &name) {
2641 throw nb::key_error(
"attempt to delete a non-existent attribute");
2644 intptr_t dunderLen() {
2648 bool dunderContains(
const std::string &name) {
2654 forEachAttr(MlirOperation op,
2655 llvm::function_ref<
void(
MlirStringRef, MlirAttribute)> fn) {
2657 for (intptr_t i = 0; i < n; ++i) {
2664 static void bind(nb::module_ &m) {
2665 nb::class_<PyOpAttributeMap>(m,
"OpAttributeMap")
2666 .def(
"__contains__", &PyOpAttributeMap::dunderContains, nb::arg(
"name"),
2667 "Checks if an attribute with the given name exists in the map.")
2668 .def(
"__len__", &PyOpAttributeMap::dunderLen,
2669 "Returns the number of attributes in the map.")
2670 .def(
"__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
2671 nb::arg(
"name"),
"Gets an attribute by name.")
2672 .def(
"__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
2673 nb::arg(
"index"),
"Gets a named attribute by index.")
2674 .def(
"__setitem__", &PyOpAttributeMap::dunderSetItem, nb::arg(
"name"),
2675 nb::arg(
"attr"),
"Sets an attribute with the given name.")
2676 .def(
"__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg(
"name"),
2677 "Deletes an attribute with the given name.")
2680 [](PyOpAttributeMap &self) {
2682 PyOpAttributeMap::forEachAttr(
2683 self.operation->
get(),
2685 keys.append(nb::str(name.data, name.length));
2687 return nb::iter(keys);
2689 "Iterates over attribute names.")
2692 [](PyOpAttributeMap &self) {
2694 PyOpAttributeMap::forEachAttr(
2695 self.operation->
get(),
2697 out.append(nb::str(name.data, name.length));
2701 "Returns a list of attribute names.")
2704 [](PyOpAttributeMap &self) {
2706 PyOpAttributeMap::forEachAttr(
2707 self.operation->
get(),
2709 out.append(PyAttribute(self.operation->getContext(), attr)
2714 "Returns a list of attribute values.")
2717 [](PyOpAttributeMap &self) {
2719 PyOpAttributeMap::forEachAttr(
2720 self.operation->
get(),
2722 out.append(nb::make_tuple(
2723 nb::str(name.data, name.length),
2724 PyAttribute(self.operation->getContext(), attr)
2729 "Returns a list of `(name, attribute)` tuples.");
2740#define _Py_CAST(type, expr) ((type)(expr))
2747#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
2748 (defined(__cplusplus) && __cplusplus >= 201103)
2749#define _Py_NULL nullptr
2751#define _Py_NULL NULL
2756#if PY_VERSION_HEX < 0x030A00A3
2759#if !defined(Py_XNewRef)
2760[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
2764#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
2768#if !defined(Py_NewRef)
2769[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
2773#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
2779#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
2782PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
2783 assert(tstate !=
_Py_NULL &&
"expected tstate != _Py_NULL");
2788PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
2789 assert(frame !=
_Py_NULL &&
"expected frame != _Py_NULL");
2794PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
2795 assert(frame !=
_Py_NULL &&
"expected frame != _Py_NULL");
2796 assert(frame->f_code !=
_Py_NULL &&
"expected frame->f_code != _Py_NULL");
2802MlirLocation tracebackToLocation(MlirContext ctx) {
2803 size_t framesLimit =
2806 thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
2810 nb::gil_scoped_acquire acquire;
2811 PyThreadState *tstate = PyThreadState_GET();
2812 PyFrameObject *next;
2813 PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
2819 for (; pyFrame !=
nullptr && count < framesLimit;
2820 next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
2821 PyCodeObject *code = PyFrame_GetCode(pyFrame);
2823 nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
2824 llvm::StringRef fileName(fileNameStr);
2825 if (!
PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
2829#if PY_VERSION_HEX < 0x030B00F0
2831 nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
2832 llvm::StringRef funcName(name);
2833 int startLine = PyFrame_GetLineNumber(pyFrame);
2838 nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
2839 llvm::StringRef funcName(name);
2840 int startLine, startCol, endLine, endCol;
2841 int lasti = PyFrame_GetLasti(pyFrame);
2842 if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
2844 throw nb::python_error();
2847 ctx,
wrap(fileName), startLine, startCol, endLine, endCol);
2855 Py_XDECREF(pyFrame);
2860 MlirLocation callee = frames[0];
2865 MlirLocation caller = frames[count - 1];
2867 for (
int i = count - 2; i >= 1; i--)
2874maybeGetTracebackLocation(
const std::optional<PyLocation> &location) {
2875 if (location.has_value())
2876 return location.value();
2881 MlirLocation mlirLoc = tracebackToLocation(ctx.
get());
2883 return {ref, mlirLoc};
2894 nb::set_leak_warnings(
false);
2898 nb::enum_<MlirDiagnosticSeverity>(m,
"DiagnosticSeverity")
2904 nb::enum_<MlirWalkOrder>(m,
"WalkOrder")
2908 nb::enum_<MlirWalkResult>(m,
"WalkResult")
2916 nb::class_<PyDiagnostic>(m,
"Diagnostic")
2918 "Returns the severity of the diagnostic.")
2920 "Returns the location associated with the diagnostic.")
2922 "Returns the message text of the diagnostic.")
2924 "Returns a tuple of attached note diagnostics.")
2929 return nb::str(
"<Invalid Diagnostic>");
2932 "Returns the diagnostic message as a string.");
2934 nb::class_<PyDiagnostic::DiagnosticInfo>(m,
"DiagnosticInfo")
2940 "diag"_a,
"Creates a DiagnosticInfo from a Diagnostic.")
2942 "The severity level of the diagnostic.")
2944 "The location associated with the diagnostic.")
2946 "The message text of the diagnostic.")
2948 "List of attached note diagnostics.")
2952 "Returns the diagnostic message as a string.");
2954 nb::class_<PyDiagnosticHandler>(m,
"DiagnosticHandler")
2956 "Detaches the diagnostic handler from the context.")
2958 "Returns True if the handler is attached to a context.")
2960 "Returns True if an error was encountered during diagnostic "
2963 "Enters the diagnostic handler as a context manager.")
2965 nb::arg(
"exc_type").none(), nb::arg(
"exc_value").none(),
2966 nb::arg(
"traceback").none(),
2967 "Exits the diagnostic handler context manager.");
2970 nb::class_<PyThreadPool>(m,
"ThreadPool")
2973 "Creates a new thread pool with default concurrency.")
2975 "Returns the maximum number of threads in the pool.")
2977 "Returns the raw pointer to the LLVM thread pool as a string.");
2979 nb::class_<PyMlirContext>(m,
"Context")
2987 Creates a new MLIR context.
2989 The context is the top-level container for all MLIR objects. It owns the storage
2990 for types, attributes, locations, and other core IR objects. A context can be
2991 configured to allow or disallow unregistered dialects and can have dialects
2992 loaded on-demand.)")
2994 "Gets the number of live Context objects.")
2996 "_get_context_again",
2997 [](
PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
3001 "Gets another reference to the same context.")
3003 "Gets the number of live modules owned by this context.")
3005 "Gets a capsule wrapping the MlirContext.")
3008 "Creates a Context from a capsule wrapping MlirContext.")
3010 "Enters the context as a context manager.")
3012 nb::arg(
"exc_value").none(), nb::arg(
"traceback").none(),
3013 "Exits the context manager.")
3014 .def_prop_ro_static(
3017 -> std::optional<nb::typed<nb::object, PyMlirContext>> {
3021 return nb::cast(context);
3023 nb::sig(
"def current(/) -> Context | None"),
3024 "Gets the Context bound to the current thread or returns None if no "
3029 "Gets a container for accessing dialects by name.")
3032 "Alias for `dialects`.")
3034 "get_dialect_descriptor",
3037 self.
get(), {name.data(), name.size()});
3039 throw nb::value_error(
3040 (Twine(
"Dialect '") + name +
"' not found").str().c_str());
3044 nb::arg(
"dialect_name"),
3045 "Gets or loads a dialect by name, returning its descriptor object.")
3047 "allow_unregistered_dialects",
3054 "Controls whether unregistered dialects are allowed in this context.")
3056 nb::arg(
"callback"),
3057 "Attaches a diagnostic handler that will receive callbacks.")
3059 "enable_multithreading",
3065 Enables or disables multi-threading support in the context.
3068 enable: Whether to enable (True) or disable (False) multi-threading.
3080 Sets a custom thread pool for the context to use.
3083 pool: A ThreadPool object to use for parallel operations.
3086 Multi-threading is automatically disabled before setting the thread pool.)")
3092 "Gets the number of threads in the context's thread pool.")
3094 "_mlir_thread_pool_ptr",
3097 std::stringstream ss;
3101 "Gets the raw pointer to the LLVM thread pool as a string.")
3103 "is_registered_operation",
3108 nb::arg(
"operation_name"),
3110 Checks whether an operation with the given name is registered.
3113 operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
3116 True if the operation is registered, False otherwise.)")
3118 "append_dialect_registry",
3122 nb::arg(
"registry"),
3124 Appends the contents of a dialect registry to the context.
3127 registry: A DialectRegistry containing dialects to append.)")
3128 .def_prop_rw("emit_error_diagnostics",
3132 Controls whether error diagnostics are emitted to diagnostic handlers.
3134 By default, error diagnostics are captured and reported through MLIRError exceptions.)")
3136 "load_all_available_dialects",
3141 Loads all dialects available in the registry into the context.
3143 This eagerly loads all dialects that have been registered, making them
3144 immediately available for use.)");
3149 nb::class_<PyDialectDescriptor>(m,
"DialectDescriptor")
3156 "Returns the namespace of the dialect.")
3161 std::string repr(
"<DialectDescriptor ");
3166 nb::sig(
"def __repr__(self) -> str"),
3167 "Returns a string representation of the dialect descriptor.");
3172 nb::class_<PyDialects>(m,
"Dialects")
3176 MlirDialect dialect =
3178 nb::object descriptor =
3182 "Gets a dialect by name using subscript notation.")
3185 [=](
PyDialects &self, std::string attrName) {
3186 MlirDialect dialect =
3188 nb::object descriptor =
3192 "Gets a dialect by name using attribute notation.");
3197 nb::class_<PyDialect>(m,
"Dialect")
3198 .def(nb::init<nb::object>(), nb::arg(
"descriptor"),
3199 "Creates a Dialect from a DialectDescriptor.")
3202 "Returns the DialectDescriptor for this dialect.")
3205 [](
const nb::object &self) {
3206 auto clazz = self.attr(
"__class__");
3207 return nb::str(
"<Dialect ") +
3208 self.attr(
"descriptor").attr(
"namespace") +
3209 nb::str(
" (class ") + clazz.attr(
"__module__") +
3210 nb::str(
".") + clazz.attr(
"__name__") + nb::str(
")>");
3212 nb::sig(
"def __repr__(self) -> str"),
3213 "Returns a string representation of the dialect.");
3218 nb::class_<PyDialectRegistry>(m,
"DialectRegistry")
3220 "Gets a capsule wrapping the MlirDialectRegistry.")
3223 "Creates a DialectRegistry from a capsule wrapping "
3224 "`MlirDialectRegistry`.")
3225 .def(nb::init<>(),
"Creates a new empty dialect registry.");
3230 nb::class_<PyLocation>(m,
"Location")
3232 "Gets a capsule wrapping the MlirLocation.")
3234 "Creates a Location from a capsule wrapping MlirLocation.")
3236 "Enters the location as a context manager.")
3238 nb::arg(
"exc_value").none(), nb::arg(
"traceback").none(),
3239 "Exits the location context manager.")
3245 "Compares two locations for equality.")
3247 "__eq__", [](
PyLocation &self, nb::object other) {
return false; },
3248 "Compares location with non-location object (always returns False).")
// NOTE(review): this static getter returns std::nullopt (None in Python) when no
// Location is bound, and its signature advertises `Location | None`, but the
// docstring below says it "raises ValueError". The docstring should likely read
// "...or returns None if no location is set" to match Context.current.
3249 .def_prop_ro_static(
3251 [](nb::object & ) -> std::optional<PyLocation *> {
3254 return std::nullopt;
3258 nb::sig(
"def current(/) -> Location | None"),
3260 "Gets the Location bound to the current thread or raises ValueError.")
3267 nb::arg(
"context") = nb::none(),
3268 "Gets a Location representing an unknown location.")
3271 [](
PyLocation callee,
const std::vector<PyLocation> &frames,
3274 throw nb::value_error(
"No caller frames provided.");
3275 MlirLocation caller = frames.back().get();
3277 llvm::reverse(llvm::ArrayRef(frames).drop_back()))
3282 nb::arg(
"callee"), nb::arg(
"frames"), nb::arg(
"context") = nb::none(),
3283 "Gets a Location representing a caller and callsite.")
3285 "Returns True if this location is a CallSiteLoc.")
3292 "Gets the callee location from a CallSiteLoc.")
3299 "Gets the caller location from a CallSiteLoc.")
3302 [](std::string filename,
int line,
int col,
3309 nb::arg(
"filename"), nb::arg(
"line"), nb::arg(
"col"),
3310 nb::arg(
"context") = nb::none(),
3311 "Gets a Location representing a file, line and column.")
3314 [](std::string filename,
int startLine,
int startCol,
int endLine,
3319 startLine, startCol, endLine, endCol));
3321 nb::arg(
"filename"), nb::arg(
"start_line"), nb::arg(
"start_col"),
3322 nb::arg(
"end_line"), nb::arg(
"end_col"),
3323 nb::arg(
"context") = nb::none(),
3324 "Gets a Location representing a file, line and column range.")
3326 "Returns True if this location is a FileLineColLoc.")
3329 [](MlirLocation loc) {
3333 "Gets the filename from a FileLineColLoc.")
3335 "Gets the start line number from a `FileLineColLoc`.")
3337 "Gets the start column number from a `FileLineColLoc`.")
3339 "Gets the end line number from a `FileLineColLoc`.")
3341 "Gets the end column number from a `FileLineColLoc`.")
3344 [](
const std::vector<PyLocation> &pyLocations,
3345 std::optional<PyAttribute> metadata,
3347 llvm::SmallVector<MlirLocation, 4> locations;
3348 locations.reserve(pyLocations.size());
3349 for (
auto &pyLocation : pyLocations)
3350 locations.push_back(pyLocation.get());
3352 context->
get(), locations.size(), locations.data(),
3353 metadata ? metadata->get() : MlirAttribute{0});
3354 return PyLocation(context->getRef(), location);
3356 nb::arg(
"locations"), nb::arg(
"metadata") = nb::none(),
3357 nb::arg(
"context") = nb::none(),
3358 "Gets a Location representing a fused location with optional "
3361 "Returns True if this location is a `FusedLoc`.")
3366 std::vector<MlirLocation> locations(numLocations);
3369 std::vector<PyLocation> pyLocations{};
3370 pyLocations.reserve(numLocations);
3371 for (
unsigned i = 0; i < numLocations; ++i)
3372 pyLocations.emplace_back(self.
getContext(), locations[i]);
3375 "Gets the list of locations from a `FusedLoc`.")
3378 [](std::string name, std::optional<PyLocation> childLoc,
3384 childLoc ? childLoc->get()
3387 nb::arg(
"name"), nb::arg(
"childLoc") = nb::none(),
3388 nb::arg(
"context") = nb::none(),
3389 "Gets a Location representing a named location with optional child "
3392 "Returns True if this location is a `NameLoc`.")
3395 [](MlirLocation loc) {
3398 "Gets the name string from a `NameLoc`.")
3405 "Gets the child location from a `NameLoc`.")
3412 nb::arg(
"attribute"), nb::arg(
"context") = nb::none(),
3413 "Gets a Location from a `LocationAttr`.")
3416 [](
PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
3419 "Context that owns the `Location`.")
3426 "Get the underlying `LocationAttr`.")
3434 Emits an error diagnostic at this location.
3437 message: The error message to emit.)")
3441 PyPrintAccumulator printAccum;
3444 return printAccum.
join();
3446 "Returns the assembly representation of the location.");
3451 nb::class_<PyModule>(m,
"Module", nb::is_weak_referenceable())
3453 "Gets a capsule wrapping the MlirModule.")
3456 Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
3458 This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
3459 prevent double-frees (of the underlying `mlir::Module`).)")
3462 Clears the internal MLIR module reference.
3464 This is used internally to prevent double-free when ownership is transferred
3465 via the C API capsule mechanism. Not intended for normal use.)")
3469 -> nb::typed<nb::object, PyModule> {
3473 if (mlirModuleIsNull(module))
3474 throw MLIRError(
"Unable to parse module assembly", errors.take());
3477 nb::arg(
"asm"), nb::arg(
"context") = nb::none(),
3482 -> nb::typed<nb::object, PyModule> {
3486 if (mlirModuleIsNull(module))
3487 throw MLIRError(
"Unable to parse module assembly", errors.take());
3490 nb::arg(
"asm"), nb::arg(
"context") = nb::none(),
3495 -> nb::typed<nb::object, PyModule> {
3499 if (mlirModuleIsNull(module))
3500 throw MLIRError(
"Unable to parse module assembly", errors.take());
3503 nb::arg(
"path"), nb::arg(
"context") = nb::none(),
3507 [](
const std::optional<PyLocation> &loc)
3508 -> nb::typed<nb::object, PyModule> {
3509 PyLocation pyLoc = maybeGetTracebackLocation(loc);
3513 nb::arg(
"loc") = nb::none(),
"Creates an empty module.")
3516 [](
PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
3519 "Context that created the `Module`.")
3522 [](
PyModule &self) -> nb::typed<nb::object, PyOperation> {
3528 "Accesses the module as an operation.")
3538 "Return the block for this module.")
3547 [](
const nb::object &self) {
3549 return self.attr(
"operation").attr(
"__str__")();
3551 nb::sig(
"def __str__(self) -> str"),
3553 Gets the assembly form of the operation with default options.
3555 If more advanced control over the assembly formatting or I/O options is needed,
3556 use the dedicated print or get_asm method, which supports keyword arguments to
3564 "other"_a,
"Compares two modules for equality.")
3568 "Returns the hash value of the module.");
3573 nb::class_<PyOperationBase>(m,
"_OperationBase")
3579 "Gets a capsule wrapping the `MlirOperation`.")
3586 "Compares two operations for equality.")
3590 "Compares operation with non-operation object (always returns "
3597 "Returns the hash value of the operation.")
3603 "Returns a dictionary-like map of operation attributes.")
3611 "Context that owns the operation.")
3617 MlirOperation operation = concreteOperation.
get();
3620 "Returns the fully qualified name of the operation.")
3626 "Returns the list of operation operands.")
3632 "Returns the list of operation regions.")
3638 "Returns the list of Operation results.")
3646 "Shortcut to get an op result if it has only one (throws an error "
3659 nb::for_getter(
"Returns the source location the operation was "
3660 "defined or derived from."),
3661 nb::for_setter(
"Sets the source location the operation was defined "
3662 "or derived from."))
3666 -> std::optional<nb::typed<nb::object, PyOperation>> {
3669 return parent->getObject();
3672 "Returns the parent operation, or `None` if at top level.")
3676 return self.
getAsm(
false,
3687 nb::sig(
"def __str__(self) -> str"),
3688 "Returns the assembly form of the operation.")
3690 nb::overload_cast<PyAsmState &, nb::object, bool>(
3692 nb::arg(
"state"), nb::arg(
"file") = nb::none(),
3693 nb::arg(
"binary") =
false,
3695 Prints the assembly form of the operation to a file like object.
3698 state: `AsmState` capturing the operation numbering and flags.
3699 file: Optional file like object to write to. Defaults to sys.stdout.
3700 binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
3702 nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
3703 bool,
bool,
bool,
bool,
bool,
bool, nb::object,
3706 nb::arg(
"large_elements_limit") = nb::none(),
3707 nb::arg(
"large_resource_limit") = nb::none(),
3708 nb::arg(
"enable_debug_info") =
false,
3709 nb::arg(
"pretty_debug_info") =
false,
3710 nb::arg(
"print_generic_op_form") =
false,
3711 nb::arg(
"use_local_scope") =
false,
3712 nb::arg(
"use_name_loc_as_prefix") =
false,
3713 nb::arg(
"assume_verified") =
false, nb::arg(
"file") = nb::none(),
3714 nb::arg(
"binary") =
false, nb::arg(
"skip_regions") =
false,
3716 Prints the assembly form of the operation to a file like object.
3719 large_elements_limit: Whether to elide elements attributes above this
3720 number of elements. Defaults to None (no limit).
3721 large_resource_limit: Whether to elide resource attributes above this
3722 number of characters. Defaults to None (no limit). If large_elements_limit
3723 is set and this is None, the behavior will be to use large_elements_limit
3724 as large_resource_limit.
3725 enable_debug_info: Whether to print debug/location information. Defaults
3727 pretty_debug_info: Whether to format debug information for easier reading
3728 by a human (warning: the result is unparseable). Defaults to False.
3729 print_generic_op_form: Whether to print the generic assembly forms of all
3730 ops. Defaults to False.
3731 use_local_scope: Whether to print in a way that is more optimized for
3732 multi-threaded access but may not be consistent with how the overall
3734 use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
3735 prefixes for the SSA identifiers. Defaults to False.
3736 assume_verified: By default, if not printing generic form, the verifier
3737 will be run and if it fails, generic form will be printed with a comment
3738 about failed verification. While a reasonable default for interactive use,
3739 for systematic use, it is often better for the caller to verify explicitly
3740 and report failures in a more robust fashion. Set this to True if doing this
3741 in order to avoid running a redundant verification. If the IR is actually
3742 invalid, behavior is undefined.
3743 file: The file like object to write to. Defaults to sys.stdout.
3744 binary: Whether to write bytes (True) or str (False). Defaults to False.
3745 skip_regions: Whether to skip printing regions. Defaults to False.)")
3747 nb::arg(
"desired_version") = nb::none(),
3749 Write the bytecode form of the operation to a file like object.
3752 file: The file like object to write to.
3753 desired_version: Optional version of bytecode to emit.
3755 The bytecode writer status.)")
3758 nb::arg(
"binary") =
false,
3759 nb::arg(
"large_elements_limit") = nb::none(),
3760 nb::arg(
"large_resource_limit") = nb::none(),
3761 nb::arg(
"enable_debug_info") =
false,
3762 nb::arg(
"pretty_debug_info") =
false,
3763 nb::arg(
"print_generic_op_form") =
false,
3764 nb::arg(
"use_local_scope") =
false,
3765 nb::arg(
"use_name_loc_as_prefix") =
false,
3766 nb::arg(
"assume_verified") =
false, nb::arg(
"skip_regions") =
false,
3768 Gets the assembly form of the operation with all options available.
3771 binary: Whether to return a bytes (True) or str (False) object. Defaults to
3773 ... others ...: See the print() method for common keyword arguments for
3774 configuring the printout.
3776 Either a bytes or str object, depending on the setting of the `binary`
3779 "Verify the operation. Raises MLIRError if verification fails, and "
3780 "returns true otherwise.")
3782 "Puts self immediately after the other operation in its parent "
3785 "Puts self immediately before the other operation in its parent "
3790 Checks if this operation is before another in the same block.
3793 other: Another operation in the same parent block.
3796 True if this operation is before `other` in the operation list of the parent block.)")
3800 const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
3803 nb::arg(
"ip") = nb::none(),
3805 Creates a deep copy of the operation.
3808 ip: Optional insertion point where the cloned operation should be inserted.
3809 If None, the current insertion point is used. If False, the operation
3813 A new Operation that is a clone of this operation.)")
3815 "detach_from_parent",
3820 throw nb::value_error(
"Detached operation has no parent.");
3825 "Detaches the operation from its parent block.")
3833 "Reports if the operation is attached to its parent block.")
3837 Erases the operation and frees its memory.
3840 After erasing, any Python references to the operation become invalid.)")
3844 nb::sig(
"def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
3847 Walks the operation tree with a callback function.
3850 callback: A callable that takes an Operation and returns a WalkResult.
3851 walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
3853 nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3856 [](std::string_view name,
3857 std::optional<std::vector<PyType *>> results,
3858 std::optional<std::vector<PyValue *>> operands,
3859 std::optional<nb::dict> attributes,
3860 std::optional<std::vector<PyBlock *>> successors,
int regions,
3861 const std::optional<PyLocation> &location,
3862 const nb::object &maybeIp,
3863 bool inferType) -> nb::typed<nb::object, PyOperation> {
3865 llvm::SmallVector<MlirValue, 4> mlirOperands;
3867 mlirOperands.reserve(operands->size());
3868 for (
PyValue *operand : *operands) {
3870 throw nb::value_error(
"operand value cannot be None");
3871 mlirOperands.push_back(operand->get());
3875 PyLocation pyLoc = maybeGetTracebackLocation(location);
3877 successors, regions, pyLoc, maybeIp,
// NOTE(review): the location argument is registered below as nb::arg("loc"),
// but the docstring that follows documents it as "location:". One of the two
// should be aligned so the Python-visible keyword matches its documentation.
3880 nb::arg(
"name"), nb::arg(
"results") = nb::none(),
3881 nb::arg(
"operands") = nb::none(), nb::arg(
"attributes") = nb::none(),
3882 nb::arg(
"successors") = nb::none(), nb::arg(
"regions") = 0,
3883 nb::arg(
"loc") = nb::none(), nb::arg(
"ip") = nb::none(),
3884 nb::arg(
"infer_type") =
false,
3886 Creates a new operation.
3889 name: Operation name (e.g. `dialect.operation`).
3890 results: Optional sequence of Type representing op result types.
3891 operands: Optional operands of the operation.
3892 attributes: Optional Dict of {str: Attribute}.
3893 successors: Optional List of Block for the operation's successors.
3894 regions: Number of regions to create (default = 0).
3895 location: Optional Location object (defaults to resolve from context manager).
3896 ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
3897 infer_type: Whether to infer result types (default = False).
3899 A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
3902 [](
const std::string &sourceStr,
const std::string &sourceName,
3904 -> nb::typed<nb::object, PyOpView> {
3908 nb::arg(
"source"), nb::kw_only(), nb::arg(
"source_name") =
"",
3909 nb::arg(
"context") = nb::none(),
3910 "Parses an operation. Supports both text assembly format and binary "
3913 "Gets a capsule wrapping the MlirOperation.")
3916 "Creates an Operation from a capsule wrapping MlirOperation.")
3919 [](nb::object self) -> nb::typed<nb::object, PyOperation> {
3922 "Returns self (the operation).")
3925 [](
PyOperation &self) -> nb::typed<nb::object, PyOpView> {
3929 Returns an OpView of this operation.
3932 If the operation has a registered and loaded dialect then this OpView will
3933 be concrete wrapper class.)")
3935 "Returns the block containing this operation.")
3941 "Returns the list of Operation successors.")
3943 "replace_uses_of_with",
3948 "Replaces uses of the 'of' value with the 'with' value inside the "
3951 "Invalidate the operation.");
3954 nb::class_<PyOpView, PyOperationBase>(m,
"OpView")
3955 .def(nb::init<nb::typed<nb::object, PyOperation>>(),
3956 nb::arg(
"operation"))
3959 [](
PyOpView *self, std::string_view name,
3960 std::tuple<int, bool> opRegionSpec,
3961 nb::object operandSegmentSpecObj,
3962 nb::object resultSegmentSpecObj,
3963 std::optional<nb::list> resultTypeList, nb::list operandList,
3964 std::optional<nb::dict> attributes,
3965 std::optional<std::vector<PyBlock *>> successors,
3966 std::optional<int> regions,
3967 const std::optional<PyLocation> &location,
3968 const nb::object &maybeIp) {
3969 PyLocation pyLoc = maybeGetTracebackLocation(location);
3971 name, opRegionSpec, operandSegmentSpecObj,
3972 resultSegmentSpecObj, resultTypeList, operandList,
3973 attributes, successors, regions, pyLoc, maybeIp));
3975 nb::arg(
"name"), nb::arg(
"opRegionSpec"),
3976 nb::arg(
"operandSegmentSpecObj") = nb::none(),
3977 nb::arg(
"resultSegmentSpecObj") = nb::none(),
3978 nb::arg(
"results") = nb::none(), nb::arg(
"operands") = nb::none(),
3979 nb::arg(
"attributes") = nb::none(),
3980 nb::arg(
"successors") = nb::none(),
3981 nb::arg(
"regions") = nb::none(), nb::arg(
"loc") = nb::none(),
3982 nb::arg(
"ip") = nb::none())
3985 [](
PyOpView &self) -> nb::typed<nb::object, PyOperation> {
3988 .def_prop_ro(
"opview",
3989 [](nb::object self) -> nb::typed<nb::object, PyOpView> {
4000 "Returns the list of Operation successors.")
4004 "Invalidate the operation.");
4005 opViewClass.attr(
"_ODS_REGIONS") = nb::make_tuple(0,
true);
4006 opViewClass.attr(
"_ODS_OPERAND_SEGMENTS") = nb::none();
4007 opViewClass.attr(
"_ODS_RESULT_SEGMENTS") = nb::none();
4012 [](nb::handle cls, std::optional<nb::list> resultTypeList,
4013 nb::list operandList, std::optional<nb::dict> attributes,
4014 std::optional<std::vector<PyBlock *>> successors,
4015 std::optional<int> regions, std::optional<PyLocation> location,
4016 const nb::object &maybeIp) {
4017 std::string name = nb::cast<std::string>(cls.attr(
"OPERATION_NAME"));
4018 std::tuple<int, bool> opRegionSpec =
4019 nb::cast<std::tuple<int, bool>>(cls.attr(
"_ODS_REGIONS"));
4020 nb::object operandSegmentSpec = cls.attr(
"_ODS_OPERAND_SEGMENTS");
4021 nb::object resultSegmentSpec = cls.attr(
"_ODS_RESULT_SEGMENTS");
4022 PyLocation pyLoc = maybeGetTracebackLocation(location);
4024 resultSegmentSpec, resultTypeList,
4025 operandList, attributes, successors,
4026 regions, pyLoc, maybeIp);
4028 nb::arg(
"cls"), nb::arg(
"results") = nb::none(),
4029 nb::arg(
"operands") = nb::none(), nb::arg(
"attributes") = nb::none(),
4030 nb::arg(
"successors") = nb::none(), nb::arg(
"regions") = nb::none(),
4031 nb::arg(
"loc") = nb::none(), nb::arg(
"ip") = nb::none(),
4032 "Builds a specific, generated OpView based on class level attributes.");
4034 [](
const nb::object &cls,
const std::string &sourceStr,
4035 const std::string &sourceName,
4045 std::string clsOpName =
4046 nb::cast<std::string>(cls.attr(
"OPERATION_NAME"));
4049 std::string_view parsedOpName(identifier.
data, identifier.
length);
4050 if (clsOpName != parsedOpName)
4051 throw MLIRError(Twine(
"Expected a '") + clsOpName +
"' op, got: '" +
4052 parsedOpName +
"'");
4055 nb::arg(
"cls"), nb::arg(
"source"), nb::kw_only(),
4056 nb::arg(
"source_name") =
"", nb::arg(
"context") = nb::none(),
4057 "Parses a specific, generated OpView based on class level attributes.");
4062 nb::class_<PyRegion>(m,
"Region")
4068 "Returns a forward-optimized sequence of blocks.")
4071 [](
PyRegion &self) -> nb::typed<nb::object, PyOpView> {
4074 "Returns the operation owning this region.")
4082 "Iterates over blocks in the region.")
4086 return self.
get().ptr == other.
get().ptr;
4088 "Compares two regions for pointer equality.")
4090 "__eq__", [](
PyRegion &self, nb::object &other) {
return false; },
4091 "Compares region with non-region object (always returns False).");
4096 nb::class_<PyBlock>(m,
"Block")
4098 "Gets a capsule wrapping the MlirBlock.")
4101 [](
PyBlock &self) -> nb::typed<nb::object, PyOpView> {
4104 "Returns the owning operation of this block.")
4111 "Returns the owning region of this block.")
4117 "Returns a list of block arguments.")
4126 Appends an argument of the specified type to the block.
4129 type: The type of the argument to add.
4130 loc: The source location for the argument.
4133 The newly added block argument.)")
4136 [](
PyBlock &self,
unsigned index) {
4141 Erases the argument at the specified index.
4144 index: The index of the argument to erase.)")
4150 "Returns a forward-optimized sequence of operations.")
4153 [](
PyRegion &parent,
const nb::sequence &pyArgTypes,
4154 const std::optional<nb::sequence> &pyArgLocs) {
4156 MlirBlock block =
createBlock(pyArgTypes, pyArgLocs);
4160 nb::arg(
"parent"), nb::arg(
"arg_types") = nb::list(),
4161 nb::arg(
"arg_locs") = std::nullopt,
4162 "Creates and returns a new Block at the beginning of the given "
4163 "region (with given argument types and locations).")
4167 MlirBlock
b = self.
get();
4174 Appends this block to a region.
4176 Transfers ownership if the block is currently owned by another region.
4179 region: The region to append the block to.)")
4182 [](
PyBlock &self,
const nb::args &pyArgTypes,
4183 const std::optional<nb::sequence> &pyArgLocs) {
4186 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4191 nb::arg(
"arg_types"), nb::kw_only(),
4192 nb::arg(
"arg_locs") = std::nullopt,
4193 "Creates and returns a new Block before this block "
4194 "(with given argument types and locations).")
4197 [](
PyBlock &self,
const nb::args &pyArgTypes,
4198 const std::optional<nb::sequence> &pyArgLocs) {
4201 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
4206 nb::arg(
"arg_types"), nb::kw_only(),
4207 nb::arg(
"arg_locs") = std::nullopt,
4208 "Creates and returns a new Block after this block "
4209 "(with given argument types and locations).")
4214 MlirOperation firstOperation =
4219 "Iterates over operations in the block.")
4223 return self.
get().ptr == other.
get().ptr;
4225 "Compares two blocks for pointer equality.")
4227 "__eq__", [](
PyBlock &self, nb::object &other) {
return false; },
4228 "Compares block with non-block object (always returns False).")
4232 return static_cast<size_t>(llvm::hash_value(self.
get().ptr));
4234 "Returns the hash value of the block.")
4239 PyPrintAccumulator printAccum;
4242 return printAccum.
join();
4244 "Returns the assembly form of the block.")
4256 nb::arg(
"operation"),
4258 Appends an operation to this block.
4260 If the operation is currently in another block, it will be moved.
4263 operation: The operation to append to the block.)")
4269 "Returns the list of Block successors.")
4275 "Returns the list of Block predecessors.");
4281 nb::class_<PyInsertionPoint>(m,
"InsertionPoint")
4282 .def(nb::init<PyBlock &>(), nb::arg(
"block"),
4283 "Inserts after the last operation but still inside the block.")
4285 "Enters the insertion point as a context manager.")
4287 nb::arg(
"exc_type").none(), nb::arg(
"exc_value").none(),
4288 nb::arg(
"traceback").none(),
4289 "Exits the insertion point context manager.")
4290 .def_prop_ro_static(
4295 throw nb::value_error(
"No current InsertionPoint");
4298 nb::sig(
"def current(/) -> InsertionPoint"),
4299 "Gets the InsertionPoint bound to the current thread or raises "
4300 "ValueError if none has been set.")
4301 .def(nb::init<PyOperationBase &>(), nb::arg(
"beforeOperation"),
4302 "Inserts before a referenced operation.")
4306 Creates an insertion point at the beginning of a block.
4309 block: The block at whose beginning operations should be inserted.
4312 An InsertionPoint at the block's beginning.)")
4316 Creates an insertion point before a block's terminator.
4319 block: The block whose terminator to insert before.
4322 An InsertionPoint before the terminator.
4325 ValueError: If the block has no terminator.)")
4328 Creates an insertion point immediately after an operation.
4331 operation: The operation after which to insert.
4334 An InsertionPoint after the operation.)")
4337 Inserts an operation at this insertion point.
4340 operation: The operation to insert.)")
4343 "Returns the block that this `InsertionPoint` points to.")
4347 -> std::optional<nb::typed<nb::object, PyOperation>> {
4350 return refOperation->getObject();
4353 "The reference operation before which new operations are "
4354 "inserted, or None if the insertion point is at the end of "
4360 nb::class_<PyAttribute>(m,
"Attribute")
4363 .def(nb::init<PyAttribute &>(), nb::arg(
"cast_from_type"),
4364 "Casts the passed attribute to the generic `Attribute`.")
4366 "Gets a capsule wrapping the MlirAttribute.")
4369 "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
4373 -> nb::typed<nb::object, PyAttribute> {
4377 if (mlirAttributeIsNull(attr))
4378 throw MLIRError(
"Unable to parse attribute", errors.take());
4381 nb::arg(
"asm"), nb::arg(
"context") = nb::none(),
4382 "Parses an attribute from an assembly form. Raises an `MLIRError` on "
4386 [](
PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
4389 "Context that owns the `Attribute`.")
4392 [](
PyAttribute &self) -> nb::typed<nb::object, PyType> {
4396 "Returns the type of the `Attribute`.")
4402 nb::keep_alive<0, 1>(),
4404 Binds a name to the attribute, creating a `NamedAttribute`.
4407 name: The name to bind to the `Attribute`.
4410 A `NamedAttribute` with the given name and this attribute.)")
4414 "Compares two attributes for equality.")
4416 "__eq__", [](
PyAttribute &self, nb::object &other) {
return false; },
4417 "Compares attribute with non-attribute object (always returns "
4422 return static_cast<size_t>(llvm::hash_value(self.
get().ptr));
4424 "Returns the hash value of the attribute.")
4431 PyPrintAccumulator printAccum;
4434 return printAccum.
join();
4436 "Returns the assembly form of the Attribute.")
4445 PyPrintAccumulator printAccum;
4446 printAccum.
parts.append(
"Attribute(");
4449 printAccum.
parts.append(
")");
4450 return printAccum.
join();
4452 "Returns a string representation of the attribute.")
4458 "mlirTypeID was expected to be non-null.");
4461 "Returns the `TypeID` of the attribute.")
4464 [](
PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
4467 "Downcasts the attribute to a more specific attribute if possible.");
4472 nb::class_<PyNamedAttribute>(m,
"NamedAttribute")
4476 PyPrintAccumulator printAccum;
4477 printAccum.
parts.append(
"NamedAttribute(");
4478 printAccum.
parts.append(
4481 printAccum.
parts.append(
"=");
4485 printAccum.
parts.append(
")");
4486 return printAccum.
join();
4488 "Returns a string representation of the named attribute.")
4494 "The name of the `NamedAttribute` binding.")
4498 nb::keep_alive<0, 1>(), nb::sig(
"def attr(self) -> Attribute"),
4499 "The underlying generic attribute of the `NamedAttribute` binding.");
4504 nb::class_<PyType>(m,
"Type")
4507 .def(nb::init<PyType &>(), nb::arg(
"cast_from_type"),
4508 "Casts the passed type to the generic `Type`.")
4510 "Gets a capsule wrapping the `MlirType`.")
4512 "Creates a Type from a capsule wrapping `MlirType`.")
4515 [](std::string typeSpec,
4521 throw MLIRError(
"Unable to parse type", errors.take());
4524 nb::arg(
"asm"), nb::arg(
"context") = nb::none(),
4526 Parses the assembly form of a type.
4528 Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
4530 See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
4533 [](
PyType &self) -> nb::typed<nb::object, PyMlirContext> {
4536 "Context that owns the `Type`.")
4538 "__eq__", [](
PyType &self,
PyType &other) {
return self == other; },
4539 "Compares two types for equality.")
4541 "__eq__", [](
PyType &self, nb::object &other) {
return false; },
4542 nb::arg(
"other").none(),
4543 "Compares type with non-type object (always returns False).")
4547 return static_cast<size_t>(llvm::hash_value(self.
get().ptr));
4549 "Returns the hash value of the `Type`.")
4555 PyPrintAccumulator printAccum;
4558 return printAccum.
join();
4560 "Returns the assembly form of the `Type`.")
4568 PyPrintAccumulator printAccum;
4569 printAccum.
parts.append(
"Type(");
4572 printAccum.
parts.append(
")");
4573 return printAccum.
join();
4575 "Returns a string representation of the `Type`.")
4578 [](
PyType &self) -> nb::typed<nb::object, PyType> {
4581 "Downcasts the Type to a more specific `Type` if possible.")
4588 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
4589 throw nb::value_error(
4590 (origRepr + llvm::Twine(
" has no typeid.")).str().c_str());
4592 "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
4599 nb::class_<PyTypeID>(m,
"TypeID")
4601 "Gets a capsule wrapping the `MlirTypeID`.")
4603 "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
4610 "Compares two `TypeID`s for equality.")
4613 [](
PyTypeID &self,
const nb::object &other) {
return false; },
4614 "Compares TypeID with non-TypeID object (always returns False).")
4623 "Returns the hash value of the `TypeID`.");
4628 m.attr(
"_T") = nb::type_var(
"_T", nb::arg(
"bound") = m.attr(
"Type"));
4630 nb::class_<PyValue>(m,
"Value", nb::is_generic(),
4631 nb::sig(
"class Value(Generic[_T])"))
4632 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg(
"value"),
4633 "Creates a Value reference from another `Value`.")
4635 "Gets a capsule wrapping the `MlirValue`.")
4637 "Creates a `Value` from a capsule wrapping `MlirValue`.")
4640 [](
PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
4643 "Context in which the value lives.")
4649 [](
PyValue &self) -> nb::object {
4650 MlirValue v = self.
get();
4654 "expected the owner of the value in Python to match "
4665 assert(
false &&
"Value must be a block argument or an op result");
4668 "Returns the owner of the value (`Operation` for results, `Block` "
4676 "Returns an iterator over uses of this value.")
4680 return self.
get().ptr == other.
get().ptr;
4682 "Compares two values for pointer equality.")
4684 "__eq__", [](
PyValue &self, nb::object other) {
return false; },
4685 "Compares value with non-value object (always returns False).")
4689 return static_cast<size_t>(llvm::hash_value(self.
get().ptr));
4691 "Returns the hash value of the value.")
4695 PyPrintAccumulator printAccum;
4696 printAccum.
parts.append(
"Value(");
4699 printAccum.
parts.append(
")");
4700 return printAccum.
join();
4703 Returns the string form of the value.
4705 If the value is a block argument, this is the assembly form of its type and the
4706 position in the argument list. If the value is an operation result, this is
4707 equivalent to printing the operation that produced it.
4711 [](
PyValue &self,
bool useLocalScope,
bool useNameLocAsPrefix) {
4712 PyPrintAccumulator printAccum;
4716 if (useNameLocAsPrefix)
4718 MlirAsmState valueState =
4725 return printAccum.
join();
4727 nb::arg(
"use_local_scope") =
false,
4728 nb::arg(
"use_name_loc_as_prefix") =
false,
4730 Returns the string form of value as an operand.
4733 use_local_scope: Whether to use local scope for naming.
4734 use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
4737 The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
4741 PyPrintAccumulator printAccum;
4742 MlirAsmState valueState = state.
get();
4746 return printAccum.
join();
4749 "Returns the string form of value as an operand (i.e., the ValueID).")
4752 [](
PyValue &self) -> nb::typed<nb::object, PyType> {
4757 "Returns the type of the value.")
4763 nb::arg(
"type"),
"Sets the type of the value.",
4764 nb::sig(
"def set_type(self, type: _T)"))
4766 "replace_all_uses_with",
4770 "Replace all uses of value with the new value, updating anything in "
4771 "the IR that uses `self` to use the other value instead.")
4773 "replace_all_uses_except",
4775 MlirOperation exceptedUser = exception.
get();
4778 nb::arg(
"with_"), nb::arg(
"exceptions"),
4781 "replace_all_uses_except",
4784 llvm::SmallVector<MlirOperation> exceptionOps;
4785 for (nb::handle exception : exceptions) {
4786 exceptionOps.push_back(nb::cast<PyOperation &>(exception).
get());
4790 self, with,
static_cast<intptr_t
>(exceptionOps.size()),
4791 exceptionOps.data());
4793 nb::arg(
"with_"), nb::arg(
"exceptions"),
4796 "replace_all_uses_except",
4798 MlirOperation exceptedUser = exception.
get();
4801 nb::arg(
"with_"), nb::arg(
"exceptions"),
4804 "replace_all_uses_except",
4806 std::vector<PyOperation> &exceptions) {
4808 llvm::SmallVector<MlirOperation> exceptionOps;
4810 exceptionOps.push_back(exception);
4812 self, with,
static_cast<intptr_t
>(exceptionOps.size()),
4813 exceptionOps.data());
4815 nb::arg(
"with_"), nb::arg(
"exceptions"),
4819 [](
PyValue &self) -> nb::typed<nb::object, PyValue> {
4822 "Downcasts the `Value` to a more specific kind if possible.")
4825 [](MlirValue self) {
4830 "Returns the source location of the value.");
4832 PyBlockArgument::bind(m);
4833 PyOpResult::bind(m);
4834 PyOpOperand::bind(m);
4836 nb::class_<PyAsmState>(m,
"AsmState")
4837 .def(nb::init<PyValue &, bool>(), nb::arg(
"value"),
4838 nb::arg(
"use_local_scope") =
false,
4840 Creates an `AsmState` for consistent SSA value naming.
4843 value: The value to create state for.
4844 use_local_scope: Whether to use local scope for naming.)")
4845 .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
4846 nb::arg(
"use_local_scope") =
false,
4848 Creates an AsmState for consistent SSA value naming.
4851 op: The operation to create state for.
4852 use_local_scope: Whether to use local scope for naming.)");
4857 nb::class_<PySymbolTable>(m,
"SymbolTable")
4858 .def(nb::init<PyOperationBase &>(),
4860 Creates a symbol table for an operation.
4863 operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
4866 TypeError: If the operation is not a symbol table.)")
4870 const std::string &name) -> nb::typed<nb::object, PyOpView> {
4874 Looks up a symbol by name in the symbol table.
4877 name: The name of the symbol to look up.
4880 The operation defining the symbol.
4883 KeyError: If the symbol is not found.)")
4886 Inserts a symbol operation into the symbol table.
4889 operation: An operation with a symbol name to insert.
4892 The symbol name attribute of the inserted operation.
4895 ValueError: If the operation does not have a symbol name.)")
4898 Erases a symbol operation from the symbol table.
4901 operation: The symbol operation to erase.
4904 The operation is also erased from the IR and invalidated.)")
4906 "Deletes a symbol by name from the symbol table.")
4913 "Checks if a symbol with the given name exists in the table.")
4916 nb::arg(
"symbol"), nb::arg(
"name"),
4917 "Sets the symbol name for a symbol operation.")
4920 "Gets the symbol name from a symbol operation.")
4923 "Gets the visibility attribute of a symbol operation.")
4925 nb::arg(
"symbol"), nb::arg(
"visibility"),
4926 "Sets the visibility attribute of a symbol operation.")
4927 .def_static(
"replace_all_symbol_uses",
4929 nb::arg(
"new_symbol"), nb::arg(
"from_op"),
4930 "Replaces all uses of a symbol with a new symbol name within "
4931 "the given operation.")
4933 nb::arg(
"from_op"), nb::arg(
"all_sym_uses_visible"),
4934 nb::arg(
"callback"),
4935 "Walks symbol tables starting from an operation with a "
4936 "callback function.");
4939 PyBlockArgumentList::bind(m);
4940 PyBlockIterator::bind(m);
4941 PyBlockList::bind(m);
4942 PyBlockSuccessors::bind(m);
4943 PyBlockPredecessors::bind(m);
4944 PyOperationIterator::bind(m);
4945 PyOperationList::bind(m);
4946 PyOpAttributeMap::bind(m);
4947 PyOpOperandIterator::bind(m);
4948 PyOpOperandList::bind(m);
4950 PyOpSuccessors::bind(m);
4951 PyRegionIterator::bind(m);
4952 PyRegionList::bind(m);
4960 nb::register_exception_translator([](
const std::exception_ptr &p,
4966 std::rethrow_exception(p);
4970 PyErr_SetObject(PyExc_Exception, obj.ptr());
void mlirSetGlobalDebugTypes(const char **types, intptr_t n)
MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type)
Sets the current debug type, similarly to -debug-only=type in the command-line tools.
MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled()
Retuns true if the global debugging flag is set, false otherwise.
MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable)
Sets the global debugging flag.
static const char kDumpDocstring[]
static MlirStringRef toMlirStringRef(const std::string &s)
static const char kModuleParseDocstring[]
static std::vector< nb::typed< nb::object, PyType > > getValueTypes(Container &container, PyMlirContextRef &context)
Returns the list of types of the values held by container.
static nb::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
static MlirValue getUniqueResult(MlirOperation operation)
#define _Py_CAST(type, expr)
static MlirValue getOpResultOrValue(nb::handle operand)
static void maybeInsertOperation(PyOperationRef &op, const nb::object &maybeIp)
static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor)
static MlirBlock createBlock(const nb::sequence &pyArgTypes, const std::optional< nb::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
static void populateResultTypes(StringRef name, nb::list resultTypeList, const nb::object &resultSegmentSpecObj, std::vector< int32_t > &resultSegmentLengths, std::vector< PyType * > &resultTypes)
static const char kValueReplaceAllUsesExceptDocstring[]
MlirContext mlirModuleGetContext(MlirModule module)
size_t mlirModuleHashValue(MlirModule mod)
intptr_t mlirBlockGetNumPredecessors(MlirBlock block)
MlirIdentifier mlirOperationGetName(MlirOperation op)
bool mlirValueIsABlockArgument(MlirValue value)
intptr_t mlirOperationGetNumRegions(MlirOperation op)
MlirBlock mlirOperationGetBlock(MlirOperation op)
void mlirBlockArgumentSetType(MlirValue value, MlirType type)
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute const *attributes)
MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos)
MlirModule mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName)
MlirAsmState mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags)
intptr_t mlirOperationGetNumResults(MlirOperation op)
void mlirOperationDestroy(MlirOperation op)
MlirContext mlirAttributeGetContext(MlirAttribute attribute)
MlirType mlirValueGetType(MlirValue value)
void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData)
MlirOpPrintingFlags mlirOpPrintingFlagsCreate()
bool mlirModuleEqual(MlirModule lhs, MlirModule rhs)
void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block)
MlirOperation mlirOperationGetNextInBlock(MlirOperation op)
void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
MlirOperation mlirModuleGetOperation(MlirModule module)
void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value)
MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos)
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2)
bool mlirOperationEqual(MlirOperation op, MlirOperation other)
void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags)
void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config)
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos)
void mlirModuleDestroy(MlirModule module)
MlirModule mlirModuleCreateEmpty(MlirLocation location)
void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
MlirOperation mlirOperationGetParentOperation(MlirOperation op)
void mlirValueSetType(MlirValue value, MlirType type)
intptr_t mlirOperationGetNumSuccessors(MlirOperation op)
MlirDialect mlirAttributeGetDialect(MlirAttribute attr)
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr)
void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue)
MlirOperation mlirOpResultGetOwner(MlirValue value)
MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module)
size_t mlirOperationHashValue(MlirOperation op)
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, MlirType const *results)
MlirOperation mlirOperationClone(MlirOperation op)
MlirBlock mlirBlockArgumentGetOwner(MlirValue value)
void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc)
MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos)
MlirLocation mlirOperationGetLocation(MlirOperation op)
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name)
MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr)
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, MlirRegion const *regions)
void mlirOperationSetLocation(MlirOperation op, MlirLocation loc)
MlirType mlirAttributeGetType(MlirAttribute attribute)
bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name)
bool mlirValueIsAOpResult(MlirValue value)
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos)
MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos)
MlirOperation mlirOperationCreate(MlirOperationState *state)
void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, int64_t version)
MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr)
intptr_t mlirBlockGetNumSuccessors(MlirBlock block)
MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos)
void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
void mlirValueDump(MlirValue value)
void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData)
MlirBlock mlirModuleGetBody(MlirModule module)
MlirOperation mlirOperationCreateParse(MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName)
void mlirAsmStateDestroy(MlirAsmState state)
Destroys printing flags created with mlirAsmStateCreate.
MlirContext mlirOperationGetContext(MlirOperation op)
intptr_t mlirOpResultGetResultNumber(MlirValue value)
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, MlirBlock const *successors)
MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate()
void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags)
void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags)
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, MlirValue const *operands)
MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc)
intptr_t mlirOperationGetNumOperands(MlirOperation op)
void mlirTypeDump(MlirType type)
intptr_t mlirOperationGetNumAttributes(MlirOperation op)
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
static PyObject * mlirPythonContextToCapsule(MlirContext context)
Creates a capsule object encapsulating the raw C-API MlirContext.
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperations from a capsule as produced from mlirPythonOperationToCapsule.
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts a MlirContext from a capsule as produced from mlirPythonContextToCapsule.
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
static PyObject * mlirPythonBlockToCapsule(MlirBlock block)
Creates a capsule object encapsulating the raw C-API MlirBlock.
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
#define MAKE_MLIR_PYTHON_QUALNAME(local)
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static std::string diag(const llvm::Value &value)
A list of operation results.
static void bindDerived(ClassTy &c)
PyOpResultList(PyOperationRef operation, intptr_t startIndex=0, intptr_t length=-1, intptr_t step=1)
static constexpr const char * pyClassName
Sliceable< PyOpResultList, PyOpResult > SliceableT
PyOperationRef & getOperation()
Python wrapper for MlirOpResult.
static constexpr IsAFunctionTy isaFunction
static void bindDerived(ClassTy &c)
static constexpr const char * pyClassName
Accumulates into a file, either writing text (default) or binary.
MlirStringCallback getCallback()
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
static void bind(nanobind::module_ &m)
nanobind::class_< PyOpResultList > ClassTy
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
Base class for all objects that directly or indirectly depend on an MlirContext.
PyMlirContextRef & getContext()
Accesses the context reference.
BaseContextObject(PyMlirContextRef ref)
static PyLocation & resolve()
Used in function arguments when None should resolve to the current context manager set instance.
static PyMlirContext & resolve()
ReferrentTy * get() const
Wrapper around an MlirAsmState.
Wrapper around the generic MlirAttribute.
static PyAttribute createFromCapsule(const nanobind::object &capsule)
Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirAttribute.
nanobind::object maybeDownCast()
MlirAttribute get() const
bool operator==(const PyAttribute &other) const
Wrapper around an MlirBlock.
PyOperationRef & getParentOperation()
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirBlock.
Represents a diagnostic handler attached to the context.
PyDiagnosticHandler(MlirContext context, nanobind::object callback)
nanobind::object contextEnter()
void detach()
Detaches the handler. Does nothing if not attached.
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
Python class mirroring the C MlirDiagnostic struct.
nanobind::tuple getNotes()
nanobind::str getMessage()
PyDiagnostic(MlirDiagnostic diagnostic)
MlirDiagnosticSeverity getSeverity()
Wrapper around an MlirDialect.
Wrapper around an MlirDialectRegistry.
nanobind::object getCapsule()
static PyDialectRegistry createFromCapsule(nanobind::object capsule)
User-level dialect object.
nanobind::object getDescriptor()
User-level object for accessing dialects with dotted syntax such as: ctx.dialect.std.
MlirDialect getDialectForKey(const std::string &key, bool attrError)
size_t locTracebackFramesLimit()
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false)
Adds a user-friendly Attribute builder.
TracebackLoc & getTracebackLoc()
static PyGlobals & get()
Most code should get the globals via this static accessor.
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
std::optional< nanobind::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
An insertion point maintains a pointer to a Block and a reference operation.
static PyInsertionPoint atBlockTerminator(PyBlock &block)
Shortcut to create an insertion point before the block terminator.
static PyInsertionPoint atBlockBegin(PyBlock &block)
Shortcut to create an insertion point at the beginning of the block.
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
static PyInsertionPoint after(PyOperationBase &op)
Shortcut to create an insertion point to the node after the specified operation.
PyInsertionPoint(const PyBlock &block)
Creates an insertion point positioned after the last operation in the block, but still inside the blo...
void insert(PyOperationBase &operationBase)
Inserts an operation.
static nanobind::object contextEnter(nanobind::object insertionPoint)
Enter and exit the context manager.
std::optional< PyOperationRef > & getRefOperation()
Wrapper around an MlirLocation.
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirLocation.
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
static PyLocation createFromCapsule(nanobind::object capsule)
Creates a PyLocation from the MlirLocation wrapped by a capsule.
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
static nanobind::object contextEnter(nanobind::object location)
Enter and exit the context manager.
MlirContext get()
Accesses the underlying MlirContext.
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
static size_t getLiveCount()
Gets the count of live context objects. Used for testing.
size_t getLiveModuleCount()
Gets the count of live modules associated with this context.
nanobind::object attachDiagnosticHandler(nanobind::object callback)
Attaches a Python callback as a diagnostic handler, returning a registration object (internally a PyD...
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirContext.
void contextExit(const nanobind::object &excType, const nanobind::object &excVal, const nanobind::object &excTb)
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyMlirContext from the MlirContext wrapped by a capsule.
static nanobind::object contextEnter(nanobind::object context)
Enter and exit the context manager.
void setEmitErrorDiagnostics(bool value)
Controls whether error diagnostics should be propagated to diagnostic handlers, instead of being capt...
bool getEmitErrorDiagnostics()
MlirModule get()
Gets the backing MlirModule.
static PyModuleRef forModule(MlirModule module)
Returns a PyModule reference for the given MlirModule.
static nanobind::object createFromCapsule(nanobind::object capsule)
Creates a PyModule from the MlirModule wrapped by a capsule.
PyModuleRef getRef()
Gets a strong reference to this module.
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirModule.
PyModule(PyModule &)=delete
Represents a Python MlirNamedAttr, carrying an optional owned name.
PyNamedAttribute(MlirAttribute attr, std::string ownedName)
Constructs a PyNamedAttr that retains an owned name.
MlirNamedAttribute namedAttr
Template for a reference to a concrete type which captures a python reference to its underlying pytho...
nanobind::object getObject()
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for providing more instance-sp...
PyOperation & getOperation() override
Each must provide access to the raw Operation.
static nanobind::object constructDerived(const nanobind::object &cls, const nanobind::object &operation)
Construct an instance of a class deriving from OpView, bypassing its __init__ method.
static nanobind::object buildGeneric(std::string_view name, std::tuple< int, bool > opRegionSpec, nanobind::object operandSegmentSpecObj, nanobind::object resultSegmentSpecObj, std::optional< nanobind::list > resultTypeList, nanobind::list operandList, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, std::optional< int > regions, PyLocation &location, const nanobind::object &maybeIp)
nanobind::object getOperationObject()
PyOpView(const nanobind::object &operationObject)
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
void walk(std::function< MlirWalkResult(MlirOperation)> callback, MlirWalkOrder walkOrder)
bool isBeforeInBlock(PyOperationBase &other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
nanobind::object getAsm(bool binary, std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions)
void moveAfter(PyOperationBase &other)
Moves the operation before or after the other operation.
void print(std::optional< int64_t > largeElementsLimit, std::optional< int64_t > largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions)
Implements the bound 'print' method and helps with others.
void writeBytecode(const nanobind::object &fileObject, std::optional< int64_t > bytecodeVersion)
void moveBefore(PyOperationBase &other)
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
bool verify()
Verify the operation.
void detachFromParent()
Detaches the operation from its parent block and updates its state accordingly.
PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
void erase()
Erases the underlying MlirOperation, removes its pointer from the parent context's live operations ma...
static nanobind::object createFromCapsule(const nanobind::object &capsule)
Creates a PyOperation from the MlirOperation wrapped by a capsule.
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirOperation.
static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Creates a detached operation.
nanobind::object clone(const nanobind::object &ip)
Clones this operation.
MlirOperation get() const
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
void setAttached(const nanobind::object &parent=nanobind::object())
std::optional< PyOperationRef > getParentOperation()
Gets the parent operation or raises an exception if the operation has no parent.
nanobind::object createOpView()
Creates an OpView suitable for this operation.
PyBlock getBlock()
Gets the owning block or raises an exception if the operation has no owning block.
static nanobind::object create(std::string_view name, std::optional< std::vector< PyType * > > results, llvm::ArrayRef< MlirValue > operands, std::optional< nanobind::dict > attributes, std::optional< std::vector< PyBlock * > > successors, int regions, PyLocation &location, const nanobind::object &ip, bool inferType)
Creates an operation. See corresponding python docstring.
static PyOperationRef parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName)
Parses a source string (either text assembly or bytecode), creating a detached operation.
void setInvalid()
Invalidate the operation.
Wrapper around an MlirRegion.
PyOperationRef & getParentOperation()
Bindings for MLIR symbol tables.
void dunderDel(const std::string &name)
Removes the operation with the given name from the symbol table and erases it, throws if there is no ...
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, nanobind::object callback)
Walks all symbol tables under and including 'from'.
static void replaceAllSymbolUses(const std::string &oldSymbol, const std::string &newSymbol, PyOperationBase &from)
Replaces all symbol uses within an operation.
static void setVisibility(PyOperationBase &symbol, const std::string &visibility)
static void setSymbolName(PyOperationBase &symbol, const std::string &name)
PyStringAttribute insert(PyOperationBase &symbol)
Inserts the given operation into the symbol table.
void erase(PyOperationBase &symbol)
Removes the given operation from the symbol table and erases it.
PySymbolTable(PyOperationBase &operation)
Constructs a symbol table for the given operation.
static PyStringAttribute getSymbolName(PyOperationBase &symbol)
Gets and sets the name of a symbol op.
static PyStringAttribute getVisibility(PyOperationBase &symbol)
Gets and sets the visibility of a symbol op.
nanobind::object dunderGetItem(const std::string &name)
Returns the symbol (opview) with the given name, throws if there is no such symbol in the table.
static PyThreadContextEntry * getTopOfStack()
Stack management.
static void popLocation(PyLocation &location)
static nanobind::object pushLocation(nanobind::object location)
PyLocation * getLocation()
static nanobind::object pushContext(nanobind::object context)
static PyLocation * getDefaultLocation()
Gets the top of stack location and returns nullptr if not defined.
static void popInsertionPoint(PyInsertionPoint &insertionPoint)
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint)
static void popContext(PyMlirContext &context)
static PyInsertionPoint * getDefaultInsertionPoint()
Gets the top of stack insertion point and return nullptr if not defined.
PyMlirContext * getContext()
static PyMlirContext * getDefaultContext()
Gets the top of stack context and return nullptr if not defined.
static std::vector< PyThreadContextEntry > & getStack()
Gets the thread local stack.
PyThreadContextEntry(FrameKind frameKind, nanobind::object context, nanobind::object insertionPoint, nanobind::object location)
PyInsertionPoint * getInsertionPoint()
Wrapper around MlirLlvmThreadPool Python object owns the C++ thread pool.
std::string _mlir_thread_pool_ptr() const
int getMaxConcurrency() const
A TypeID provides an efficient and unique identifier for a specific C++ type.
static PyTypeID createFromCapsule(nanobind::object capsule)
Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
bool operator==(const PyTypeID &other) const
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirTypeID.
PyTypeID(MlirTypeID typeID)
Wrapper around the generic MlirType.
PyType(PyMlirContextRef contextRef, MlirType type)
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirType.
static PyType createFromCapsule(nanobind::object capsule)
Creates a PyType from the MlirType wrapped by a capsule.
nanobind::object maybeDownCast()
bool operator==(const PyType &other) const
Wrapper around the generic MlirValue.
PyValue(PyOperationRef parentOperation, MlirValue value)
static PyValue createFromCapsule(nanobind::object capsule)
Creates a PyValue from the MlirValue wrapped by a capsule.
nanobind::object maybeDownCast()
nanobind::object getCapsule()
Gets a capsule wrapping the void* within the MlirValue.
PyOperationRef & getParentOperation()
MLIR_CAPI_EXPORTED intptr_t mlirDiagnosticGetNumNotes(MlirDiagnostic diagnostic)
Returns the number of notes attached to the diagnostic.
MLIR_CAPI_EXPORTED MlirDiagnosticSeverity mlirDiagnosticGetSeverity(MlirDiagnostic diagnostic)
Returns the severity of the diagnostic.
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
MlirDiagnosticSeverity
Severity of a diagnostic.
MLIR_CAPI_EXPORTED MlirDiagnostic mlirDiagnosticGetNote(MlirDiagnostic diagnostic, intptr_t pos)
Returns pos-th note attached to the diagnostic.
MLIR_CAPI_EXPORTED void mlirEmitError(MlirLocation location, const char *message)
Emits an error at the given location through the diagnostics engine.
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirLocationGetAttribute(MlirLocation location)
Returns the underlying location attribute of this location.
MlirWalkResult(* MlirOperationWalkCallback)(MlirOperation, void *userData)
Operation walker type.
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v)
Gets the location of the value.
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context)
Gets the number of threads of the thread pool of the context when multithreading is enabled.
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but writing the bytecode format.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col)
Creates an File/Line/Column location owned by the given context.
MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, void(*callback)(MlirOperation, bool, void *userData), void *userData)
Walks all symbol table operations nested within, and including, op.
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location)
Getter for end_column of FileLineColRange.
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation)
Inserts the given operation into the given symbol table.
MlirWalkOrder
Traversal order for operation walk.
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGetChildLoc(MlirLocation location)
Getter for childLoc of Name.
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation)
Removes the given operation from the symbol table and erases it.
MLIR_CAPI_EXPORTED void mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry)
Append the contents of the given dialect registry to the registry associated with the context.
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type)
Parses a type. The type is owned by the context.
MLIR_CAPI_EXPORTED MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand)
Returns an op operand representing the next use of the value, or a null op operand if there is no nex...
MLIR_CAPI_EXPORTED void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow)
Sets whether unregistered dialects are allowed in this context.
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given ...
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location)
Checks whether the given location is an FileLineColRange.
MLIR_CAPI_EXPORTED unsigned mlirLocationFusedGetNumLocations(MlirLocation location)
Getter for number of locations fused together.
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with)
Replace all uses of 'of' value with the 'with' value, updating anything in the IR that uses 'of' to u...
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData)
Prints a value as an operand (i.e., the ValueID).
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context)
Creates a location with unknown position owned by the given context.
MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand)
Returns the owner operation of an op operand.
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location)
Getter for filename of FileLineColRange.
MLIR_CAPI_EXPORTED void mlirLocationFusedGetLocations(MlirLocation location, MlirLocation *locationsCPtr)
Getter for locations of Fused.
MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding userData to callback`...
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block)
Returns the region that contains this block.
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other)
Moves the given operation immediately before the other operation in its parent block.
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions)
Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use '...
MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts AsmState controlling the printing behavior as well as caching ...
MlirWalkResult
Operation walk result.
@ MlirWalkResultInterrupt
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block)
Takes a block owned by the caller and inserts it at pos to the given region.
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name)
Returns whether the given fully-qualified operation (i.e.
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block)
Returns the number of arguments of the block.
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartLine(MlirLocation location)
Getter for start_line of FileLineColRange.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata)
Creates a fused location with an array of locations and metadata.
MLIR_CAPI_EXPORTED void mlirBlockInsertOwnedOperationBefore(MlirBlock block, MlirOperation reference, MlirOperation operation)
Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in t...
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name)
Gets the dialect instance owned by the given context using the dialect namespace to identify it,...
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location)
Checks whether the given location is an CallSite.
struct MlirNamedAttribute MlirNamedAttribute
MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, MlirBlock block)
Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given r...
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, MlirLocation const *locs)
Creates a new empty block with the given argument types and transfers ownership to the caller.
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation)
Takes an operation owned by the caller and appends it to the block.
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos)
Returns pos-th argument of the block.
MLIR_CAPI_EXPORTED MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name)
Looks up a symbol with the given name in the given symbol table and returns the operation that corres...
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, unsigned start_line, unsigned start_col, unsigned end_line, unsigned end_col)
Creates an File/Line/Column range location owned by the given context.
MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand)
Returns whether the op operand is null.
MLIR_CAPI_EXPORTED MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation)
Creates a symbol table for the given operation.
MLIR_CAPI_EXPORTED bool mlirLocationEqual(MlirLocation l1, MlirLocation l2)
Checks if two locations are equal.
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location)
Getter for start_column of FileLineColRange.
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location)
Checks whether the given location is an Fused.
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc)
Appends an argument of the specified type to the block.
MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, MlirStringCallback callback, void *userData)
Same as mlirOperationPrint but accepts flags controlling the printing behavior.
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value)
Returns an op operand representing the first use of the value, or a null op operand if there are no u...
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool)
Sets the thread pool of the context explicitly, enabling multithreading in the process.
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op)
Verify the operation and return true if it passes, false if it fails.
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
MLIR_CAPI_EXPORTED unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand)
Returns the operand number of an op operand.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location)
Getter for caller of CallSite.
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block)
Returns the terminator operation in the block or null if no terminator.
MLIR_CAPI_EXPORTED MlirIdentifier mlirLocationNameGetName(MlirLocation location)
Getter for name of Name.
MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
MLIR_CAPI_EXPORTED MlirLocation mlirLocationFromAttribute(MlirAttribute attribute)
Creates a location from a location attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type)
Gets the type ID of the type.
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(void)
Returns the name of the attribute used to store symbol visibility.
static bool mlirDialectIsNull(MlirDialect dialect)
Checks if the dialect is null.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block)
Returns the block immediately following the given block in its parent region.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller)
Creates a call site location with a callee and a caller.
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location)
Checks whether the given location is an Name.
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder)
Walks operation op in walkOrder and calls callback on that operation.
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithThreading(bool threadingEnabled)
Creates an MLIR context with an explicit setting of the multithreading setting and transfers its owne...
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location)
Gets the context that a location was created with.
MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index)
Erase the argument at 'index' and remove it from the argument list.
MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr)
Prints the attribute to the standard error stream.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from)
Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol...
MLIR_CAPI_EXPORTED void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block)
Takes a block owned by the caller and appends it to the given region.
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetFirstOperation(MlirBlock block)
Returns the first operation in the block.
static bool mlirRegionIsNull(MlirRegion region)
Checks whether a region is null.
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type)
Gets the dialect a type belongs to.
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
MLIR_CAPI_EXPORTED void mlirContextLoadAllAvailableDialects(MlirContext context)
Eagerly loads all available dialects registered with a context, making them available for use for IR ...
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context)
Gets the thread pool of the context when enabled multithreading, otherwise an assertion is raised.
MLIR_CAPI_EXPORTED int mlirLocationFileLineColRangeGetEndLine(MlirLocation location)
Getter for end_line of FileLineColRange.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc)
Creates a name location owned by the given context.
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable)
Set threading mode (must be set to false to mlir-print-ir-after-all).
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location)
Getter for callee of CallSite.
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v)
Gets the context that a value was created with.
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void)
Returns the name of the attribute used to store symbol names compatible with symbol tables.
MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void)
Creates a new empty region and transfers ownership to the caller.
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block)
Detach a block from the owning region and assume ownership.
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op)
Prints an operation to stderr.
static bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable)
Returns true if the symbol table is null.
MLIR_CAPI_EXPORTED bool mlirContextGetAllowUnregisteredDialects(MlirContext context)
Returns whether the context allows unregistered dialects.
MLIR_CAPI_EXPORTED void mlirOperationReplaceUsesOfWith(MlirOperation op, MlirValue of, MlirValue with)
Replace uses of 'of' value with the 'with' value inside the 'op' operation.
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MlirOperation other)
Moves the given operation immediately after the other operation in its parent block.
MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData)
Prints a block by sending chunks of the string representation and forwarding userData to callback`.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData)
Same as mlirOperationWriteBytecode but with writer config and returns failure only if desired bytecod...
MLIR_CAPI_EXPORTED void mlirContextDestroy(MlirContext context)
Takes an MLIR context owned by the caller and destroys it.
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region)
Gets the first block in the region.
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
struct MlirStringRef MlirStringRef
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
PyObjectRef< PyOperation > PyOperationRef
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
void populateIRCore(nanobind::module_ &m)
PyObjectRef< PyModule > PyModuleRef
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
A logical result value, essentially a boolean with named states.
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.
static bool dunderContains(const std::string &attributeKind)
static void dunderSetItemNamed(const std::string &attributeKind, nb::callable func, bool replace)
static nb::callable dunderGetItemNamed(const std::string &attributeKind)
static void bind(nb::module_ &m)
Wrapper for the global LLVM debugging flag.
static void bind(nb::module_ &m)
static void set(nb::object &o, bool enable)
static bool get(const nb::object &)
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
std::vector< PyDiagnostic::DiagnosticInfo > errorDiagnostics
Materialized diagnostic information.
MlirDiagnosticSeverity severity
std::vector< DiagnosticInfo > notes
RAII object that captures any error diagnostics emitted to the provided context.
ErrorCapture(PyMlirContextRef ctx)
std::vector< PyDiagnostic::DiagnosticInfo > take()