MLIR 23.0.0git
Interfaces.cpp
Go to the documentation of this file.
1
2
3//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
4//
5// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6// See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9//===----------------------------------------------------------------------===//
10
11#include "mlir-c/Interfaces.h"
12
13#include "mlir/CAPI/IR.h"
15#include "mlir/CAPI/Support.h"
16#include "mlir/CAPI/Wrap.h"
17#include "mlir/IR/ValueRange.h"
19#include "llvm/ADT/ScopeExit.h"
20#include <optional>
21
22using namespace mlir;
23
24namespace {
25
26std::optional<RegisteredOperationName>
27getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
28 StringRef name(opName.data, opName.length);
29 std::optional<RegisteredOperationName> info =
31 return info;
32}
33
34std::optional<Location> maybeGetLocation(MlirLocation location) {
35 std::optional<Location> maybeLocation;
36 if (!mlirLocationIsNull(location))
37 maybeLocation = unwrap(location);
38 return maybeLocation;
39}
40
41SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
42 SmallVector<Value> unwrappedOperands;
43 (void)unwrapList(nOperands, operands, unwrappedOperands);
44 return unwrappedOperands;
45}
46
47DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48 DictionaryAttr attributeDict;
49 if (!mlirAttributeIsNull(attributes))
50 attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
51 return attributeDict;
52}
53
54SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
55 MlirRegion *regions) {
56 // Create a vector of unique pointers to regions and make sure they are not
57 // deleted when exiting the scope. This is a hack caused by C++ API expecting
58 // an list of unique pointers to regions (without ownership transfer
59 // semantics) and C API making ownership transfer explicit.
61 unwrappedRegions.reserve(nRegions);
62 for (intptr_t i = 0; i < nRegions; ++i)
63 unwrappedRegions.emplace_back(unwrap(*(regions + i)));
64 llvm::scope_exit cleaner([&]() {
65 for (auto &region : unwrappedRegions)
66 region.release();
67 });
68 return unwrappedRegions;
69}
70
71} // namespace
72
73bool mlirOperationImplementsInterface(MlirOperation operation,
74 MlirTypeID interfaceTypeID) {
75 std::optional<RegisteredOperationName> info =
76 unwrap(operation)->getRegisteredInfo();
77 return info && info->hasInterface(unwrap(interfaceTypeID));
78}
79
81 MlirContext context,
82 MlirTypeID interfaceTypeID) {
83 std::optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
84 StringRef(operationName.data, operationName.length), unwrap(context));
85 return info && info->hasInterface(unwrap(interfaceTypeID));
86}
87
89 return wrap(InferTypeOpInterface::getInterfaceID());
90}
91
93 MlirStringRef opName, MlirContext context, MlirLocation location,
94 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95 void *properties, intptr_t nRegions, MlirRegion *regions,
96 MlirTypesCallback callback, void *userData) {
97 StringRef name(opName.data, opName.length);
98 std::optional<RegisteredOperationName> info =
99 getRegisteredOperationName(context, opName);
100 if (!info)
102
103 std::optional<Location> maybeLocation = maybeGetLocation(location);
104 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
105 DictionaryAttr attributeDict = unwrapAttributes(attributes);
106 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
107 unwrapRegions(nRegions, regions);
108
109 SmallVector<Type> inferredTypes;
110 // The C API passes an opaque void*; we trust the caller to pass the correct
111 // properties type for this operation.
112 // TODO: Create a C API that's more type-safe.
113 PropertyRef propertyRef =
114 properties ? PropertyRef(info->getOpPropertiesTypeID(), properties)
115 : PropertyRef();
116 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
117 unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
118 propertyRef, unwrappedRegions, inferredTypes)))
120
121 SmallVector<MlirType> wrappedInferredTypes;
122 wrappedInferredTypes.reserve(inferredTypes.size());
123 for (Type t : inferredTypes)
124 wrappedInferredTypes.push_back(wrap(t));
125 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
127}
128
130 return wrap(InferShapedTypeOpInterface::getInterfaceID());
131}
132
134 MlirStringRef opName, MlirContext context, MlirLocation location,
135 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
136 void *properties, intptr_t nRegions, MlirRegion *regions,
137 MlirShapedTypeComponentsCallback callback, void *userData) {
138 std::optional<RegisteredOperationName> info =
139 getRegisteredOperationName(context, opName);
140 if (!info)
142
143 std::optional<Location> maybeLocation = maybeGetLocation(location);
144 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
145 DictionaryAttr attributeDict = unwrapAttributes(attributes);
146 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
147 unwrapRegions(nRegions, regions);
148
149 SmallVector<ShapedTypeComponents> inferredTypeComponents;
150 // The C API passes an opaque void*; we trust the caller to pass the correct
151 // properties type for this operation.
152 PropertyRef propertyRef =
153 properties ? PropertyRef(info->getOpPropertiesTypeID(), properties)
154 : PropertyRef();
155 if (failed(info->getInterface<InferShapedTypeOpInterface>()
156 ->inferReturnTypeComponents(
157 unwrap(context), maybeLocation,
158 mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)),
159 attributeDict, propertyRef, unwrappedRegions,
160 inferredTypeComponents)))
162
163 bool hasRank;
164 intptr_t rank;
165 const int64_t *shapeData;
166 for (const ShapedTypeComponents &t : inferredTypeComponents) {
167 if (t.hasRank()) {
168 hasRank = true;
169 rank = t.getDims().size();
170 shapeData = t.getDims().data();
171 } else {
172 hasRank = false;
173 rank = 0;
174 shapeData = nullptr;
175 }
176 callback(hasRank, rank, shapeData, wrap(t.getElementType()),
177 wrap(t.getAttribute()), userData);
178 }
180}
181
182//===---------------------------------------------------------------------===//
183// MemoryEffectOpInterface
184//===---------------------------------------------------------------------===//
185
187 return wrap(MemoryEffectOpInterface::getInterfaceID());
188}
189
190/// Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
/// attachInterface default-constructs this model, so the C callbacks
/// (`getEffects`, the optional `destruct`, and the opaque `userData` handed to
/// both) are installed afterwards through setCallbacks.
193 MemoryEffectOpInterfaceFallbackModel> {
194public:
195 /// Sets the callbacks that this FallbackModel will use.
196 /// NB: the callbacks can only be set through this method as the
197 /// RegisteredOperationName::attachInterface mechanism default-constructs
198 /// the FallbackModel without being able to provide arguments.
200 this->callbacks = callbacks;
201 }
202
  // Hand `userData` back to the C side through its `destruct` callback, if
  // one was provided.
