38 namespace bufferization {
39 #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTSPASS
40 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
49 func::ReturnOp returnOp;
50 for (
Block &b : funcOp.getBody()) {
51 if (
auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
54 returnOp = candidateOp;
63 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
66 return dyn_cast_or_null<func::FuncOp>(
76 module.walk([&](func::CallOp callOp) {
78 callerMap[calledFunc].insert(callOp);
82 for (
auto funcOp : module.getOps<func::FuncOp>()) {
83 if (funcOp.isExternal())
92 BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
97 Value val = it.value();
99 val = castOp.getSource();
102 resultToArgs[it.index()] = bbArg.getArgNumber();
109 erasedResultIndices.set(it.index());
111 newReturnValues.push_back(it.value());
116 if (failed(funcOp.eraseResults(erasedResultIndices)))
118 returnOp.getOperandsMutable().assign(newReturnValues);
121 for (func::CallOp callOp : callerMap[funcOp]) {
123 auto newCallOp = rewriter.
create<func::CallOp>(callOp.getLoc(), funcOp,
124 callOp.getOperands());
126 int64_t nextResult = 0;
127 for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
128 if (!resultToArgs.count(i)) {
130 newResults.push_back(newCallOp.getResult(nextResult++));
135 Value replacement = callOp.getOperand(resultToArgs[i]);
136 Type expectedType = callOp.getResult(i).getType();
137 if (replacement.
getType() != expectedType) {
139 replacement = rewriter.
create<memref::CastOp>(
140 callOp.getLoc(), expectedType, replacement);
142 newResults.push_back(replacement);
152 struct DropEquivalentBufferResultsPass
153 : bufferization::impl::DropEquivalentBufferResultsPassBase<
154 DropEquivalentBufferResultsPass> {
155 void runOnOperation()
override {
157 return signalPassFailure();
static func::FuncOp getCalledFunction(CallOpInterface callOp)
Return the func::FuncOp called by callOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
LogicalResult dropEquivalentBufferResults(ModuleOp module)
Drop all memref function results that are equivalent to a function argument.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.