38 namespace bufferization {
39 #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTS
40 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
49 func::ReturnOp returnOp;
50 for (
Block &b : funcOp.getBody()) {
51 if (
auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
54 returnOp = candidateOp;
62 SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
65 return dyn_cast_or_null<func::FuncOp>(
73 for (
auto funcOp : module.getOps<func::FuncOp>()) {
74 if (funcOp.isExternal())
83 BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
88 Value val = it.value();
90 val = castOp.getSource();
93 resultToArgs[it.index()] = bbArg.getArgNumber();
100 erasedResultIndices.set(it.index());
102 newReturnValues.push_back(it.value());
107 funcOp.eraseResults(erasedResultIndices);
108 returnOp.getOperandsMutable().assign(newReturnValues);
111 module.walk([&](func::CallOp callOp) {
116 auto newCallOp = rewriter.
create<func::CallOp>(callOp.getLoc(), funcOp,
117 callOp.getOperands());
119 int64_t nextResult = 0;
120 for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
121 if (!resultToArgs.count(i)) {
123 newResults.push_back(newCallOp.getResult(nextResult++));
128 Value replacement = callOp.getOperand(resultToArgs[i]);
129 Type expectedType = callOp.getResult(i).getType();
130 if (replacement.
getType() != expectedType) {
132 replacement = rewriter.
create<memref::CastOp>(
133 callOp.getLoc(), expectedType, replacement);
135 newResults.push_back(replacement);
146 struct DropEquivalentBufferResultsPass
147 : bufferization::impl::DropEquivalentBufferResultsBase<
148 DropEquivalentBufferResultsPass> {
149 void runOnOperation()
override {
151 return signalPassFailure();
156 std::unique_ptr<Pass>
158 return std::make_unique<DropEquivalentBufferResultsPass>();
static func::FuncOp getCalledFunction(CallOpInterface callOp)
Return the func::FuncOp called by callOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
std::unique_ptr< Pass > createDropEquivalentBufferResultsPass()
Creates a pass that drops memref function results that are equivalent to a function argument.
LogicalResult dropEquivalentBufferResults(ModuleOp module)
Drop all memref function results that are equivalent to a function argument.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.