MLIR  18.0.0git
IRInterfaces.cpp
Go to the documentation of this file.
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <optional>
10 #include <utility>
11 
12 #include "IRModule.h"
14 #include "mlir-c/Interfaces.h"
15 #include "llvm/ADT/STLExtras.h"
16 
17 namespace py = pybind11;
18 
19 namespace mlir {
20 namespace python {
21 
/// Docstring attached to the Python constructor of every op interface class.
static constexpr const char *constructorDoc =
    R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";

/// Docstring for the read-only `operation` property.
static constexpr const char *operationDoc =
    R"(Returns an Operation for which the interface was constructed.)";

/// Docstring for the read-only `opview` property.
static constexpr const char *opviewDoc =
    R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";

/// Docstring for the `inferReturnTypes` binding.
static constexpr const char *inferReturnTypesDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";

/// Docstring for the `inferReturnTypeComponents` binding.
static constexpr const char *inferReturnTypeComponentsDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return shaped type components. Raises ValueError on failure.)";
41 
42 namespace {
43 
44 /// Takes in an optional ist of operands and converts them into a SmallVector
45 /// of MlirVlaues. Returns an empty SmallVector if the list is empty.
46 llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
47  llvm::SmallVector<MlirValue> mlirOperands;
48 
49  if (!operandList || operandList->empty()) {
50  return mlirOperands;
51  }
52 
53  // Note: as the list may contain other lists this may not be final size.
54  mlirOperands.reserve(operandList->size());
55  for (const auto &&it : llvm::enumerate(*operandList)) {
56  if (it.value().is_none())
57  continue;
58 
59  PyValue *val;
60  try {
61  val = py::cast<PyValue *>(it.value());
62  if (!val)
63  throw py::cast_error();
64  mlirOperands.push_back(val->get());
65  continue;
66  } catch (py::cast_error &err) {
67  // Intentionally unhandled to try sequence below first.
68  (void)err;
69  }
70 
71  try {
72  auto vals = py::cast<py::sequence>(it.value());
73  for (py::object v : vals) {
74  try {
75  val = py::cast<PyValue *>(v);
76  if (!val)
77  throw py::cast_error();
78  mlirOperands.push_back(val->get());
79  } catch (py::cast_error &err) {
80  throw py::value_error(
81  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
82  " must be a Value or Sequence of Values (" + err.what() + ")")
83  .str());
84  }
85  }
86  continue;
87  } catch (py::cast_error &err) {
88  throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
89  " must be a Value or Sequence of Values (" +
90  err.what() + ")")
91  .str());
92  }
93 
94  throw py::cast_error();
95  }
96 
97  return mlirOperands;
98 }
99 
100 /// Takes in an optional vector of PyRegions and returns a SmallVector of
101 /// MlirRegion. Returns an empty SmallVector if the list is empty.
103 wrapRegions(std::optional<std::vector<PyRegion>> regions) {
104  llvm::SmallVector<MlirRegion> mlirRegions;
105 
106  if (regions) {
107  mlirRegions.reserve(regions->size());
108  for (PyRegion &region : *regions) {
109  mlirRegions.push_back(region);
110  }
111  }
112 
113  return mlirRegions;
114 }
115 
116 } // namespace
117 
118 /// CRTP base class for Python classes representing MLIR Op interfaces.
119 /// Interface hierarchies are flat so no base class is expected here. The
120 /// derived class is expected to define the following static fields:
121 /// - `const char *pyClassName` - the name of the Python class to create;
122 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
123 /// of the interface.
124 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
125 /// interface-specific methods.
126 ///
127 /// An interface class may be constructed from either an Operation/OpView object
128 /// or from a subclass of OpView. In the latter case, only the static interface
129 /// methods are available, similarly to calling ConcreteOp::staticMethod on the
130 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
131 /// method to check whether the interface object was constructed from a class or
132 /// an operation/opview instance. The `getOpName` always succeeds and returns a
133 /// canonical name of the operation suitable for lookups.
134 template <typename ConcreteIface>
136 protected:
137  using ClassTy = py::class_<ConcreteIface>;
138  using GetTypeIDFunctionTy = MlirTypeID (*)();
139 
140 public:
141  /// Constructs an interface instance from an object that is either an
142  /// operation or a subclass of OpView. In the latter case, only the static
143  /// methods of the interface are accessible to the caller.
144  PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
145  : obj(std::move(object)) {
146  try {
147  operation = &py::cast<PyOperation &>(obj);
148  } catch (py::cast_error &) {
149  // Do nothing.
150  }
151 
152  try {
153  operation = &py::cast<PyOpView &>(obj).getOperation();
154  } catch (py::cast_error &) {
155  // Do nothing.
156  }
157 
158  if (operation != nullptr) {
159  if (!mlirOperationImplementsInterface(*operation,
160  ConcreteIface::getInterfaceID())) {
161  std::string msg = "the operation does not implement ";
162  throw py::value_error(msg + ConcreteIface::pyClassName);
163  }
164 
165  MlirIdentifier identifier = mlirOperationGetName(*operation);
166  MlirStringRef stringRef = mlirIdentifierStr(identifier);
167  opName = std::string(stringRef.data, stringRef.length);
168  } else {
169  try {
170  opName = obj.attr("OPERATION_NAME").template cast<std::string>();
171  } catch (py::cast_error &) {
172  throw py::type_error(
173  "Op interface does not refer to an operation or OpView class");
174  }
175 
177  mlirStringRefCreate(opName.data(), opName.length()),
178  context.resolve().get(), ConcreteIface::getInterfaceID())) {
179  std::string msg = "the operation does not implement ";
180  throw py::value_error(msg + ConcreteIface::pyClassName);
181  }
182  }
183  }
184 
185  /// Creates the Python bindings for this class in the given module.
186  static void bind(py::module &m) {
187  py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
188  py::module_local());
189  cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
190  py::arg("context") = py::none(), constructorDoc)
191  .def_property_readonly("operation",
193  operationDoc)
194  .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
195  opviewDoc);
196  ConcreteIface::bindDerived(cls);
197  }
198 
199  /// Hook for derived classes to add class-specific bindings.
200  static void bindDerived(ClassTy &cls) {}
201 
202  /// Returns `true` if this object was constructed from a subclass of OpView
203  /// rather than from an operation instance.
204  bool isStatic() { return operation == nullptr; }
205 
206  /// Returns the operation instance from which this object was constructed.
207  /// Throws a type error if this object was constructed from a subclass of
208  /// OpView.
209  py::object getOperationObject() {
210  if (operation == nullptr) {
211  throw py::type_error("Cannot get an operation from a static interface");
212  }
213 
214  return operation->getRef().releaseObject();
215  }
216 
217  /// Returns the opview of the operation instance from which this object was
218  /// constructed. Throws a type error if this object was constructed form a
219  /// subclass of OpView.
220  py::object getOpView() {
221  if (operation == nullptr) {
222  throw py::type_error("Cannot get an opview from a static interface");
223  }
224 
225  return operation->createOpView();
226  }
227 
228  /// Returns the canonical name of the operation this interface is constructed
229  /// from.
230  const std::string &getOpName() { return opName; }
231 
232 private:
233  PyOperation *operation = nullptr;
234  std::string opName;
235  py::object obj;
236 };
237 
238 /// Python wrapper for InferTypeOpInterface. This interface has only static
239 /// methods.
241  : public PyConcreteOpInterface<PyInferTypeOpInterface> {
242 public:
244 
/// Name of the Python class created for this interface.
static constexpr const char *pyClassName = "InferTypeOpInterface";
248 
249  /// C-style user-data structure for type appending callback.
251  std::vector<PyType> &inferredTypes;
253  };
254 
255  /// Appends the types provided as the two first arguments to the user-data
256  /// structure (expects AppendResultsCallbackData).
257  static void appendResultsCallback(intptr_t nTypes, MlirType *types,
258  void *userData) {
259  auto *data = static_cast<AppendResultsCallbackData *>(userData);
260  data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
261  for (intptr_t i = 0; i < nTypes; ++i) {
262  data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
263  }
264  }
265 
266  /// Given the arguments required to build an operation, attempts to infer its
267  /// return types. Throws value_error on failure.
268  std::vector<PyType>
269  inferReturnTypes(std::optional<py::list> operandList,
270  std::optional<PyAttribute> attributes, void *properties,
271  std::optional<std::vector<PyRegion>> regions,
272  DefaultingPyMlirContext context,
273  DefaultingPyLocation location) {
274  llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList);
275  llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions);
276 
277  std::vector<PyType> inferredTypes;
278  PyMlirContext &pyContext = context.resolve();
279  AppendResultsCallbackData data{inferredTypes, pyContext};
280  MlirStringRef opNameRef =
281  mlirStringRefCreate(getOpName().data(), getOpName().length());
282  MlirAttribute attributeDict =
283  attributes ? attributes->get() : mlirAttributeGetNull();
284 
286  opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
287  mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
288  mlirRegions.data(), &appendResultsCallback, &data);
289 
290  if (mlirLogicalResultIsFailure(result)) {
291  throw py::value_error("Failed to infer result types");
292  }
293 
294  return inferredTypes;
295  }
296 
297  static void bindDerived(ClassTy &cls) {
298  cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
299  py::arg("operands") = py::none(),
300  py::arg("attributes") = py::none(),
301  py::arg("properties") = py::none(), py::arg("regions") = py::none(),
302  py::arg("context") = py::none(), py::arg("loc") = py::none(),
304  }
305 };
306 
307 /// Wrapper around shaped type components.
309 public:
310  PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
311  PyShapedTypeComponents(py::list shape, MlirType elementType)
312  : shape(shape), elementType(elementType), ranked(true) {}
313  PyShapedTypeComponents(py::list shape, MlirType elementType,
314  MlirAttribute attribute)
315  : shape(shape), elementType(elementType), attribute(attribute),
316  ranked(true) {}
319  : shape(other.shape), elementType(other.elementType),
320  attribute(other.attribute), ranked(other.ranked) {}
321 
322  static void bind(py::module &m) {
323  py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
324  py::module_local())
325  .def_property_readonly(
326  "element_type",
327  [](PyShapedTypeComponents &self) { return self.elementType; },
328  "Returns the element type of the shaped type components.")
329  .def_static(
330  "get",
331  [](PyType &elementType) {
332  return PyShapedTypeComponents(elementType);
333  },
334  py::arg("element_type"),
335  "Create an shaped type components object with only the element "
336  "type.")
337  .def_static(
338  "get",
339  [](py::list shape, PyType &elementType) {
340  return PyShapedTypeComponents(shape, elementType);
341  },
342  py::arg("shape"), py::arg("element_type"),
343  "Create a ranked shaped type components object.")
344  .def_static(
345  "get",
346  [](py::list shape, PyType &elementType, PyAttribute &attribute) {
347  return PyShapedTypeComponents(shape, elementType, attribute);
348  },
349  py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
350  "Create a ranked shaped type components object with attribute.")
351  .def_property_readonly(
352  "has_rank",
353  [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
354  "Returns whether the given shaped type component is ranked.")
355  .def_property_readonly(
356  "rank",
357  [](PyShapedTypeComponents &self) -> py::object {
358  if (!self.ranked) {
359  return py::none();
360  }
361  return py::int_(self.shape.size());
362  },
363  "Returns the rank of the given ranked shaped type components. If "
364  "the shaped type components does not have a rank, None is "
365  "returned.")
366  .def_property_readonly(
367  "shape",
368  [](PyShapedTypeComponents &self) -> py::object {
369  if (!self.ranked) {
370  return py::none();
371  }
372  return py::list(self.shape);
373  },
374  "Returns the shape of the ranked shaped type components as a list "
375  "of integers. Returns none if the shaped type component does not "
376  "have a rank.");
377  }
378 
379  pybind11::object getCapsule();
380  static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);
381 
382 private:
383  py::list shape;
384  MlirType elementType;
385  MlirAttribute attribute;
386  bool ranked{false};
387 };
388 
389 /// Python wrapper for InferShapedTypeOpInterface. This interface has only
390 /// static methods.
392  : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
393 public:
394  using PyConcreteOpInterface<
396 
/// Name of the Python class created for this interface.
static constexpr const char *pyClassName = "InferShapedTypeOpInterface";
400 
401  /// C-style user-data structure for type appending callback.
403  std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
404  };
405 
406  /// Appends the shaped type components provided as unpacked shape, element
407  /// type, attribute to the user-data.
408  static void appendResultsCallback(bool hasRank, intptr_t rank,
409  const int64_t *shape, MlirType elementType,
410  MlirAttribute attribute, void *userData) {
411  auto *data = static_cast<AppendResultsCallbackData *>(userData);
412  if (!hasRank) {
413  data->inferredShapedTypeComponents.emplace_back(elementType);
414  } else {
415  py::list shapeList;
416  for (intptr_t i = 0; i < rank; ++i) {
417  shapeList.append(shape[i]);
418  }
419  data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
420  attribute);
421  }
422  }
423 
424  /// Given the arguments required to build an operation, attempts to infer the
425  /// shaped type components. Throws value_error on failure.
426  std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
427  std::optional<py::list> operandList,
428  std::optional<PyAttribute> attributes, void *properties,
429  std::optional<std::vector<PyRegion>> regions,
430  DefaultingPyMlirContext context, DefaultingPyLocation location) {
431  llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList);
432  llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions);
433 
434  std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
435  PyMlirContext &pyContext = context.resolve();
436  AppendResultsCallbackData data{inferredShapedTypeComponents};
437  MlirStringRef opNameRef =
438  mlirStringRefCreate(getOpName().data(), getOpName().length());
439  MlirAttribute attributeDict =
440  attributes ? attributes->get() : mlirAttributeGetNull();
441 
443  opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
444  mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
445  mlirRegions.data(), &appendResultsCallback, &data);
446 
447  if (mlirLogicalResultIsFailure(result)) {
448  throw py::value_error("Failed to infer result shape type components");
449  }
450 
451  return inferredShapedTypeComponents;
452  }
453 
454  static void bindDerived(ClassTy &cls) {
455  cls.def("inferReturnTypeComponents",
457  py::arg("operands") = py::none(),
458  py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
459  py::arg("properties") = py::none(), py::arg("context") = py::none(),
460  py::arg("loc") = py::none(), inferReturnTypeComponentsDoc);
461  }
462 };
463 
464 void populateIRInterfaces(py::module &m) {
468 }
469 
470 } // namespace python
471 } // namespace mlir
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:493
static PyLocation & resolve()
Definition: IRCore.cpp:992
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:268
static PyMlirContext & resolve()
Definition: IRCore.cpp:715
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:960
CRTP base class for Python classes representing MLIR Op interfaces.
static void bind(py::module &m)
Creates the Python bindings for this class in the given module.
py::object getOpView()
Returns the opview of the operation instance from which this object was constructed.
py::class_< ConcreteIface > ClassTy
PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
Constructs an interface instance from an object that is either an operation or a subclass of OpView.
bool isStatic()
Returns true if this object was constructed from a subclass of OpView rather than from an operation i...
static void bindDerived(ClassTy &cls)
Hook for derived classes to add class-specific bindings.
py::object getOperationObject()
Returns the operation instance from which this object was constructed.
const std::string & getOpName()
Returns the canonical name of the operation this interface is constructed from.
Python wrapper for InferShapedTypeOpInterface.
constexpr static const char * pyClassName
static void appendResultsCallback(bool hasRank, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute attribute, void *userData)
Appends the shaped type components provided as unpacked shape, element type, attribute to the user-da...
std::vector< PyShapedTypeComponents > inferReturnTypeComponents(std::optional< py::list > operandList, std::optional< PyAttribute > attributes, void *properties, std::optional< std::vector< PyRegion >> regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer the shaped type components.
constexpr static GetTypeIDFunctionTy getInterfaceID
Python wrapper for InferTypeOpInterface.
constexpr static GetTypeIDFunctionTy getInterfaceID
std::vector< PyType > inferReturnTypes(std::optional< py::list > operandList, std::optional< PyAttribute > attributes, void *properties, std::optional< std::vector< PyRegion >> regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer its return types.
constexpr static const char * pyClassName
static void bindDerived(ClassTy &cls)
static void appendResultsCallback(intptr_t nTypes, MlirType *types, void *userData)
Appends the types provided as the two first arguments to the user-data structure (expects AppendResul...
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:182
pybind11::object releaseObject()
Releases the object held by this instance, returning it.
Definition: IRModule.h:73
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition: IRModule.h:585
PyOperationRef getRef()
Definition: IRModule.h:620
pybind11::object createOpView()
Creates an OpView suitable for this operation.
Definition: IRCore.cpp:1425
Wrapper around an MlirRegion.
Definition: IRModule.h:733
Wrapper around shaped type components.
PyShapedTypeComponents(PyShapedTypeComponents &)=delete
static void bind(py::module &m)
PyShapedTypeComponents(py::list shape, MlirType elementType)
PyShapedTypeComponents(MlirType elementType)
PyShapedTypeComponents(PyShapedTypeComponents &&other)
PyShapedTypeComponents(py::list shape, MlirType elementType, MlirAttribute attribute)
static PyShapedTypeComponents createFromCapsule(pybind11::object capsule)
Wrapper around the generic MlirType.
Definition: IRModule.h:838
Wrapper around the generic MlirValue.
Definition: IRModule.h:1091
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op)
Gets the name of the operation as an identifier.
Definition: IR.cpp:523
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:1027
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
Definition: Interfaces.cpp:127
MLIR_CAPI_EXPORTED bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identi...
Definition: Interfaces.cpp:80
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
Definition: Interfaces.cpp:88
MLIR_CAPI_EXPORTED bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
Definition: Interfaces.cpp:73
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical given the arguments that will be...
Definition: Interfaces.cpp:92
MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
Definition: Interfaces.cpp:123
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
constexpr static const char * inferReturnTypesDoc
constexpr static const char * operationDoc
constexpr static const char * opviewDoc
constexpr static const char * constructorDoc
constexpr static const char * inferReturnTypeComponentsDoc
void populateIRInterfaces(py::module &m)
This header declares functions that assist transformations in the MemRef dialect.
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
C-style user-data structure for type appending callback.
std::vector< PyShapedTypeComponents > & inferredShapedTypeComponents
C-style user-data structure for type appending callback.