21#include "llvm/ADT/DenseMapInfo.h"
22#include "llvm/ADT/ScopedHashTable.h"
23#include "llvm/Support/Allocator.h"
24#include "llvm/Support/RecyclingAllocator.h"
28#define GEN_PASS_DEF_CSEPASS
29#include "mlir/Transforms/Passes.h.inc"
36 static unsigned getHashValue(
const Operation *opC) {
38 const_cast<Operation *
>(opC),
43 static bool isEqual(
const Operation *lhsC,
const Operation *rhsC) {
44 auto *
lhs =
const_cast<Operation *
>(lhsC);
45 auto *
rhs =
const_cast<Operation *
>(rhsC);
48 if (
lhs == getTombstoneKey() ||
lhs == getEmptyKey() ||
49 rhs == getTombstoneKey() ||
rhs == getEmptyKey())
52 const_cast<Operation *
>(lhsC),
const_cast<Operation *
>(rhsC),
62 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
63 : rewriter(rewriter), domInfo(domInfo) {}
66 void simplify(Operation *op,
bool *changed =
nullptr);
68 int64_t getNumCSE()
const {
return numCSE; }
69 int64_t getNumDCE()
const {
return numDCE; }
73 using AllocatorTy = llvm::RecyclingAllocator<
74 llvm::BumpPtrAllocator,
75 llvm::ScopedHashTableVal<Operation *, Operation *>>;
76 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
77 SimpleOperationInfo, AllocatorTy>;
84 using MemEffectsCache =
90 : scope(knownValues), node(node), childIterator(node->begin()) {}
93 ScopedMapTy::ScopeTy scope;
96 DominanceInfoNode::const_iterator childIterator;
99 bool processed =
false;
104 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
105 bool hasSSADominance);
106 void simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
bool hasSSADominance);
107 void simplifyRegion(ScopedMapTy &knownValues, Region ®ion);
109 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
110 Operation *existing,
bool hasSSADominance);
114 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
117 RewriterBase &rewriter;
120 std::vector<Operation *> opsToErase;
122 MemEffectsCache memEffectsCache;
130void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues,
Operation *op,
132 bool hasSSADominance) {
136 if (hasSSADominance) {
139 if (
auto *rewriteListener =
140 dyn_cast_if_present<RewriterBase::Listener>(rewriter.
getListener()))
141 rewriteListener->notifyOperationReplaced(op, existing);
145 opsToErase.push_back(op);
149 auto wasVisited = [&](
OpOperand &operand) {
150 return !knownValues.count(operand.getOwner());
152 if (
auto *rewriteListener =
153 dyn_cast_if_present<RewriterBase::Listener>(rewriter.
getListener()))
155 if (all_of(v.getUses(), wasVisited))
156 rewriteListener->notifyOperationReplaced(op, existing);
165 opsToErase.push_back(op);
171 if (isa<UnknownLoc>(existing->
getLoc()) && !isa<UnknownLoc>(op->
getLoc()))
177bool CSEDriver::hasOtherSideEffectingOpInBetween(
Operation *fromOp,
181 "expected read effect on fromOp");
183 "expected read effect on toOp");
188 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(fromOp)) {
190 memOp.getEffects(fromEffects);
192 if (isa<MemoryEffects::Read>(e.getEffect()))
193 readEffects.push_back(e);
196 Operation *nextOp = fromOp->getNextNode();
198 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp,
nullptr));
200 auto memEffectsCachePair =
result.first->second;
201 if (memEffectsCachePair.second ==
nullptr) {
204 nextOp = memEffectsCachePair.first;
211 while (nextOp && nextOp != toOp) {
212 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
224 if (isa<MemoryEffects::Write>(effect.getEffect())) {
229 llvm::any_of(readEffects, [&](
const auto &readEffect) {
235 if (readEffect.getValue() && !writeResource->
isAddressable())
247 nextOp = nextOp->getNextNode();
249 result.first->second = std::make_pair(toOp,
nullptr);
254LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
256 bool hasSSADominance) {
263 opsToErase.push_back(op);
271 [](Region &r) { return r.empty() || r.hasOneBlock(); }))
284 if (
auto *existing = knownValues.lookup(op)) {
286 !hasOtherSideEffectingOpInBetween(existing, op)) {
290 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
294 knownValues.insert(op, op);
299 if (
auto *existing = knownValues.lookup(op)) {
300 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
305 knownValues.insert(op, op);
309void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
310 bool hasSSADominance) {
311 for (
auto &op : *bb) {
318 ScopedMapTy nestedKnownValues;
320 simplifyRegion(nestedKnownValues, region);
324 simplifyRegion(knownValues, region);
329 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
333 memEffectsCache.clear();
336void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region ®ion) {
345 ScopedMapTy::ScopeTy scope(knownValues);
346 simplifyBlock(knownValues, ®ion.
front(), hasSSADominance);
353 if (!hasSSADominance)
362 std::deque<std::unique_ptr<CFGStackNode>> stack;
365 stack.emplace_back(std::make_unique<CFGStackNode>(
368 while (!stack.empty()) {
369 auto ¤tNode = stack.back();
372 if (!currentNode->processed) {
373 currentNode->processed =
true;
374 simplifyBlock(knownValues, currentNode->node->getBlock(),
379 if (currentNode->childIterator != currentNode->node->end()) {
380 auto *childNode = *(currentNode->childIterator++);
382 std::make_unique<CFGStackNode>(knownValues, childNode));
391void CSEDriver::simplify(Operation *op,
bool *changed) {
393 ScopedMapTy knownValues;
395 simplifyRegion(knownValues, region);
398 for (
auto *op : opsToErase)
401 *changed = !opsToErase.empty();
410 CSEDriver driver(rewriter, &domInfo);
411 driver.simplify(op, changed);
417 void runOnOperation()
override;
421void CSE::runOnOperation() {
424 CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
425 bool changed =
false;
426 driver.simplify(getOperation(), &changed);
429 numCSE = driver.getNumCSE();
430 numDCE = driver.getNumDCE();
434 return markAllAnalysesPreserved();
438 markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
template bool mlir::hasSingleEffect< MemoryEffects::Read >(Operation *)
A class for computing basic dominance information.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class represents an operand of an operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
static DerivedEffect * get()
This class represents a specific resource that an effect applies to.
virtual bool isAddressable() const
Returns true if this resource is addressable (effects on it can alias pointer-based memory).
bool isDisjointFrom(const Resource *other) const
Returns true if this resource is disjoint from another.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
DominanceInfoNode * getRootNode(Region *region)
Get the root dominance node of the given region.
bool hasSSADominance(Block *block) const
Return true if operations in the specified block are known to obey SSA dominance requirements.
::mlir::Pass::Statistic numDCE
::mlir::Pass::Statistic numCSE
Explicitly declare the TypeID for this class.
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash to compute the hash value of operands/results directly.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.