19 #include "llvm/ADT/ScopeExit.h"
26 std::optional<RegisteredOperationName>
27 getRegisteredOperationName(MlirContext context,
MlirStringRef opName) {
29 std::optional<RegisteredOperationName> info =
34 std::optional<Location> maybeGetLocation(MlirLocation location) {
35 std::optional<Location> maybeLocation;
37 maybeLocation =
unwrap(location);
43 (void)
unwrapList(nOperands, operands, unwrappedOperands);
44 return unwrappedOperands;
47 DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48 DictionaryAttr attributeDict;
50 attributeDict = llvm::cast<DictionaryAttr>(
unwrap(attributes));
55 MlirRegion *regions) {
61 unwrappedRegions.reserve(nRegions);
62 for (intptr_t i = 0; i < nRegions; ++i)
63 unwrappedRegions.emplace_back(
unwrap(*(regions + i)));
64 auto cleaner = llvm::make_scope_exit([&]() {
65 for (
auto ®ion : unwrappedRegions)
68 return unwrappedRegions;
74 MlirTypeID interfaceTypeID) {
75 std::optional<RegisteredOperationName> info =
76 unwrap(operation)->getRegisteredInfo();
77 return info && info->hasInterface(
unwrap(interfaceTypeID));
82 MlirTypeID interfaceTypeID) {
85 return info && info->hasInterface(
unwrap(interfaceTypeID));
89 return wrap(InferTypeOpInterface::getInterfaceID());
93 MlirStringRef opName, MlirContext context, MlirLocation location,
94 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95 void *properties, intptr_t nRegions, MlirRegion *regions,
98 std::optional<RegisteredOperationName> info =
99 getRegisteredOperationName(context, opName);
103 std::optional<Location> maybeLocation = maybeGetLocation(location);
105 DictionaryAttr attributeDict = unwrapAttributes(attributes);
107 unwrapRegions(nRegions, regions);
110 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
111 unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
112 properties, unwrappedRegions, inferredTypes)))
116 wrappedInferredTypes.reserve(inferredTypes.size());
117 for (
Type t : inferredTypes)
118 wrappedInferredTypes.push_back(
wrap(t));
119 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
124 return wrap(InferShapedTypeOpInterface::getInterfaceID());
128 MlirStringRef opName, MlirContext context, MlirLocation location,
129 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
130 void *properties, intptr_t nRegions, MlirRegion *regions,
132 std::optional<RegisteredOperationName> info =
133 getRegisteredOperationName(context, opName);
137 std::optional<Location> maybeLocation = maybeGetLocation(location);
139 DictionaryAttr attributeDict = unwrapAttributes(attributes);
141 unwrapRegions(nRegions, regions);
144 if (failed(info->getInterface<InferShapedTypeOpInterface>()
145 ->inferReturnTypeComponents(
146 unwrap(context), maybeLocation,
148 attributeDict, properties, unwrappedRegions,
149 inferredTypeComponents)))
154 const int64_t *shapeData;
158 rank = t.getDims().size();
159 shapeData = t.getDims().data();
165 callback(hasRank, rank, shapeData,
wrap(t.getElementType()),
166 wrap(t.getAttribute()), userData);
MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID)
Returns true if the given operation implements an interface identified by its TypeID.
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID)
Returns true if the operation identified by its canonical string name implements the interface identi...
MlirTypeID mlirInferTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferTypeOpInterface.
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical name given the arguments that will be...
MlirTypeID mlirInferShapedTypeOpInterfaceTypeID()
Returns the interface TypeID of the InferShapedTypeOpInterface.
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
ShapedTypeComponents represents the components of a ShapedType.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
void(* MlirShapedTypeComponentsCallback)(bool, intptr_t, const int64_t *, MlirType, MlirAttribute, void *)
These callbacks are used to return multiple shaped type components from functions while transferring ...
void(* MlirTypesCallback)(intptr_t, MlirType *, void *)
These callbacks are used to return multiple types from functions while transferring ownership to the ...
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Include the generated interface declarations.
A logical result value, essentially a boolean with named states.
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.