MLIR  19.0.0git
DuplicateFunctionElimination.cpp
Go to the documentation of this file.
1 //===- DuplicateFunctionElimination.cpp - Duplicate function elimination --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
12 namespace mlir {
13 namespace {
14 
15 #define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
16 #include "mlir/Dialect/Func/Transforms/Passes.h.inc"
17 
18 // Define a notion of function equivalence that allows for reuse. Ignore the
19 // symbol name for this purpose.
20 struct DuplicateFuncOpEquivalenceInfo
21  : public llvm::DenseMapInfo<func::FuncOp> {
22 
23  static unsigned getHashValue(const func::FuncOp cFunc) {
24  if (!cFunc) {
25  return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
26  }
27 
28  // Aggregate attributes, ignoring the symbol name.
29  llvm::hash_code hash = {};
30  func::FuncOp func = const_cast<func::FuncOp &>(cFunc);
31  StringAttr symNameAttrName = func.getSymNameAttrName();
32  for (NamedAttribute namedAttr : cFunc->getAttrs()) {
33  StringAttr attrName = namedAttr.getName();
34  if (attrName == symNameAttrName)
35  continue;
36  hash = llvm::hash_combine(hash, namedAttr);
37  }
38 
39  // Also hash the func body.
40  func.getBody().walk([&](Operation *op) {
41  hash = llvm::hash_combine(
43  op, /*hashOperands=*/OperationEquivalence::ignoreHashValue,
46  });
47 
48  return hash;
49  }
50 
51  static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) {
52  if (lhs == rhs)
53  return true;
54  if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
55  rhs == getTombstoneKey() || rhs == getEmptyKey())
56  return false;
57  // Check discardable attributes equivalence
58  if (lhs->getDiscardableAttrDictionary() !=
59  rhs->getDiscardableAttrDictionary())
60  return false;
61 
62  // Check properties equivalence, ignoring the symbol name.
63  // Make a copy, so that we can erase the symbol name and perform the
64  // comparison.
65  auto pLhs = lhs.getProperties();
66  auto pRhs = rhs.getProperties();
67  pLhs.sym_name = nullptr;
68  pRhs.sym_name = nullptr;
69  if (pLhs != pRhs)
70  return false;
71 
72  // Compare inner workings.
74  &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
75  }
76 };
77 
78 struct DuplicateFunctionEliminationPass
79  : public impl::DuplicateFunctionEliminationPassBase<
80  DuplicateFunctionEliminationPass> {
81 
82  using DuplicateFunctionEliminationPassBase<
83  DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
84 
85  void runOnOperation() override {
86  auto module = getOperation();
87 
88  // Find unique representant per equivalent func ops.
89  DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
90  DenseMap<StringAttr, func::FuncOp> getRepresentant;
91  DenseSet<func::FuncOp> toBeErased;
92  module.walk([&](func::FuncOp f) {
93  auto [repr, inserted] = uniqueFuncOps.insert(f);
94  getRepresentant[f.getSymNameAttr()] = *repr;
95  if (!inserted) {
96  toBeErased.insert(f);
97  }
98  });
99 
100  // Update call ops to call unique func op representants.
101  module.walk([&](func::CallOp callOp) {
102  func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()];
103  callOp.setCallee(callee.getSymName());
104  });
105 
106  // Erase redundant func ops.
107  for (auto it : toBeErased) {
108  it.erase();
109  }
110  }
111 };
112 
113 } // namespace
114 
116  return std::make_unique<DuplicateFunctionEliminationPass>();
117 }
118 
119 } // namespace mlir
std::unique_ptr< Pass > createDuplicateFunctionEliminationPass()
Pass to deduplicate functions.
Include the generated interface declarations.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.