27 void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
34 transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
35 transform::TypeConverterBuilderOpInterface builder) {
36 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
37 return emitOpError(
"expected LLVMTypeConverter");
51 llvm::append_range(inputs, state.getPayloadValues(getInputs()));
55 outputs.insert_range(state.getPayloadValues(getOutputs()));
59 llvm::range_size(state.getPayloadValues(getOutputs()))) {
61 <<
"cast and call output values must be unique";
66 auto insertionOps = state.getPayloadOps(getInsertionPoint());
67 if (!llvm::hasSingleElement(insertionOps)) {
69 <<
"Only one op can be specified as an insertion point";
71 bool insertAfter = getInsertAfter();
72 Operation *insertionPoint = *insertionOps.begin();
77 for (
Value output : outputs) {
82 bool doesDominate = insertAfter
83 ? dom.properlyDominates(insertionPoint, user)
84 : dom.dominates(insertionPoint, user);
87 <<
"User " << user <<
" is not dominated by insertion point "
93 for (
Value input : inputs) {
97 bool doesDominate = insertAfter
98 ? dom.dominates(input, insertionPoint)
99 : dom.properlyDominates(input, insertionPoint);
102 <<
"input " << input <<
" does not dominate insertion point "
109 func::FuncOp targetFunction =
nullptr;
110 if (getFunctionName()) {
111 targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
112 insertionPoint, *getFunctionName());
113 if (!targetFunction) {
115 <<
"unresolved symbol " << *getFunctionName();
117 }
else if (getFunction()) {
118 auto payloadOps = state.getPayloadOps(getFunction());
119 if (!llvm::hasSingleElement(payloadOps)) {
122 targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
123 if (!targetFunction) {
127 llvm_unreachable(
"Invalid CastAndCall op without a function to call");
133 if (targetFunction.getNumArguments() != inputs.size()) {
135 <<
"mismatch between number of function arguments "
136 << targetFunction.getNumArguments() <<
" and number of inputs "
139 if (targetFunction.getNumResults() != outputs.size()) {
141 <<
"mismatch between number of function results "
142 << targetFunction->getNumResults() <<
" and number of outputs "
148 if (!getRegion().empty()) {
149 for (
Operation &op : getRegion().front()) {
150 cast<transform::TypeConverterBuilderOpInterface>(&op)
151 .populateTypeMaterializations(converter);
160 for (
auto [input, type] :
161 llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
162 if (input.getType() != type) {
164 rewriter, input.getLoc(), type, input);
167 << input <<
" to type " << type;
173 auto callOp = func::CallOp::create(rewriter, insertionPoint->
getLoc(),
174 targetFunction, inputs);
178 for (
auto [output, newOutput] :
179 llvm::zip_equal(outputs, callOp.getResults())) {
180 Value convertedOutput = newOutput;
181 if (output.getType() != newOutput.getType()) {
183 rewriter, output.getLoc(), output.getType(), newOutput);
184 if (!convertedOutput) {
186 <<
"Failed to materialize conversion of " << newOutput
187 <<
" to type " << output.getType();
192 results.
set(cast<OpResult>(getResult()), {callOp});
197 if (!getRegion().empty()) {
198 for (
Operation &op : getRegion().front()) {
199 if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
201 <<
"expected children ops to implement "
202 "TypeConverterBuilderOpInterface";
203 diag.attachNote(op.getLoc()) <<
"op without interface";
208 if (!getFunction() && !getFunctionName()) {
209 return emitOpError() <<
"expected a function handle or name to call";
211 if (getFunction() && getFunctionName()) {
212 return emitOpError() <<
"function handle and name are mutually exclusive";
217 void transform::CastAndCallOp::getEffects(
238 auto payloadOps = state.getPayloadOps(getModule());
239 if (!llvm::hasSingleElement(payloadOps))
242 auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
245 <<
"target is expected to be module operation";
247 func::FuncOp funcOp =
248 targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
251 <<
"function with name '" << getFunctionName() <<
"' not found";
253 unsigned numArgs = funcOp.getNumArguments();
254 unsigned numResults = funcOp.getNumResults();
257 if (numArgs != getArgsInterchange().size())
259 <<
"function with name '" << getFunctionName() <<
"' has " << numArgs
260 <<
" arguments, but " << getArgsInterchange().size()
261 <<
" args interchange were given";
263 if (numResults != getResultsInterchange().size())
265 <<
"function with name '" << getFunctionName() <<
"' has "
266 << numResults <<
" results, but " << getResultsInterchange().size()
267 <<
" results interchange were given";
271 argsInterchange.insert_range(getArgsInterchange());
272 resultsInterchange.insert_range(getResultsInterchange());
273 if (argsInterchange.size() != getArgsInterchange().size())
275 <<
"args interchange must be unique";
277 if (resultsInterchange.size() != getResultsInterchange().size())
279 <<
"results interchange must be unique";
282 for (
unsigned index : argsInterchange) {
283 if (index >= numArgs) {
285 <<
"args interchange index " << index
286 <<
" is out of bounds for function with name '"
287 << getFunctionName() <<
"' with " << numArgs <<
" arguments";
290 for (
unsigned index : resultsInterchange) {
291 if (index >= numResults) {
293 <<
"results interchange index " << index
294 <<
" is out of bounds for function with name '"
295 << getFunctionName() <<
"' with " << numResults <<
" results";
300 rewriter, funcOp, argsInterchange.getArrayRef(),
301 resultsInterchange.getArrayRef());
302 if (
failed(newFuncOpOrFailure))
304 <<
"failed to replace function signature '" << getFunctionName()
305 <<
"' with new order";
307 if (getAdjustFuncCalls()) {
309 targetModuleOp.walk([&](func::CallOp callOp) {
310 if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
311 callOps.push_back(callOp);
314 for (func::CallOp callOp : callOps)
316 argsInterchange.getArrayRef(),
317 resultsInterchange.getArrayRef());
320 results.
set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
321 results.
set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
326 void transform::ReplaceFuncSignatureOp::getEffects(
338 class FuncTransformDialectExtension
340 FuncTransformDialectExtension> {
347 declareGeneratedDialect<LLVM::LLVMDialect>();
349 registerTransformOps<
351 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
357 #define GET_OP_CLASSES
358 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
This class represents a diagnostic that is inflight and set to be reported.
Conversion from types to the LLVM IR dialect.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
user_range getUsers()
Returns a range of all users.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp, llvm::ArrayRef< unsigned > newArgsOrder, llvm::ArrayRef< unsigned > newResultsOrder)
Creates a new call operation with the values as the original call operation, but with the arguments r...
void registerTransformDialectExtension(DialectRegistry ®istry)
FailureOr< FuncOp > replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp, llvm::ArrayRef< unsigned > newArgsOrder, llvm::ArrayRef< unsigned > newResultsOrder)
Creates a new function operation with the same name as the original function operation,...
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...