20 #include "llvm/ADT/STLExtras.h" 
   28 void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
 
   35 transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
 
   36     transform::TypeConverterBuilderOpInterface builder) {
 
   37   if (builder.getTypeConverterType() != 
"LLVMTypeConverter")
 
   38     return emitOpError(
"expected LLVMTypeConverter");
 
   52     llvm::append_range(inputs, state.getPayloadValues(getInputs()));
 
   56     outputs.insert_range(state.getPayloadValues(getOutputs()));
 
   60         llvm::range_size(state.getPayloadValues(getOutputs()))) {
 
   62              << 
"cast and call output values must be unique";
 
   67   auto insertionOps = state.getPayloadOps(getInsertionPoint());
 
   68   if (!llvm::hasSingleElement(insertionOps)) {
 
   70            << 
"Only one op can be specified as an insertion point";
 
   72   bool insertAfter = getInsertAfter();
 
   73   Operation *insertionPoint = *insertionOps.begin();
 
   78   for (
Value output : outputs) {
 
   79     for (
Operation *user : output.getUsers()) {
 
   83       bool doesDominate = insertAfter
 
   84                               ? dom.properlyDominates(insertionPoint, user)
 
   85                               : dom.dominates(insertionPoint, user);
 
   88                << 
"User " << user << 
" is not dominated by insertion point " 
   94   for (
Value input : inputs) {
 
   98     bool doesDominate = insertAfter
 
   99                             ? dom.dominates(input, insertionPoint)
 
  100                             : dom.properlyDominates(input, insertionPoint);
 
  103              << 
"input " << input << 
" does not dominate insertion point " 
  110   func::FuncOp targetFunction = 
nullptr;
 
  111   if (getFunctionName()) {
 
  112     targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
 
  113         insertionPoint, *getFunctionName());
 
  114     if (!targetFunction) {
 
  116              << 
"unresolved symbol " << *getFunctionName();
 
  118   } 
else if (getFunction()) {
 
  119     auto payloadOps = state.getPayloadOps(getFunction());
 
  120     if (!llvm::hasSingleElement(payloadOps)) {
 
  123     targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
 
  124     if (!targetFunction) {
 
  128     llvm_unreachable(
"Invalid CastAndCall op without a function to call");
 
  134   if (targetFunction.getNumArguments() != inputs.size()) {
 
  136            << 
"mismatch between number of function arguments " 
  137            << targetFunction.getNumArguments() << 
" and number of inputs " 
  140   if (targetFunction.getNumResults() != outputs.size()) {
 
  142            << 
"mismatch between number of function results " 
  143            << targetFunction->getNumResults() << 
" and number of outputs " 
  149   if (!getRegion().empty()) {
 
  150     for (
Operation &op : getRegion().front()) {
 
  151       cast<transform::TypeConverterBuilderOpInterface>(&op)
 
  152           .populateTypeMaterializations(converter);
 
  161   for (
auto [input, type] :
 
  162        llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
 
  163     if (input.getType() != type) {
 
  165           rewriter, input.getLoc(), type, input);
 
  168                                      << input << 
" to type " << type;
 
  174   auto callOp = func::CallOp::create(rewriter, insertionPoint->
getLoc(),
 
  175                                      targetFunction, inputs);
 
  179   for (
auto [output, newOutput] :
 
  180        llvm::zip_equal(outputs, callOp.getResults())) {
 
  181     Value convertedOutput = newOutput;
 
  182     if (output.getType() != newOutput.getType()) {
 
  184           rewriter, output.getLoc(), output.getType(), newOutput);
 
  185       if (!convertedOutput) {
 
  187                << 
"Failed to materialize conversion of " << newOutput
 
  188                << 
" to type " << output.getType();
 
  193   results.
set(cast<OpResult>(getResult()), {callOp});
 
  198   if (!getRegion().empty()) {
 
  199     for (
Operation &op : getRegion().front()) {
 
  200       if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
 
  202                                   << 
"expected children ops to implement " 
  203                                      "TypeConverterBuilderOpInterface";
 
  204         diag.attachNote(op.getLoc()) << 
"op without interface";
 
  209   if (!getFunction() && !getFunctionName()) {
 
  210     return emitOpError() << 
"expected a function handle or name to call";
 
  212   if (getFunction() && getFunctionName()) {
 
  213     return emitOpError() << 
"function handle and name are mutually exclusive";
 
  218 void transform::CastAndCallOp::getEffects(
 
  239   auto payloadOps = state.getPayloadOps(getModule());
 
  240   if (!llvm::hasSingleElement(payloadOps))
 
  243   auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
 
  246            << 
"target is expected to be module operation";
 
  248   func::FuncOp funcOp =
 
  249       targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
 
  252            << 
"function with name '" << getFunctionName() << 
"' not found";
 
  254   unsigned numArgs = funcOp.getNumArguments();
 
  255   unsigned numResults = funcOp.getNumResults();
 
  258   if (numArgs != getArgsInterchange().size())
 
  260            << 
"function with name '" << getFunctionName() << 
"' has " << numArgs
 
  261            << 
" arguments, but " << getArgsInterchange().size()
 
  262            << 
" args interchange were given";
 
  264   if (numResults != getResultsInterchange().size())
 
  266            << 
"function with name '" << getFunctionName() << 
"' has " 
  267            << numResults << 
" results, but " << getResultsInterchange().size()
 
  268            << 
" results interchange were given";
 
  272   argsInterchange.insert_range(getArgsInterchange());
 
  273   resultsInterchange.insert_range(getResultsInterchange());
 
  274   if (argsInterchange.size() != getArgsInterchange().size())
 
  276            << 
"args interchange must be unique";
 
  278   if (resultsInterchange.size() != getResultsInterchange().size())
 
  280            << 
"results interchange must be unique";
 
  283   for (
unsigned index : argsInterchange) {
 
  284     if (index >= numArgs) {
 
  286              << 
"args interchange index " << index
 
  287              << 
" is out of bounds for function with name '" 
  288              << getFunctionName() << 
"' with " << numArgs << 
" arguments";
 
  291   for (
unsigned index : resultsInterchange) {
 
  292     if (index >= numResults) {
 
  294              << 
"results interchange index " << index
 
  295              << 
" is out of bounds for function with name '" 
  296              << getFunctionName() << 
"' with " << numResults << 
" results";
 
  302     oldArgToNewArg[oldArgIdx] = newArgIdx;
 
  305   for (
auto [newResIdx, oldResIdx] : 
llvm::enumerate(resultsInterchange))
 
  306     oldResToNewRes[oldResIdx] = newResIdx;
 
  309       rewriter, funcOp, oldArgToNewArg, oldResToNewRes);
 
  310   if (
failed(newFuncOpOrFailure))
 
  312            << 
"failed to replace function signature '" << getFunctionName()
 
  313            << 
"' with new order";
 
  315   if (getAdjustFuncCalls()) {
 
  317     targetModuleOp.walk([&](func::CallOp callOp) {
 
  318       if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
 
  319         callOps.push_back(callOp);
 
  322     for (func::CallOp callOp : callOps)
 
  327   results.
set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
 
  328   results.
set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
 
  333 void transform::ReplaceFuncSignatureOp::getEffects(
 
  348   auto payloadOps = state.getPayloadOps(getModule());
 
  349   if (!llvm::hasSingleElement(payloadOps))
 
  352   auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
 
  355            << 
"target is expected to be module operation";
 
  357   func::FuncOp funcOp =
 
  358       targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
 
  361            << 
"function with name '" << getFunctionName() << 
"' is not found";
 
  363   auto transformationResult =
 
  365   if (
failed(transformationResult))
 
  367            << 
"failed to deduplicate function arguments of function " 
  370   auto [newFuncOp, newCallOp] = *transformationResult;
 
  372   results.
set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
 
  373   results.
set(cast<OpResult>(getTransformedFunction()), {newFuncOp});
 
  378 void transform::DeduplicateFuncArgsOp::getEffects(
 
  390 class FuncTransformDialectExtension
 
  392           FuncTransformDialectExtension> {
 
  399     declareGeneratedDialect<LLVM::LLVMDialect>();
 
  401     registerTransformOps<
 
  403 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc" 
  409 #define GET_OP_CLASSES 
  410 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc" 
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
This class represents a diagnostic that is inflight and set to be reported.
Conversion from types to the LLVM IR dialect.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mlir::func::CallOp replaceCallOpWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::CallOp callOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new call operation with the values as the original call operation, but with the arguments m...
void registerTransformDialectExtension(DialectRegistry &registry)
mlir::FailureOr< mlir::func::FuncOp > replaceFuncWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new function operation with the same name as the original function operation,...
mlir::FailureOr< std::pair< mlir::func::FuncOp, mlir::func::CallOp > > deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp)
This utility function examines all call operations within the given moduleOp that target the specifie...
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...