28 R
"(Given the arguments required to build an operation, attempts to infer
29its return types. Raises ValueError on failure.)";
32 R
"(Given the arguments required to build an operation, attempts to infer
33its return shaped type components. Raises ValueError on failure.)";
39std::vector<MlirValue> wrapOperands(std::optional<nb::sequence> operandList) {
40 std::vector<MlirValue> mlirOperands;
42 if (!operandList || nb::len(*operandList) == 0) {
47 mlirOperands.reserve(nb::len(*operandList));
48 for (
size_t i = 0, e = nb::len(*operandList); i < e; ++i) {
49 nb::handle operand = (*operandList)[i];
51 if (operand.is_none())
56 val = nb::cast<PyValue *>(operand);
58 throw nb::cast_error();
59 mlirOperands.push_back(val->
get());
61 }
catch (nb::cast_error &err) {
67 auto vals = nb::cast<nb::sequence>(operand);
68 for (nb::handle v : vals) {
70 val = nb::cast<PyValue *>(v);
72 throw nb::cast_error();
73 mlirOperands.push_back(val->
get());
74 }
catch (nb::cast_error &err) {
75 throw nb::value_error(
77 " must be a Value or Sequence of Values (",
83 }
catch (nb::cast_error &err) {
84 throw nb::value_error(
86 " must be a Value or Sequence of Values (",
91 throw nb::cast_error();
99std::vector<MlirRegion>
100wrapRegions(std::optional<std::vector<PyRegion>> regions) {
101 std::vector<MlirRegion> mlirRegions;
104 mlirRegions.reserve(regions->size());
106 mlirRegions.push_back(region);
122 constexpr static const char *
pyClassName =
"InferTypeOpInterface";
137 data->
inferredTypes.reserve(data->inferredTypes.size() + nTypes);
138 for (
intptr_t i = 0; i < nTypes; ++i) {
139 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
147 std::optional<PyAttribute> attributes,
void *properties,
148 std::optional<std::vector<PyRegion>> regions,
151 std::vector<MlirValue> mlirOperands = wrapOperands(std::move(operandList));
152 std::vector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
154 std::vector<PyType> inferredTypes;
159 MlirAttribute attributeDict =
163 opNameRef, pyContext.
get(), location.
resolve(), mlirOperands.size(),
164 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
168 throw nb::value_error(
"Failed to infer result types");
171 return inferredTypes;
176 nb::arg(
"operands") = nb::none(),
177 nb::arg(
"attributes") = nb::none(),
178 nb::arg(
"properties") = nb::none(), nb::arg(
"regions") = nb::none(),
179 nb::arg(
"context") = nb::none(), nb::arg(
"loc") = nb::none(),
189 : shape(std::move(shape)), elementType(elementType), ranked(
true) {}
191 MlirAttribute attribute)
192 : shape(std::move(shape)), elementType(elementType), attribute(attribute),
196 : shape(other.shape), elementType(other.elementType),
197 attribute(other.attribute), ranked(other.ranked) {}
199 static void bind(nb::module_ &m) {
200 nb::class_<PyShapedTypeComponents>(m,
"ShapedTypeComponents")
204 nb::sig(
"def element_type(self) -> Type"),
205 "Returns the element type of the shaped type components.")
211 nb::arg(
"element_type"),
212 "Create an shaped type components object with only the element "
216 [](nb::typed<nb::list, nb::int_> shape,
PyType &elementType) {
219 nb::arg(
"shape"), nb::arg(
"element_type"),
220 "Create a ranked shaped type components object.")
223 [](nb::typed<nb::list, nb::int_> shape,
PyType &elementType,
228 nb::arg(
"shape"), nb::arg(
"element_type"), nb::arg(
"attribute"),
229 "Create a ranked shaped type components object with attribute.")
233 "Returns whether the given shaped type component is ranked.")
239 return nb::int_(self.shape.size());
241 "Returns the rank of the given ranked shaped type components. If "
242 "the shaped type components does not have a rank, None is "
249 return nb::list(self.shape);
251 "Returns the shape of the ranked shaped type components as a list "
252 "of integers. Returns none if the shaped type component does not "
261 MlirType elementType;
262 MlirAttribute attribute;
274 constexpr static const char *
pyClassName =
"InferShapedTypeOpInterface";
287 MlirAttribute attribute,
void *userData) {
290 data->inferredShapedTypeComponents.emplace_back(elementType);
293 for (
intptr_t i = 0; i < rank; ++i) {
294 shapeList.append(
shape[i]);
296 data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
304 std::optional<nb::sequence> operandList,
305 std::optional<PyAttribute> attributes,
void *properties,
306 std::optional<std::vector<PyRegion>> regions,
308 std::vector<MlirValue> mlirOperands = wrapOperands(std::move(operandList));
309 std::vector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
311 std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
316 MlirAttribute attributeDict =
320 opNameRef, pyContext.
get(), location.
resolve(), mlirOperands.size(),
321 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
325 throw nb::value_error(
"Failed to infer result shape type components");
328 return inferredShapedTypeComponents;
332 cls.def(
"inferReturnTypeComponents",
334 nb::arg(
"operands") = nb::none(),
335 nb::arg(
"attributes") = nb::none(), nb::arg(
"regions") = nb::none(),
336 nb::arg(
"properties") = nb::none(), nb::arg(
"context") = nb::none(),
348 constexpr static const char *
pyClassName =
"MemoryEffectsOpInterface";
359 nb::handle(
static_cast<PyObject *
>(callbacks.
userData)).inc_ref();
361 callbacks.
destruct = [](
void *userData) {
362 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
365 MlirMemoryEffectInstancesList effects,
367 nb::handle pyClass(
static_cast<PyObject *
>(userData));
371 nb::cast<nb::callable>(nb::getattr(pyClass,
"get_effects"));
380 pyGetEffects(opview, effectsWrapper);
390 [](
const nb::object &cls,
const nb::object &opName, nb::object
target,
394 return attach(
target, nb::cast<std::string>(opName), context);
396 nb::arg(
"cls"), nb::arg(
"op_name"), nb::kw_only(),
397 nb::arg(
"target").none() = nb::none(),
398 nb::arg(
"context").none() = nb::none(),
399 "Attach the interface subclass to the given operation name.");
404 nb::class_<PyMemoryEffectsInstanceList>(m,
"MemoryEffectInstancesList");
true
Given two iterators into the same block, return "true" if a is before `b.
MlirContext mlirOperationGetContext(MlirOperation op)
ReferrentTy * get() const
Used in function arguments when None should resolve to the current context manager set instance.
static PyLocation & resolve()
Used in function arguments when None should resolve to the current context manager set instance.
static PyMlirContext & resolve()
Wrapper around the generic MlirAttribute.
static void bind(nanobind::module_ &m)
MlirTypeID(*)() GetTypeIDFunctionTy
const std::string & getOpName()
PyConcreteOpInterface(nanobind::object object, DefaultingPyMlirContext context)
nanobind::class_< PyInferTypeOpInterface > ClassTy
Python wrapper for InferShapedTypeOpInterface.
static constexpr GetTypeIDFunctionTy getInterfaceID
static void bindDerived(ClassTy &cls)
static constexpr const char * pyClassName
static void appendResultsCallback(bool hasRank, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute attribute, void *userData)
Appends the shaped type components provided as unpacked shape, element type, attribute to the user-da...
std::vector< PyShapedTypeComponents > inferReturnTypeComponents(std::optional< nb::sequence > operandList, std::optional< PyAttribute > attributes, void *properties, std::optional< std::vector< PyRegion > > regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer the shaped type components.
Python wrapper for InferTypeOpInterface.
std::vector< PyType > inferReturnTypes(std::optional< nb::sequence > operandList, std::optional< PyAttribute > attributes, void *properties, std::optional< std::vector< PyRegion > > regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer its return types.
static void bindDerived(ClassTy &cls)
static constexpr const char * pyClassName
static constexpr GetTypeIDFunctionTy getInterfaceID
static void appendResultsCallback(intptr_t nTypes, MlirType *types, void *userData)
Appends the types provided as the two first arguments to the user-data structure (expects AppendResul...
Wrapper around the MemoryEffectsOpInterface.
static constexpr GetTypeIDFunctionTy getInterfaceID
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new MemoryEffectsOpInterface FallbackModel to the named operation.
static void bindDerived(ClassTy &cls)
static constexpr const char * pyClassName
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
MlirContext get()
Accesses the underlying MlirContext.
nanobind::object createOpView()
Creates an OpView suitable for this operation.
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Wrapper around an MlirRegion.
PyShapedTypeComponents(nb::list shape, MlirType elementType)
PyShapedTypeComponents(MlirType elementType)
PyShapedTypeComponents(nb::list shape, MlirType elementType, MlirAttribute attribute)
static PyShapedTypeComponents createFromCapsule(nb::object capsule)
PyShapedTypeComponents(PyShapedTypeComponents &)=delete
PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
static void bind(nb::module_ &m)
Wrapper around the generic MlirType.
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(void)
Returns the interface TypeID of the InferTypeOpInterface.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical given the arguments that will be...
MLIR_CAPI_EXPORTED MlirTypeID mlirMemoryEffectsOpInterfaceTypeID(void)
Returns the interface TypeID of the MemoryEffectsOpInterface.
MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(void)
Returns the interface TypeID of the InferShapedTypeOpInterface.
MLIR_CAPI_EXPORTED void mlirMemoryEffectsOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Attach a new FallbackModel for the MemoryEffectsOpInterface to the named operation.
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
void populateIRInterfaces(nb::module_ &m)
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
static constexpr const char * inferReturnTypesDoc
static constexpr const char * inferReturnTypeComponentsDoc
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Include the generated interface declarations.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
A logical result value, essentially a boolean with named states.
Callbacks for implementing MemoryEffectsOpInterface from external code.
void(* construct)(void *userData)
Optional constructor for user data. Set to nullptr to disable it.
void(* getEffects)(MlirOperation op, MlirMemoryEffectInstancesList effects, void *userData)
Get memory effects callback.
void(* destruct)(void *userData)
Optional destructor for user data. Set to nullptr to disable it.
A pointer to a sized fragment of a string, not necessarily null-terminated.
C-style user-data structure for type appending callback.
std::vector< PyShapedTypeComponents > & inferredShapedTypeComponents
C-style user-data structure for type appending callback.
PyMlirContext & pyMlirContext
std::vector< PyType > & inferredTypes