MLIR  19.0.0git
IRInterfaces.cpp
Go to the documentation of this file.
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cstdint>
10 #include <optional>
11 #include <pybind11/cast.h>
12 #include <pybind11/detail/common.h>
13 #include <pybind11/pybind11.h>
14 #include <pybind11/pytypes.h>
15 #include <string>
16 #include <utility>
17 #include <vector>
18 
19 #include "IRModule.h"
21 #include "mlir-c/IR.h"
22 #include "mlir-c/Interfaces.h"
23 #include "mlir-c/Support.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 
27 namespace py = pybind11;
28 
29 namespace mlir {
30 namespace python {
31 
// Python docstrings attached to the interface bindings defined in this file.
// Adjacent string literals are concatenated; each "\n" marks a line break in
// the rendered docstring.

static constexpr const char *constructorDoc =
    "Creates an interface from a given operation/opview object or from a\n"
    "subclass of OpView. Raises ValueError if the operation does not "
    "implement the\n"
    "interface.";

static constexpr const char *operationDoc =
    "Returns an Operation for which the interface was constructed.";

static constexpr const char *opviewDoc =
    "Returns an OpView subclass _instance_ for which the interface was\n"
    "constructed";

static constexpr const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";

static constexpr const char *inferReturnTypeComponentsDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return shaped type components. Raises ValueError on failure.";
51 
52 namespace {
53 
54 /// Takes in an optional ist of operands and converts them into a SmallVector
55 /// of MlirVlaues. Returns an empty SmallVector if the list is empty.
56 llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
57  llvm::SmallVector<MlirValue> mlirOperands;
58 
59  if (!operandList || operandList->empty()) {
60  return mlirOperands;
61  }
62 
63  // Note: as the list may contain other lists this may not be final size.
64  mlirOperands.reserve(operandList->size());
65  for (const auto &&it : llvm::enumerate(*operandList)) {
66  if (it.value().is_none())
67  continue;
68 
69  PyValue *val;
70  try {
71  val = py::cast<PyValue *>(it.value());
72  if (!val)
73  throw py::cast_error();
74  mlirOperands.push_back(val->get());
75  continue;
76  } catch (py::cast_error &err) {
77  // Intentionally unhandled to try sequence below first.
78  (void)err;
79  }
80 
81  try {
82  auto vals = py::cast<py::sequence>(it.value());
83  for (py::object v : vals) {
84  try {
85  val = py::cast<PyValue *>(v);
86  if (!val)
87  throw py::cast_error();
88  mlirOperands.push_back(val->get());
89  } catch (py::cast_error &err) {
90  throw py::value_error(
91  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
92  " must be a Value or Sequence of Values (" + err.what() + ")")
93  .str());
94  }
95  }
96  continue;
97  } catch (py::cast_error &err) {
98  throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
99  " must be a Value or Sequence of Values (" +
100  err.what() + ")")
101  .str());
102  }
103 
104  throw py::cast_error();
105  }
106 
107  return mlirOperands;
108 }
109 
110 /// Takes in an optional vector of PyRegions and returns a SmallVector of
111 /// MlirRegion. Returns an empty SmallVector if the list is empty.
113 wrapRegions(std::optional<std::vector<PyRegion>> regions) {
114  llvm::SmallVector<MlirRegion> mlirRegions;
115 
116  if (regions) {
117  mlirRegions.reserve(regions->size());
118  for (PyRegion &region : *regions) {
119  mlirRegions.push_back(region);
120  }
121  }
122 
123  return mlirRegions;
124 }
125 
126 } // namespace
127 
128 /// CRTP base class for Python classes representing MLIR Op interfaces.
129 /// Interface hierarchies are flat so no base class is expected here. The
130 /// derived class is expected to define the following static fields:
131 /// - `const char *pyClassName` - the name of the Python class to create;
132 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
133 /// of the interface.
134 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
135 /// interface-specific methods.
136 ///
137 /// An interface class may be constructed from either an Operation/OpView object
138 /// or from a subclass of OpView. In the latter case, only the static interface
139 /// methods are available, similarly to calling ConcreteOp::staticMethod on the
140 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
141 /// method to check whether the interface object was constructed from a class or
142 /// an operation/opview instance. The `getOpName` always succeeds and returns a
143 /// canonical name of the operation suitable for lookups.
144 template <typename ConcreteIface>
146 protected:
147  using ClassTy = py::class_<ConcreteIface>;
148  using GetTypeIDFunctionTy = MlirTypeID (*)();
149 
150 public:
151  /// Constructs an interface instance from an object that is either an
152  /// operation or a subclass of OpView. In the latter case, only the static
153  /// methods of the interface are accessible to the caller.
154  PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
155  : obj(std::move(object)) {
156  try {
157  operation = &py::cast<PyOperation &>(obj);
158  } catch (py::cast_error &) {
159  // Do nothing.
160  }
161 
162  try {
163  operation = &py::cast<PyOpView &>(obj).getOperation();
164  } catch (py::cast_error &) {
165  // Do nothing.
166  }
167 
168  if (operation != nullptr) {
169  if (!mlirOperationImplementsInterface(*operation,
170  ConcreteIface::getInterfaceID())) {
171  std::string msg = "the operation does not implement ";
172  throw py::value_error(msg + ConcreteIface::pyClassName);
173  }
174 
175  MlirIdentifier identifier = mlirOperationGetName(*operation);
176  MlirStringRef stringRef = mlirIdentifierStr(identifier);
177  opName = std::string(stringRef.data, stringRef.length);
178  } else {
179  try {
180  opName = obj.attr("OPERATION_NAME").template cast<std::string>();
181  } catch (py::cast_error &) {
182  throw py::type_error(
183  "Op interface does not refer to an operation or OpView class");
184  }
185 
187  mlirStringRefCreate(opName.data(), opName.length()),
188  context.resolve().get(), ConcreteIface::getInterfaceID())) {
189  std::string msg = "the operation does not implement ";
190  throw py::value_error(msg + ConcreteIface::pyClassName);
191  }
192  }
193  }
194 
195  /// Creates the Python bindings for this class in the given module.
196  static void bind(py::module &m) {
197  py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
198  py::module_local());
199  cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
200  py::arg("context") = py::none(), constructorDoc)
201  .def_property_readonly("operation",
203  operationDoc)
204  .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
205  opviewDoc);
206  ConcreteIface::bindDerived(cls);
207  }
208 
209  /// Hook for derived classes to add class-specific bindings.
210  static void bindDerived(ClassTy &cls) {}
211 
212  /// Returns `true` if this object was constructed from a subclass of OpView
213  /// rather than from an operation instance.
214  bool isStatic() { return operation == nullptr; }
215 
216  /// Returns the operation instance from which this object was constructed.
217  /// Throws a type error if this object was constructed from a subclass of
218  /// OpView.
219  py::object getOperationObject() {
220  if (operation == nullptr) {
221  throw py::type_error("Cannot get an operation from a static interface");
222  }
223 
224  return operation->getRef().releaseObject();
225  }
226 
227  /// Returns the opview of the operation instance from which this object was
228  /// constructed. Throws a type error if this object was constructed form a
229  /// subclass of OpView.
230  py::object getOpView() {
231  if (operation == nullptr) {
232  throw py::type_error("Cannot get an opview from a static interface");
233  }
234 
235  return operation->createOpView();
236  }
237 
238  /// Returns the canonical name of the operation this interface is constructed
239  /// from.
240  const std::string &getOpName() { return opName; }
241 
242 private:
243  PyOperation *operation = nullptr;
244  std::string opName;
245  py::object obj;
246 };
247 
248 /// Python wrapper for InferTypeOpInterface. This interface has only static
249 /// methods.
251  : public PyConcreteOpInterface<PyInferTypeOpInterface> {
252 public:
254 
255  constexpr static const char *pyClassName = "InferTypeOpInterface";
258 
259  /// C-style user-data structure for type appending callback.
261  std::vector<PyType> &inferredTypes;
263  };
264 
265  /// Appends the types provided as the two first arguments to the user-data
266  /// structure (expects AppendResultsCallbackData).
267  static void appendResultsCallback(intptr_t nTypes, MlirType *types,
268  void *userData) {
269  auto *data = static_cast<AppendResultsCallbackData *>(userData);
270  data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
271  for (intptr_t i = 0; i < nTypes; ++i) {
272  data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
273  }
274  }
275 
276  /// Given the arguments required to build an operation, attempts to infer its
277  /// return types. Throws value_error on failure.
278  std::vector<PyType>
279  inferReturnTypes(std::optional<py::list> operandList,
280  std::optional<PyAttribute> attributes, void *properties,
281  std::optional<std::vector<PyRegion>> regions,
282  DefaultingPyMlirContext context,
283  DefaultingPyLocation location) {
284  llvm::SmallVector<MlirValue> mlirOperands =
285  wrapOperands(std::move(operandList));
286  llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
287 
288  std::vector<PyType> inferredTypes;
289  PyMlirContext &pyContext = context.resolve();
290  AppendResultsCallbackData data{inferredTypes, pyContext};
291  MlirStringRef opNameRef =
292  mlirStringRefCreate(getOpName().data(), getOpName().length());
293  MlirAttribute attributeDict =
294  attributes ? attributes->get() : mlirAttributeGetNull();
295 
297  opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
298  mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
299  mlirRegions.data(), &appendResultsCallback, &data);
300 
301  if (mlirLogicalResultIsFailure(result)) {
302  throw py::value_error("Failed to infer result types");
303  }
304 
305  return inferredTypes;
306  }
307 
308  static void bindDerived(ClassTy &cls) {
309  cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
310  py::arg("operands") = py::none(),
311  py::arg("attributes") = py::none(),
312  py::arg("properties") = py::none(), py::arg("regions") = py::none(),
313  py::arg("context") = py::none(), py::arg("loc") = py::none(),
315  }
316 };
317 
318 /// Wrapper around a set of shaped type components.
320 public:
321  PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
322  PyShapedTypeComponents(py::list shape, MlirType elementType)
323  : shape(std::move(shape)), elementType(elementType), ranked(true) {}
324  PyShapedTypeComponents(py::list shape, MlirType elementType,
325  MlirAttribute attribute)
326  : shape(std::move(shape)), elementType(elementType), attribute(attribute),
327  ranked(true) {}
330  : shape(other.shape), elementType(other.elementType),
331  attribute(other.attribute), ranked(other.ranked) {}
332 
333  static void bind(py::module &m) {
334  py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
335  py::module_local())
336  .def_property_readonly(
337  "element_type",
338  [](PyShapedTypeComponents &self) { return self.elementType; },
339  "Returns the element type of the shaped type components.")
340  .def_static(
341  "get",
342  [](PyType &elementType) {
343  return PyShapedTypeComponents(elementType);
344  },
345  py::arg("element_type"),
346  "Create an shaped type components object with only the element "
347  "type.")
348  .def_static(
349  "get",
350  [](py::list shape, PyType &elementType) {
351  return PyShapedTypeComponents(std::move(shape), elementType);
352  },
353  py::arg("shape"), py::arg("element_type"),
354  "Create a ranked shaped type components object.")
355  .def_static(
356  "get",
357  [](py::list shape, PyType &elementType, PyAttribute &attribute) {
358  return PyShapedTypeComponents(std::move(shape), elementType,
359  attribute);
360  },
361  py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
362  "Create a ranked shaped type components object with attribute.")
363  .def_property_readonly(
364  "has_rank",
365  [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
366  "Returns whether the given shaped type component is ranked.")
367  .def_property_readonly(
368  "rank",
369  [](PyShapedTypeComponents &self) -> py::object {
370  if (!self.ranked) {
371  return py::none();
372  }
373  return py::int_(self.shape.size());
374  },
375  "Returns the rank of the given ranked shaped type components. If "
376  "the shaped type components does not have a rank, None is "
377  "returned.")
378  .def_property_readonly(
379  "shape",
380  [](PyShapedTypeComponents &self) -> py::object {
381  if (!self.ranked) {
382  return py::none();
383  }
384  return py::list(self.shape);
385  },
386  "Returns the shape of the ranked shaped type components as a list "
387  "of integers. Returns none if the shaped type component does not "
388  "have a rank.");
389  }
390 
391  pybind11::object getCapsule();
392  static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);
393 
394 private:
395  py::list shape;
396  MlirType elementType;
397  MlirAttribute attribute;
398  bool ranked{false};
399 };
400 
401 /// Python wrapper for InferShapedTypeOpInterface. This interface has only
402 /// static methods.
404  : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
405 public:
406  using PyConcreteOpInterface<
408 
409  constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
412 
413  /// C-style user-data structure for type appending callback.
415  std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
416  };
417 
418  /// Appends the shaped type components provided as unpacked shape, element
419  /// type, attribute to the user-data.
420  static void appendResultsCallback(bool hasRank, intptr_t rank,
421  const int64_t *shape, MlirType elementType,
422  MlirAttribute attribute, void *userData) {
423  auto *data = static_cast<AppendResultsCallbackData *>(userData);
424  if (!hasRank) {
425  data->inferredShapedTypeComponents.emplace_back(elementType);
426  } else {
427  py::list shapeList;
428  for (intptr_t i = 0; i < rank; ++i) {
429  shapeList.append(shape[i]);
430  }
431  data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
432  attribute);
433  }
434  }
435 
436  /// Given the arguments required to build an operation, attempts to infer the
437  /// shaped type components. Throws value_error on failure.
438  std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
439  std::optional<py::list> operandList,
440  std::optional<PyAttribute> attributes, void *properties,
441  std::optional<std::vector<PyRegion>> regions,
442  DefaultingPyMlirContext context, DefaultingPyLocation location) {
443  llvm::SmallVector<MlirValue> mlirOperands =
444  wrapOperands(std::move(operandList));
445  llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
446 
447  std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
448  PyMlirContext &pyContext = context.resolve();
449  AppendResultsCallbackData data{inferredShapedTypeComponents};
450  MlirStringRef opNameRef =
451  mlirStringRefCreate(getOpName().data(), getOpName().length());
452  MlirAttribute attributeDict =
453  attributes ? attributes->get() : mlirAttributeGetNull();
454 
456  opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
457  mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
458  mlirRegions.data(), &appendResultsCallback, &data);
459 
460  if (mlirLogicalResultIsFailure(result)) {
461  throw py::value_error("Failed to infer result shape type components");
462  }
463 
464  return inferredShapedTypeComponents;
465  }
466 
467  static void bindDerived(ClassTy &cls) {
468  cls.def("inferReturnTypeComponents",
470  py::arg("operands") = py::none(),
471  py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
472  py::arg("properties") = py::none(), py::arg("context") = py::none(),
473  py::arg("loc") = py::none(), inferReturnTypeComponentsDoc);
474  }
475 };
476 
477 void populateIRInterfaces(py::module &m) {
481 }
482 
483 } // namespace python
484 } // namespace mlir
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:510
static PyLocation & resolve()
Definition: IRCore.cpp:1039
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:284
static PyMlirContext & resolve()
Definition: IRCore.cpp:763
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:989
CRTP base class for Python classes representing MLIR Op interfaces.
static void bind(py::module &m)
Creates the Python bindings for this class in the given module.
py::object getOpView()
Returns the opview of the operation instance from which this object was constructed.
py::class_< ConcreteIface > ClassTy
PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
Constructs an interface instance from an object that is either an operation or a subclass of OpView.
bool isStatic()
Returns true if this object was constructed from a subclass of OpView rather than from an operation i...
static void bindDerived(ClassTy &cls)
Hook for derived classes to add class-specific bindings.
py::object getOperationObject()
Returns the operation instance from which this object was constructed.
const std::string & getOpName()
Returns the canonical name of the operation this interface is constructed from.
Python wrapper for InferShapedTypeOpInterface.
constexpr static const char * pyClassName
static void appendResultsCallback(bool hasRank, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute attribute, void *userData)
Appends the shaped type components provided as unpacked shape, element type, attribute to the user-da...
std::vector< PyShapedTypeComponents > inferReturnTypeComponents(std::optional< py::list > operandList, std::optional< PyAttribute > attributes, void *properties, std::optional< std::vector< PyRegion >> regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer the shaped type components.
constexpr static GetTypeIDFunctionTy getInterfaceID
Python wrapper for InferTypeOpInterface.
constexpr static GetTypeIDFunctionTy getInterfaceID
std::vector< PyType > inferReturnTypes(std::optional< py::list > operandList, std::optional< PyAttribute > attributes, void *properties, std::optional< std::vector< PyRegion >> regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer its return types.
constexpr static const char * pyClassName
static void bindDerived(ClassTy &cls)
static void appendResultsCallback(intptr_t nTypes, MlirType *types, void *userData)
Appends the types provided as the two first arguments to the user-data structure (expects AppendResul...
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:184
pybind11::object releaseObject()
Releases the object held by this instance, returning it.
Definition: IRModule.h:75
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition: IRModule.h:605
PyOperationRef getRef()
Definition: IRModule.h:640
pybind11::object createOpView()
Creates an OpView suitable for this operation.
Definition: IRCore.cpp:1484
Wrapper around an MlirRegion.
Definition: IRModule.h:753
Wrapper around a set of shaped type components.
PyShapedTypeComponents(PyShapedTypeComponents &)=delete
static void bind(py::module &m)
PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
PyShapedTypeComponents(py::list shape, MlirType elementType)
PyShapedTypeComponents(MlirType elementType)
PyShapedTypeComponents(py::list shape, MlirType elementType, MlirAttribute attribute)
static PyShapedTypeComponents createFromCapsule(pybind11::object capsule)
Wrapper around the generic MlirType.
Definition: IRModule.h:867
Wrapper around the generic MlirValue.
Definition: IRModule.h:1120
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
Definition: IR.cpp:519
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:1103
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
Definition: Interfaces.cpp:127
MLIR_CAPI_EXPORTED bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identi...
Definition: Interfaces.cpp:80
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
Definition: Interfaces.cpp:88
MLIR_CAPI_EXPORTED bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
Definition: Interfaces.cpp:73
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical given the arguments that will be...
Definition: Interfaces.cpp:92
MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
Definition: Interfaces.cpp:123
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
constexpr static const char * inferReturnTypesDoc
constexpr static const char * operationDoc
constexpr static const char * opviewDoc
constexpr static const char * constructorDoc
constexpr static const char * inferReturnTypeComponentsDoc
void populateIRInterfaces(py::module &m)
Include the generated interface declarations.
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
C-style user-data structure for type appending callback.
std::vector< PyShapedTypeComponents > & inferredShapedTypeComponents
C-style user-data structure for type appending callback.