MLIR  17.0.0git
IRInterfaces.cpp
Go to the documentation of this file.
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 #include <optional>
11 
12 #include "IRModule.h"
14 #include "mlir-c/Interfaces.h"
15 #include "llvm/ADT/STLExtras.h"
16 
17 namespace py = pybind11;
18 
19 namespace mlir {
20 namespace python {
21 
// Python docstrings attached to the interface bindings (constructor, the
// `operation`/`opview` properties and `inferReturnTypes`). Each text is built
// from adjacent string literals; embedded "\n" reproduce the line breaks of the
// help text.
static constexpr const char *constructorDoc =
    "Creates an interface from a given operation/opview object or from a\n"
    "subclass of OpView. Raises ValueError if the operation does not "
    "implement the\n"
    "interface.";

static constexpr const char *operationDoc =
    "Returns an Operation for which the interface was constructed.";

static constexpr const char *opviewDoc =
    "Returns an OpView subclass _instance_ for which the interface was\n"
    "constructed";

static constexpr const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";
37 
38 /// CRTP base class for Python classes representing MLIR Op interfaces.
39 /// Interface hierarchies are flat so no base class is expected here. The
40 /// derived class is expected to define the following static fields:
41 /// - `const char *pyClassName` - the name of the Python class to create;
42 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
43 /// of the interface.
44 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
45 /// interface-specific methods.
46 ///
47 /// An interface class may be constructed from either an Operation/OpView object
48 /// or from a subclass of OpView. In the latter case, only the static interface
49 /// methods are available, similarly to calling ConcreteOp::staticMethod on the
50 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
51 /// method to check whether the interface object was constructed from a class or
52 /// an operation/opview instance. The `getOpName` always succeeds and returns a
53 /// canonical name of the operation suitable for lookups.
54 template <typename ConcreteIface>
56 protected:
57  using ClassTy = py::class_<ConcreteIface>;
58  using GetTypeIDFunctionTy = MlirTypeID (*)();
59 
60 public:
61  /// Constructs an interface instance from an object that is either an
62  /// operation or a subclass of OpView. In the latter case, only the static
63  /// methods of the interface are accessible to the caller.
64  PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
65  : obj(std::move(object)) {
66  try {
67  operation = &py::cast<PyOperation &>(obj);
68  } catch (py::cast_error &) {
69  // Do nothing.
70  }
71 
72  try {
73  operation = &py::cast<PyOpView &>(obj).getOperation();
74  } catch (py::cast_error &) {
75  // Do nothing.
76  }
77 
78  if (operation != nullptr) {
79  if (!mlirOperationImplementsInterface(*operation,
80  ConcreteIface::getInterfaceID())) {
81  std::string msg = "the operation does not implement ";
82  throw py::value_error(msg + ConcreteIface::pyClassName);
83  }
84 
85  MlirIdentifier identifier = mlirOperationGetName(*operation);
86  MlirStringRef stringRef = mlirIdentifierStr(identifier);
87  opName = std::string(stringRef.data, stringRef.length);
88  } else {
89  try {
90  opName = obj.attr("OPERATION_NAME").template cast<std::string>();
91  } catch (py::cast_error &) {
92  throw py::type_error(
93  "Op interface does not refer to an operation or OpView class");
94  }
95 
97  mlirStringRefCreate(opName.data(), opName.length()),
98  context.resolve().get(), ConcreteIface::getInterfaceID())) {
99  std::string msg = "the operation does not implement ";
100  throw py::value_error(msg + ConcreteIface::pyClassName);
101  }
102  }
103  }
104 
105  /// Creates the Python bindings for this class in the given module.
106  static void bind(py::module &m) {
107  py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
108  py::module_local());
109  cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
110  py::arg("context") = py::none(), constructorDoc)
111  .def_property_readonly("operation",
113  operationDoc)
114  .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
115  opviewDoc);
116  ConcreteIface::bindDerived(cls);
117  }
118 
119  /// Hook for derived classes to add class-specific bindings.
120  static void bindDerived(ClassTy &cls) {}
121 
122  /// Returns `true` if this object was constructed from a subclass of OpView
123  /// rather than from an operation instance.
124  bool isStatic() { return operation == nullptr; }
125 
126  /// Returns the operation instance from which this object was constructed.
127  /// Throws a type error if this object was constructed from a subclass of
128  /// OpView.
129  py::object getOperationObject() {
130  if (operation == nullptr) {
131  throw py::type_error("Cannot get an operation from a static interface");
132  }
133 
134  return operation->getRef().releaseObject();
135  }
136 
137  /// Returns the opview of the operation instance from which this object was
138  /// constructed. Throws a type error if this object was constructed form a
139  /// subclass of OpView.
140  py::object getOpView() {
141  if (operation == nullptr) {
142  throw py::type_error("Cannot get an opview from a static interface");
143  }
144 
145  return operation->createOpView();
146  }
147 
148  /// Returns the canonical name of the operation this interface is constructed
149  /// from.
150  const std::string &getOpName() { return opName; }
151 
152 private:
153  PyOperation *operation = nullptr;
154  std::string opName;
155  py::object obj;
156 };
157 
158 /// Python wrapper for InferTypeOpInterface. This interface has only static
159 /// methods.
161  : public PyConcreteOpInterface<PyInferTypeOpInterface> {
162 public:
164 
165  constexpr static const char *pyClassName = "InferTypeOpInterface";
168 
169  /// C-style user-data structure for type appending callback.
171  std::vector<PyType> &inferredTypes;
173  };
174 
175  /// Appends the types provided as the two first arguments to the user-data
176  /// structure (expects AppendResultsCallbackData).
177  static void appendResultsCallback(intptr_t nTypes, MlirType *types,
178  void *userData) {
179  auto *data = static_cast<AppendResultsCallbackData *>(userData);
180  data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
181  for (intptr_t i = 0; i < nTypes; ++i) {
182  data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
183  }
184  }
185 
186  /// Given the arguments required to build an operation, attempts to infer its
187  /// return types. Throws value_error on failure.
188  std::vector<PyType>
189  inferReturnTypes(std::optional<py::list> operandList,
190  std::optional<PyAttribute> attributes,
191  std::optional<std::vector<PyRegion>> regions,
192  DefaultingPyMlirContext context,
193  DefaultingPyLocation location) {
194  llvm::SmallVector<MlirValue> mlirOperands;
195  llvm::SmallVector<MlirRegion> mlirRegions;
196 
197  if (operandList && !operandList->empty()) {
198  // Note: as the list may contain other lists this may not be final size.
199  mlirOperands.reserve(operandList->size());
200  for (const auto& it : llvm::enumerate(*operandList)) {
201  PyValue* val;
202  try {
203  val = py::cast<PyValue *>(it.value());
204  if (!val)
205  throw py::cast_error();
206  mlirOperands.push_back(val->get());
207  continue;
208  } catch (py::cast_error &err) {
209  // Intentionally unhandled to try sequence below first.
210  (void)err;
211  }
212 
213  try {
214  auto vals = py::cast<py::sequence>(it.value());
215  for (py::object v : vals) {
216  try {
217  val = py::cast<PyValue *>(v);
218  if (!val)
219  throw py::cast_error();
220  mlirOperands.push_back(val->get());
221  } catch (py::cast_error &err) {
222  throw py::value_error(
223  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
224  " must be a Value or Sequence of Values (" + err.what() +
225  ")")
226  .str());
227  }
228  }
229  continue;
230  } catch (py::cast_error &err) {
231  throw py::value_error(
232  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
233  " must be a Value or Sequence of Values (" + err.what() + ")")
234  .str());
235  }
236 
237  throw py::cast_error();
238  }
239  }
240 
241  if (regions) {
242  mlirRegions.reserve(regions->size());
243  for (PyRegion &region : *regions) {
244  mlirRegions.push_back(region);
245  }
246  }
247 
248  std::vector<PyType> inferredTypes;
249  PyMlirContext &pyContext = context.resolve();
250  AppendResultsCallbackData data{inferredTypes, pyContext};
251  MlirStringRef opNameRef =
252  mlirStringRefCreate(getOpName().data(), getOpName().length());
253  MlirAttribute attributeDict =
254  attributes ? attributes->get() : mlirAttributeGetNull();
255 
257  opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
258  mlirOperands.data(), attributeDict, mlirRegions.size(),
259  mlirRegions.data(), &appendResultsCallback, &data);
260 
261  if (mlirLogicalResultIsFailure(result)) {
262  throw py::value_error("Failed to infer result types");
263  }
264 
265  return inferredTypes;
266  }
267 
268  static void bindDerived(ClassTy &cls) {
269  cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
270  py::arg("operands") = py::none(),
271  py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
272  py::arg("context") = py::none(), py::arg("loc") = py::none(),
274  }
275 };
276 
278 
279 } // namespace python
280 } // namespace mlir
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:491
static PyLocation & resolve()
Definition: IRCore.cpp:949
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:266
static PyMlirContext & resolve()
Definition: IRCore.cpp:672
CRTP base class for Python classes representing MLIR Op interfaces.
static void bind(py::module &m)
Creates the Python bindings for this class in the given module.
py::object getOpView()
Returns the opview of the operation instance from which this object was constructed.
py::class_< ConcreteIface > ClassTy
PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
Constructs an interface instance from an object that is either an operation or a subclass of OpView.
bool isStatic()
Returns true if this object was constructed from a subclass of OpView rather than from an operation i...
static void bindDerived(ClassTy &cls)
Hook for derived classes to add class-specific bindings.
py::object getOperationObject()
Returns the operation instance from which this object was constructed.
const std::string & getOpName()
Returns the canonical name of the operation this interface is constructed from.
Python wrapper for InferTypeOpInterface.
std::vector< PyType > inferReturnTypes(std::optional< py::list > operandList, std::optional< PyAttribute > attributes, std::optional< std::vector< PyRegion >> regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer its return types.
constexpr static GetTypeIDFunctionTy getInterfaceID
constexpr static const char * pyClassName
static void bindDerived(ClassTy &cls)
static void appendResultsCallback(intptr_t nTypes, MlirType *types, void *userData)
Appends the types provided as the two first arguments to the user-data structure (expects AppendResul...
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:180
pybind11::object releaseObject()
Releases the object held by this instance, returning it.
Definition: IRModule.h:71
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition: IRModule.h:582
PyOperationRef getRef()
Definition: IRModule.h:617
pybind11::object createOpView()
Creates an OpView suitable for this operation.
Definition: IRCore.cpp:1367
Wrapper around an MlirRegion.
Definition: IRModule.h:730
Wrapper around the generic MlirValue.
Definition: IRModule.h:978
MlirValue get()
Definition: IRModule.h:984
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
Definition: IR.cpp:406
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:864
MLIR_CAPI_EXPORTED bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identi...
Definition: Interfaces.cpp:27
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
Definition: Interfaces.cpp:35
MLIR_CAPI_EXPORTED bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
Definition: Interfaces.cpp:20
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical given the arguments that will be...
Definition: Interfaces.cpp:39
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:80
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:125
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
constexpr static const char * inferReturnTypesDoc
constexpr static const char * operationDoc
constexpr static const char * opviewDoc
constexpr static const char * constructorDoc
void populateIRInterfaces(py::module &m)
This header declares functions that assist transformations in the MemRef dialect.
A logical result value, essentially a boolean with named states.
Definition: Support.h:114
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:71
const char * data
Pointer to the first symbol.
Definition: Support.h:72
size_t length
Length of the fragment.
Definition: Support.h:73
C-style user-data structure for type appending callback.