MLIR 22.0.0git
Interfaces.cpp
Go to the documentation of this file.
1
2
3//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
4//
5// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6// See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9//===----------------------------------------------------------------------===//
10
11#include "mlir-c/Interfaces.h"
12
13#include "mlir/CAPI/IR.h"
15#include "mlir/CAPI/Support.h"
16#include "mlir/CAPI/Wrap.h"
17#include "mlir/IR/ValueRange.h"
19#include "llvm/ADT/ScopeExit.h"
20#include <optional>
21
22using namespace mlir;
23
24namespace {
25
26std::optional<RegisteredOperationName>
27getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
28 StringRef name(opName.data, opName.length);
29 std::optional<RegisteredOperationName> info =
31 return info;
32}
33
34std::optional<Location> maybeGetLocation(MlirLocation location) {
35 std::optional<Location> maybeLocation;
36 if (!mlirLocationIsNull(location))
37 maybeLocation = unwrap(location);
38 return maybeLocation;
39}
40
41SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
42 SmallVector<Value> unwrappedOperands;
43 (void)unwrapList(nOperands, operands, unwrappedOperands);
44 return unwrappedOperands;
45}
46
47DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48 DictionaryAttr attributeDict;
49 if (!mlirAttributeIsNull(attributes))
50 attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
51 return attributeDict;
52}
53
54SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
55 MlirRegion *regions) {
56 // Create a vector of unique pointers to regions and make sure they are not
57 // deleted when exiting the scope. This is a hack caused by C++ API expecting
58 // an list of unique pointers to regions (without ownership transfer
59 // semantics) and C API making ownership transfer explicit.
61 unwrappedRegions.reserve(nRegions);
62 for (intptr_t i = 0; i < nRegions; ++i)
63 unwrappedRegions.emplace_back(unwrap(*(regions + i)));
64 auto cleaner = llvm::make_scope_exit([&]() {
65 for (auto &region : unwrappedRegions)
66 region.release();
67 });
68 return unwrappedRegions;
69}
70
71} // namespace
72
73bool mlirOperationImplementsInterface(MlirOperation operation,
74 MlirTypeID interfaceTypeID) {
75 std::optional<RegisteredOperationName> info =
76 unwrap(operation)->getRegisteredInfo();
77 return info && info->hasInterface(unwrap(interfaceTypeID));
78}
79
81 MlirContext context,
82 MlirTypeID interfaceTypeID) {
83 std::optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
84 StringRef(operationName.data, operationName.length), unwrap(context));
85 return info && info->hasInterface(unwrap(interfaceTypeID));
86}
87
89 return wrap(InferTypeOpInterface::getInterfaceID());
90}
91
93 MlirStringRef opName, MlirContext context, MlirLocation location,
94 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95 void *properties, intptr_t nRegions, MlirRegion *regions,
96 MlirTypesCallback callback, void *userData) {
97 StringRef name(opName.data, opName.length);
98 std::optional<RegisteredOperationName> info =
99 getRegisteredOperationName(context, opName);
100 if (!info)
102
103 std::optional<Location> maybeLocation = maybeGetLocation(location);
104 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
105 DictionaryAttr attributeDict = unwrapAttributes(attributes);
106 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
107 unwrapRegions(nRegions, regions);
108
109 SmallVector<Type> inferredTypes;
110 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
111 unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
112 properties, unwrappedRegions, inferredTypes)))
114
115 SmallVector<MlirType> wrappedInferredTypes;
116 wrappedInferredTypes.reserve(inferredTypes.size());
117 for (Type t : inferredTypes)
118 wrappedInferredTypes.push_back(wrap(t));
119 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
121}
122
124 return wrap(InferShapedTypeOpInterface::getInterfaceID());
125}
126
128 MlirStringRef opName, MlirContext context, MlirLocation location,
129 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
130 void *properties, intptr_t nRegions, MlirRegion *regions,
131 MlirShapedTypeComponentsCallback callback, void *userData) {
132 std::optional<RegisteredOperationName> info =
133 getRegisteredOperationName(context, opName);
134 if (!info)
136
137 std::optional<Location> maybeLocation = maybeGetLocation(location);
138 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
139 DictionaryAttr attributeDict = unwrapAttributes(attributes);
140 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
141 unwrapRegions(nRegions, regions);
142
143 SmallVector<ShapedTypeComponents> inferredTypeComponents;
144 if (failed(info->getInterface<InferShapedTypeOpInterface>()
145 ->inferReturnTypeComponents(
146 unwrap(context), maybeLocation,
147 mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)),
148 attributeDict, properties, unwrappedRegions,
149 inferredTypeComponents)))
151
152 bool hasRank;
153 intptr_t rank;
154 const int64_t *shapeData;
155 for (const ShapedTypeComponents &t : inferredTypeComponents) {
156 if (t.hasRank()) {
157 hasRank = true;
158 rank = t.getDims().size();
159 shapeData = t.getDims().data();
160 } else {
161 hasRank = false;
162 rank = 0;
163 shapeData = nullptr;
164 }
165 callback(hasRank, rank, shapeData, wrap(t.getElementType()),
166 wrap(t.getAttribute()), userData);
167 }
169}
MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identi...
MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical name, given the arguments that will be...
MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
Definition Wrap.h:40
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
ShapedTypeComponents that represents the components of a ShapedType.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition Diagnostics.h:24
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition Diagnostics.h:19
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition IR.h:370
void(* MlirShapedTypeComponentsCallback)(bool, intptr_t, const int64_t *, MlirType, MlirAttribute, void *)
These callbacks are used to return multiple shaped type components from functions while transferring ...
Definition Interfaces.h:77
void(* MlirTypesCallback)(intptr_t, MlirType *, void *)
These callbacks are used to return multiple types from functions while transferring ownership to the ...
Definition Interfaces.h:51
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:138
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:132
Include the generated interface declarations.
A logical result value, essentially a boolean with named states.
Definition Support.h:116
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:73
const char * data
Pointer to the first symbol.
Definition Support.h:74
size_t length
Length of the fragment.
Definition Support.h:75