MLIR 23.0.0git
Interfaces.cpp
Go to the documentation of this file.
1
2
3//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
4//
5// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6// See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9//===----------------------------------------------------------------------===//
10
11#include "mlir-c/Interfaces.h"
12
13#include "mlir/CAPI/IR.h"
15#include "mlir/CAPI/Support.h"
16#include "mlir/CAPI/Wrap.h"
17#include "mlir/IR/ValueRange.h"
19#include "llvm/ADT/ScopeExit.h"
20#include <optional>
21
22using namespace mlir;
23
24namespace {
25
26std::optional<RegisteredOperationName>
27getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
28 StringRef name(opName.data, opName.length);
29 std::optional<RegisteredOperationName> info =
31 return info;
32}
33
34std::optional<Location> maybeGetLocation(MlirLocation location) {
35 std::optional<Location> maybeLocation;
36 if (!mlirLocationIsNull(location))
37 maybeLocation = unwrap(location);
38 return maybeLocation;
39}
40
41SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
42 SmallVector<Value> unwrappedOperands;
43 (void)unwrapList(nOperands, operands, unwrappedOperands);
44 return unwrappedOperands;
45}
46
47DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48 DictionaryAttr attributeDict;
49 if (!mlirAttributeIsNull(attributes))
50 attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
51 return attributeDict;
52}
53
54SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
55 MlirRegion *regions) {
56 // Create a vector of unique pointers to regions and make sure they are not
57 // deleted when exiting the scope. This is a hack caused by C++ API expecting
58 // an list of unique pointers to regions (without ownership transfer
59 // semantics) and C API making ownership transfer explicit.
61 unwrappedRegions.reserve(nRegions);
62 for (intptr_t i = 0; i < nRegions; ++i)
63 unwrappedRegions.emplace_back(unwrap(*(regions + i)));
64 llvm::scope_exit cleaner([&]() {
65 for (auto &region : unwrappedRegions)
66 region.release();
67 });
68 return unwrappedRegions;
69}
70
71} // namespace
72
73bool mlirOperationImplementsInterface(MlirOperation operation,
74 MlirTypeID interfaceTypeID) {
75 std::optional<RegisteredOperationName> info =
76 unwrap(operation)->getRegisteredInfo();
77 return info && info->hasInterface(unwrap(interfaceTypeID));
78}
79
81 MlirContext context,
82 MlirTypeID interfaceTypeID) {
83 std::optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
84 StringRef(operationName.data, operationName.length), unwrap(context));
85 return info && info->hasInterface(unwrap(interfaceTypeID));
86}
87
89 return wrap(InferTypeOpInterface::getInterfaceID());
90}
91
93 MlirStringRef opName, MlirContext context, MlirLocation location,
94 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95 void *properties, intptr_t nRegions, MlirRegion *regions,
96 MlirTypesCallback callback, void *userData) {
97 StringRef name(opName.data, opName.length);
98 std::optional<RegisteredOperationName> info =
99 getRegisteredOperationName(context, opName);
100 if (!info)
102
103 std::optional<Location> maybeLocation = maybeGetLocation(location);
104 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
105 DictionaryAttr attributeDict = unwrapAttributes(attributes);
106 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
107 unwrapRegions(nRegions, regions);
108
109 SmallVector<Type> inferredTypes;
110 // The C API passes an opaque void*; we trust the caller to pass the correct
111 // properties type for this operation.
112 // TODO: Create a C API that's more type-safe.
113 PropertyRef propertyRef =
114 properties ? PropertyRef(info->getOpPropertiesTypeID(), properties)
115 : PropertyRef();
116 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
117 unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
118 propertyRef, unwrappedRegions, inferredTypes)))
120
121 SmallVector<MlirType> wrappedInferredTypes;
122 wrappedInferredTypes.reserve(inferredTypes.size());
123 for (Type t : inferredTypes)
124 wrappedInferredTypes.push_back(wrap(t));
125 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
127}
128
130 return wrap(InferShapedTypeOpInterface::getInterfaceID());
131}
132
134 MlirStringRef opName, MlirContext context, MlirLocation location,
135 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
136 void *properties, intptr_t nRegions, MlirRegion *regions,
137 MlirShapedTypeComponentsCallback callback, void *userData) {
138 std::optional<RegisteredOperationName> info =
139 getRegisteredOperationName(context, opName);
140 if (!info)
142
143 std::optional<Location> maybeLocation = maybeGetLocation(location);
144 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
145 DictionaryAttr attributeDict = unwrapAttributes(attributes);
146 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
147 unwrapRegions(nRegions, regions);
148
149 SmallVector<ShapedTypeComponents> inferredTypeComponents;
150 // The C API passes an opaque void*; we trust the caller to pass the correct
151 // properties type for this operation.
152 PropertyRef propertyRef =
153 properties ? PropertyRef(info->getOpPropertiesTypeID(), properties)
154 : PropertyRef();
155 if (failed(info->getInterface<InferShapedTypeOpInterface>()
156 ->inferReturnTypeComponents(
157 unwrap(context), maybeLocation,
158 mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)),
159 attributeDict, propertyRef, unwrappedRegions,
160 inferredTypeComponents)))
162
163 bool hasRank;
164 intptr_t rank;
165 const int64_t *shapeData;
166 for (const ShapedTypeComponents &t : inferredTypeComponents) {
167 if (t.hasRank()) {
168 hasRank = true;
169 rank = t.getDims().size();
170 shapeData = t.getDims().data();
171 } else {
172 hasRank = false;
173 rank = 0;
174 shapeData = nullptr;
175 }
176 callback(hasRank, rank, shapeData, wrap(t.getElementType()),
177 wrap(t.getAttribute()), userData);
178 }
180}
181
182//===---------------------------------------------------------------------===//
183// ConditionallySpeculatable
184//===---------------------------------------------------------------------===//
185
187 return wrap(ConditionallySpeculatable::getInterfaceID());
188}
189
190/// Fallback model for the ConditionallySpeculatable interface that uses C API
191/// callbacks.
194 ConditionallySpeculatableOpInterfaceFallbackModel> {
195public:
196 /// Sets the callbacks that this FallbackModel will use.
197 /// NB: the callbacks can only be set through this method as the
198 /// RegisteredOperationName::attachInterface mechanism default-constructs
199 /// the FallbackModel without being able to provide arguments.
200 void
202 this->callbacks = callbacks;
203 }
204
206 if (callbacks.destruct)
207 callbacks.destruct(callbacks.userData);
208 }
209
211 return ConditionallySpeculatable::getInterfaceID();
212 }
213
214 static bool classof(const mlir::ConditionallySpeculatable::Concept *op) {
215 // Enable casting back to the FallbackModel from the Interface. This is
216 // necessary as attachInterface(...) default-constructs the FallbackModel
217 // without being able to pass in the callbacks and returns just the Concept.
218 return true;
219 }
220
222 assert(callbacks.getSpeculatability &&
223 "getSpeculatability callback not set");
224
225 switch (callbacks.getSpeculatability(wrap(op), callbacks.userData)) {
232 }
233 llvm_unreachable("unknown speculatability");
234 }
235
236private:
238};
239
240/// Attach a ConditionallySpeculatable FallbackModel to the given named op.
241/// The FallbackModel uses the provided callbacks to implement the interface.
243 MlirContext ctx, MlirStringRef opName,
245 // Look up the operation definition in the context.
246 std::optional<RegisteredOperationName> opInfo =
248
249 assert(opInfo.has_value() && "operation not found in context");
250
251 // NB: the following default-constructs the FallbackModel _without_ being able
252 // to provide arguments.
253 opInfo->attachInterface<ConditionallySpeculatableOpInterfaceFallbackModel>();
254 // Cast to get the underlying FallbackModel and set the callbacks.
255 auto *model = cast<ConditionallySpeculatableOpInterfaceFallbackModel>(
256 opInfo
257 ->getInterface<ConditionallySpeculatableOpInterfaceFallbackModel>());
258 assert(model &&
259 "Failed to get ConditionallySpeculatableOpInterfaceFallbackModel");
260 model->setCallbacks(callbacks);
261}
262
264 MlirOperation operation) {
265 auto iface = dyn_cast<ConditionallySpeculatable>(unwrap(operation));
266 assert(iface && "operation does not implement ConditionallySpeculatable");
267
268 switch (iface.getSpeculatability()) {
275 }
276 llvm_unreachable("unknown speculatability");
277}
278
279//===---------------------------------------------------------------------===//
280// MemoryEffectOpInterface
281//===---------------------------------------------------------------------===//
282
284 return wrap(MemoryEffectOpInterface::getInterfaceID());
285}
286
287/// Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
290 MemoryEffectOpInterfaceFallbackModel> {
291public:
292 /// Sets the callbacks that this FallbackModel will use.
293 /// NB: the callbacks can only be set through this method as the
294 /// RegisteredOperationName::attachInterface mechanism default-constructs
295 /// the FallbackModel without being able to provide arguments.
297 this->callbacks = callbacks;
298 }
299
301 if (callbacks.destruct)
302 callbacks.destruct(callbacks.userData);
303 }
304
306 return MemoryEffectOpInterface::getInterfaceID();
307 }
308
309 static bool classof(const mlir::MemoryEffectOpInterface::Concept *op) {
310 // Enable casting back to the FallbackModel from the Interface. This is
311 // necessary as attachInterface(...) default-constructs the FallbackModel
312 // without being able to pass in the callbacks and returns just the Concept.
313 return true;
314 }
315
316 void
319 assert(callbacks.getEffects && "getEffects callback not set");
320 MlirMemoryEffectInstancesList cEffects = wrap(&effects);
321 callbacks.getEffects(wrap(op), cEffects, callbacks.userData);
322 }
323
324private:
326};
327
328/// Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
329/// The FallbackModel uses the provided callbacks to implement the interface.
331 MlirContext ctx, MlirStringRef opName,
333 // Look up the operation definition in the context
334 std::optional<RegisteredOperationName> opInfo =
336
337 assert(opInfo.has_value() && "operation not found in context");
338
339 // NB: the following default-constructs the FallbackModel _without_ being able
340 // to provide arguments.
341 opInfo->attachInterface<MemoryEffectOpInterfaceFallbackModel>();
342 // Cast to get the underlying FallbackModel and set the callbacks.
343 auto *model = cast<MemoryEffectOpInterfaceFallbackModel>(
344 opInfo->getInterface<MemoryEffectOpInterfaceFallbackModel>());
345 assert(model && "Failed to get MemoryEffectOpInterfaceFallbackModel");
346 model->setCallbacks(callbacks);
347}
MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
MlirTypeID mlirConditionallySpeculatableOpInterfaceTypeID()
Returns the interface TypeID of the ConditionallySpeculatable interface.
MlirSpeculatability mlirConditionallySpeculatableOpInterfaceGetSpeculatability(MlirOperation operation)
Returns the speculatability of the given operation.
bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
MlirTypeID mlirMemoryEffectsOpInterfaceTypeID()
Returns the interface TypeID of the MemoryEffectsOpInterface.
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identified by its TypeID.
void mlirConditionallySpeculatableOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirConditionallySpeculatableOpInterfaceCallbacks callbacks)
Attach a ConditionallySpeculatable FallbackModel to the given named op.
void mlirMemoryEffectsOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical name, given the arguments that will be...
MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
Definition Wrap.h:40
Fallback model for the ConditionallySpeculatable interface that uses C API callbacks.
static bool classof(const mlir::ConditionallySpeculatable::Concept *op)
Speculation::Speculatability getSpeculatability(Operation *op) const
void setCallbacks(MlirConditionallySpeculatableOpInterfaceCallbacks callbacks)
Sets the callbacks that this FallbackModel will use.
Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
void setCallbacks(MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Sets the callbacks that this FallbackModel will use.
static bool classof(const mlir::MemoryEffectOpInterface::Concept *op)
void getEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects) const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
ShapedTypeComponents that represents the components of a ShapedType.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition Diagnostics.h:24
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition Diagnostics.h:19
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition IR.h:376
MlirSpeculatability
Enum representing the speculatability of an operation.
Definition Interfaces.h:105
@ MlirSpeculatabilityRecursivelySpeculatable
The operation is speculatable if all nested operations are speculatable.
Definition Interfaces.h:111
@ MlirSpeculatabilitySpeculatable
The operation is speculatable.
Definition Interfaces.h:109
@ MlirSpeculatabilityNotSpeculatable
The operation is not speculatable.
Definition Interfaces.h:107
void(* MlirShapedTypeComponentsCallback)(bool, intptr_t, const int64_t *, MlirType, MlirAttribute, void *)
These callbacks are used to return multiple shaped type components from functions while transferring ...
Definition Interfaces.h:87
void(* MlirTypesCallback)(intptr_t, MlirType *, void *)
These callbacks are used to return multiple types from functions while transferring ownership to the ...
Definition Interfaces.h:61
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:143
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:137
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
Include the generated interface declarations.
Callbacks for implementing ConditionallySpeculatable from external code.
Definition Interfaces.h:119
A logical result value, essentially a boolean with named states.
Definition Support.h:121
Callbacks for implementing MemoryEffectsOpInterface from external code.
Definition Interfaces.h:151
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80