MLIR 23.0.0git
DropEquivalentBufferResults.cpp
Go to the documentation of this file.
1//===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass drops return values from functions if they are equivalent to one of
10// their arguments. E.g.:
11//
12// ```
13// func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) {
14// return %m : memref<?xf32>
15// }
16// ```
17//
18// This function is rewritten to:
19//
20// ```
21// func.func @foo(%m : memref<?xf32>) {
22// return
23// }
24// ```
25//
26// All call sites are updated accordingly. If a function returns a cast of a
27// function argument, it is also considered equivalent. A cast is inserted at
28// the call site in that case.
29
31
34
35namespace mlir {
36namespace bufferization {
37#define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTSPASS
38#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
39} // namespace bufferization
40} // namespace mlir
41
42using namespace mlir;
43
44/// Get all the ReturnOp in the funcOp.
45static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
47 for (Block &b : funcOp.getBody()) {
48 if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
49 returnOps.push_back(candidateOp);
50 }
51 }
52 return returnOps;
53}
54
55/// Get the operands at the specified position for all returnOps.
58 return llvm::map_to_vector(returnOps, [&](func::ReturnOp returnOp) {
59 return returnOp.getOperand(pos);
60 });
61}
62
63/// Check if all given values are the same buffer as the block argument (modulo
64/// cast ops).
66 BlockArgument argument) {
67 for (Value val : operands) {
68 while (auto castOp = val.getDefiningOp<memref::CastOp>())
69 val = castOp.getSource();
70
71 if (val != argument)
72 return false;
73 }
74 return true;
75}
76
78 ModuleOp module, DropBufferResultsOpts options) {
79 IRRewriter rewriter(module.getContext());
80
82 // Collect the mapping of functions to their call sites.
83 module.walk([&](func::CallOp callOp) {
84 if (func::FuncOp calledFunc =
85 dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
86 if (calledFunc.isPublic() && !options.modifyPublicFunctions)
87 return WalkResult::advance();
88 if (!calledFunc.isExternal())
89 callerMap[calledFunc].insert(callOp);
90 }
91 return WalkResult::advance();
92 });
93
94 for (auto funcOp : module.getOps<func::FuncOp>()) {
95 if (funcOp.isPublic() && !options.modifyPublicFunctions)
96 continue;
97 if (funcOp.isExternal())
98 continue;
99 SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
100 if (returnOps.empty())
101 continue;
102
103 // Compute erased results.
104 size_t numReturnOps = returnOps.size();
105 size_t numReturnValues = funcOp.getFunctionType().getNumResults();
106 SmallVector<SmallVector<Value>> newReturnValues(numReturnOps);
107 BitVector erasedResultIndices(numReturnValues);
108 DenseMap<int64_t, int64_t> resultToArgs;
109 for (size_t i = 0; i < numReturnValues; ++i) {
110 bool erased = false;
111 SmallVector<Value> returnOperands =
112 getReturnOpsOperandInPos(returnOps, i);
113 for (BlockArgument bbArg : funcOp.getArguments()) {
114 if (operandsEqualFuncArgument(returnOperands, bbArg)) {
115 resultToArgs[i] = bbArg.getArgNumber();
116 erased = true;
117 break;
118 }
119 }
120
121 if (erased) {
122 erasedResultIndices.set(i);
123 } else {
124 for (auto [newReturnValue, operand] :
125 llvm::zip(newReturnValues, returnOperands)) {
126 newReturnValue.push_back(operand);
127 }
128 }
129 }
130
131 // Update function.
132 if (failed(funcOp.eraseResults(erasedResultIndices)))
133 return failure();
134
135 for (auto [returnOp, newReturnValue] :
136 llvm::zip(returnOps, newReturnValues))
137 returnOp.getOperandsMutable().assign(newReturnValue);
138
139 // Update function calls.
140 for (func::CallOp callOp : callerMap[funcOp]) {
141 rewriter.setInsertionPoint(callOp);
142 auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), funcOp,
143 callOp.getOperands());
144 SmallVector<Value> newResults;
145 int64_t nextResult = 0;
146 for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
147 if (!resultToArgs.count(i)) {
148 // This result was not erased.
149 newResults.push_back(newCallOp.getResult(nextResult++));
150 continue;
151 }
152
153 // This result was erased.
154 Value replacement = callOp.getOperand(resultToArgs[i]);
155 Type expectedType = callOp.getResult(i).getType();
156 if (replacement.getType() != expectedType) {
157 // A cast must be inserted at the call site.
158 replacement = memref::CastOp::create(rewriter, callOp.getLoc(),
159 expectedType, replacement);
160 }
161 newResults.push_back(replacement);
162 }
163 rewriter.replaceOp(callOp, newResults);
164 }
165 }
166
167 return success();
168}
169
170namespace {
171struct DropEquivalentBufferResultsPass
172 : bufferization::impl::DropEquivalentBufferResultsPassBase<
173 DropEquivalentBufferResultsPass> {
174 using Base::Base;
175
176 void runOnOperation() override {
177 // Convert pass options.
178 options.modifyPublicFunctions = modifyPublicFunctions;
179
181 options)))
182 return signalPassFailure();
183 }
184
185private:
186 bufferization::DropBufferResultsOpts options;
187};
188} // namespace
return success()
static SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Get all the ReturnOp in the funcOp.
static bool operandsEqualFuncArgument(ArrayRef< Value > operands, BlockArgument argument)
Check if all given values are the same buffer as the block argument (modulo cast ops).
static SmallVector< Value > getReturnOpsOperandInPos(ArrayRef< func::ReturnOp > returnOps, size_t pos)
Get the operands at the specified position for all returnOps.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static WalkResult advance()
Definition WalkResult.h:47
LogicalResult dropEquivalentBufferResults(ModuleOp module, DropBufferResultsOpts options=DropBufferResultsOpts())
Drop all memref function results that are equivalent to a function argument.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:578
Include the generated interface declarations.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
Options for dropping equivalent memref buffer results.
Definition Passes.h:186