38 namespace bufferization {
39 #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTS
40 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
49 func::ReturnOp returnOp;
50 for (
Block &b : funcOp.getBody()) {
51 if (
auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
54 returnOp = candidateOp;
63 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
66 return dyn_cast_or_null<func::FuncOp>(
74 for (
auto funcOp : module.getOps<func::FuncOp>()) {
75 if (funcOp.isExternal())
84 BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
89 Value val = it.value();
91 val = castOp.getSource();
94 resultToArgs[it.index()] = bbArg.getArgNumber();
101 erasedResultIndices.set(it.index());
103 newReturnValues.push_back(it.value());
108 funcOp.eraseResults(erasedResultIndices);
109 returnOp.getOperandsMutable().assign(newReturnValues);
112 module.walk([&](func::CallOp callOp) {
117 auto newCallOp = rewriter.
create<func::CallOp>(callOp.getLoc(), funcOp,
118 callOp.getOperands());
120 int64_t nextResult = 0;
121 for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
122 if (!resultToArgs.count(i)) {
124 newResults.push_back(newCallOp.getResult(nextResult++));
129 Value replacement = callOp.getOperand(resultToArgs[i]);
130 Type expectedType = callOp.getResult(i).getType();
131 if (replacement.
getType() != expectedType) {
133 replacement = rewriter.
create<memref::CastOp>(
134 callOp.getLoc(), expectedType, replacement);
136 newResults.push_back(replacement);
147 struct DropEquivalentBufferResultsPass
148 : bufferization::impl::DropEquivalentBufferResultsBase<
149 DropEquivalentBufferResultsPass> {
150 void runOnOperation()
override {
152 return signalPassFailure();
157 std::unique_ptr<Pass>
159 return std::make_unique<DropEquivalentBufferResultsPass>();
static func::FuncOp getCalledFunction(CallOpInterface callOp)
Return the func::FuncOp called by callOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
std::unique_ptr< Pass > createDropEquivalentBufferResultsPass()
Creates a pass that drops memref function results that are equivalent to a function argument.
LogicalResult dropEquivalentBufferResults(ModuleOp module)
Drop all memref function results that are equivalent to a function argument.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.