MLIR  19.0.0git
DropEquivalentBufferResults.cpp
Go to the documentation of this file.
1 //===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass drops return values from functions if they are equivalent to one of
10 // their arguments. E.g.:
11 //
12 // ```
13 // func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) {
14 // return %m : memref<?xf32>
15 // }
16 // ```
17 //
18 // This function is rewritten to:
19 //
20 // ```
21 // func.func @foo(%m : memref<?xf32>) {
22 // return
23 // }
24 // ```
25 //
26 // All call sites are updated accordingly. If a function returns a cast of a
27 // function argument, it is also considered equivalent. A cast is inserted at
28 // the call site in that case.
29 
31 
34 #include "mlir/IR/Operation.h"
35 #include "mlir/Pass/Pass.h"
36 
37 namespace mlir {
38 namespace bufferization {
39 #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTS
40 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
41 } // namespace bufferization
42 } // namespace mlir
43 
44 using namespace mlir;
45 
46 /// Return the unique ReturnOp that terminates `funcOp`.
47 /// Return nullptr if there is no such unique ReturnOp.
48 static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
49  func::ReturnOp returnOp;
50  for (Block &b : funcOp.getBody()) {
51  if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
52  if (returnOp)
53  return nullptr;
54  returnOp = candidateOp;
55  }
56  }
57  return returnOp;
58 }
59 
60 /// Return the func::FuncOp called by `callOp`.
61 static func::FuncOp getCalledFunction(CallOpInterface callOp) {
62  SymbolRefAttr sym =
63  llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
64  if (!sym)
65  return nullptr;
66  return dyn_cast_or_null<func::FuncOp>(
68 }
69 
72  IRRewriter rewriter(module.getContext());
73 
74  for (auto funcOp : module.getOps<func::FuncOp>()) {
75  if (funcOp.isExternal())
76  continue;
77  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
78  // TODO: Support functions with multiple blocks.
79  if (!returnOp)
80  continue;
81 
82  // Compute erased results.
83  SmallVector<Value> newReturnValues;
84  BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
85  DenseMap<int64_t, int64_t> resultToArgs;
86  for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
87  bool erased = false;
88  for (BlockArgument bbArg : funcOp.getArguments()) {
89  Value val = it.value();
90  while (auto castOp = val.getDefiningOp<memref::CastOp>())
91  val = castOp.getSource();
92 
93  if (val == bbArg) {
94  resultToArgs[it.index()] = bbArg.getArgNumber();
95  erased = true;
96  break;
97  }
98  }
99 
100  if (erased) {
101  erasedResultIndices.set(it.index());
102  } else {
103  newReturnValues.push_back(it.value());
104  }
105  }
106 
107  // Update function.
108  funcOp.eraseResults(erasedResultIndices);
109  returnOp.getOperandsMutable().assign(newReturnValues);
110 
111  // Update function calls.
112  module.walk([&](func::CallOp callOp) {
113  if (getCalledFunction(callOp) != funcOp)
114  return WalkResult::skip();
115 
116  rewriter.setInsertionPoint(callOp);
117  auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp,
118  callOp.getOperands());
119  SmallVector<Value> newResults;
120  int64_t nextResult = 0;
121  for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
122  if (!resultToArgs.count(i)) {
123  // This result was not erased.
124  newResults.push_back(newCallOp.getResult(nextResult++));
125  continue;
126  }
127 
128  // This result was erased.
129  Value replacement = callOp.getOperand(resultToArgs[i]);
130  Type expectedType = callOp.getResult(i).getType();
131  if (replacement.getType() != expectedType) {
132  // A cast must be inserted at the call site.
133  replacement = rewriter.create<memref::CastOp>(
134  callOp.getLoc(), expectedType, replacement);
135  }
136  newResults.push_back(replacement);
137  }
138  rewriter.replaceOp(callOp, newResults);
139  return WalkResult::advance();
140  });
141  }
142 
143  return success();
144 }
145 
146 namespace {
147 struct DropEquivalentBufferResultsPass
148  : bufferization::impl::DropEquivalentBufferResultsBase<
149  DropEquivalentBufferResultsPass> {
150  void runOnOperation() override {
152  return signalPassFailure();
153  }
154 };
155 } // namespace
156 
157 std::unique_ptr<Pass>
159  return std::make_unique<DropEquivalentBufferResultsPass>();
160 }
static func::FuncOp getCalledFunction(CallOpInterface callOp)
Return the func::FuncOp called by callOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:756
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
std::unique_ptr< Pass > createDropEquivalentBufferResultsPass()
Creates a pass that drops memref function results that are equivalent to a function argument.
LogicalResult dropEquivalentBufferResults(ModuleOp module)
Drop all memref function results that are equivalent to a function argument.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26