MLIR  21.0.0git
DuplicateFunctionElimination.cpp
Go to the documentation of this file.
1 //===- DuplicateFunctionElimination.cpp - Duplicate function elimination --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
12 namespace mlir {
13 namespace func {
14 #define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
15 #include "mlir/Dialect/Func/Transforms/Passes.h.inc"
16 } // namespace func
17 
18 namespace {
19 
20 // Define a notion of function equivalence that allows for reuse. Ignore the
21 // symbol name for this purpose.
22 struct DuplicateFuncOpEquivalenceInfo
23  : public llvm::DenseMapInfo<func::FuncOp> {
24 
25  static unsigned getHashValue(const func::FuncOp cFunc) {
26  if (!cFunc) {
27  return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
28  }
29 
30  // Aggregate attributes, ignoring the symbol name.
31  llvm::hash_code hash = {};
32  func::FuncOp func = const_cast<func::FuncOp &>(cFunc);
33  StringAttr symNameAttrName = func.getSymNameAttrName();
34  for (NamedAttribute namedAttr : cFunc->getAttrs()) {
35  StringAttr attrName = namedAttr.getName();
36  if (attrName == symNameAttrName)
37  continue;
38  hash = llvm::hash_combine(hash, namedAttr);
39  }
40 
41  // Also hash the func body.
42  func.getBody().walk([&](Operation *op) {
43  hash = llvm::hash_combine(
45  op, /*hashOperands=*/OperationEquivalence::ignoreHashValue,
48  });
49 
50  return hash;
51  }
52 
53  static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) {
54  if (lhs == rhs)
55  return true;
56  if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
57  rhs == getTombstoneKey() || rhs == getEmptyKey())
58  return false;
59 
60  if (lhs.isDeclaration() || rhs.isDeclaration())
61  return false;
62 
63  // Check discardable attributes equivalence
64  if (lhs->getDiscardableAttrDictionary() !=
65  rhs->getDiscardableAttrDictionary())
66  return false;
67 
68  // Check properties equivalence, ignoring the symbol name.
69  // Make a copy, so that we can erase the symbol name and perform the
70  // comparison.
71  auto pLhs = lhs.getProperties();
72  auto pRhs = rhs.getProperties();
73  pLhs.sym_name = nullptr;
74  pRhs.sym_name = nullptr;
75  if (pLhs != pRhs)
76  return false;
77 
78  // Compare inner workings.
80  &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
81  }
82 };
83 
84 struct DuplicateFunctionEliminationPass
85  : public func::impl::DuplicateFunctionEliminationPassBase<
86  DuplicateFunctionEliminationPass> {
87 
88  using DuplicateFunctionEliminationPassBase<
89  DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
90 
91  void runOnOperation() override {
92  auto module = getOperation();
93 
94  // Find unique representant per equivalent func ops.
95  DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
96  DenseMap<StringAttr, func::FuncOp> getRepresentant;
97  DenseSet<func::FuncOp> toBeErased;
98  module.walk([&](func::FuncOp f) {
99  auto [repr, inserted] = uniqueFuncOps.insert(f);
100  getRepresentant[f.getSymNameAttr()] = *repr;
101  if (!inserted) {
102  toBeErased.insert(f);
103  }
104  });
105 
106  // Update all symbol uses to reference unique func op
107  // representants and erase redundant func ops.
108  SymbolTableCollection symbolTable;
109  SymbolUserMap userMap(symbolTable, module);
110  for (auto it : toBeErased) {
111  StringAttr oldSymbol = it.getSymNameAttr();
112  StringAttr newSymbol = getRepresentant[oldSymbol].getSymNameAttr();
113  userMap.replaceAllUsesWith(it, newSymbol);
114  it.erase();
115  }
116  }
117 };
118 
119 } // namespace
120 } // namespace mlir
Include the generated interface declarations.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.