MLIR  20.0.0git
DuplicateFunctionElimination.cpp
Go to the documentation of this file.
1 //===- DuplicateFunctionElimination.cpp - Duplicate function elimination --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
12 namespace mlir {
13 namespace {
14 
15 #define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
16 #include "mlir/Dialect/Func/Transforms/Passes.h.inc"
17 
18 // Define a notion of function equivalence that allows for reuse. Ignore the
19 // symbol name for this purpose.
20 struct DuplicateFuncOpEquivalenceInfo
21  : public llvm::DenseMapInfo<func::FuncOp> {
22 
23  static unsigned getHashValue(const func::FuncOp cFunc) {
24  if (!cFunc) {
25  return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
26  }
27 
28  // Aggregate attributes, ignoring the symbol name.
29  llvm::hash_code hash = {};
30  func::FuncOp func = const_cast<func::FuncOp &>(cFunc);
31  StringAttr symNameAttrName = func.getSymNameAttrName();
32  for (NamedAttribute namedAttr : cFunc->getAttrs()) {
33  StringAttr attrName = namedAttr.getName();
34  if (attrName == symNameAttrName)
35  continue;
36  hash = llvm::hash_combine(hash, namedAttr);
37  }
38 
39  // Also hash the func body.
40  func.getBody().walk([&](Operation *op) {
41  hash = llvm::hash_combine(
43  op, /*hashOperands=*/OperationEquivalence::ignoreHashValue,
46  });
47 
48  return hash;
49  }
50 
51  static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) {
52  if (lhs == rhs)
53  return true;
54  if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
55  rhs == getTombstoneKey() || rhs == getEmptyKey())
56  return false;
57 
58  if (lhs.isDeclaration() || rhs.isDeclaration())
59  return false;
60 
61  // Check discardable attributes equivalence
62  if (lhs->getDiscardableAttrDictionary() !=
63  rhs->getDiscardableAttrDictionary())
64  return false;
65 
66  // Check properties equivalence, ignoring the symbol name.
67  // Make a copy, so that we can erase the symbol name and perform the
68  // comparison.
69  auto pLhs = lhs.getProperties();
70  auto pRhs = rhs.getProperties();
71  pLhs.sym_name = nullptr;
72  pRhs.sym_name = nullptr;
73  if (pLhs != pRhs)
74  return false;
75 
76  // Compare inner workings.
78  &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
79  }
80 };
81 
82 struct DuplicateFunctionEliminationPass
83  : public impl::DuplicateFunctionEliminationPassBase<
84  DuplicateFunctionEliminationPass> {
85 
86  using DuplicateFunctionEliminationPassBase<
87  DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
88 
89  void runOnOperation() override {
90  auto module = getOperation();
91 
92  // Find unique representant per equivalent func ops.
93  DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
94  DenseMap<StringAttr, func::FuncOp> getRepresentant;
95  DenseSet<func::FuncOp> toBeErased;
96  module.walk([&](func::FuncOp f) {
97  auto [repr, inserted] = uniqueFuncOps.insert(f);
98  getRepresentant[f.getSymNameAttr()] = *repr;
99  if (!inserted) {
100  toBeErased.insert(f);
101  }
102  });
103 
104  // Update all symbol uses to reference unique func op
105  // representants and erase redundant func ops.
106  SymbolTableCollection symbolTable;
107  SymbolUserMap userMap(symbolTable, module);
108  for (auto it : toBeErased) {
109  StringAttr oldSymbol = it.getSymNameAttr();
110  StringAttr newSymbol = getRepresentant[oldSymbol].getSymNameAttr();
111  userMap.replaceAllUsesWith(it, newSymbol);
112  it.erase();
113  }
114  }
115 };
116 
117 } // namespace
118 
120  return std::make_unique<DuplicateFunctionEliminationPass>();
121 }
122 
123 } // namespace mlir
std::unique_ptr< Pass > createDuplicateFunctionEliminationPass()
Pass to deduplicate functions.
Include the generated interface declarations.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.