19#include "llvm/ADT/ScopeExit.h"
26std::optional<RegisteredOperationName>
27getRegisteredOperationName(MlirContext context,
MlirStringRef opName) {
29 std::optional<RegisteredOperationName> info =
34std::optional<Location> maybeGetLocation(MlirLocation location) {
35 std::optional<Location> maybeLocation;
37 maybeLocation =
unwrap(location);
44 return unwrappedOperands;
47DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48 DictionaryAttr attributeDict;
49 if (!mlirAttributeIsNull(attributes))
50 attributeDict = llvm::cast<DictionaryAttr>(
unwrap(attributes));
55 MlirRegion *regions) {
61 unwrappedRegions.reserve(nRegions);
62 for (
intptr_t i = 0; i < nRegions; ++i)
63 unwrappedRegions.emplace_back(
unwrap(*(regions + i)));
64 llvm::scope_exit cleaner([&]() {
65 for (
auto ®ion : unwrappedRegions)
68 return unwrappedRegions;
74 MlirTypeID interfaceTypeID) {
75 std::optional<RegisteredOperationName> info =
76 unwrap(operation)->getRegisteredInfo();
77 return info && info->hasInterface(
unwrap(interfaceTypeID));
82 MlirTypeID interfaceTypeID) {
85 return info && info->hasInterface(
unwrap(interfaceTypeID));
89 return wrap(InferTypeOpInterface::getInterfaceID());
93 MlirStringRef opName, MlirContext context, MlirLocation location,
94 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95 void *properties,
intptr_t nRegions, MlirRegion *regions,
98 std::optional<RegisteredOperationName> info =
99 getRegisteredOperationName(context, opName);
103 std::optional<Location> maybeLocation = maybeGetLocation(location);
105 DictionaryAttr attributeDict = unwrapAttributes(attributes);
107 unwrapRegions(nRegions, regions);
114 properties ?
PropertyRef(info->getOpPropertiesTypeID(), properties)
116 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
117 unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
118 propertyRef, unwrappedRegions, inferredTypes)))
122 wrappedInferredTypes.reserve(inferredTypes.size());
123 for (
Type t : inferredTypes)
124 wrappedInferredTypes.push_back(
wrap(t));
125 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
130 return wrap(InferShapedTypeOpInterface::getInterfaceID());
134 MlirStringRef opName, MlirContext context, MlirLocation location,
135 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
136 void *properties,
intptr_t nRegions, MlirRegion *regions,
138 std::optional<RegisteredOperationName> info =
139 getRegisteredOperationName(context, opName);
143 std::optional<Location> maybeLocation = maybeGetLocation(location);
145 DictionaryAttr attributeDict = unwrapAttributes(attributes);
147 unwrapRegions(nRegions, regions);
153 properties ?
PropertyRef(info->getOpPropertiesTypeID(), properties)
155 if (failed(info->getInterface<InferShapedTypeOpInterface>()
156 ->inferReturnTypeComponents(
157 unwrap(context), maybeLocation,
159 attributeDict, propertyRef, unwrappedRegions,
160 inferredTypeComponents)))
169 rank = t.getDims().size();
170 shapeData = t.getDims().data();
176 callback(hasRank, rank, shapeData,
wrap(t.getElementType()),
177 wrap(t.getAttribute()), userData);
187 return wrap(ConditionallySpeculatable::getInterfaceID());
194 ConditionallySpeculatableOpInterfaceFallbackModel> {
202 this->callbacks = callbacks;
206 if (callbacks.destruct)
207 callbacks.destruct(callbacks.userData);
211 return ConditionallySpeculatable::getInterfaceID();
214 static bool classof(
const mlir::ConditionallySpeculatable::Concept *op) {
222 assert(callbacks.getSpeculatability &&
223 "getSpeculatability callback not set");
225 switch (callbacks.getSpeculatability(
wrap(op), callbacks.userData)) {
233 llvm_unreachable(
"unknown speculatability");
246 std::optional<RegisteredOperationName> opInfo =
249 assert(opInfo.has_value() &&
"operation not found in context");
255 auto *model = cast<ConditionallySpeculatableOpInterfaceFallbackModel>(
257 ->getInterface<ConditionallySpeculatableOpInterfaceFallbackModel>());
259 "Failed to get ConditionallySpeculatableOpInterfaceFallbackModel");
264 MlirOperation operation) {
265 auto iface = dyn_cast<ConditionallySpeculatable>(
unwrap(operation));
266 assert(iface &&
"operation does not implement ConditionallySpeculatable");
268 switch (iface.getSpeculatability()) {
276 llvm_unreachable(
"unknown speculatability");
284 return wrap(MemoryEffectOpInterface::getInterfaceID());
290 MemoryEffectOpInterfaceFallbackModel> {
297 this->callbacks = callbacks;
301 if (callbacks.destruct)
302 callbacks.destruct(callbacks.userData);
306 return MemoryEffectOpInterface::getInterfaceID();
309 static bool classof(
const mlir::MemoryEffectOpInterface::Concept *op) {
319 assert(callbacks.getEffects &&
"getEffects callback not set");
320 MlirMemoryEffectInstancesList cEffects =
wrap(&effects);
321 callbacks.getEffects(
wrap(op), cEffects, callbacks.userData);
334 std::optional<RegisteredOperationName> opInfo =
337 assert(opInfo.has_value() &&
"operation not found in context");
343 auto *model = cast<MemoryEffectOpInterfaceFallbackModel>(
345 assert(model &&
"Failed to get MemoryEffectOpInterfaceFallbackModel");
MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
MlirTypeID mlirConditionallySpeculatableOpInterfaceTypeID()
Returns the interface TypeID of the ConditionallySpeculatable interface.
MlirSpeculatability mlirConditionallySpeculatableOpInterfaceGetSpeculatability(MlirOperation operation)
Returns the speculatability of the given operation.
bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
MlirTypeID mlirMemoryEffectsOpInterfaceTypeID()
Returns the interface TypeID of the MemoryEffectsOpInterface.
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identified by its TypeID.
void mlirConditionallySpeculatableOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirConditionallySpeculatableOpInterfaceCallbacks callbacks)
Attach a ConditionallySpeculatable FallbackModel to the given named op.
void mlirMemoryEffectsOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical name, given its operands, attributes, properties, and regions; the inferred types are passed to `callback`.
MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
Fallback model for the ConditionallySpeculatable interface that uses C API callbacks.
static bool classof(const mlir::ConditionallySpeculatable::Concept *op)
static TypeID getInterfaceID()
Speculation::Speculatability getSpeculatability(Operation *op) const
~ConditionallySpeculatableOpInterfaceFallbackModel()
void setCallbacks(MlirConditionallySpeculatableOpInterfaceCallbacks callbacks)
Sets the callbacks that this FallbackModel will use.
Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
void setCallbacks(MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Sets the callbacks that this FallbackModel will use.
static bool classof(const mlir::MemoryEffectOpInterface::Concept *op)
~MemoryEffectOpInterfaceFallbackModel()
static TypeID getInterfaceID()
void getEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects) const
Operation is the basic unit of execution within MLIR.
Type-safe wrapper around a void* for passing properties, including the properties structs of operations.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
ShapedTypeComponents that represents the components of a ShapedType.
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
MlirSpeculatability
Enum representing the speculatability of an operation.
@ MlirSpeculatabilityRecursivelySpeculatable
The operation is speculatable if all nested operations are speculatable.
@ MlirSpeculatabilitySpeculatable
The operation is speculatable.
@ MlirSpeculatabilityNotSpeculatable
The operation is not speculatable.
void(* MlirShapedTypeComponentsCallback)(bool, intptr_t, const int64_t *, MlirType, MlirAttribute, void *)
These callbacks are used to return multiple shaped type components from functions while transferring ownership to the caller.
void(* MlirTypesCallback)(intptr_t, MlirType *, void *)
These callbacks are used to return multiple types from functions while transferring ownership to the caller.
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
Include the generated interface declarations.
Callbacks for implementing ConditionallySpeculatable from external code.
A logical result value, essentially a boolean with named states.
Callbacks for implementing MemoryEffectsOpInterface from external code.
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.