21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
30 R
"(Creates an interface from a given operation/opview object or from a
31 subclass of OpView. Raises ValueError if the operation does not implement the
35 R
"(Returns an Operation for which the interface was constructed.)";
38 R
"(Returns an OpView subclass _instance_ for which the interface was
42 R
"(Given the arguments required to build an operation, attempts to infer
43 its return types. Raises ValueError on failure.)";
46 R
"(Given the arguments required to build an operation, attempts to infer
47 its return shaped type components. Raises ValueError on failure.)";
56 if (!operandList || operandList->size() == 0) {
61 mlirOperands.reserve(operandList->size());
63 if (it.value().is_none())
68 val = nb::cast<PyValue *>(it.value());
70 throw nb::cast_error();
71 mlirOperands.push_back(val->
get());
73 }
catch (nb::cast_error &err) {
79 auto vals = nb::cast<nb::sequence>(it.value());
80 for (nb::handle v : vals) {
82 val = nb::cast<PyValue *>(v);
84 throw nb::cast_error();
85 mlirOperands.push_back(val->
get());
86 }
catch (nb::cast_error &err) {
87 throw nb::value_error(
88 (llvm::Twine(
"Operand ") + llvm::Twine(it.index()) +
89 " must be a Value or Sequence of Values (" + err.what() +
")")
95 }
catch (nb::cast_error &err) {
96 throw nb::value_error((llvm::Twine(
"Operand ") + llvm::Twine(it.index()) +
97 " must be a Value or Sequence of Values (" +
103 throw nb::cast_error();
112 wrapRegions(std::optional<std::vector<PyRegion>> regions) {
116 mlirRegions.reserve(regions->size());
118 mlirRegions.push_back(region);
143 template <
typename ConcreteIface>
154 : obj(std::move(object)) {
156 operation = &nb::cast<PyOperation &>(obj);
157 }
catch (nb::cast_error &) {
163 }
catch (nb::cast_error &) {
167 if (operation !=
nullptr) {
169 ConcreteIface::getInterfaceID())) {
170 std::string msg =
"the operation does not implement ";
171 throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
176 opName = std::string(stringRef.
data, stringRef.
length);
179 opName = nb::cast<std::string>(obj.attr(
"OPERATION_NAME"));
180 }
catch (nb::cast_error &) {
181 throw nb::type_error(
182 "Op interface does not refer to an operation or OpView class");
187 context.
resolve().
get(), ConcreteIface::getInterfaceID())) {
188 std::string msg =
"the operation does not implement ";
189 throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
195 static void bind(nb::module_ &m) {
196 nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
197 cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg(
"object"),
202 ConcreteIface::bindDerived(cls);
216 if (operation ==
nullptr) {
217 throw nb::type_error(
"Cannot get an operation from a static interface");
227 if (operation ==
nullptr) {
228 throw nb::type_error(
"Cannot get an opview from a static interface");
251 constexpr
static const char *
pyClassName =
"InferTypeOpInterface";
266 data->
inferredTypes.reserve(data->inferredTypes.size() + nTypes);
267 for (intptr_t i = 0; i < nTypes; ++i) {
268 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
276 std::optional<PyAttribute> attributes,
void *properties,
277 std::optional<std::vector<PyRegion>> regions,
281 wrapOperands(std::move(operandList));
284 std::vector<PyType> inferredTypes;
289 MlirAttribute attributeDict =
293 opNameRef, pyContext.
get(), location.
resolve(), mlirOperands.size(),
294 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
298 throw nb::value_error(
"Failed to infer result types");
301 return inferredTypes;
306 nb::arg(
"operands").none() = nb::none(),
307 nb::arg(
"attributes").none() = nb::none(),
308 nb::arg(
"properties").none() = nb::none(),
309 nb::arg(
"regions").none() = nb::none(),
310 nb::arg(
"context").none() = nb::none(),
320 : shape(std::move(shape)), elementType(elementType), ranked(true) {}
322 MlirAttribute attribute)
323 : shape(std::move(shape)), elementType(elementType), attribute(attribute),
327 : shape(other.shape), elementType(other.elementType),
328 attribute(other.attribute), ranked(other.ranked) {}
330 static void bind(nb::module_ &m) {
331 nb::class_<PyShapedTypeComponents>(m,
"ShapedTypeComponents")
335 "Returns the element type of the shaped type components.")
341 nb::arg(
"element_type"),
342 "Create an shaped type components object with only the element "
346 [](nb::list shape,
PyType &elementType) {
349 nb::arg(
"shape"), nb::arg(
"element_type"),
350 "Create a ranked shaped type components object.")
357 nb::arg(
"shape"), nb::arg(
"element_type"), nb::arg(
"attribute"),
358 "Create a ranked shaped type components object with attribute.")
362 "Returns whether the given shaped type component is ranked.")
369 return nb::int_(
self.shape.size());
371 "Returns the rank of the given ranked shaped type components. If "
372 "the shaped type components does not have a rank, None is "
380 return nb::list(
self.shape);
382 "Returns the shape of the ranked shaped type components as a list "
383 "of integers. Returns none if the shaped type component does not "
392 MlirType elementType;
393 MlirAttribute attribute;
405 constexpr
static const char *
pyClassName =
"InferShapedTypeOpInterface";
417 const int64_t *shape, MlirType elementType,
418 MlirAttribute attribute,
void *userData) {
424 for (intptr_t i = 0; i < rank; ++i) {
425 shapeList.append(shape[i]);
427 data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
435 std::optional<nb::list> operandList,
436 std::optional<PyAttribute> attributes,
void *properties,
437 std::optional<std::vector<PyRegion>> regions,
440 wrapOperands(std::move(operandList));
443 std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
448 MlirAttribute attributeDict =
452 opNameRef, pyContext.
get(), location.
resolve(), mlirOperands.size(),
453 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
457 throw nb::value_error(
"Failed to infer result shape type components");
460 return inferredShapedTypeComponents;
464 cls.def(
"inferReturnTypeComponents",
466 nb::arg(
"operands").none() = nb::none(),
467 nb::arg(
"attributes").none() = nb::none(),
468 nb::arg(
"regions").none() = nb::none(),
469 nb::arg(
"properties").none() = nb::none(),
470 nb::arg(
"context").none() = nb::none(),
Used in function arguments when None should resolve to the current context manager set instance.
static PyLocation & resolve()
Used in function arguments when None should resolve to the current context manager set instance.
static PyMlirContext & resolve()
Wrapper around the generic MlirAttribute.
CRTP base class for Python classes representing MLIR Op interfaces.
MlirTypeID(*)() GetTypeIDFunctionTy
nb::object getOperationObject()
Returns the operation instance from which this object was constructed.
bool isStatic()
Returns true if this object was constructed from a subclass of OpView rather than from an operation i...
static void bind(nb::module_ &m)
Creates the Python bindings for this class in the given module.
PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
Constructs an interface instance from an object that is either an operation or a subclass of OpView.
static void bindDerived(ClassTy &cls)
Hook for derived classes to add class-specific bindings.
const std::string & getOpName()
Returns the canonical name of the operation this interface is constructed from.
nb::class_< ConcreteIface > ClassTy
nb::object getOpView()
Returns the opview of the operation instance from which this object was constructed.
Python wrapper for InferShapedTypeOpInterface.
constexpr static const char * pyClassName
static void appendResultsCallback(bool hasRank, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute attribute, void *userData)
Appends the shaped type components provided as unpacked shape, element type, attribute to the user-da...
std::vector< PyShapedTypeComponents > inferReturnTypeComponents(std::optional< nb::list > operandList, std::optional< PyAttribute > attributes, void *properties, std::optional< std::vector< PyRegion >> regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer the shaped type components.
constexpr static GetTypeIDFunctionTy getInterfaceID
static void bindDerived(ClassTy &cls)
Python wrapper for InferTypeOpInterface.
std::vector< PyType > inferReturnTypes(std::optional< nb::list > operandList, std::optional< PyAttribute > attributes, void *properties, std::optional< std::vector< PyRegion >> regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer its return types.
constexpr static GetTypeIDFunctionTy getInterfaceID
constexpr static const char * pyClassName
static void bindDerived(ClassTy &cls)
static void appendResultsCallback(intptr_t nTypes, MlirType *types, void *userData)
Appends the types provided as the first two arguments to the user-data structure (expects AppendResul...
MlirContext get()
Accesses the underlying MlirContext.
nanobind::object releaseObject()
Releases the object held by this instance, returning it.
PyOperation & getOperation() override
Each must provide access to the raw Operation.
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Wrapper around an MlirRegion.
Wrapper around a shaped type components object.
PyShapedTypeComponents(PyShapedTypeComponents &)=delete
PyShapedTypeComponents(nb::list shape, MlirType elementType, MlirAttribute attribute)
PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
PyShapedTypeComponents(nb::list shape, MlirType elementType)
PyShapedTypeComponents(MlirType elementType)
static PyShapedTypeComponents createFromCapsule(nb::object capsule)
static void bind(nb::module_ &m)
Wrapper around the generic MlirType.
Wrapper around the generic MlirValue.
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
MLIR_CAPI_EXPORTED bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identi...
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
MLIR_CAPI_EXPORTED bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical name given the arguments that will be...
MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
constexpr static const char * inferReturnTypesDoc
constexpr static const char * operationDoc
void populateIRInterfaces(nb::module_ &m)
constexpr static const char * opviewDoc
constexpr static const char * constructorDoc
constexpr static const char * inferReturnTypeComponentsDoc
Include the generated interface declarations.
A logical result value, essentially a boolean with named states.
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.
C-style user-data structure for the shaped type components appending callback.
std::vector< PyShapedTypeComponents > & inferredShapedTypeComponents
C-style user-data structure for type appending callback.
std::vector< PyType > & inferredTypes
PyMlirContext & pyMlirContext