19#include "llvm/ADT/ScopeExit.h"
26std::optional<RegisteredOperationName>
27getRegisteredOperationName(MlirContext context,
MlirStringRef opName) {
29 std::optional<RegisteredOperationName> info =
34std::optional<Location> maybeGetLocation(MlirLocation location) {
35 std::optional<Location> maybeLocation;
37 maybeLocation =
unwrap(location);
44 return unwrappedOperands;
47DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48 DictionaryAttr attributeDict;
49 if (!mlirAttributeIsNull(attributes))
50 attributeDict = llvm::cast<DictionaryAttr>(
unwrap(attributes));
55 MlirRegion *regions) {
61 unwrappedRegions.reserve(nRegions);
62 for (
intptr_t i = 0; i < nRegions; ++i)
63 unwrappedRegions.emplace_back(
unwrap(*(regions + i)));
64 llvm::scope_exit cleaner([&]() {
65 for (
auto ®ion : unwrappedRegions)
68 return unwrappedRegions;
74 MlirTypeID interfaceTypeID) {
75 std::optional<RegisteredOperationName> info =
76 unwrap(operation)->getRegisteredInfo();
77 return info && info->hasInterface(
unwrap(interfaceTypeID));
82 MlirTypeID interfaceTypeID) {
85 return info && info->hasInterface(
unwrap(interfaceTypeID));
89 return wrap(InferTypeOpInterface::getInterfaceID());
93 MlirStringRef opName, MlirContext context, MlirLocation location,
94 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95 void *properties,
intptr_t nRegions, MlirRegion *regions,
98 std::optional<RegisteredOperationName> info =
99 getRegisteredOperationName(context, opName);
103 std::optional<Location> maybeLocation = maybeGetLocation(location);
105 DictionaryAttr attributeDict = unwrapAttributes(attributes);
107 unwrapRegions(nRegions, regions);
114 properties ?
PropertyRef(info->getOpPropertiesTypeID(), properties)
116 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
117 unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
118 propertyRef, unwrappedRegions, inferredTypes)))
122 wrappedInferredTypes.reserve(inferredTypes.size());
123 for (
Type t : inferredTypes)
124 wrappedInferredTypes.push_back(
wrap(t));
125 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
130 return wrap(InferShapedTypeOpInterface::getInterfaceID());
134 MlirStringRef opName, MlirContext context, MlirLocation location,
135 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
136 void *properties,
intptr_t nRegions, MlirRegion *regions,
138 std::optional<RegisteredOperationName> info =
139 getRegisteredOperationName(context, opName);
143 std::optional<Location> maybeLocation = maybeGetLocation(location);
145 DictionaryAttr attributeDict = unwrapAttributes(attributes);
147 unwrapRegions(nRegions, regions);
153 properties ?
PropertyRef(info->getOpPropertiesTypeID(), properties)
155 if (failed(info->getInterface<InferShapedTypeOpInterface>()
156 ->inferReturnTypeComponents(
157 unwrap(context), maybeLocation,
159 attributeDict, propertyRef, unwrappedRegions,
160 inferredTypeComponents)))
169 rank = t.getDims().size();
170 shapeData = t.getDims().data();
176 callback(hasRank, rank, shapeData,
wrap(t.getElementType()),
177 wrap(t.getAttribute()), userData);
187 return wrap(MemoryEffectOpInterface::getInterfaceID());
193 MemoryEffectOpInterfaceFallbackModel> {
200 this->callbacks = callbacks;
204 if (callbacks.destruct)
205 callbacks.destruct(callbacks.userData);
209 return MemoryEffectOpInterface::getInterfaceID();
212 static bool classof(
const mlir::MemoryEffectOpInterface::Concept *op) {
222 assert(callbacks.getEffects &&
"getEffects callback not set");
223 MlirMemoryEffectInstancesList cEffects =
wrap(&effects);
224 callbacks.getEffects(
wrap(op), cEffects, callbacks.userData);
237 std::optional<RegisteredOperationName> opInfo =
240 assert(opInfo.has_value() &&
"operation not found in context");
246 auto *model = cast<MemoryEffectOpInterfaceFallbackModel>(
248 assert(model &&
"Failed to get MemoryEffectOpInterfaceFallbackModel");
MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
MlirTypeID mlirMemoryEffectsOpInterfaceTypeID()
Returns the interface TypeID of the MemoryEffectsOpInterface.
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identi...
void mlirMemoryEffectsOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical given the arguments that will be...
MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
void setCallbacks(MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Sets the callbacks that this FallbackModel will use.
static bool classof(const mlir::MemoryEffectOpInterface::Concept *op)
~MemoryEffectOpInterfaceFallbackModel()
static TypeID getInterfaceID()
void getEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects) const
Operation is the basic unit of execution within MLIR.
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
ShapedTypeComponents that represents the components of a ShapedType.
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
void(* MlirShapedTypeComponentsCallback)(bool, intptr_t, const int64_t *, MlirType, MlirAttribute, void *)
These callbacks are used to return multiple shaped type components from functions while transferring ...
void(* MlirTypesCallback)(intptr_t, MlirType *, void *)
These callbacks are used to return multiple types from functions while transferring ownership to the ...
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Include the generated interface declarations.
A logical result value, essentially a boolean with named states.
Callbacks for implementing MemoryEffectsOpInterface from external code.
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.