17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/Support/DebugLog.h"
21#define DEBUG_TYPE "func-utils"
40 if (!oldIdxToNewIdx.empty())
41 numOfNewIdxs = 1 + *llvm::max_element(oldIdxToNewIdx);
43 for (
auto [oldIdx, newIdx] : llvm::enumerate(oldIdxToNewIdx))
44 newToOldIdxs[newIdx].push_back(oldIdx);
53template <
typename Element>
58 for (
const auto &oldIdxs : newIdxToOldIdxs) {
59 assert(llvm::all_of(oldIdxs,
60 [&origElements](
int idx) ->
bool {
62 static_cast<size_t>(idx) < origElements.size();
64 "idx must be less than the number of elements in the original "
66 assert(!oldIdxs.empty() &&
"oldIdx must not be empty");
67 Element origTypeToCheck = origElements[oldIdxs.front()];
68 assert(llvm::all_of(oldIdxs,
69 [&](
int idx) ->
bool {
70 return origElements[idx] == origTypeToCheck;
72 "all oldIdxs must be equal");
73 newElements.push_back(origTypeToCheck);
78FailureOr<func::FuncOp>
84 assert(funcOp.getNumArguments() == oldArgIdxToNewArgIdx.size() &&
85 "oldArgIdxToNewArgIdx must match the number of arguments in the "
88 funcOp.getNumResults() == oldResIdxToNewResIdx.size() &&
89 "oldResIdxToNewResIdx must match the number of results in the function");
91 if (!funcOp.getBody().hasOneBlock())
93 funcOp,
"expected function to have exactly one block");
101 funcOp.getFunctionType().getInputs(), newArgIdxToOldArgIdxs);
104 for (
const auto &oldArgIdxs : newArgIdxToOldArgIdxs)
105 locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc());
110 funcOp.getFunctionType().getResults(), newResToOldResIdxs);
113 auto newFuncOp = func::FuncOp::create(
114 rewriter, funcOp.getLoc(), funcOp.getName(),
117 Region &newRegion = newFuncOp.getBody();
119 newFuncOp.setVisibility(funcOp.getVisibility());
125 funcOp.getAllArgAttrs(argAttrs);
126 for (
auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgIdxToNewArgIdx))
127 operandMapper.
map(funcOp.getArgument(oldArgIdx),
128 newFuncOp.getArgument(newArgIdx));
129 for (
auto [newArgIdx, oldArgIdx] : llvm::enumerate(newArgIdxToOldArgIdxs))
130 newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);
132 funcOp.getAllResultAttrs(resultAttrs);
133 for (
auto [newResIdx, oldResIdx] : llvm::enumerate(newResToOldResIdxs))
134 newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);
139 rewriter.
clone(op, operandMapper);
142 auto returnOp = cast<func::ReturnOp>(
143 newFuncOp.getFunctionBody().begin()->getTerminator());
145 for (
const auto &oldResIdxs : newResToOldResIdxs)
146 newReturnValues.push_back(returnOp.getOperand(oldResIdxs.front()));
149 func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
161 assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() &&
162 "oldArgIdxToNewArgIdx must match the number of operands in the call "
164 assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() &&
165 "oldResIdxToNewResIdx must match the number of results in the call "
183 func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
184 newResultTypes, newOperandsValues);
185 newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
186 for (
auto &&[oldResIdx, newResIdx] : llvm::enumerate(oldResIdxToNewResIdx))
188 newCallOp.getResult(newResIdx));
194FailureOr<std::pair<func::FuncOp, func::CallOp>>
198 auto traversalResult = moduleOp.walk([&](func::CallOp callOp) {
199 if (callOp.getCallee() == funcOp.getSymName()) {
200 if (!callOps.empty())
203 callOps.push_back(callOp);
208 if (traversalResult.wasInterrupted()) {
209 LDBG() <<
"function " << funcOp.getName() <<
" has more than one callOp";
213 if (callOps.empty()) {
214 LDBG() <<
"function " << funcOp.getName() <<
" does not have any callOp";
218 func::CallOp callOp = callOps.front();
223 for (
auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
224 auto [iterator,
inserted] = valueToNewArgIdx.insert(
225 {operand,
static_cast<int>(valueToNewArgIdx.size())});
227 oldArgIdxToNewArgIdx[operandIdx] = iterator->second;
230 bool hasDuplicateOperands =
231 valueToNewArgIdx.size() != callOp.getNumOperands();
232 if (!hasDuplicateOperands) {
233 LDBG() <<
"function " << funcOp.getName()
234 <<
" does not have duplicate operands";
240 for (
int resultIdx : llvm::seq<int>(0, callOp.getNumResults()))
241 oldResIdxToNewResIdx[resultIdx] = resultIdx;
245 rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
246 if (
failed(newFuncOpOrFailure)) {
247 LDBG() <<
"failed to replace function signature with name "
248 << funcOp.getName() <<
" with new order";
253 rewriter, callOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
255 return std::make_pair(*newFuncOpOrFailure, newCallOp);
258FailureOr<func::FuncOp>
264 symTable, StringAttr::get(symTable->getContext(), name));
266 func = llvm::dyn_cast_or_null<FuncOp>(
273 mlir::FunctionType foundFuncT =
func.getFunctionType();
275 if (funcT != foundFuncT) {
276 return func.emitError(
"matched function '")
277 << name <<
"' but with different type: " << foundFuncT
278 <<
" (expected " << funcT <<
")";
static llvm::SmallVector< llvm::SmallVector< int > > getInverseMapping(ArrayRef< int > oldIdxToNewIdx)
This method creates an inverse mapping of the provided map oldToNew.
static SmallVector< Element > getMappedElements(ArrayRef< Element > origElements, const llvm::SmallVector< llvm::SmallVector< int > > &newIdxToOldIdxs)
This method returns a new vector of elements that are mapped from the origElements based on the newId...
if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method...
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents a collection of SymbolTables.
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static WalkResult advance()
static WalkResult interrupt()
mlir::func::CallOp replaceCallOpWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::CallOp callOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new call operation with the values as the original call operation, but with the arguments m...
FailureOr< FuncOp > lookupFnDecl(SymbolOpInterface symTable, StringRef name, FunctionType funcT, SymbolTableCollection *symbolTables=nullptr)
Look up a FuncOp with signature `resultTypes(paramTypes)` and name `name`.
mlir::FailureOr< mlir::func::FuncOp > replaceFuncWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new function operation with the same name as the original function operation,...
mlir::FailureOr< std::pair< mlir::func::FuncOp, mlir::func::CallOp > > deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp)
This utility function examines all call operations within the given moduleOp that target the specifie...
Include the generated interface declarations.