83 module.walk([&](func::CallOp callOp) {
84 if (func::FuncOp calledFunc =
85 dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
86 if (!calledFunc.isPublic() && !calledFunc.isExternal())
87 callerMap[calledFunc].insert(callOp);
91 for (
auto funcOp : module.getOps<
func::FuncOp>()) {
92 if (funcOp.isExternal() || funcOp.isPublic())
95 if (returnOps.empty())
99 size_t numReturnOps = returnOps.size();
100 size_t numReturnValues = funcOp.getFunctionType().getNumResults();
102 BitVector erasedResultIndices(numReturnValues);
104 for (
size_t i = 0; i < numReturnValues; ++i) {
110 resultToArgs[i] = bbArg.getArgNumber();
117 erasedResultIndices.set(i);
119 for (
auto [newReturnValue, operand] :
120 llvm::zip(newReturnValues, returnOperands)) {
121 newReturnValue.push_back(operand);
127 if (
failed(funcOp.eraseResults(erasedResultIndices)))
130 for (
auto [returnOp, newReturnValue] :
131 llvm::zip(returnOps, newReturnValues))
132 returnOp.getOperandsMutable().assign(newReturnValue);
135 for (func::CallOp callOp : callerMap[funcOp]) {
136 rewriter.setInsertionPoint(callOp);
137 auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), funcOp,
138 callOp.getOperands());
141 for (
int64_t i = 0; i < callOp.getNumResults(); ++i) {
142 if (!resultToArgs.count(i)) {
144 newResults.push_back(newCallOp.getResult(nextResult++));
150 Type expectedType = callOp.getResult(i).getType();
153 replacement = memref::CastOp::create(rewriter, callOp.getLoc(),
158 rewriter.replaceOp(callOp, newResults);