19#include "llvm/ADT/DenseMapInfo.h"
20#include "llvm/ADT/ScopedHashTable.h"
21#include "llvm/Support/Allocator.h"
22#include "llvm/Support/RecyclingAllocator.h"
29 static unsigned getHashValue(
const Operation *opC) {
31 const_cast<Operation *
>(opC),
36 static bool isEqual(
const Operation *lhsC,
const Operation *rhsC) {
37 auto *
lhs =
const_cast<Operation *
>(lhsC);
38 auto *
rhs =
const_cast<Operation *
>(rhsC);
41 if (
lhs == getEmptyKey() ||
rhs == getEmptyKey())
44 const_cast<Operation *
>(lhsC),
const_cast<Operation *
>(rhsC),
54 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
55 : rewriter(rewriter), domInfo(domInfo) {}
58 void simplify(Operation *op,
bool *changed =
nullptr);
61 void simplify(Region &region,
bool *changed =
nullptr);
63 int64_t getNumCSE()
const {
return numCSE; }
64 int64_t getNumDCE()
const {
return numDCE; }
68 using AllocatorTy = llvm::RecyclingAllocator<
69 llvm::BumpPtrAllocator,
70 llvm::ScopedHashTableVal<Operation *, Operation *>>;
71 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
72 SimpleOperationInfo, AllocatorTy>;
79 using MemEffectsCache =
85 : scope(knownValues), node(node), childIterator(node->begin()) {}
88 ScopedMapTy::ScopeTy scope;
91 DominanceInfoNode::const_iterator childIterator;
94 bool processed =
false;
99 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
100 bool hasSSADominance);
101 void simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
bool hasSSADominance);
102 void simplifyRegion(ScopedMapTy &knownValues, Region &region);
105 void eraseDeadOps(
bool *changed);
107 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
108 Operation *existing,
bool hasSSADominance);
112 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
115 RewriterBase &rewriter;
118 std::vector<Operation *> opsToErase;
119 DominanceInfo *domInfo =
nullptr;
120 MemEffectsCache memEffectsCache;
128void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues,
Operation *op,
130 bool hasSSADominance) {
134 if (hasSSADominance) {
139 opsToErase.push_back(op);
143 auto wasVisited = [&](OpOperand &operand) {
144 return !knownValues.count(operand.getOwner());
146 if (
auto *rewriteListener =
147 dyn_cast_if_present<RewriterBase::Listener>(rewriter.
getListener()))
149 if (all_of(v.getUses(), wasVisited))
150 rewriteListener->notifyOperationReplaced(op, existing);
159 opsToErase.push_back(op);
165 if (isa<UnknownLoc>(existing->
getLoc()) && !isa<UnknownLoc>(op->
getLoc()))
171bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
175 "expected read effect on fromOp");
177 "expected read effect on toOp");
181 SmallVector<MemoryEffects::EffectInstance> readEffects;
182 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(fromOp)) {
183 SmallVector<MemoryEffects::EffectInstance> fromEffects;
184 memOp.getEffects(fromEffects);
186 if (isa<MemoryEffects::Read>(e.getEffect()))
187 readEffects.push_back(e);
190 Operation *nextOp = fromOp->getNextNode();
192 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp,
nullptr));
194 auto memEffectsCachePair =
result.first->second;
195 if (memEffectsCachePair.second ==
nullptr) {
198 nextOp = memEffectsCachePair.first;
205 while (nextOp && nextOp != toOp) {
206 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
213 std::make_pair(nextOp, MemoryEffects::Write::get());
218 if (isa<MemoryEffects::Write>(effect.getEffect())) {
221 SideEffects::Resource *writeResource = effect.getResource();
223 llvm::any_of(readEffects, [&](
const auto &readEffect) {
224 SideEffects::Resource *readResource = readEffect.getResource();
229 if (readEffect.getValue() && !writeResource->
isAddressable())
236 result.first->second = {nextOp, MemoryEffects::Write::get()};
241 nextOp = nextOp->getNextNode();
243 result.first->second = std::make_pair(toOp,
nullptr);
248LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
250 bool hasSSADominance) {
252 if (op->
hasTrait<OpTrait::IsTerminator>())
258 [](Region &r) { return r.empty() || r.hasOneBlock(); }))
271 if (
auto *existing = knownValues.lookup(op)) {
273 !hasOtherSideEffectingOpInBetween(existing, op)) {
277 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
281 knownValues.insert(op, op);
286 if (
auto *existing = knownValues.lookup(op)) {
287 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
292 knownValues.insert(op, op);
296void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
297 bool hasSSADominance) {
298 for (
auto &op : llvm::make_early_inc_range(*bb)) {
303 opsToErase.push_back(&op);
314 ScopedMapTy nestedKnownValues;
316 simplifyRegion(nestedKnownValues, region);
320 simplifyRegion(knownValues, region);
325 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
329 memEffectsCache.clear();
332void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
341 ScopedMapTy::ScopeTy scope(knownValues);
342 simplifyBlock(knownValues, &region.
front(), hasSSADominance);
349 if (!hasSSADominance)
358 std::deque<std::unique_ptr<CFGStackNode>> stack;
361 stack.emplace_back(std::make_unique<CFGStackNode>(
364 while (!stack.empty()) {
365 auto &currentNode = stack.back();
368 if (!currentNode->processed) {
369 currentNode->processed =
true;
370 simplifyBlock(knownValues, currentNode->node->getBlock(),
375 if (currentNode->childIterator != currentNode->node->end()) {
376 auto *childNode = *(currentNode->childIterator++);
378 std::make_unique<CFGStackNode>(knownValues, childNode));
387void CSEDriver::eraseDeadOps(
bool *changed) {
390 for (
auto *op : opsToErase) {
396 *changed = !opsToErase.empty();
403void CSEDriver::simplify(Operation *op,
bool *changed) {
405 ScopedMapTy knownValues;
407 simplifyRegion(knownValues, region);
408 eraseDeadOps(changed);
411void CSEDriver::simplify(Region &region,
bool *changed) {
412 ScopedMapTy knownValues;
413 simplifyRegion(knownValues, region);
414 eraseDeadOps(changed);
419 bool *changed,
int64_t *numCSE,
421 CSEDriver driver(rewriter, &domInfo);
422 driver.simplify(op, changed);
424 *numCSE = driver.getNumCSE();
426 *numDCE = driver.getNumDCE();
432 CSEDriver driver(rewriter, &domInfo);
433 driver.simplify(region, changed);
template bool mlir::hasEffect< MemoryEffects::Read >(Operation *)
template bool mlir::hasSingleEffect< MemoryEffects::Read >(Operation *)
A class for computing basic dominance information.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual bool isAddressable() const
Returns true if this resource is addressable (effects on it can alias pointer-based memory).
bool isDisjointFrom(const Resource *other) const
Returns true if this resource is disjoint from another.
DominanceInfoNode * getRootNode(Region *region)
Get the root dominance node of the given region.
bool hasSSADominance(Block *block) const
Return true if operations in the specified block are known to obey SSA dominance requirements.
void invalidate()
Invalidate dominance info.
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr, int64_t *numCSE=nullptr, int64_t *numDCE=nullptr)
Eliminate common subexpressions within the given operation.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash to compute the hash value of operands/results directly.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.