MLIR  19.0.0git
Interfaces.cpp
Go to the documentation of this file.
1 
2 
3 //===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
4 //
5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6 // See https://llvm.org/LICENSE.txt for license information.
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "mlir-c/Interfaces.h"
12 
13 #include "mlir/CAPI/IR.h"
14 #include "mlir/CAPI/Interfaces.h"
15 #include "mlir/CAPI/Support.h"
16 #include "mlir/CAPI/Wrap.h"
17 #include "mlir/IR/ValueRange.h"
19 #include "llvm/ADT/ScopeExit.h"
20 #include <optional>
21 
22 using namespace mlir;
23 
24 namespace {
25 
26 std::optional<RegisteredOperationName>
27 getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
28  StringRef name(opName.data, opName.length);
29  std::optional<RegisteredOperationName> info =
31  return info;
32 }
33 
34 std::optional<Location> maybeGetLocation(MlirLocation location) {
35  std::optional<Location> maybeLocation;
36  if (!mlirLocationIsNull(location))
37  maybeLocation = unwrap(location);
38  return maybeLocation;
39 }
40 
41 SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
42  SmallVector<Value> unwrappedOperands;
43  (void)unwrapList(nOperands, operands, unwrappedOperands);
44  return unwrappedOperands;
45 }
46 
47 DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48  DictionaryAttr attributeDict;
49  if (!mlirAttributeIsNull(attributes))
50  attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
51  return attributeDict;
52 }
53 
54 SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
55  MlirRegion *regions) {
56  // Create a vector of unique pointers to regions and make sure they are not
57  // deleted when exiting the scope. This is a hack caused by C++ API expecting
58  // an list of unique pointers to regions (without ownership transfer
59  // semantics) and C API making ownership transfer explicit.
60  SmallVector<std::unique_ptr<Region>> unwrappedRegions;
61  unwrappedRegions.reserve(nRegions);
62  for (intptr_t i = 0; i < nRegions; ++i)
63  unwrappedRegions.emplace_back(unwrap(*(regions + i)));
64  auto cleaner = llvm::make_scope_exit([&]() {
65  for (auto &region : unwrappedRegions)
66  region.release();
67  });
68  return unwrappedRegions;
69 }
70 
71 } // namespace
72 
73 bool mlirOperationImplementsInterface(MlirOperation operation,
74  MlirTypeID interfaceTypeID) {
75  std::optional<RegisteredOperationName> info =
76  unwrap(operation)->getRegisteredInfo();
77  return info && info->hasInterface(unwrap(interfaceTypeID));
78 }
79 
81  MlirContext context,
82  MlirTypeID interfaceTypeID) {
83  std::optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
84  StringRef(operationName.data, operationName.length), unwrap(context));
85  return info && info->hasInterface(unwrap(interfaceTypeID));
86 }
87 
89  return wrap(InferTypeOpInterface::getInterfaceID());
90 }
91 
93  MlirStringRef opName, MlirContext context, MlirLocation location,
94  intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95  void *properties, intptr_t nRegions, MlirRegion *regions,
96  MlirTypesCallback callback, void *userData) {
97  StringRef name(opName.data, opName.length);
98  std::optional<RegisteredOperationName> info =
99  getRegisteredOperationName(context, opName);
100  if (!info)
101  return mlirLogicalResultFailure();
102 
103  std::optional<Location> maybeLocation = maybeGetLocation(location);
104  SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
105  DictionaryAttr attributeDict = unwrapAttributes(attributes);
106  SmallVector<std::unique_ptr<Region>> unwrappedRegions =
107  unwrapRegions(nRegions, regions);
108 
109  SmallVector<Type> inferredTypes;
110  if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
111  unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
112  properties, unwrappedRegions, inferredTypes)))
113  return mlirLogicalResultFailure();
114 
115  SmallVector<MlirType> wrappedInferredTypes;
116  wrappedInferredTypes.reserve(inferredTypes.size());
117  for (Type t : inferredTypes)
118  wrappedInferredTypes.push_back(wrap(t));
119  callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
120  return mlirLogicalResultSuccess();
121 }
122 
124  return wrap(InferShapedTypeOpInterface::getInterfaceID());
125 }
126 
128  MlirStringRef opName, MlirContext context, MlirLocation location,
129  intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
130  void *properties, intptr_t nRegions, MlirRegion *regions,
131  MlirShapedTypeComponentsCallback callback, void *userData) {
132  std::optional<RegisteredOperationName> info =
133  getRegisteredOperationName(context, opName);
134  if (!info)
135  return mlirLogicalResultFailure();
136 
137  std::optional<Location> maybeLocation = maybeGetLocation(location);
138  SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
139  DictionaryAttr attributeDict = unwrapAttributes(attributes);
140  SmallVector<std::unique_ptr<Region>> unwrappedRegions =
141  unwrapRegions(nRegions, regions);
142 
143  SmallVector<ShapedTypeComponents> inferredTypeComponents;
144  if (failed(info->getInterface<InferShapedTypeOpInterface>()
145  ->inferReturnTypeComponents(
146  unwrap(context), maybeLocation,
147  mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)),
148  attributeDict, properties, unwrappedRegions,
149  inferredTypeComponents)))
150  return mlirLogicalResultFailure();
151 
152  bool hasRank;
153  intptr_t rank;
154  const int64_t *shapeData;
155  for (const ShapedTypeComponents &t : inferredTypeComponents) {
156  if (t.hasRank()) {
157  hasRank = true;
158  rank = t.getDims().size();
159  shapeData = t.getDims().data();
160  } else {
161  hasRank = false;
162  rank = 0;
163  shapeData = nullptr;
164  }
165  callback(hasRank, rank, shapeData, wrap(t.getElementType()),
166  wrap(t.getAttribute()), userData);
167  }
168  return mlirLogicalResultSuccess();
169 }
MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
Definition: Interfaces.cpp:127
bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
Definition: Interfaces.cpp:73
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identified by its TypeID.
Definition: Interfaces.cpp:80
MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
Definition: Interfaces.cpp:88
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical name, given the arguments that will be...
Definition: Interfaces.cpp:92
MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
Definition: Interfaces.cpp:123
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
Definition: Wrap.h:40
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
ShapedTypeComponents that represents the components of a ShapedType.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition: Diagnostics.h:19
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition: Diagnostics.h:24
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1019
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition: IR.h:282
void(* MlirShapedTypeComponentsCallback)(bool, intptr_t, const int64_t *, MlirType, MlirAttribute, void *)
These callbacks are used to return multiple shaped type components from functions while transferring ...
Definition: Interfaces.h:77
void(* MlirTypesCallback)(intptr_t, MlirType *, void *)
These callbacks are used to return multiple types from functions while transferring ownership to the ...
Definition: Interfaces.h:51
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition: Support.h:138
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition: Support.h:132
Include the generated interface declarations.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75