38 namespace bufferization {
39 #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTS
40 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
49 func::ReturnOp returnOp;
50 for (
Block &b : funcOp.getBody()) {
51 if (
auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
54 returnOp = candidateOp;
63 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
66 return dyn_cast_or_null<func::FuncOp>(
76 module.walk([&](func::CallOp callOp) {
78 callerMap[calledFunc].insert(callOp);
82 for (
auto funcOp : module.getOps<func::FuncOp>()) {
83 if (funcOp.isExternal())
92 BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
97 Value val = it.value();
99 val = castOp.getSource();
102 resultToArgs[it.index()] = bbArg.getArgNumber();
109 erasedResultIndices.set(it.index());
111 newReturnValues.push_back(it.value());
116 funcOp.eraseResults(erasedResultIndices);
117 returnOp.getOperandsMutable().assign(newReturnValues);
120 for (func::CallOp callOp : callerMap[funcOp]) {
122 auto newCallOp = rewriter.
create<func::CallOp>(callOp.getLoc(), funcOp,
123 callOp.getOperands());
125 int64_t nextResult = 0;
126 for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
127 if (!resultToArgs.count(i)) {
129 newResults.push_back(newCallOp.getResult(nextResult++));
134 Value replacement = callOp.getOperand(resultToArgs[i]);
135 Type expectedType = callOp.getResult(i).getType();
136 if (replacement.
getType() != expectedType) {
138 replacement = rewriter.
create<memref::CastOp>(
139 callOp.getLoc(), expectedType, replacement);
141 newResults.push_back(replacement);
151 struct DropEquivalentBufferResultsPass
152 : bufferization::impl::DropEquivalentBufferResultsBase<
153 DropEquivalentBufferResultsPass> {
154 void runOnOperation()
override {
156 return signalPassFailure();
161 std::unique_ptr<Pass>
163 return std::make_unique<DropEquivalentBufferResultsPass>();
static func::FuncOp getCalledFunction(CallOpInterface callOp)
Return the func::FuncOp called by callOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
std::unique_ptr< Pass > createDropEquivalentBufferResultsPass()
Creates a pass that drops memref function results that are equivalent to a function argument.
LogicalResult dropEquivalentBufferResults(ModuleOp module)
Drop all memref function results that are equivalent to a function argument.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.