18#include "llvm/ADT/TypeSwitch.h"
23 transform::TransformDialect)
30 return isa<transform::AnyOpType>(
unwrap(type));
34 return wrap(transform::AnyOpType::getTypeID());
38 return wrap(transform::AnyOpType::get(
unwrap(ctx)));
42 return wrap(transform::AnyOpType::name);
50 return isa<transform::AnyParamType>(
unwrap(type));
54 return wrap(transform::AnyParamType::getTypeID());
58 return wrap(transform::AnyParamType::get(
unwrap(ctx)));
62 return wrap(transform::AnyParamType::name);
70 return isa<transform::AnyValueType>(
unwrap(type));
74 return wrap(transform::AnyValueType::getTypeID());
78 return wrap(transform::AnyValueType::get(
unwrap(ctx)));
82 return wrap(transform::AnyValueType::name);
90 return isa<transform::OperationType>(
unwrap(type));
94 return wrap(transform::OperationType::getTypeID());
100 transform::OperationType::get(
unwrap(ctx),
unwrap(operationName)));
104 return wrap(transform::OperationType::name);
108 return wrap(cast<transform::OperationType>(
unwrap(type)).getOperationName());
116 return isa<transform::ParamType>(
unwrap(type));
120 return wrap(transform::ParamType::getTypeID());
128 return wrap(transform::ParamType::name);
151 intptr_t numOps, MlirOperation *ops) {
153 opsVec.reserve(numOps);
154 for (
intptr_t i = 0; i < numOps; ++i)
155 opsVec.push_back(
unwrap(ops[i]));
163 valuesVec.reserve(numValues);
164 for (
intptr_t i = 0; i < numValues; ++i)
165 valuesVec.push_back(
unwrap(values[i]));
171 MlirAttribute *params) {
173 paramsVec.reserve(numParams);
174 for (
intptr_t i = 0; i < numParams; ++i)
175 paramsVec.push_back(
unwrap(params[i]));
188 callback(
wrap(op), userData);
196 callback(
wrap(val), userData);
203 callback(
wrap(attr), userData);
211 return wrap(transform::TransformOpInterface::getInterfaceID());
217 TransformOpInterfaceFallbackModel> {
224 this->callbacks = callbacks;
228 if (callbacks.destruct)
229 callbacks.destruct(callbacks.userData);
233 return transform::TransformOpInterface::getInterfaceID();
237 TransformOpInterfaceInterfaceTraits::Concept *op) {
248 assert(callbacks.apply &&
"apply callback not set");
251 callbacks.apply(
wrap(op),
wrap(&rewriter),
wrap(&transformResults),
252 wrap(&state), callbacks.userData);
261 <<
"TransformOpInterfaceFallbackModel: silenceable failure")
262 .getUnderlyingDiagnostic()));
266 llvm_unreachable(
"unknown transform status");
270 assert(callbacks.allowsRepeatedHandleOperands &&
271 "allowsRepeatedHandleOperands callback not set");
272 return callbacks.allowsRepeatedHandleOperands(
wrap(op), callbacks.userData);
285 std::optional<RegisteredOperationName> opInfo =
288 assert(opInfo.has_value() &&
"operation not found in context");
294 auto *model = cast<TransformOpInterfaceFallbackModel>(
297 assert(model &&
"Failed to get TransformOpInterfaceFallbackModel");
306 return wrap(transform::PatternDescriptorOpInterface::getInterfaceID());
313 PatternDescriptorOpInterfaceFallbackModel> {
320 this->callbacks = callbacks;
324 if (callbacks.destruct)
325 callbacks.destruct(callbacks.userData);
329 return transform::PatternDescriptorOpInterface::getInterfaceID();
334 PatternDescriptorOpInterfaceInterfaceTraits::Concept *op) {
342 assert(callbacks.populatePatterns &&
"populatePatterns callback not set");
343 callbacks.populatePatterns(
wrap(op),
wrap(&patterns), callbacks.userData);
348 if (callbacks.populatePatternsWithState) {
349 callbacks.populatePatternsWithState(
wrap(op),
wrap(&patterns),
350 wrap(&state), callbacks.userData);
368 std::optional<RegisteredOperationName> opInfo =
371 assert(opInfo.has_value() &&
"operation not found in context");
377 auto *model = cast<PatternDescriptorOpInterfaceFallbackModel>(
380 assert(model &&
"Failed to get PatternDescriptorOpInterfaceFallbackModel");
390 MlirMemoryEffectInstancesList effects) {
397 MlirMemoryEffectInstancesList effects) {
404 MlirMemoryEffectInstancesList effects) {
409 for (
intptr_t i = 0; i < numResults; ++i) {
410 auto opResult = cast<OpResult>(
unwrap(results[i]));
#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName)
Fallback model for the PatternDescriptorOpInterface that uses C API callbacks.
static bool classof(const mlir::transform::detail::PatternDescriptorOpInterfaceInterfaceTraits::Concept *op)
~PatternDescriptorOpInterfaceFallbackModel()
void populatePatterns(Operation *op, RewritePatternSet &patterns) const
void setCallbacks(MlirPatternDescriptorOpInterfaceCallbacks callbacks)
Sets the callbacks that this FallbackModel will use.
void populatePatternsWithState(Operation *op, RewritePatternSet &patterns, transform::TransformState &state) const
static TypeID getInterfaceID()
Attributes are known-constant values of operations.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given diagnostic.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This class provides an efficient unique identifier for a specific C++ type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Callbacks for implementing PatternDescriptorOpInterface from external code.
A pointer to a sized fragment of a string, not necessarily null-terminated.