36 namespace bufferization {
37 #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTSPASS
38 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
47 func::ReturnOp returnOp;
48 for (
Block &b : funcOp.getBody()) {
49 if (
auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
52 returnOp = candidateOp;
61 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
64 return dyn_cast_or_null<func::FuncOp>(
74 module.walk([&](func::CallOp callOp) {
76 callerMap[calledFunc].insert(callOp);
80 for (
auto funcOp : module.getOps<func::FuncOp>()) {
81 if (funcOp.isExternal())
90 BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
95 Value val = it.value();
97 val = castOp.getSource();
100 resultToArgs[it.index()] = bbArg.getArgNumber();
107 erasedResultIndices.set(it.index());
109 newReturnValues.push_back(it.value());
114 if (
failed(funcOp.eraseResults(erasedResultIndices)))
116 returnOp.getOperandsMutable().assign(newReturnValues);
119 for (func::CallOp callOp : callerMap[funcOp]) {
121 auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), funcOp,
122 callOp.getOperands());
124 int64_t nextResult = 0;
125 for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
126 if (!resultToArgs.count(i)) {
128 newResults.push_back(newCallOp.getResult(nextResult++));
133 Value replacement = callOp.getOperand(resultToArgs[i]);
134 Type expectedType = callOp.getResult(i).getType();
135 if (replacement.
getType() != expectedType) {
137 replacement = memref::CastOp::create(rewriter, callOp.getLoc(),
138 expectedType, replacement);
140 newResults.push_back(replacement);
150 struct DropEquivalentBufferResultsPass
151 : bufferization::impl::DropEquivalentBufferResultsPassBase<
152 DropEquivalentBufferResultsPass> {
153 void runOnOperation()
override {
155 return signalPassFailure();
static func::FuncOp getCalledFunction(CallOpInterface callOp)
Return the func::FuncOp called by callOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
LogicalResult dropEquivalentBufferResults(ModuleOp module)
Drop all memref function results that are equivalent to a function argument.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.