19#include "llvm/ADT/DenseMapInfo.h"
20#include "llvm/ADT/ScopedHashTable.h"
21#include "llvm/Support/Allocator.h"
22#include "llvm/Support/RecyclingAllocator.h"
29 static unsigned getHashValue(
const Operation *opC) {
31 const_cast<Operation *
>(opC),
36 static bool isEqual(
const Operation *lhsC,
const Operation *rhsC) {
37 auto *
lhs =
const_cast<Operation *
>(lhsC);
38 auto *
rhs =
const_cast<Operation *
>(rhsC);
41 if (
lhs == getTombstoneKey() ||
lhs == getEmptyKey() ||
42 rhs == getTombstoneKey() ||
rhs == getEmptyKey())
45 const_cast<Operation *
>(lhsC),
const_cast<Operation *
>(rhsC),
55 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
56 : rewriter(rewriter), domInfo(domInfo) {}
59 void simplify(Operation *op,
bool *changed =
nullptr);
62 void simplify(Region ®ion,
bool *changed =
nullptr);
64 int64_t getNumCSE()
const {
return numCSE; }
65 int64_t getNumDCE()
const {
return numDCE; }
69 using AllocatorTy = llvm::RecyclingAllocator<
70 llvm::BumpPtrAllocator,
71 llvm::ScopedHashTableVal<Operation *, Operation *>>;
72 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
73 SimpleOperationInfo, AllocatorTy>;
80 using MemEffectsCache =
86 : scope(knownValues), node(node), childIterator(node->begin()) {}
89 ScopedMapTy::ScopeTy scope;
92 DominanceInfoNode::const_iterator childIterator;
95 bool processed =
false;
100 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
101 bool hasSSADominance);
102 void simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
bool hasSSADominance);
103 void simplifyRegion(ScopedMapTy &knownValues, Region ®ion);
106 void eraseDeadOps(
bool *changed);
108 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
109 Operation *existing,
bool hasSSADominance);
113 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
116 RewriterBase &rewriter;
119 std::vector<Operation *> opsToErase;
120 DominanceInfo *domInfo =
nullptr;
121 MemEffectsCache memEffectsCache;
129void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues,
Operation *op,
131 bool hasSSADominance) {
135 if (hasSSADominance) {
140 opsToErase.push_back(op);
144 auto wasVisited = [&](OpOperand &operand) {
145 return !knownValues.count(operand.getOwner());
147 if (
auto *rewriteListener =
148 dyn_cast_if_present<RewriterBase::Listener>(rewriter.
getListener()))
150 if (all_of(v.getUses(), wasVisited))
151 rewriteListener->notifyOperationReplaced(op, existing);
160 opsToErase.push_back(op);
166 if (isa<UnknownLoc>(existing->
getLoc()) && !isa<UnknownLoc>(op->
getLoc()))
172bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
176 "expected read effect on fromOp");
178 "expected read effect on toOp");
182 SmallVector<MemoryEffects::EffectInstance> readEffects;
183 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(fromOp)) {
184 SmallVector<MemoryEffects::EffectInstance> fromEffects;
185 memOp.getEffects(fromEffects);
187 if (isa<MemoryEffects::Read>(e.getEffect()))
188 readEffects.push_back(e);
191 Operation *nextOp = fromOp->getNextNode();
193 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp,
nullptr));
195 auto memEffectsCachePair =
result.first->second;
196 if (memEffectsCachePair.second ==
nullptr) {
199 nextOp = memEffectsCachePair.first;
206 while (nextOp && nextOp != toOp) {
207 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
214 std::make_pair(nextOp, MemoryEffects::Write::get());
219 if (isa<MemoryEffects::Write>(effect.getEffect())) {
222 SideEffects::Resource *writeResource = effect.getResource();
224 llvm::any_of(readEffects, [&](
const auto &readEffect) {
225 SideEffects::Resource *readResource = readEffect.getResource();
230 if (readEffect.getValue() && !writeResource->
isAddressable())
237 result.first->second = {nextOp, MemoryEffects::Write::get()};
242 nextOp = nextOp->getNextNode();
244 result.first->second = std::make_pair(toOp,
nullptr);
249LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
251 bool hasSSADominance) {
253 if (op->
hasTrait<OpTrait::IsTerminator>())
259 [](Region &r) { return r.empty() || r.hasOneBlock(); }))
272 if (
auto *existing = knownValues.lookup(op)) {
274 !hasOtherSideEffectingOpInBetween(existing, op)) {
278 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
282 knownValues.insert(op, op);
287 if (
auto *existing = knownValues.lookup(op)) {
288 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
293 knownValues.insert(op, op);
297void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
298 bool hasSSADominance) {
299 for (
auto &op : llvm::make_early_inc_range(*bb)) {
304 opsToErase.push_back(&op);
315 ScopedMapTy nestedKnownValues;
317 simplifyRegion(nestedKnownValues, region);
321 simplifyRegion(knownValues, region);
326 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
330 memEffectsCache.clear();
333void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region ®ion) {
342 ScopedMapTy::ScopeTy scope(knownValues);
343 simplifyBlock(knownValues, ®ion.
front(), hasSSADominance);
350 if (!hasSSADominance)
359 std::deque<std::unique_ptr<CFGStackNode>> stack;
362 stack.emplace_back(std::make_unique<CFGStackNode>(
365 while (!stack.empty()) {
366 auto ¤tNode = stack.back();
369 if (!currentNode->processed) {
370 currentNode->processed =
true;
371 simplifyBlock(knownValues, currentNode->node->getBlock(),
376 if (currentNode->childIterator != currentNode->node->end()) {
377 auto *childNode = *(currentNode->childIterator++);
379 std::make_unique<CFGStackNode>(knownValues, childNode));
388void CSEDriver::eraseDeadOps(
bool *changed) {
391 for (
auto *op : opsToErase) {
397 *changed = !opsToErase.empty();
404void CSEDriver::simplify(Operation *op,
bool *changed) {
406 ScopedMapTy knownValues;
408 simplifyRegion(knownValues, region);
409 eraseDeadOps(changed);
412void CSEDriver::simplify(Region ®ion,
bool *changed) {
413 ScopedMapTy knownValues;
414 simplifyRegion(knownValues, region);
415 eraseDeadOps(changed);
420 bool *changed,
int64_t *numCSE,
422 CSEDriver driver(rewriter, &domInfo);
423 driver.simplify(op, changed);
425 *numCSE = driver.getNumCSE();
427 *numDCE = driver.getNumDCE();
433 CSEDriver driver(rewriter, &domInfo);
434 driver.simplify(region, changed);
template bool mlir::hasEffect< MemoryEffects::Read >(Operation *)
template bool mlir::hasSingleEffect< MemoryEffects::Read >(Operation *)
A class for computing basic dominance information.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual bool isAddressable() const
Returns true if this resource is addressable (effects on it can alias pointer-based memory).
bool isDisjointFrom(const Resource *other) const
Returns true if this resource is disjoint from another.
DominanceInfoNode * getRootNode(Region *region)
Get the root dominance node of the given region.
bool hasSSADominance(Block *block) const
Return true if operations in the specified block are known to obey SSA dominance requirements.
void invalidate()
Invalidate dominance info.
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr, int64_t *numCSE=nullptr, int64_t *numDCE=nullptr)
Eliminate common subexpressions within the given operation.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash to compute the hash value of operands/results directly.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.