15 #define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
16 #include "mlir/Dialect/Func/Transforms/Passes.h.inc"
20 struct DuplicateFuncOpEquivalenceInfo
23 static unsigned getHashValue(
const func::FuncOp cFunc) {
25 return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
29 llvm::hash_code hash = {};
30 func::FuncOp func =
const_cast<func::FuncOp &
>(cFunc);
31 StringAttr symNameAttrName = func.getSymNameAttrName();
32 for (NamedAttribute namedAttr : cFunc->getAttrs()) {
33 StringAttr attrName = namedAttr.getName();
34 if (attrName == symNameAttrName)
36 hash = llvm::hash_combine(hash, namedAttr);
40 func.getBody().walk([&](Operation *op) {
41 hash = llvm::hash_combine(
51 static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) {
54 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
55 rhs == getTombstoneKey() || rhs == getEmptyKey())
58 if (lhs.isDeclaration() || rhs.isDeclaration())
62 if (lhs->getDiscardableAttrDictionary() !=
63 rhs->getDiscardableAttrDictionary())
69 auto pLhs = lhs.getProperties();
70 auto pRhs = rhs.getProperties();
71 pLhs.sym_name =
nullptr;
72 pRhs.sym_name =
nullptr;
82 struct DuplicateFunctionEliminationPass
83 :
public impl::DuplicateFunctionEliminationPassBase<
84 DuplicateFunctionEliminationPass> {
86 using DuplicateFunctionEliminationPassBase<
87 DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
89 void runOnOperation()
override {
90 auto module = getOperation();
93 DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
94 DenseMap<StringAttr, func::FuncOp> getRepresentant;
95 DenseSet<func::FuncOp> toBeErased;
96 module.walk([&](func::FuncOp f) {
97 auto [repr, inserted] = uniqueFuncOps.insert(f);
98 getRepresentant[f.getSymNameAttr()] = *repr;
100 toBeErased.insert(f);
106 SymbolTableCollection symbolTable;
107 SymbolUserMap userMap(symbolTable, module);
108 for (
auto it : toBeErased) {
109 StringAttr oldSymbol = it.getSymNameAttr();
110 StringAttr newSymbol = getRepresentant[oldSymbol].getSymNameAttr();
111 userMap.replaceAllUsesWith(it, newSymbol);
120 return std::make_unique<DuplicateFunctionEliminationPass>();
std::unique_ptr< Pass > createDuplicateFunctionEliminationPass()
Pass to deduplicate functions.
Include the generated interface declarations.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.