21#include "llvm/ADT/DenseMapInfo.h"
22#include "llvm/ADT/ScopedHashTable.h"
23#include "llvm/Support/Allocator.h"
24#include "llvm/Support/RecyclingAllocator.h"
28#define GEN_PASS_DEF_CSEPASS
29#include "mlir/Transforms/Passes.h.inc"
36 static unsigned getHashValue(
const Operation *opC) {
38 const_cast<Operation *
>(opC),
43 static bool isEqual(
const Operation *lhsC,
const Operation *rhsC) {
44 auto *
lhs =
const_cast<Operation *
>(lhsC);
45 auto *
rhs =
const_cast<Operation *
>(rhsC);
48 if (
lhs == getTombstoneKey() ||
lhs == getEmptyKey() ||
49 rhs == getTombstoneKey() ||
rhs == getEmptyKey())
52 const_cast<Operation *
>(lhsC),
const_cast<Operation *
>(rhsC),
62 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
63 : rewriter(rewriter), domInfo(domInfo) {}
66 void simplify(Operation *op,
bool *changed =
nullptr);
68 int64_t getNumCSE()
const {
return numCSE; }
69 int64_t getNumDCE()
const {
return numDCE; }
73 using AllocatorTy = llvm::RecyclingAllocator<
74 llvm::BumpPtrAllocator,
75 llvm::ScopedHashTableVal<Operation *, Operation *>>;
76 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
77 SimpleOperationInfo, AllocatorTy>;
84 using MemEffectsCache =
90 : scope(knownValues), node(node), childIterator(node->begin()) {}
93 ScopedMapTy::ScopeTy scope;
96 DominanceInfoNode::const_iterator childIterator;
99 bool processed =
false;
104 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
105 bool hasSSADominance);
106 void simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
bool hasSSADominance);
107 void simplifyRegion(ScopedMapTy &knownValues, Region &region);
109 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
110 Operation *existing,
bool hasSSADominance);
114 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
117 RewriterBase &rewriter;
120 std::vector<Operation *> opsToErase;
121 DominanceInfo *domInfo =
nullptr;
122 MemEffectsCache memEffectsCache;
130void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues,
Operation *op,
132 bool hasSSADominance) {
136 if (hasSSADominance) {
139 if (
auto *rewriteListener =
140 dyn_cast_if_present<RewriterBase::Listener>(rewriter.
getListener()))
141 rewriteListener->notifyOperationReplaced(op, existing);
145 opsToErase.push_back(op);
149 auto wasVisited = [&](OpOperand &operand) {
150 return !knownValues.count(operand.getOwner());
152 if (
auto *rewriteListener =
153 dyn_cast_if_present<RewriterBase::Listener>(rewriter.
getListener()))
155 if (all_of(v.getUses(), wasVisited))
156 rewriteListener->notifyOperationReplaced(op, existing);
165 opsToErase.push_back(op);
171 if (isa<UnknownLoc>(existing->
getLoc()) && !isa<UnknownLoc>(op->
getLoc()))
177bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
181 "expected read effect on fromOp");
183 "expected read effect on toOp");
187 SmallPtrSet<SideEffects::Resource *, 1> readResources;
188 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(fromOp)) {
189 SmallVector<MemoryEffects::EffectInstance> fromEffects;
190 memOp.getEffects(fromEffects);
191 for (
const auto &e : fromEffects)
192 if (isa<MemoryEffects::Read>(e.getEffect()))
193 readResources.insert(e.getResource());
196 Operation *nextOp = fromOp->getNextNode();
198 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp,
nullptr));
200 auto memEffectsCachePair =
result.first->second;
201 if (memEffectsCachePair.second ==
nullptr) {
204 nextOp = memEffectsCachePair.first;
211 while (nextOp && nextOp != toOp) {
212 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
219 std::make_pair(nextOp, MemoryEffects::Write::get());
224 if (isa<MemoryEffects::Write>(effect.getEffect())) {
227 auto *writeResource = effect.getResource();
228 bool canConflict = llvm::any_of(readResources, [&](
auto *readResource) {
229 return !writeResource->isDisjointFrom(readResource);
232 result.first->second = {nextOp, MemoryEffects::Write::get()};
237 nextOp = nextOp->getNextNode();
239 result.first->second = std::make_pair(toOp,
nullptr);
244LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
246 bool hasSSADominance) {
248 if (op->
hasTrait<OpTrait::IsTerminator>())
253 opsToErase.push_back(op);
261 [](Region &r) { return r.empty() || r.hasOneBlock(); }))
274 if (
auto *existing = knownValues.lookup(op)) {
276 !hasOtherSideEffectingOpInBetween(existing, op)) {
280 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
284 knownValues.insert(op, op);
289 if (
auto *existing = knownValues.lookup(op)) {
290 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
296 knownValues.insert(op, op);
300void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
301 bool hasSSADominance) {
302 for (
auto &op : *bb) {
309 ScopedMapTy nestedKnownValues;
311 simplifyRegion(nestedKnownValues, region);
315 simplifyRegion(knownValues, region);
320 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
324 memEffectsCache.clear();
327void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
336 ScopedMapTy::ScopeTy scope(knownValues);
337 simplifyBlock(knownValues, &region.
front(), hasSSADominance);
344 if (!hasSSADominance)
353 std::deque<std::unique_ptr<CFGStackNode>> stack;
356 stack.emplace_back(std::make_unique<CFGStackNode>(
359 while (!stack.empty()) {
360 auto &currentNode = stack.back();
363 if (!currentNode->processed) {
364 currentNode->processed =
true;
365 simplifyBlock(knownValues, currentNode->node->getBlock(),
370 if (currentNode->childIterator != currentNode->node->end()) {
371 auto *childNode = *(currentNode->childIterator++);
373 std::make_unique<CFGStackNode>(knownValues, childNode));
382void CSEDriver::simplify(Operation *op,
bool *changed) {
384 ScopedMapTy knownValues;
386 simplifyRegion(knownValues, region);
389 for (
auto *op : opsToErase)
392 *changed = !opsToErase.empty();
401 CSEDriver driver(rewriter, &domInfo);
402 driver.simplify(op, changed);
407struct CSE :
public impl::CSEPassBase<CSE> {
408 void runOnOperation()
override;
412void CSE::runOnOperation() {
415 CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
416 bool changed =
false;
417 driver.simplify(getOperation(), &changed);
420 numCSE = driver.getNumCSE();
421 numDCE = driver.getNumDCE();
425 return markAllAnalysesPreserved();
429 markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
template bool mlir::hasEffect< MemoryEffects::Read >(Operation *)
template bool mlir::hasSingleEffect< MemoryEffects::Read >(Operation *)
A class for computing basic dominance information.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
DominanceInfoNode * getRootNode(Region *region)
Get the root dominance node of the given region.
bool hasSSADominance(Block *block) const
Return true if operations in the specified block are known to obey SSA dominance requirements.
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash to compute the hash value of operands/results directly.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.