17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/Support/DebugLog.h"
21 #define DEBUG_TYPE "func-utils"
40 if (!oldIdxToNewIdx.empty())
41 numOfNewIdxs = 1 + *llvm::max_element(oldIdxToNewIdx);
44 newToOldIdxs[newIdx].push_back(oldIdx);
53 template <
typename Element>
58 for (
const auto &oldIdxs : newIdxToOldIdxs) {
59 assert(llvm::all_of(oldIdxs,
60 [&origElements](
int idx) ->
bool {
62 static_cast<size_t>(idx) < origElements.size();
64 "idx must be less than the number of elements in the original "
66 assert(!oldIdxs.empty() &&
"oldIdx must not be empty");
67 Element origTypeToCheck = origElements[oldIdxs.front()];
68 assert(llvm::all_of(oldIdxs,
69 [&](
int idx) ->
bool {
70 return origElements[idx] == origTypeToCheck;
72 "all oldIdxs must be equal");
73 newElements.push_back(origTypeToCheck);
78 FailureOr<func::FuncOp>
84 assert(funcOp.getNumArguments() == oldArgIdxToNewArgIdx.size() &&
85 "oldArgIdxToNewArgIdx must match the number of arguments in the "
88 funcOp.getNumResults() == oldResIdxToNewResIdx.size() &&
89 "oldResIdxToNewResIdx must match the number of results in the function");
91 if (!funcOp.getBody().hasOneBlock())
93 funcOp,
"expected function to have exactly one block");
101 funcOp.getFunctionType().getInputs(), newArgIdxToOldArgIdxs);
104 for (
const auto &oldArgIdxs : newArgIdxToOldArgIdxs)
105 locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc());
110 funcOp.getFunctionType().getResults(), newResToOldResIdxs);
113 auto newFuncOp = func::FuncOp::create(
114 rewriter, funcOp.getLoc(), funcOp.getName(),
117 Region &newRegion = newFuncOp.getBody();
119 newFuncOp.setVisibility(funcOp.getVisibility());
125 funcOp.getAllArgAttrs(argAttrs);
126 for (
auto [oldArgIdx, newArgIdx] :
llvm::enumerate(oldArgIdxToNewArgIdx))
127 operandMapper.
map(funcOp.getArgument(oldArgIdx),
128 newFuncOp.getArgument(newArgIdx));
129 for (
auto [newArgIdx, oldArgIdx] :
llvm::enumerate(newArgIdxToOldArgIdxs))
130 newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);
132 funcOp.getAllResultAttrs(resultAttrs);
133 for (
auto [newResIdx, oldResIdx] :
llvm::enumerate(newResToOldResIdxs))
134 newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);
139 rewriter.
clone(op, operandMapper);
142 auto returnOp = cast<func::ReturnOp>(
143 newFuncOp.getFunctionBody().begin()->getTerminator());
145 for (
const auto &oldResIdxs : newResToOldResIdxs)
146 newReturnValues.push_back(returnOp.getOperand(oldResIdxs.front()));
149 func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
161 assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() &&
162 "oldArgIdxToNewArgIdx must match the number of operands in the call "
164 assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() &&
165 "oldResIdxToNewResIdx must match the number of results in the call "
172 getMappedElements<Value>(origOperands, newArgIdxToOldArgIdxs);
177 getMappedElements<Type>(origResultTypes, newResToOldResIdxs);
183 func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
184 newResultTypes, newOperandsValues);
185 newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
186 for (
auto &&[oldResIdx, newResIdx] :
llvm::enumerate(oldResIdxToNewResIdx))
188 newCallOp.getResult(newResIdx));
194 FailureOr<std::pair<func::FuncOp, func::CallOp>>
198 auto traversalResult = moduleOp.walk([&](func::CallOp callOp) {
199 if (callOp.getCallee() == funcOp.getSymName()) {
200 if (!callOps.empty())
202 return WalkResult::interrupt();
203 callOps.push_back(callOp);
208 if (traversalResult.wasInterrupted()) {
209 LDBG() <<
"function " << funcOp.getName() <<
" has more than one callOp";
213 if (callOps.empty()) {
214 LDBG() <<
"function " << funcOp.getName() <<
" does not have any callOp";
218 func::CallOp callOp = callOps.front();
223 for (
auto [operandIdx, operand] :
llvm::enumerate(callOp.getOperands())) {
224 auto [iterator, inserted] = valueToNewArgIdx.insert(
225 {operand,
static_cast<int>(valueToNewArgIdx.size())});
227 oldArgIdxToNewArgIdx[operandIdx] = iterator->second;
230 bool hasDuplicateOperands =
231 valueToNewArgIdx.size() != callOp.getNumOperands();
232 if (!hasDuplicateOperands) {
233 LDBG() <<
"function " << funcOp.getName()
234 <<
" does not have duplicate operands";
240 for (
int resultIdx : llvm::seq<int>(0, callOp.getNumResults()))
241 oldResIdxToNewResIdx[resultIdx] = resultIdx;
245 rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
246 if (
failed(newFuncOpOrFailure)) {
247 LDBG() <<
"failed to replace function signature with name "
248 << funcOp.getName() <<
" with new order";
253 rewriter, callOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
255 return std::make_pair(*newFuncOpOrFailure, newCallOp);
static SmallVector< Element > getMappedElements(ArrayRef< Element > origElements, const llvm::SmallVector< llvm::SmallVector< int >> &newIdxToOldIdxs)
This method returns a new vector of elements that are mapped from the origElements based on the newId...
static llvm::SmallVector< llvm::SmallVector< int > > getInverseMapping(ArrayRef< int > oldIdxToNewIdx)
This method creates an inverse mapping of the provided map oldToNew.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
static WalkResult advance()
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mlir::func::CallOp replaceCallOpWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::CallOp callOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new call operation with the values as the original call operation, but with the arguments m...
mlir::FailureOr< mlir::func::FuncOp > replaceFuncWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new function operation with the same name as the original function operation,...
mlir::FailureOr< std::pair< mlir::func::FuncOp, mlir::func::CallOp > > deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp)
This utility function examines all call operations within the given moduleOp that target the specifie...
Include the generated interface declarations.