19#include "llvm/ADT/ScopeExit.h"
26std::optional<RegisteredOperationName>
27getRegisteredOperationName(MlirContext context,
MlirStringRef opName) {
29 std::optional<RegisteredOperationName> info =
34std::optional<Location> maybeGetLocation(MlirLocation location) {
35 std::optional<Location> maybeLocation;
37 maybeLocation =
unwrap(location);
44 return unwrappedOperands;
47DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48 DictionaryAttr attributeDict;
49 if (!mlirAttributeIsNull(attributes))
50 attributeDict = llvm::cast<DictionaryAttr>(
unwrap(attributes));
55 MlirRegion *regions) {
61 unwrappedRegions.reserve(nRegions);
62 for (
intptr_t i = 0; i < nRegions; ++i)
63 unwrappedRegions.emplace_back(
unwrap(*(regions + i)));
64 llvm::scope_exit cleaner([&]() {
65 for (
auto &region : unwrappedRegions)
68 return unwrappedRegions;
74 MlirTypeID interfaceTypeID) {
75 std::optional<RegisteredOperationName> info =
76 unwrap(operation)->getRegisteredInfo();
77 return info && info->hasInterface(
unwrap(interfaceTypeID));
82 MlirTypeID interfaceTypeID) {
85 return info && info->hasInterface(
unwrap(interfaceTypeID));
89 return wrap(InferTypeOpInterface::getInterfaceID());
93 MlirStringRef opName, MlirContext context, MlirLocation location,
94 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95 void *properties,
intptr_t nRegions, MlirRegion *regions,
98 std::optional<RegisteredOperationName> info =
99 getRegisteredOperationName(context, opName);
103 std::optional<Location> maybeLocation = maybeGetLocation(location);
105 DictionaryAttr attributeDict = unwrapAttributes(attributes);
107 unwrapRegions(nRegions, regions);
110 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
111 unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
112 properties, unwrappedRegions, inferredTypes)))
116 wrappedInferredTypes.reserve(inferredTypes.size());
117 for (
Type t : inferredTypes)
118 wrappedInferredTypes.push_back(
wrap(t));
119 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
124 return wrap(InferShapedTypeOpInterface::getInterfaceID());
128 MlirStringRef opName, MlirContext context, MlirLocation location,
129 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
130 void *properties,
intptr_t nRegions, MlirRegion *regions,
132 std::optional<RegisteredOperationName> info =
133 getRegisteredOperationName(context, opName);
137 std::optional<Location> maybeLocation = maybeGetLocation(location);
139 DictionaryAttr attributeDict = unwrapAttributes(attributes);
141 unwrapRegions(nRegions, regions);
144 if (failed(info->getInterface<InferShapedTypeOpInterface>()
145 ->inferReturnTypeComponents(
146 unwrap(context), maybeLocation,
148 attributeDict, properties, unwrappedRegions,
149 inferredTypeComponents)))
158 rank = t.getDims().size();
159 shapeData = t.getDims().data();
165 callback(hasRank, rank, shapeData,
wrap(t.getElementType()),
166 wrap(t.getAttribute()), userData);
176 return wrap(MemoryEffectOpInterface::getInterfaceID());
182 MemoryEffectOpInterfaceFallbackModel> {
189 this->callbacks = callbacks;
193 if (callbacks.destruct)
194 callbacks.destruct(callbacks.userData);
198 return MemoryEffectOpInterface::getInterfaceID();
201 static bool classof(
const mlir::MemoryEffectOpInterface::Concept *op) {
211 assert(callbacks.getEffects &&
"getEffects callback not set");
212 MlirMemoryEffectInstancesList cEffects =
wrap(&effects);
213 callbacks.getEffects(
wrap(op), cEffects, callbacks.userData);
226 std::optional<RegisteredOperationName> opInfo =
229 assert(opInfo.has_value() &&
"operation not found in context");
235 auto *model = cast<MemoryEffectOpInterfaceFallbackModel>(
237 assert(model &&
"Failed to get MemoryEffectOpInterfaceFallbackModel");
MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
MlirTypeID mlirMemoryEffectsOpInterfaceTypeID()
Returns the interface TypeID of the MemoryEffectsOpInterface.
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identified by its TypeID in the given context.
void mlirMemoryEffectsOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical name, given the arguments that will be supplied to its generic builder.
MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
void setCallbacks(MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Sets the callbacks that this FallbackModel will use.
static bool classof(const mlir::MemoryEffectOpInterface::Concept *op)
~MemoryEffectOpInterfaceFallbackModel()
static TypeID getInterfaceID()
void getEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects) const
Operation is the basic unit of execution within MLIR.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
ShapedTypeComponents that represents the components of a ShapedType.
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
void(* MlirShapedTypeComponentsCallback)(bool, intptr_t, const int64_t *, MlirType, MlirAttribute, void *)
These callbacks are used to return multiple shaped type components from functions while transferring ownership to the caller.
void(* MlirTypesCallback)(intptr_t, MlirType *, void *)
These callbacks are used to return multiple types from functions while transferring ownership to the caller.
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Include the generated interface declarations.
A logical result value, essentially a boolean with named states.
Callbacks for implementing MemoryEffectsOpInterface from external code.
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.