MLIR  22.0.0git
DropEquivalentBufferResults.cpp
Go to the documentation of this file.
1 //===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass drops return values from functions if they are equivalent to one of
10 // their arguments. E.g.:
11 //
12 // ```
13 // func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) {
14 // return %m : memref<?xf32>
15 // }
16 // ```
17 //
18 // This function is rewritten to:
19 //
20 // ```
21 // func.func @foo(%m : memref<?xf32>) {
22 // return
23 // }
24 // ```
25 //
26 // All call sites are updated accordingly. If a function returns a cast of a
27 // function argument, it is also considered equivalent. A cast is inserted at
28 // the call site in that case.
29 
31 
34 
35 namespace mlir {
36 namespace bufferization {
37 #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTSPASS
38 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
39 } // namespace bufferization
40 } // namespace mlir
41 
42 using namespace mlir;
43 
44 /// Return the unique ReturnOp that terminates `funcOp`.
45 /// Return nullptr if there is no such unique ReturnOp.
46 static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
47  func::ReturnOp returnOp;
48  for (Block &b : funcOp.getBody()) {
49  if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
50  if (returnOp)
51  return nullptr;
52  returnOp = candidateOp;
53  }
54  }
55  return returnOp;
56 }
57 
58 LogicalResult
60  IRRewriter rewriter(module.getContext());
61 
63  // Collect the mapping of functions to their call sites.
64  module.walk([&](func::CallOp callOp) {
65  if (func::FuncOp calledFunc =
66  dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
67  if (!calledFunc.isPublic() && !calledFunc.isExternal())
68  callerMap[calledFunc].insert(callOp);
69  }
70  });
71 
72  for (auto funcOp : module.getOps<func::FuncOp>()) {
73  if (funcOp.isExternal() || funcOp.isPublic())
74  continue;
75  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
76  // TODO: Support functions with multiple blocks.
77  if (!returnOp)
78  continue;
79 
80  // Compute erased results.
81  SmallVector<Value> newReturnValues;
82  BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
83  DenseMap<int64_t, int64_t> resultToArgs;
84  for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
85  bool erased = false;
86  for (BlockArgument bbArg : funcOp.getArguments()) {
87  Value val = it.value();
88  while (auto castOp = val.getDefiningOp<memref::CastOp>())
89  val = castOp.getSource();
90 
91  if (val == bbArg) {
92  resultToArgs[it.index()] = bbArg.getArgNumber();
93  erased = true;
94  break;
95  }
96  }
97 
98  if (erased) {
99  erasedResultIndices.set(it.index());
100  } else {
101  newReturnValues.push_back(it.value());
102  }
103  }
104 
105  // Update function.
106  if (failed(funcOp.eraseResults(erasedResultIndices)))
107  return failure();
108  returnOp.getOperandsMutable().assign(newReturnValues);
109 
110  // Update function calls.
111  for (func::CallOp callOp : callerMap[funcOp]) {
112  rewriter.setInsertionPoint(callOp);
113  auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), funcOp,
114  callOp.getOperands());
115  SmallVector<Value> newResults;
116  int64_t nextResult = 0;
117  for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
118  if (!resultToArgs.count(i)) {
119  // This result was not erased.
120  newResults.push_back(newCallOp.getResult(nextResult++));
121  continue;
122  }
123 
124  // This result was erased.
125  Value replacement = callOp.getOperand(resultToArgs[i]);
126  Type expectedType = callOp.getResult(i).getType();
127  if (replacement.getType() != expectedType) {
128  // A cast must be inserted at the call site.
129  replacement = memref::CastOp::create(rewriter, callOp.getLoc(),
130  expectedType, replacement);
131  }
132  newResults.push_back(replacement);
133  }
134  rewriter.replaceOp(callOp, newResults);
135  }
136  }
137 
138  return success();
139 }
140 
141 namespace {
142 struct DropEquivalentBufferResultsPass
143  : bufferization::impl::DropEquivalentBufferResultsPassBase<
144  DropEquivalentBufferResultsPass> {
145  void runOnOperation() override {
147  return signalPassFailure();
148  }
149 };
150 } // namespace
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:774
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
LogicalResult dropEquivalentBufferResults(ModuleOp module)
Drop all memref function results that are equivalent to a function argument.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.