21#include "llvm/ADT/DenseMapInfo.h"
22#include "llvm/ADT/ScopedHashTable.h"
23#include "llvm/Support/Allocator.h"
24#include "llvm/Support/RecyclingAllocator.h"
28#define GEN_PASS_DEF_CSE
29#include "mlir/Transforms/Passes.h.inc"
/// SimpleOperationInfo::getHashValue (excerpt; upstream lines 37 and 39-42
/// are elided). Hash function for the keys of the CSE table.
/// The const Operation* key is const_cast to a mutable pointer before being
/// handed to the elided hashing call — presumably
/// OperationEquivalence::computeHash; confirm against the full source.
36 static unsigned getHashValue(
const Operation *opC) {
38 const_cast<Operation *
>(opC),
/// SimpleOperationInfo::isEqual (excerpt; several upstream lines elided).
/// Equality predicate for the CSE table's keys.
43 static bool isEqual(
const Operation *lhsC,
const Operation *rhsC) {
// Strip const so the operations can be passed to APIs taking Operation*.
44 auto *
lhs =
const_cast<Operation *
>(lhsC);
45 auto *
rhs =
const_cast<Operation *
>(rhsC);
// The DenseMapInfo sentinel keys (empty/tombstone) are not real operations,
// so they are filtered out before any structural comparison. The branch body
// is elided — presumably `return lhs == rhs;`; confirm.
48 if (
lhs == getTombstoneKey() ||
lhs == getEmptyKey() ||
49 rhs == getTombstoneKey() ||
rhs == getEmptyKey())
// Arguments of the elided structural-equivalence call (presumably
// OperationEquivalence::isEquivalentTo). Its head and remaining arguments
// are not shown here.
52 const_cast<Operation *
>(lhsC),
const_cast<Operation *
>(rhsC),
/// Construct a driver that mutates IR through `rewriter` and consults
/// `domInfo` for dominance queries. Only a pointer to the DominanceInfo is
/// stored; callers pass the address of an object they manage.
62 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
63 : rewriter(rewriter), domInfo(domInfo) {}
/// Run CSE on the regions nested under `op`. When `changed` is provided it
/// receives whether any operation was queued for erasure.
66 void simplify(Operation *op,
bool *
changed =
nullptr);
/// Number of operations removed as common subexpressions / as dead code.
/// The pass reads these to populate its statistics.
68 int64_t getNumCSE()
const {
return numCSE; }
69 int64_t getNumDCE()
const {
return numDCE; }
/// Allocator for the scoped hash table's entries: a bump-pointer allocator
/// whose freed entries are recycled.
73 using AllocatorTy = llvm::RecyclingAllocator<
74 llvm::BumpPtrAllocator,
75 llvm::ScopedHashTableVal<Operation *, Operation *>>;
/// Table of "known" operations. It maps an operation to an equivalent
/// operation already seen, with hashing and comparison done by
/// SimpleOperationInfo. Entries are added within ScopeTy scopes (see
/// CFGStackNode and simplifyRegion).
76 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
77 SimpleOperationInfo, AllocatorTy>;
/// Memoization for hasOtherSideEffectingOpInBetween. The definition is elided
/// here; judging from its uses, it maps an operation to a pair of
/// (operation, memory effect or nullptr).
84 using MemEffectsCache =
/// CFGStackNode (excerpt): one entry of the explicit stack used to walk a
/// region's dominator tree. Constructing it opens a scope on `knownValues`,
/// records the tree node, and positions the iterator at the node's first
/// child.
90 : scope(knownValues), node(node), childIterator(node->begin()) {}
/// Scope on the known-values table. Because it is a member, its lifetime is
/// tied to this stack entry.
93 ScopedMapTy::ScopeTy scope;
/// Next child of `node` that has not been visited yet.
96 DominanceInfoNode::const_iterator childIterator;
/// Set once this node's block has been simplified.
99 bool processed =
false;
/// Try to eliminate a single operation against the entries in `knownValues`.
/// `hasSSADominance` says whether the enclosing region obeys SSA dominance;
/// this affects how uses may be rewritten (see replaceUsesAndDelete).
104 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
105 bool hasSSADominance);
/// Simplify every operation of block `bb`, recursing into nested regions.
106 void simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
bool hasSSADominance);
/// Simplify the blocks of `region`, scoping entries added to `knownValues`.
107 void simplifyRegion(ScopedMapTy &knownValues, Region &region);
/// Redirect uses of `op` to the equivalent `existing` op and queue `op` for
/// erasure.
109 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
110 Operation *existing,
bool hasSSADominance);
/// Whether an operation strictly between `fromOp` and `toOp` writes memory.
/// Results are memoized in `memEffectsCache`.
114 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
/// Operations queued for erasure. simplify() processes them after the region
/// traversal has finished.
120 std::vector<Operation *> opsToErase;
/// Memoization for hasOtherSideEffectingOpInBetween; cleared by simplifyBlock.
122 MemEffectsCache memEffectsCache;
/// Replace the results of `op` with those of the equivalent operation
/// `existing`, and queue `op` for erasure. This is an excerpt: several
/// upstream lines are elided, including the declaration of the `existing`
/// parameter.
130void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues,
Operation *op,
132 bool hasSSADominance) {
// With SSA dominance, the listener (if it is a RewriterBase::Listener) is
// told that `op` was replaced, and `op` is queued for erasure. The use
// replacement itself is on elided lines.
136 if (hasSSADominance) {
139 if (
auto *rewriteListener =
140 dyn_cast_if_present<RewriterBase::Listener>(rewriter.
getListener()))
141 rewriteListener->notifyOperationReplaced(op, existing);
145 opsToErase.push_back(op);
// Without SSA dominance, only uses accepted by `wasVisited` are considered
// (presumably because some uses may not have been visited yet).
// NOTE(review): `wasVisited` returns true when the use's owner is NOT in
// `knownValues`, so the name reads inverted relative to its result. Confirm
// the intended meaning before relying on it.
149 auto wasVisited = [&](
OpOperand &operand) {
150 return !knownValues.count(operand.getOwner());
// The listener is notified only when every use of `v` passes `wasVisited`.
// NOTE(review): `v` comes from elided line 154. If that line is a loop over
// op's results, a multi-result op can be reported as replaced more than once.
152 if (
auto *rewriteListener =
153 dyn_cast_if_present<RewriterBase::Listener>(rewriter.
getListener()))
155 if (all_of(v.getUses(), wasVisited))
156 rewriteListener->notifyOperationReplaced(op, existing);
// Queued under an elided condition (lines 157-164), presumably that `op`
// has no remaining uses after the partial replacement; confirm.
165 opsToErase.push_back(op);
// If `existing` has an unknown location while `op` has a real one, the
// elided body presumably moves op's location onto `existing` (setLoc) so
// location information is not lost; confirm.
171 if (isa<UnknownLoc>(existing->
getLoc()) && !isa<UnknownLoc>(op->
getLoc()))
/// Checks whether an operation strictly between `fromOp` and `toOp` has a
/// write effect, walking forward from `fromOp` with getNextNode. The return
/// statements are elided; presumably the result is true when such a writer
/// exists. Both operations are asserted to have a read effect.
///
/// Results are memoized in memEffectsCache under `fromOp` as a pair:
///  - {op, nullptr}: no writer found up to `op`, so a later query can resume
///    from there.
///  - {op, Write}:   `op` was found to write, or was treated as writing.
177bool CSEDriver::hasOtherSideEffectingOpInBetween(
Operation *fromOp,
181 "expected read effect on fromOp");
183 "expected read effect on toOp");
184 Operation *nextOp = fromOp->getNextNode();
// If absent, seed the cache entry for `fromOp` with "no writer found yet".
186 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp,
nullptr));
188 auto memEffectsCachePair =
result.first->second;
// No writer recorded so far: resume scanning from the cached position
// instead of from fromOp's successor.
189 if (memEffectsCachePair.second ==
nullptr) {
192 nextOp = memEffectsCachePair.first;
// Scan forward until `toOp` or the end of the block (getNextNode yields null).
199 while (nextOp && nextOp != toOp) {
200 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
// Surrounding lines are elided. Presumably this is reached when nextOp's
// effects cannot be determined, and it is conservatively cached as a write.
207 std::make_pair(nextOp, MemoryEffects::Write::get());
// A write effect ends the search; remember where it was found.
212 if (isa<MemoryEffects::Write>(effect.getEffect())) {
213 result.first->second = {nextOp, MemoryEffects::Write::get()};
217 nextOp = nextOp->getNextNode();
// The scan finished without finding a writer: record that the range up to
// `toOp` is write-free.
219 result.first->second = std::make_pair(toOp,
nullptr);
/// Try to eliminate `op` as dead, or as a duplicate of an operation already
/// in `knownValues` (excerpt; many upstream lines elided). The return
/// statements are elided; presumably success() means `op` was queued for
/// removal.
224LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
226 bool hasSSADominance) {
// Terminators are excluded (the branch body is elided).
228 if (op->
hasTrait<OpTrait::IsTerminator>())
// Queued for erasure under an elided condition, presumably that `op` is
// trivially dead (isOpTriviallyDead).
233 opsToErase.push_back(op);
// Predicate over op's regions: each must be empty or have exactly one block.
// The surrounding condition is elided.
241 [](Region &r) { return r.empty() || r.hasOneBlock(); }))
// Reuse an earlier equivalent op only when no operation between the two
// writes memory; elided line 255 adds further conditions. This is presumably
// the path for ops with only read effects.
254 if (
auto *existing = knownValues.lookup(op)) {
256 !hasOtherSideEffectingOpInBetween(existing, op)) {
260 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
// Record `op` as the representative for later equivalent ops.
264 knownValues.insert(op, op);
// Second lookup path (presumably for memory-effect-free ops): replace `op`
// with an existing equivalent op if one is known, otherwise record `op`.
269 if (
auto *existing = knownValues.lookup(op)) {
270 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
276 knownValues.insert(op, op);
/// Simplify each operation of `bb` in order (excerpt; several lines elided).
/// An operation's nested regions are processed before the operation itself.
280void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
281 bool hasSSADominance) {
282 for (
auto &op : *bb) {
// One path processes nested regions with a fresh, empty table, so entries
// from the enclosing scope are not visible inside them. The condition that
// selects this path is elided (presumably an IsIsolatedFromAbove check via
// mightHaveTrait — confirm).
289 ScopedMapTy nestedKnownValues;
291 simplifyRegion(nestedKnownValues, region);
// Otherwise nested regions share the enclosing table.
295 simplifyRegion(knownValues, region);
300 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
// Drop the memory-effect memoization filled by
// hasOtherSideEffectingOpInBetween so it does not carry over to the next
// block processed.
304 memEffectsCache.clear();
/// Run CSE over the blocks of `region` (excerpt; several upstream lines are
/// elided).
///
/// On one path (presumably single-block regions; the guard is elided), the
/// region's front block is simplified inside a fresh scope of `knownValues`.
/// When `!hasSSADominance`, the elided branch presumably returns early.
/// Otherwise, blocks are visited along the dominator tree using an explicit
/// stack of CFGStackNode entries instead of recursion. Each entry owns a scope
/// on `knownValues`, so operations recorded in a block remain visible while
/// the blocks it dominates are processed.
307void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
// Scope the table so that entries added for this block are dropped afterwards.
316 ScopedMapTy::ScopeTy scope(knownValues);
317 simplifyBlock(knownValues, &region.
front(), hasSSADominance);
324 if (!hasSSADominance)
// Explicit DFS stack over the dominator tree.
333 std::deque<std::unique_ptr<CFGStackNode>> stack;
336 stack.emplace_back(std::make_unique<CFGStackNode>(
339 while (!stack.empty()) {
340 auto &currentNode = stack.back();
// On the first visit to a node, simplify its block exactly once.
343 if (!currentNode->processed) {
344 currentNode->processed =
true;
345 simplifyBlock(knownValues, currentNode->node->getBlock(),
// Descend into the next unvisited child; its CFGStackNode opens a nested
// scope.
350 if (currentNode->childIterator != currentNode->node->end()) {
351 auto *childNode = *(currentNode->childIterator++);
353 std::make_unique<CFGStackNode>(knownValues, childNode));
/// Run CSE over the regions of `op` (the region loop is elided), then process
/// the queued `opsToErase` (the loop body is elided; presumably each op is
/// erased through the rewriter). When `changed` is provided it receives
/// whether any operation was queued.
362void CSEDriver::simplify(Operation *op,
bool *
changed) {
// Root table for the regions of `op`; nested scopes are opened inside it.
364 ScopedMapTy knownValues;
366 simplifyRegion(knownValues, region);
// NOTE(review): this loop variable shadows the parameter `op`.
369 for (
auto *op : opsToErase)
// Presumably guarded by a null check on `changed` (elided line 371).
372 *
changed = !opsToErase.empty();
// Presumably part of mlir::eliminateCommonSubExpressions (signature and
// remaining lines elided): builds a CSEDriver from the caller's rewriter and
// DominanceInfo.
381 CSEDriver driver(rewriter, &domInfo);
/// Pass hook: run CSE over the operation this pass is scheduled on.
388 void runOnOperation()
override;
/// Run the CSE driver on the pass's root operation, copy its CSE/DCE counts
/// into the pass statistics, and declare which analyses remain valid.
/// Excerpt: several lines are elided, including the declarations of
/// `rewriter` and `changed`.
392void CSE::runOnOperation() {
395 CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
397 driver.simplify(getOperation(), &
changed);
400 numCSE = driver.getNumCSE();
401 numDCE = driver.getNumDCE();
// Presumably reached when nothing changed (the condition is elided), in which
// case every analysis is still valid.
405 return markAllAnalysesPreserved();
// Otherwise only dominance information is declared preserved, presumably
// because CSE removes operations but does not alter the block structure.
409 markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
template bool mlir::hasSingleEffect< MemoryEffects::Read >(Operation *)
A class for computing basic dominance information.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
DominanceInfoNode * getRootNode(Region *region)
Get the root dominance node of the given region.
bool hasSSADominance(Block *block) const
Return true if operations in the specified block are known to obey SSA dominance requirements.
::mlir::Pass::Statistic numDCE
::mlir::Pass::Statistic numCSE
Explicitly declare the TypeID for this class.
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
std::unique_ptr< Pass > createCSEPass()
Creates a pass to perform common sub expression elimination.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash above to hash operands/results directly by their SSA value identity, rather than ignoring them.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.