MLIR 23.0.0git
Transform.cpp
Go to the documentation of this file.
1//===- Transform.cpp - C Interface for Transform dialect ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir-c/Support.h"
14#include "mlir/CAPI/Rewrite.h"
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir;
21
23 transform::TransformDialect)
24
25//===---------------------------------------------------------------------===//
26// AnyOpType
27//===---------------------------------------------------------------------===//
28
29bool mlirTypeIsATransformAnyOpType(MlirType type) {
30 return isa<transform::AnyOpType>(unwrap(type));
31}
32
34 return wrap(transform::AnyOpType::getTypeID());
35}
36
37MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
38 return wrap(transform::AnyOpType::get(unwrap(ctx)));
39}
40
42 return wrap(transform::AnyOpType::name);
43}
44
45//===---------------------------------------------------------------------===//
46// AnyParamType
47//===---------------------------------------------------------------------===//
48
50 return isa<transform::AnyParamType>(unwrap(type));
51}
52
54 return wrap(transform::AnyParamType::getTypeID());
55}
56
57MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
58 return wrap(transform::AnyParamType::get(unwrap(ctx)));
59}
60
62 return wrap(transform::AnyParamType::name);
63}
64
65//===---------------------------------------------------------------------===//
66// AnyValueType
67//===---------------------------------------------------------------------===//
68
70 return isa<transform::AnyValueType>(unwrap(type));
71}
72
74 return wrap(transform::AnyValueType::getTypeID());
75}
76
77MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
78 return wrap(transform::AnyValueType::get(unwrap(ctx)));
79}
80
82 return wrap(transform::AnyValueType::name);
83}
84
85//===---------------------------------------------------------------------===//
86// OperationType
87//===---------------------------------------------------------------------===//
88
90 return isa<transform::OperationType>(unwrap(type));
91}
92
94 return wrap(transform::OperationType::getTypeID());
95}
96
97MlirType mlirTransformOperationTypeGet(MlirContext ctx,
98 MlirStringRef operationName) {
99 return wrap(
100 transform::OperationType::get(unwrap(ctx), unwrap(operationName)));
101}
102
104 return wrap(transform::OperationType::name);
105}
106
108 return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
109}
110
111//===---------------------------------------------------------------------===//
112// ParamType
113//===---------------------------------------------------------------------===//
114
115bool mlirTypeIsATransformParamType(MlirType type) {
116 return isa<transform::ParamType>(unwrap(type));
117}
118
120 return wrap(transform::ParamType::getTypeID());
121}
122
123MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
124 return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
125}
126
128 return wrap(transform::ParamType::name);
129}
130
131MlirType mlirTransformParamTypeGetType(MlirType type) {
132 return wrap(cast<transform::ParamType>(unwrap(type)).getType());
133}
134
135//===---------------------------------------------------------------------===//
136// TransformRewriter
137//===---------------------------------------------------------------------===//
138
139/// Casts a `MlirTransformRewriter` to a `MlirRewriterBase`.
140MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter) {
142 mlir::RewriterBase *base = static_cast<mlir::RewriterBase *>(t);
143 return wrap(base);
144}
145
146//===---------------------------------------------------------------------===//
147// TransformResults
148//===---------------------------------------------------------------------===//
149
150void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result,
151 intptr_t numOps, MlirOperation *ops) {
153 opsVec.reserve(numOps);
154 for (intptr_t i = 0; i < numOps; ++i)
155 opsVec.push_back(unwrap(ops[i]));
156 unwrap(results)->set(cast<OpResult>(unwrap(result)), opsVec);
157}
158
159void mlirTransformResultsSetValues(MlirTransformResults results,
160 MlirValue result, intptr_t numValues,
161 MlirValue *values) {
162 SmallVector<Value> valuesVec;
163 valuesVec.reserve(numValues);
164 for (intptr_t i = 0; i < numValues; ++i)
165 valuesVec.push_back(unwrap(values[i]));
166 unwrap(results)->setValues(cast<OpResult>(unwrap(result)), valuesVec);
167}
168
169void mlirTransformResultsSetParams(MlirTransformResults results,
170 MlirValue result, intptr_t numParams,
171 MlirAttribute *params) {
172 SmallVector<Attribute> paramsVec;
173 paramsVec.reserve(numParams);
174 for (intptr_t i = 0; i < numParams; ++i)
175 paramsVec.push_back(unwrap(params[i]));
176 unwrap(results)->setParams(cast<OpResult>(unwrap(result)), paramsVec);
177}
178
179//===---------------------------------------------------------------------===//
180// TransformState
181//===---------------------------------------------------------------------===//
182
183void mlirTransformStateForEachPayloadOp(MlirTransformState state,
184 MlirValue value,
185 MlirOperationCallback callback,
186 void *userData) {
187 for (Operation *op : unwrap(state)->getPayloadOps(unwrap(value)))
188 callback(wrap(op), userData);
189}
190
191void mlirTransformStateForEachPayloadValue(MlirTransformState state,
192 MlirValue value,
193 MlirValueCallback callback,
194 void *userData) {
195 for (Value val : unwrap(state)->getPayloadValues(unwrap(value)))
196 callback(wrap(val), userData);
197}
198
199void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
200 MlirAttributeCallback callback,
201 void *userData) {
202 for (Attribute attr : unwrap(state)->getParams(unwrap(value)))
203 callback(wrap(attr), userData);
204}
205
206//===---------------------------------------------------------------------===//
207// TransformOpInterface
208//===---------------------------------------------------------------------===//
209
211 return wrap(transform::TransformOpInterface::getInterfaceID());
212}
213
214/// Fallback model for the TransformOpInterface that uses C API callbacks.
217 TransformOpInterfaceFallbackModel> {
218public:
219 /// Sets the callbacks that this FallbackModel will use.
220 /// NB: the callbacks can only be set through this method as the
221 /// RegisteredOperationName::attachInterface mechanism default-constructs
222 /// the FallbackModel without being able to provide arguments.
224 this->callbacks = callbacks;
225 }
226
228 if (callbacks.destruct)
229 callbacks.destruct(callbacks.userData);
230 }
231
233 return transform::TransformOpInterface::getInterfaceID();
234 }
235
237 TransformOpInterfaceInterfaceTraits::Concept *op) {
238 // Enable casting back to the FallbackModel from the Interface. This is
239 // necessary as attachInterface(...) default-constructs the FallbackModel
240 // without being able to pass in the callbacks and returns just the Concept.
241 return true;
242 }
243
246 ::mlir::transform::TransformResults &transformResults,
248 assert(callbacks.apply && "apply callback not set");
249
251 callbacks.apply(wrap(op), wrap(&rewriter), wrap(&transformResults),
252 wrap(&state), callbacks.userData);
253
254 switch (status) {
258 // TODO: enable passing diagnostic info from C API to C++ API.
260 *(op->emitError()
261 << "TransformOpInterfaceFallbackModel: silenceable failure")
262 .getUnderlyingDiagnostic()));
265 }
266 llvm_unreachable("unknown transform status");
267 }
268
270 assert(callbacks.allowsRepeatedHandleOperands &&
271 "allowsRepeatedHandleOperands callback not set");
272 return callbacks.allowsRepeatedHandleOperands(wrap(op), callbacks.userData);
273 }
274
275private:
277};
278
279/// Attach a TransformOpInterface FallbackModel to the given named operation.
280/// The FallbackModel uses the provided callbacks to implement the interface.
282 MlirContext ctx, MlirStringRef opName,
284 // Look up the operation definition in the context.
285 std::optional<RegisteredOperationName> opInfo =
287
288 assert(opInfo.has_value() && "operation not found in context");
289
290 // NB: the following default-constructs the FallbackModel _without_ being able
291 // to provide arguments.
292 opInfo->attachInterface<TransformOpInterfaceFallbackModel>();
293 // Cast to get the underlying FallbackModel and set the callbacks.
294 auto *model = cast<TransformOpInterfaceFallbackModel>(
295 opInfo->getInterface<TransformOpInterfaceFallbackModel>());
296
297 assert(model && "Failed to get TransformOpInterfaceFallbackModel");
298 model->setCallbacks(callbacks);
299}
300
301//===---------------------------------------------------------------------===//
302// MemoryEffectsOpInterface helpers
303//===---------------------------------------------------------------------===//
304
305/// Set the effect for the operands to only read the transform handles.
306void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
307 MlirMemoryEffectInstancesList effects) {
308 MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
309 transform::onlyReadsHandle(operandArray, *unwrap(effects));
310}
311
312/// Set the effect for the operands to consuming the transform handles.
313void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
314 MlirMemoryEffectInstancesList effects) {
315 MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
316 transform::consumesHandle(operandArray, *unwrap(effects));
317}
318
/// Set the effect for the results to indicate they produce transform handles.
320void mlirTransformProducesHandle(MlirValue *results, intptr_t numResults,
321 MlirMemoryEffectInstancesList effects) {
322 // NB: calling `producesHandle()` `numResults` as we cannot cast array of
323 // `OpResult`s to a single `ResultRange` (and neither is `ResultRange` exposed
324 // to Python). `producesHandle` iterates over the given `ResultRange` anyway.
326 for (intptr_t i = 0; i < numResults; ++i) {
327 auto opResult = cast<OpResult>(unwrap(results[i]));
328 transform::producesHandle(ResultRange(opResult), effectList);
329 }
330}
331
332/// Set the effect of potentially modifying payload IR.
333void mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects) {
335}
336
337/// Set the effect of potentially reading payload IR.
338void mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects) {
340}
#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName)
MlirTypeID mlirTransformAnyValueTypeGetTypeID(void)
Definition Transform.cpp:73
MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter)
Casts a MlirTransformRewriter to a MlirRewriterBase.
void mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects)
Set the effect of potentially reading payload IR.
void mlirTransformStateForEachPayloadOp(MlirTransformState state, MlirValue value, MlirOperationCallback callback, void *userData)
Iterate over payload operations associated with the transform IR value.
MlirTypeID mlirTransformOperationTypeGetTypeID(void)
Definition Transform.cpp:93
bool mlirTypeIsATransformAnyValueType(MlirType type)
Definition Transform.cpp:69
MlirType mlirTransformParamTypeGetType(MlirType type)
MlirStringRef mlirTransformParamTypeGetName(void)
void mlirTransformProducesHandle(MlirValue *results, intptr_t numResults, MlirMemoryEffectInstancesList effects)
Set the effect for the results to indicate they produce transform handles.
MlirTypeID mlirTransformAnyOpTypeGetTypeID(void)
Definition Transform.cpp:33
MlirTypeID mlirTransformAnyParamTypeGetTypeID(void)
Definition Transform.cpp:53
MlirTypeID mlirTransformParamTypeGetTypeID(void)
MlirStringRef mlirTransformAnyParamTypeGetName(void)
Definition Transform.cpp:61
void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands, MlirMemoryEffectInstancesList effects)
Set the effect for the operands to only read the transform handles.
void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result, intptr_t numOps, MlirOperation *ops)
Set the payload operations for a transform result by iterating over a list.
bool mlirTypeIsATransformParamType(MlirType type)
void mlirTransformOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirTransformOpInterfaceCallbacks callbacks)
Attach a TransformOpInterface FallbackModel to the given named operation.
bool mlirTypeIsATransformAnyParamType(MlirType type)
Definition Transform.cpp:49
void mlirTransformStateForEachPayloadValue(MlirTransformState state, MlirValue value, MlirValueCallback callback, void *userData)
Iterate over payload values associated with the transform IR value.
bool mlirTypeIsATransformOperationType(MlirType type)
Definition Transform.cpp:89
MlirStringRef mlirTransformAnyOpTypeGetName(void)
Definition Transform.cpp:41
void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value, MlirAttributeCallback callback, void *userData)
Iterate over parameters associated with the transform IR value.
MlirType mlirTransformAnyOpTypeGet(MlirContext ctx)
Definition Transform.cpp:37
MlirType mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName)
Definition Transform.cpp:97
MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type)
MlirStringRef mlirTransformAnyValueTypeGetName(void)
Definition Transform.cpp:81
void mlirTransformResultsSetParams(MlirTransformResults results, MlirValue result, intptr_t numParams, MlirAttribute *params)
Set the parameters for a transform result by iterating over a list.
MlirType mlirTransformAnyParamTypeGet(MlirContext ctx)
Definition Transform.cpp:57
MlirStringRef mlirTransformOperationTypeGetName(void)
MlirType mlirTransformAnyValueTypeGet(MlirContext ctx)
Definition Transform.cpp:77
void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands, MlirMemoryEffectInstancesList effects)
Set the effect for the operands to consuming the transform handles.
void mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects)
Set the effect of potentially modifying payload IR.
void mlirTransformResultsSetValues(MlirTransformResults results, MlirValue result, intptr_t numValues, MlirValue *values)
Set the payload values for a transform result by iterating over a list.
MlirTypeID mlirTransformOpInterfaceTypeID(void)
Returns the interface TypeID of the TransformOpInterface.
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type)
Fallback model for the TransformOpInterface that uses C API callbacks.
void setCallbacks(MlirTransformOpInterfaceCallbacks callbacks)
Sets the callbacks that this FallbackModel will use.
bool allowsRepeatedHandleOperands(Operation *op) const
static bool classof(const mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Concept *op)
::mlir::DiagnosedSilenceableFailure apply(Operation *op, ::mlir::transform::TransformRewriter &rewriter, ::mlir::transform::TransformResults &transformResults, ::mlir::transform::TransformState &state) const
Attributes are known-constant values of operations.
Definition Attributes.h:25
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class implements the result iterators for the Operation class.
Definition ValueRange.h:247
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition Diagnostics.h:24
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition Diagnostics.h:19
void(* MlirOperationCallback)(MlirOperation, void *userData)
Callback for iterating over payload operations.
Definition Transform.h:150
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type)
MlirDiagnosedSilenceableFailure
Enum representing the result of a transform operation.
Definition Transform.h:41
@ MlirDiagnosedSilenceableFailureSuccess
The operation succeeded.
Definition Transform.h:43
@ MlirDiagnosedSilenceableFailureDefiniteFailure
The operation failed definitively.
Definition Transform.h:47
@ MlirDiagnosedSilenceableFailureSilenceableFailure
The operation failed in a silenceable way.
Definition Transform.h:45
void(* MlirValueCallback)(MlirValue, void *userData)
Callback for iterating over payload values.
Definition Transform.h:160
void(* MlirAttributeCallback)(MlirAttribute, void *userData)
Callback for iterating over parameters.
Definition Transform.h:170
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
Callbacks for implementing TransformOpInterface from external code.
Definition Transform.h:186