204 if (callbacks.destruct)
205 callbacks.destruct(callbacks.userData);
206 }
207
209 return MemoryEffectOpInterface::getInterfaceID();
210 }
211
  // NOTE(review): classof accepts every Concept, so cast<> to this model is
  // only sound on a Concept known to come from this model (as right after
  // attachInterface<MemoryEffectOpInterfaceFallbackModel>()).
212 static bool classof(const mlir::MemoryEffectOpInterface::Concept *op) {
213 // Enable casting back to the FallbackModel from the Interface. This is
214 // necessary as attachInterface(...) default-constructs the FallbackModel
215 // without being able to pass in the callbacks and returns just the Concept.
216 return true;
217 }
218
  // Forwards effect queries to the C `getEffects` callback, passing the op,
  // `effects` wrapped as an MlirMemoryEffectInstancesList, and `userData`.
  // Asserts that a `getEffects` callback was set.
219 void
222 assert(callbacks.getEffects && "getEffects callback not set");
223 MlirMemoryEffectInstancesList cEffects = wrap(&effects);
224 callbacks.getEffects(wrap(op), cEffects, callbacks.userData);
225 }
226
227private:
  // NOTE(review): the callbacks member declaration is not visible here; make
  // sure it is value-initialized so the `destruct` check is well-defined for a
  // model whose setCallbacks was never called.
229};
230
231/// Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
232/// The FallbackModel uses the provided callbacks to implement the interface.
234 MlirContext ctx, MlirStringRef opName,
236 // Look up the operation definition in the context
237 std::optional<RegisteredOperationName> opInfo =
239
240 assert(opInfo.has_value() && "operation not found in context");
241
242 // NB: the following default-constructs the FallbackModel _without_ being able
243 // to provide arguments.
244 opInfo->attachInterface<MemoryEffectOpInterfaceFallbackModel>();
245 // Cast to get the underlying FallbackModel and set the callbacks.
246 auto *model = cast<MemoryEffectOpInterfaceFallbackModel>(
247 opInfo->getInterface<MemoryEffectOpInterfaceFallbackModel>());
248 assert(model && "Failed to get MemoryEffectOpInterfaceFallbackModel");
249 model->setCallbacks(callbacks);
250}
MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
MlirTypeID mlirMemoryEffectsOpInterfaceTypeID()
Returns the interface TypeID of the MemoryEffectsOpInterface.
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identified by its TypeID.
void mlirMemoryEffectsOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical name from the given arguments (operands, attributes, properties and regions).
MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
Definition Wrap.h:40
Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
void setCallbacks(MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Sets the callbacks that this FallbackModel will use.
static bool classof(const mlir::MemoryEffectOpInterface::Concept *op)
void getEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects) const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Type-safe wrapper around a void* for passing properties, including the properties structs of operations.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
ShapedTypeComponents that represents the components of a ShapedType.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition Diagnostics.h:24
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition Diagnostics.h:19
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition IR.h:370
void(* MlirShapedTypeComponentsCallback)(bool, intptr_t, const int64_t *, MlirType, MlirAttribute, void *)
These callbacks are used to return multiple shaped type components from functions while transferring ownership to the caller.
Definition Interfaces.h:87
void(* MlirTypesCallback)(intptr_t, MlirType *, void *)
These callbacks are used to return multiple types from functions while transferring ownership to the caller.
Definition Interfaces.h:61
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:143
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:137
Include the generated interface declarations.
A logical result value, essentially a boolean with named states.
Definition Support.h:121
Callbacks for implementing MemoryEffectsOpInterface from external code.
Definition Interfaces.h:108
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80