MLIR 22.0.0git
DropEquivalentBufferResults.cpp
Go to the documentation of this file.
1//===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This pass drops return values from functions if they are equivalent to one of
// their arguments. E.g.:
//
// ```
// func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) {
//   return %m : memref<?xf32>
// }
// ```
//
// This function is rewritten to:
//
// ```
// func.func @foo(%m : memref<?xf32>) {
//   return
// }
// ```
//
// Only private functions that have a body are rewritten. All call sites are
// updated accordingly. If a function returns a cast of a function argument,
// the result is also considered equivalent; in that case a cast is inserted at
// the call site.
29
31
34
35namespace mlir {
36namespace bufferization {
37#define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTSPASS
38#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
39} // namespace bufferization
40} // namespace mlir
41
42using namespace mlir;
43
44/// Get all the ReturnOp in the funcOp.
45static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
47 for (Block &b : funcOp.getBody()) {
48 if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
49 returnOps.push_back(candidateOp);
50 }
51 }
52 return returnOps;
53}
54
55/// Get the operands at the specified position for all returnOps.
58 return llvm::map_to_vector(returnOps, [&](func::ReturnOp returnOp) {
59 return returnOp.getOperand(pos);
60 });
61}
62
63/// Check if all given values are the same buffer as the block argument (modulo
64/// cast ops).
66 BlockArgument argument) {
67 for (Value val : operands) {
68 while (auto castOp = val.getDefiningOp<memref::CastOp>())
69 val = castOp.getSource();
70
71 if (val != argument)
72 return false;
73 }
74 return true;
75}
76
77LogicalResult
79 IRRewriter rewriter(module.getContext());
80
82 // Collect the mapping of functions to their call sites.
83 module.walk([&](func::CallOp callOp) {
84 if (func::FuncOp calledFunc =
85 dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
86 if (!calledFunc.isPublic() && !calledFunc.isExternal())
87 callerMap[calledFunc].insert(callOp);
88 }
89 });
90
91 for (auto funcOp : module.getOps<func::FuncOp>()) {
92 if (funcOp.isExternal() || funcOp.isPublic())
93 continue;
94 SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
95 if (returnOps.empty())
96 continue;
97
98 // Compute erased results.
99 size_t numReturnOps = returnOps.size();
100 size_t numReturnValues = funcOp.getFunctionType().getNumResults();
101 SmallVector<SmallVector<Value>> newReturnValues(numReturnOps);
102 BitVector erasedResultIndices(numReturnValues);
103 DenseMap<int64_t, int64_t> resultToArgs;
104 for (size_t i = 0; i < numReturnValues; ++i) {
105 bool erased = false;
106 SmallVector<Value> returnOperands =
107 getReturnOpsOperandInPos(returnOps, i);
108 for (BlockArgument bbArg : funcOp.getArguments()) {
109 if (operandsEqualFuncArgument(returnOperands, bbArg)) {
110 resultToArgs[i] = bbArg.getArgNumber();
111 erased = true;
112 break;
113 }
114 }
115
116 if (erased) {
117 erasedResultIndices.set(i);
118 } else {
119 for (auto [newReturnValue, operand] :
120 llvm::zip(newReturnValues, returnOperands)) {
121 newReturnValue.push_back(operand);
122 }
123 }
124 }
125
126 // Update function.
127 if (failed(funcOp.eraseResults(erasedResultIndices)))
128 return failure();
129
130 for (auto [returnOp, newReturnValue] :
131 llvm::zip(returnOps, newReturnValues))
132 returnOp.getOperandsMutable().assign(newReturnValue);
133
134 // Update function calls.
135 for (func::CallOp callOp : callerMap[funcOp]) {
136 rewriter.setInsertionPoint(callOp);
137 auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), funcOp,
138 callOp.getOperands());
139 SmallVector<Value> newResults;
140 int64_t nextResult = 0;
141 for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
142 if (!resultToArgs.count(i)) {
143 // This result was not erased.
144 newResults.push_back(newCallOp.getResult(nextResult++));
145 continue;
146 }
147
148 // This result was erased.
149 Value replacement = callOp.getOperand(resultToArgs[i]);
150 Type expectedType = callOp.getResult(i).getType();
151 if (replacement.getType() != expectedType) {
152 // A cast must be inserted at the call site.
153 replacement = memref::CastOp::create(rewriter, callOp.getLoc(),
154 expectedType, replacement);
155 }
156 newResults.push_back(replacement);
157 }
158 rewriter.replaceOp(callOp, newResults);
159 }
160 }
161
162 return success();
163}
164
165namespace {
166struct DropEquivalentBufferResultsPass
168 DropEquivalentBufferResultsPass> {
169 void runOnOperation() override {
171 return signalPassFailure();
172 }
173};
174} // namespace
return success()
static SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Get all the ReturnOp in the funcOp.
static bool operandsEqualFuncArgument(ArrayRef< Value > operands, BlockArgument argument)
Check if all given values are the same buffer as the block argument (modulo cast ops).
static SmallVector< Value > getReturnOpsOperandInPos(ArrayRef< func::ReturnOp > returnOps, size_t pos)
Get the operands at the specified position for all returnOps.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
LogicalResult dropEquivalentBufferResults(ModuleOp module)
Drop all memref function results that are equivalent to a function argument.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126