MLIR 23.0.0git
Interfaces.cpp
Go to the documentation of this file.
1
2
3//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
4//
5// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6// See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9//===----------------------------------------------------------------------===//
10
11#include "mlir-c/Interfaces.h"
12
13#include "mlir/CAPI/IR.h"
15#include "mlir/CAPI/Support.h"
16#include "mlir/CAPI/Wrap.h"
17#include "mlir/IR/ValueRange.h"
19#include "llvm/ADT/ScopeExit.h"
20#include <optional>
21
22using namespace mlir;
23
24namespace {
25
26std::optional<RegisteredOperationName>
27getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
28 StringRef name(opName.data, opName.length);
29 std::optional<RegisteredOperationName> info =
31 return info;
32}
33
34std::optional<Location> maybeGetLocation(MlirLocation location) {
35 std::optional<Location> maybeLocation;
36 if (!mlirLocationIsNull(location))
37 maybeLocation = unwrap(location);
38 return maybeLocation;
39}
40
41SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
42 SmallVector<Value> unwrappedOperands;
43 (void)unwrapList(nOperands, operands, unwrappedOperands);
44 return unwrappedOperands;
45}
46
47DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48 DictionaryAttr attributeDict;
49 if (!mlirAttributeIsNull(attributes))
50 attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
51 return attributeDict;
52}
53
54SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
55 MlirRegion *regions) {
56 // Create a vector of unique pointers to regions and make sure they are not
57 // deleted when exiting the scope. This is a hack caused by C++ API expecting
58 // an list of unique pointers to regions (without ownership transfer
59 // semantics) and C API making ownership transfer explicit.
61 unwrappedRegions.reserve(nRegions);
62 for (intptr_t i = 0; i < nRegions; ++i)
63 unwrappedRegions.emplace_back(unwrap(*(regions + i)));
64 llvm::scope_exit cleaner([&]() {
65 for (auto &region : unwrappedRegions)
66 region.release();
67 });
68 return unwrappedRegions;
69}
70
71} // namespace
72
73bool mlirOperationImplementsInterface(MlirOperation operation,
74 MlirTypeID interfaceTypeID) {
75 std::optional<RegisteredOperationName> info =
76 unwrap(operation)->getRegisteredInfo();
77 return info && info->hasInterface(unwrap(interfaceTypeID));
78}
79
81 MlirContext context,
82 MlirTypeID interfaceTypeID) {
83 std::optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
84 StringRef(operationName.data, operationName.length), unwrap(context));
85 return info && info->hasInterface(unwrap(interfaceTypeID));
86}
87
89 return wrap(InferTypeOpInterface::getInterfaceID());
90}
91
93 MlirStringRef opName, MlirContext context, MlirLocation location,
94 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95 void *properties, intptr_t nRegions, MlirRegion *regions,
96 MlirTypesCallback callback, void *userData) {
97 StringRef name(opName.data, opName.length);
98 std::optional<RegisteredOperationName> info =
99 getRegisteredOperationName(context, opName);
100 if (!info)
102
103 std::optional<Location> maybeLocation = maybeGetLocation(location);
104 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
105 DictionaryAttr attributeDict = unwrapAttributes(attributes);
106 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
107 unwrapRegions(nRegions, regions);
108
109 SmallVector<Type> inferredTypes;
110 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
111 unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
112 properties, unwrappedRegions, inferredTypes)))
114
115 SmallVector<MlirType> wrappedInferredTypes;
116 wrappedInferredTypes.reserve(inferredTypes.size());
117 for (Type t : inferredTypes)
118 wrappedInferredTypes.push_back(wrap(t));
119 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
121}
122
124 return wrap(InferShapedTypeOpInterface::getInterfaceID());
125}
126
128 MlirStringRef opName, MlirContext context, MlirLocation location,
129 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
130 void *properties, intptr_t nRegions, MlirRegion *regions,
131 MlirShapedTypeComponentsCallback callback, void *userData) {
132 std::optional<RegisteredOperationName> info =
133 getRegisteredOperationName(context, opName);
134 if (!info)
136
137 std::optional<Location> maybeLocation = maybeGetLocation(location);
138 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
139 DictionaryAttr attributeDict = unwrapAttributes(attributes);
140 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
141 unwrapRegions(nRegions, regions);
142
143 SmallVector<ShapedTypeComponents> inferredTypeComponents;
144 if (failed(info->getInterface<InferShapedTypeOpInterface>()
145 ->inferReturnTypeComponents(
146 unwrap(context), maybeLocation,
147 mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)),
148 attributeDict, properties, unwrappedRegions,
149 inferredTypeComponents)))
151
152 bool hasRank;
153 intptr_t rank;
154 const int64_t *shapeData;
155 for (const ShapedTypeComponents &t : inferredTypeComponents) {
156 if (t.hasRank()) {
157 hasRank = true;
158 rank = t.getDims().size();
159 shapeData = t.getDims().data();
160 } else {
161 hasRank = false;
162 rank = 0;
163 shapeData = nullptr;
164 }
165 callback(hasRank, rank, shapeData, wrap(t.getElementType()),
166 wrap(t.getAttribute()), userData);
167 }
169}
170
171//===---------------------------------------------------------------------===//
172// MemoryEffectOpInterface
173//===---------------------------------------------------------------------===//
174
176 return wrap(MemoryEffectOpInterface::getInterfaceID());
177}
178
179/// Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
182 MemoryEffectOpInterfaceFallbackModel> {
183public:
184 /// Sets the callbacks that this FallbackModel will use.
185 /// NB: the callbacks can only be set through this method as the
186 /// RegisteredOperationName::attachInterface mechanism default-constructs
187 /// the FallbackModel without being able to provide arguments.
189 this->callbacks = callbacks;
190 }
191
193 if (callbacks.destruct)
194 callbacks.destruct(callbacks.userData);
195 }
196
198 return MemoryEffectOpInterface::getInterfaceID();
199 }
200
201 static bool classof(const mlir::MemoryEffectOpInterface::Concept *op) {
202 // Enable casting back to the FallbackModel from the Interface. This is
203 // necessary as attachInterface(...) default-constructs the FallbackModel
204 // without being able to pass in the callbacks and returns just the Concept.
205 return true;
206 }
207
208 void
211 assert(callbacks.getEffects && "getEffects callback not set");
212 MlirMemoryEffectInstancesList cEffects = wrap(&effects);
213 callbacks.getEffects(wrap(op), cEffects, callbacks.userData);
214 }
215
216private:
218};
219
220/// Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
221/// The FallbackModel uses the provided callbacks to implement the interface.
223 MlirContext ctx, MlirStringRef opName,
225 // Look up the operation definition in the context
226 std::optional<RegisteredOperationName> opInfo =
228
229 assert(opInfo.has_value() && "operation not found in context");
230
231 // NB: the following default-constructs the FallbackModel _without_ being able
232 // to provide arguments.
233 opInfo->attachInterface<MemoryEffectOpInterfaceFallbackModel>();
234 // Cast to get the underlying FallbackModel and set the callbacks.
235 auto *model = cast<MemoryEffectOpInterfaceFallbackModel>(
236 opInfo->getInterface<MemoryEffectOpInterfaceFallbackModel>());
237 assert(model && "Failed to get MemoryEffectOpInterfaceFallbackModel");
238 model->setCallbacks(callbacks);
239}
MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
MlirTypeID mlirMemoryEffectsOpInterfaceTypeID()
Returns the interface TypeID of the MemoryEffectsOpInterface.
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identi...
void mlirMemoryEffectsOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical name given the arguments that will be...
MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
Definition Wrap.h:40
Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
void setCallbacks(MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Sets the callbacks that this FallbackModel will use.
static bool classof(const mlir::MemoryEffectOpInterface::Concept *op)
void getEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects) const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
ShapedTypeComponents that represents the components of a ShapedType.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition Diagnostics.h:24
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition Diagnostics.h:19
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
Definition IR.h:370
void(* MlirShapedTypeComponentsCallback)(bool, intptr_t, const int64_t *, MlirType, MlirAttribute, void *)
These callbacks are used to return multiple shaped type components from functions while transferring ...
Definition Interfaces.h:87
void(* MlirTypesCallback)(intptr_t, MlirType *, void *)
These callbacks are used to return multiple types from functions while transferring ownership to the ...
Definition Interfaces.h:61
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:143
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:137
Include the generated interface declarations.
A logical result value, essentially a boolean with named states.
Definition Support.h:121
Callbacks for implementing MemoryEffectsOpInterface from external code.
Definition Interfaces.h:108
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80