36 namespace bufferization {
37 #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTSPASS
38 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
47 func::ReturnOp returnOp;
48 for (
Block &b : funcOp.getBody()) {
49 if (
auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
52 returnOp = candidateOp;
64 module.walk([&](func::CallOp callOp) {
65 if (func::FuncOp calledFunc =
66 dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
67 if (!calledFunc.isPublic() && !calledFunc.isExternal())
68 callerMap[calledFunc].insert(callOp);
72 for (
auto funcOp : module.getOps<func::FuncOp>()) {
73 if (funcOp.isExternal() || funcOp.isPublic())
82 BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
87 Value val = it.value();
89 val = castOp.getSource();
92 resultToArgs[it.index()] = bbArg.getArgNumber();
99 erasedResultIndices.set(it.index());
101 newReturnValues.push_back(it.value());
106 if (
failed(funcOp.eraseResults(erasedResultIndices)))
108 returnOp.getOperandsMutable().assign(newReturnValues);
111 for (func::CallOp callOp : callerMap[funcOp]) {
113 auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), funcOp,
114 callOp.getOperands());
116 int64_t nextResult = 0;
117 for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
118 if (!resultToArgs.count(i)) {
120 newResults.push_back(newCallOp.getResult(nextResult++));
125 Value replacement = callOp.getOperand(resultToArgs[i]);
126 Type expectedType = callOp.getResult(i).getType();
127 if (replacement.
getType() != expectedType) {
129 replacement = memref::CastOp::create(rewriter, callOp.getLoc(),
130 expectedType, replacement);
132 newResults.push_back(replacement);
142 struct DropEquivalentBufferResultsPass
143 : bufferization::impl::DropEquivalentBufferResultsPassBase<
144 DropEquivalentBufferResultsPass> {
145 void runOnOperation()
override {
147 return signalPassFailure();
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
LogicalResult dropEquivalentBufferResults(ModuleOp module)
Drop all memref function results that are equivalent to a function argument.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.