MLIR  16.0.0git
IRInterfaces.cpp
Go to the documentation of this file.
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "IRModule.h"
13 #include "mlir-c/Interfaces.h"
14 
15 namespace py = pybind11;
16 
17 namespace mlir {
18 namespace python {
19 
/// Python docstring for the interface constructor.
static constexpr const char *constructorDoc =
    "Creates an interface from a given operation/opview object or from a\n"
    "subclass of OpView. Raises ValueError if the operation does not "
    "implement the\ninterface.";

/// Python docstring for the `operation` read-only property.
static constexpr const char *operationDoc =
    "Returns an Operation for which the interface was constructed.";

/// Python docstring for the `opview` read-only property.
static constexpr const char *opviewDoc =
    "Returns an OpView subclass _instance_ for which the interface was\n"
    "constructed";

/// Python docstring for `InferTypeOpInterface.inferReturnTypes`.
static constexpr const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";
35 
36 /// CRTP base class for Python classes representing MLIR Op interfaces.
37 /// Interface hierarchies are flat so no base class is expected here. The
38 /// derived class is expected to define the following static fields:
39 /// - `const char *pyClassName` - the name of the Python class to create;
40 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
41 /// of the interface.
42 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
43 /// interface-specific methods.
44 ///
45 /// An interface class may be constructed from either an Operation/OpView object
46 /// or from a subclass of OpView. In the latter case, only the static interface
47 /// methods are available, similarly to calling ConcereteOp::staticMethod on the
48 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
49 /// method to check whether the interface object was constructed from a class or
50 /// an operation/opview instance. The `getOpName` always succeeds and returns a
51 /// canonical name of the operation suitable for lookups.
52 template <typename ConcreteIface>
54 protected:
55  using ClassTy = py::class_<ConcreteIface>;
56  using GetTypeIDFunctionTy = MlirTypeID (*)();
57 
58 public:
59  /// Constructs an interface instance from an object that is either an
60  /// operation or a subclass of OpView. In the latter case, only the static
61  /// methods of the interface are accessible to the caller.
62  PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
63  : obj(std::move(object)) {
64  try {
65  operation = &py::cast<PyOperation &>(obj);
66  } catch (py::cast_error &) {
67  // Do nothing.
68  }
69 
70  try {
71  operation = &py::cast<PyOpView &>(obj).getOperation();
72  } catch (py::cast_error &) {
73  // Do nothing.
74  }
75 
76  if (operation != nullptr) {
77  if (!mlirOperationImplementsInterface(*operation,
78  ConcreteIface::getInterfaceID())) {
79  std::string msg = "the operation does not implement ";
80  throw py::value_error(msg + ConcreteIface::pyClassName);
81  }
82 
83  MlirIdentifier identifier = mlirOperationGetName(*operation);
84  MlirStringRef stringRef = mlirIdentifierStr(identifier);
85  opName = std::string(stringRef.data, stringRef.length);
86  } else {
87  try {
88  opName = obj.attr("OPERATION_NAME").template cast<std::string>();
89  } catch (py::cast_error &) {
90  throw py::type_error(
91  "Op interface does not refer to an operation or OpView class");
92  }
93 
95  mlirStringRefCreate(opName.data(), opName.length()),
96  context.resolve().get(), ConcreteIface::getInterfaceID())) {
97  std::string msg = "the operation does not implement ";
98  throw py::value_error(msg + ConcreteIface::pyClassName);
99  }
100  }
101  }
102 
103  /// Creates the Python bindings for this class in the given module.
104  static void bind(py::module &m) {
105  py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
106  py::module_local());
107  cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
108  py::arg("context") = py::none(), constructorDoc)
109  .def_property_readonly("operation",
111  operationDoc)
112  .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
113  opviewDoc);
114  ConcreteIface::bindDerived(cls);
115  }
116 
117  /// Hook for derived classes to add class-specific bindings.
118  static void bindDerived(ClassTy &cls) {}
119 
120  /// Returns `true` if this object was constructed from a subclass of OpView
121  /// rather than from an operation instance.
122  bool isStatic() { return operation == nullptr; }
123 
124  /// Returns the operation instance from which this object was constructed.
125  /// Throws a type error if this object was constructed from a subclass of
126  /// OpView.
127  py::object getOperationObject() {
128  if (operation == nullptr) {
129  throw py::type_error("Cannot get an operation from a static interface");
130  }
131 
132  return operation->getRef().releaseObject();
133  }
134 
135  /// Returns the opview of the operation instance from which this object was
136  /// constructed. Throws a type error if this object was constructed form a
137  /// subclass of OpView.
138  py::object getOpView() {
139  if (operation == nullptr) {
140  throw py::type_error("Cannot get an opview from a static interface");
141  }
142 
143  return operation->createOpView();
144  }
145 
146  /// Returns the canonical name of the operation this interface is constructed
147  /// from.
148  const std::string &getOpName() { return opName; }
149 
150 private:
151  PyOperation *operation = nullptr;
152  std::string opName;
153  py::object obj;
154 };
155 
156 /// Python wrapper for InterTypeOpInterface. This interface has only static
157 /// methods.
159  : public PyConcreteOpInterface<PyInferTypeOpInterface> {
160 public:
162 
163  constexpr static const char *pyClassName = "InferTypeOpInterface";
164  constexpr static GetTypeIDFunctionTy getInterfaceID =
166 
167  /// C-style user-data structure for type appending callback.
169  std::vector<PyType> &inferredTypes;
171  };
172 
173  /// Appends the types provided as the two first arguments to the user-data
174  /// structure (expects AppendResultsCallbackData).
175  static void appendResultsCallback(intptr_t nTypes, MlirType *types,
176  void *userData) {
177  auto *data = static_cast<AppendResultsCallbackData *>(userData);
178  data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
179  for (intptr_t i = 0; i < nTypes; ++i) {
180  data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
181  }
182  }
183 
184  /// Given the arguments required to build an operation, attempts to infer its
185  /// return types. Throws value_error on faliure.
186  std::vector<PyType>
187  inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands,
188  llvm::Optional<PyAttribute> attributes,
189  llvm::Optional<std::vector<PyRegion>> regions,
190  DefaultingPyMlirContext context,
191  DefaultingPyLocation location) {
192  llvm::SmallVector<MlirValue> mlirOperands;
193  llvm::SmallVector<MlirRegion> mlirRegions;
194 
195  if (operands) {
196  mlirOperands.reserve(operands->size());
197  for (PyValue &value : *operands) {
198  mlirOperands.push_back(value);
199  }
200  }
201 
202  if (regions) {
203  mlirRegions.reserve(regions->size());
204  for (PyRegion &region : *regions) {
205  mlirRegions.push_back(region);
206  }
207  }
208 
209  std::vector<PyType> inferredTypes;
210  PyMlirContext &pyContext = context.resolve();
211  AppendResultsCallbackData data{inferredTypes, pyContext};
212  MlirStringRef opNameRef =
213  mlirStringRefCreate(getOpName().data(), getOpName().length());
214  MlirAttribute attributeDict =
215  attributes ? attributes->get() : mlirAttributeGetNull();
216 
218  opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
219  mlirOperands.data(), attributeDict, mlirRegions.size(),
220  mlirRegions.data(), &appendResultsCallback, &data);
221 
222  if (mlirLogicalResultIsFailure(result)) {
223  throw py::value_error("Failed to infer result types");
224  }
225 
226  return inferredTypes;
227  }
228 
229  static void bindDerived(ClassTy &cls) {
230  cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
231  py::arg("operands") = py::none(),
232  py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
233  py::arg("context") = py::none(), py::arg("loc") = py::none(),
234  inferReturnTypesDoc);
235  }
236 };
237 
239 
240 } // namespace python
241 } // namespace mlir
Include the generated interface declarations.
static void bindDerived(ClassTy &cls)
const char * data
Pointer to the first symbol.
Definition: Support.h:72
MLIR_CAPI_EXPORTED bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identi...
Definition: Interfaces.cpp:26
std::vector< PyType > inferReturnTypes(llvm::Optional< std::vector< PyValue >> operands, llvm::Optional< PyAttribute > attributes, llvm::Optional< std::vector< PyRegion >> regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer its return types.
static void appendResultsCallback(intptr_t nTypes, MlirType *types, void *userData)
Appends the types provided as the first two arguments to the user-data structure (expects AppendResultsCallbackData).
bool isStatic()
Returns true if this object was constructed from a subclass of OpView rather than from an operation i...
static constexpr const char * opviewDoc
Used in function arguments when None should resolve to the current context manager set instance...
Definition: IRModule.h:449
static constexpr const char * constructorDoc
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:125
static PyLocation & resolve()
Definition: IRCore.cpp:847
static void bind(py::module &m)
Creates the Python bindings for this class in the given module.
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
Definition: Interfaces.cpp:34
pybind11::object createOpView()
Creates an OpView suitable for this operation.
Definition: IRCore.cpp:1244
static constexpr const char * inferReturnTypesDoc
Wrapper around an MlirRegion.
Definition: IRModule.h:668
Used in function arguments when None should resolve to the current context manager set instance...
Definition: IRModule.h:258
C-style user-data structure for type appending callback.
static constexpr const bool value
py::object getOpView()
Returns the opview of the operation instance from which this object was constructed.
void populateIRInterfaces(py::module &m)
static constexpr const char * operationDoc
pybind11::object releaseObject()
Releases the object held by this instance, returning it.
Definition: IRModule.h:71
PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
Constructs an interface instance from an object that is either an operation or a subclass of OpView...
A logical result value, essentially a boolean with named states.
Definition: Support.h:114
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical name, given the arguments that will be...
Definition: Interfaces.cpp:38
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
Definition: IR.cpp:384
static void bindDerived(ClassTy &cls)
Hook for derived classes to add class-specific bindings.
size_t length
Length of the fragment.
Definition: Support.h:73
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:180
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:80
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull()
Returns an empty attribute.
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:71
CRTP base class for Python classes representing MLIR Op interfaces.
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:799
Wrapper around the generic MlirValue.
Definition: IRModule.h:916
static PyMlirContext & resolve()
Definition: IRCore.cpp:577
Python wrapper for InferTypeOpInterface.
py::object getOperationObject()
Returns the operation instance from which this object was constructed.
const std::string & getOpName()
Returns the canonical name of the operation this interface is constructed from.
PyOperationRef getRef()
Definition: IRModule.h:563
MLIR_CAPI_EXPORTED bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
Definition: Interfaces.cpp:19