17 #include "llvm/ADT/STLExtras.h" 
   18 #include "llvm/ADT/SmallVector.h" 
   19 #include "llvm/Support/DebugLog.h" 
   21 #define DEBUG_TYPE "func-utils" 
   40   if (!oldIdxToNewIdx.empty())
 
   41     numOfNewIdxs = 1 + *llvm::max_element(oldIdxToNewIdx);
 
   44     newToOldIdxs[newIdx].push_back(oldIdx);
 
   53 template <
typename Element>
 
   58   for (
const auto &oldIdxs : newIdxToOldIdxs) {
 
   59     assert(llvm::all_of(oldIdxs,
 
   60                         [&origElements](
int idx) -> 
bool {
 
   62                                  static_cast<size_t>(idx) < origElements.size();
 
   64            "idx must be less than the number of elements in the original " 
   66     assert(!oldIdxs.empty() && 
"oldIdx must not be empty");
 
   67     Element origTypeToCheck = origElements[oldIdxs.front()];
 
   68     assert(llvm::all_of(oldIdxs,
 
   69                         [&](
int idx) -> 
bool {
 
   70                           return origElements[idx] == origTypeToCheck;
 
   72            "all oldIdxs must be equal");
 
   73     newElements.push_back(origTypeToCheck);
 
   78 FailureOr<func::FuncOp>
 
   84   assert(funcOp.getNumArguments() == oldArgIdxToNewArgIdx.size() &&
 
   85          "oldArgIdxToNewArgIdx must match the number of arguments in the " 
   88       funcOp.getNumResults() == oldResIdxToNewResIdx.size() &&
 
   89       "oldResIdxToNewResIdx must match the number of results in the function");
 
   91   if (!funcOp.getBody().hasOneBlock())
 
   93         funcOp, 
"expected function to have exactly one block");
 
  101       funcOp.getFunctionType().getInputs(), newArgIdxToOldArgIdxs);
 
  104   for (
const auto &oldArgIdxs : newArgIdxToOldArgIdxs)
 
  105     locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc());
 
  110       funcOp.getFunctionType().getResults(), newResToOldResIdxs);
 
  113   auto newFuncOp = func::FuncOp::create(
 
  114       rewriter, funcOp.getLoc(), funcOp.getName(),
 
  117   Region &newRegion = newFuncOp.getBody();
 
  119   newFuncOp.setVisibility(funcOp.getVisibility());
 
  125   funcOp.getAllArgAttrs(argAttrs);
 
  126   for (
auto [oldArgIdx, newArgIdx] : 
llvm::enumerate(oldArgIdxToNewArgIdx))
 
  127     operandMapper.
map(funcOp.getArgument(oldArgIdx),
 
  128                       newFuncOp.getArgument(newArgIdx));
 
  129   for (
auto [newArgIdx, oldArgIdx] : 
llvm::enumerate(newArgIdxToOldArgIdxs))
 
  130     newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);
 
  132   funcOp.getAllResultAttrs(resultAttrs);
 
  133   for (
auto [newResIdx, oldResIdx] : 
llvm::enumerate(newResToOldResIdxs))
 
  134     newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);
 
  139     rewriter.
clone(op, operandMapper);
 
  142   auto returnOp = cast<func::ReturnOp>(
 
  143       newFuncOp.getFunctionBody().begin()->getTerminator());
 
  145   for (
const auto &oldResIdxs : newResToOldResIdxs)
 
  146     newReturnValues.push_back(returnOp.getOperand(oldResIdxs.front()));
 
  149   func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
 
  161   assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() &&
 
  162          "oldArgIdxToNewArgIdx must match the number of operands in the call " 
  164   assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() &&
 
  165          "oldResIdxToNewResIdx must match the number of results in the call " 
  172       getMappedElements<Value>(origOperands, newArgIdxToOldArgIdxs);
 
  177       getMappedElements<Type>(origResultTypes, newResToOldResIdxs);
 
  183       func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
 
  184                            newResultTypes, newOperandsValues);
 
  185   newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
 
  186   for (
auto &&[oldResIdx, newResIdx] : 
llvm::enumerate(oldResIdxToNewResIdx))
 
  188                                 newCallOp.getResult(newResIdx));
 
  194 FailureOr<std::pair<func::FuncOp, func::CallOp>>
 
  198   auto traversalResult = moduleOp.walk([&](func::CallOp callOp) {
 
  199     if (callOp.getCallee() == funcOp.getSymName()) {
 
  200       if (!callOps.empty())
 
  202         return WalkResult::interrupt();
 
  203       callOps.push_back(callOp);
 
  208   if (traversalResult.wasInterrupted()) {
 
  209     LDBG() << 
"function " << funcOp.getName() << 
" has more than one callOp";
 
  213   if (callOps.empty()) {
 
  214     LDBG() << 
"function " << funcOp.getName() << 
" does not have any callOp";
 
  218   func::CallOp callOp = callOps.front();
 
  223   for (
auto [operandIdx, operand] : 
llvm::enumerate(callOp.getOperands())) {
 
  224     auto [iterator, inserted] = valueToNewArgIdx.insert(
 
  225         {operand, 
static_cast<int>(valueToNewArgIdx.size())});
 
  227     oldArgIdxToNewArgIdx[operandIdx] = iterator->second;
 
  230   bool hasDuplicateOperands =
 
  231       valueToNewArgIdx.size() != callOp.getNumOperands();
 
  232   if (!hasDuplicateOperands) {
 
  233     LDBG() << 
"function " << funcOp.getName()
 
  234            << 
" does not have duplicate operands";
 
  240   for (
int resultIdx : llvm::seq<int>(0, callOp.getNumResults()))
 
  241     oldResIdxToNewResIdx[resultIdx] = resultIdx;
 
  245       rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
 
  246   if (
failed(newFuncOpOrFailure)) {
 
  247     LDBG() << 
"failed to replace function signature with name " 
  248            << funcOp.getName() << 
" with new order";
 
  253       rewriter, callOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
 
  255   return std::make_pair(*newFuncOpOrFailure, newCallOp);
 
static SmallVector< Element > getMappedElements(ArrayRef< Element > origElements, const llvm::SmallVector< llvm::SmallVector< int >> &newIdxToOldIdxs)
This method returns a new vector of elements that are mapped from the origElements based on the newId...
static llvm::SmallVector< llvm::SmallVector< int > > getInverseMapping(ArrayRef< int > oldIdxToNewIdx)
This method creates an inverse mapping of the provided map oldToNew.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
static WalkResult advance()
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mlir::func::CallOp replaceCallOpWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::CallOp callOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new call operation with the values as the original call operation, but with the arguments m...
mlir::FailureOr< mlir::func::FuncOp > replaceFuncWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new function operation with the same name as the original function operation,...
mlir::FailureOr< std::pair< mlir::func::FuncOp, mlir::func::CallOp > > deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp)
This utility function examines all call operations within the given moduleOp that target the specifie...
Include the generated interface declarations.