19#include "llvm/ADT/DenseMapInfo.h"
20#include "llvm/ADT/ScopedHashTable.h"
21#include "llvm/Support/Allocator.h"
22#include "llvm/Support/RecyclingAllocator.h"
29 static unsigned getHashValue(
const Operation *opC) {
31 const_cast<Operation *
>(opC),
36 static bool isEqual(
const Operation *lhsC,
const Operation *rhsC) {
37 auto *
lhs =
const_cast<Operation *
>(lhsC);
38 auto *
rhs =
const_cast<Operation *
>(rhsC);
42 const_cast<Operation *
>(lhsC),
const_cast<Operation *
>(rhsC),
52 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
53 : rewriter(rewriter), domInfo(domInfo) {}
56 void simplify(Operation *op,
bool *changed =
nullptr);
59 void simplify(Region ®ion,
bool *changed =
nullptr);
61 int64_t getNumCSE()
const {
return numCSE; }
62 int64_t getNumDCE()
const {
return numDCE; }
66 using AllocatorTy = llvm::RecyclingAllocator<
67 llvm::BumpPtrAllocator,
68 llvm::ScopedHashTableVal<Operation *, Operation *>>;
69 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
70 SimpleOperationInfo, AllocatorTy>;
77 using MemEffectsCache =
83 : scope(knownValues), node(node), childIterator(node->begin()) {}
86 ScopedMapTy::ScopeTy scope;
89 DominanceInfoNode::const_iterator childIterator;
92 bool processed =
false;
97 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
98 bool hasSSADominance);
99 void simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
bool hasSSADominance);
100 void simplifyRegion(ScopedMapTy &knownValues, Region ®ion);
103 void eraseDeadOps(
bool *changed);
105 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
106 Operation *existing,
bool hasSSADominance);
110 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
113 RewriterBase &rewriter;
116 std::vector<Operation *> opsToErase;
117 DominanceInfo *domInfo =
nullptr;
118 MemEffectsCache memEffectsCache;
126void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues,
Operation *op,
128 bool hasSSADominance) {
132 if (hasSSADominance) {
137 opsToErase.push_back(op);
141 auto wasVisited = [&](OpOperand &operand) {
142 return !knownValues.count(operand.getOwner());
144 if (
auto *rewriteListener =
145 dyn_cast_if_present<RewriterBase::Listener>(rewriter.
getListener()))
147 if (all_of(v.getUses(), wasVisited))
148 rewriteListener->notifyOperationReplaced(op, existing);
157 opsToErase.push_back(op);
163 if (isa<UnknownLoc>(existing->
getLoc()) && !isa<UnknownLoc>(op->
getLoc()))
169bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
173 "expected read effect on fromOp");
175 "expected read effect on toOp");
179 SmallVector<MemoryEffects::EffectInstance> readEffects;
180 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(fromOp)) {
181 SmallVector<MemoryEffects::EffectInstance> fromEffects;
182 memOp.getEffects(fromEffects);
184 if (isa<MemoryEffects::Read>(e.getEffect()))
185 readEffects.push_back(e);
188 Operation *nextOp = fromOp->getNextNode();
190 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp,
nullptr));
192 auto memEffectsCachePair =
result.first->second;
193 if (memEffectsCachePair.second ==
nullptr) {
196 nextOp = memEffectsCachePair.first;
203 while (nextOp && nextOp != toOp) {
204 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
211 std::make_pair(nextOp, MemoryEffects::Write::get());
216 if (isa<MemoryEffects::Write>(effect.getEffect())) {
219 SideEffects::Resource *writeResource = effect.getResource();
221 llvm::any_of(readEffects, [&](
const auto &readEffect) {
222 SideEffects::Resource *readResource = readEffect.getResource();
227 if (readEffect.getValue() && !writeResource->
isAddressable())
234 result.first->second = {nextOp, MemoryEffects::Write::get()};
239 nextOp = nextOp->getNextNode();
241 result.first->second = std::make_pair(toOp,
nullptr);
246LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
248 bool hasSSADominance) {
250 if (op->
hasTrait<OpTrait::IsTerminator>())
256 [](Region &r) { return r.empty() || r.hasOneBlock(); }))
269 if (
auto *existing = knownValues.lookup(op)) {
271 !hasOtherSideEffectingOpInBetween(existing, op)) {
275 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
279 knownValues.insert(op, op);
284 if (
auto *existing = knownValues.lookup(op)) {
285 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
290 knownValues.insert(op, op);
294void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
295 bool hasSSADominance) {
296 for (
auto &op : llvm::make_early_inc_range(*bb)) {
301 opsToErase.push_back(&op);
312 ScopedMapTy nestedKnownValues;
314 simplifyRegion(nestedKnownValues, region);
318 simplifyRegion(knownValues, region);
323 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
327 memEffectsCache.clear();
330void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region ®ion) {
339 ScopedMapTy::ScopeTy scope(knownValues);
340 simplifyBlock(knownValues, ®ion.
front(), hasSSADominance);
347 if (!hasSSADominance)
356 std::deque<std::unique_ptr<CFGStackNode>> stack;
359 stack.emplace_back(std::make_unique<CFGStackNode>(
362 while (!stack.empty()) {
363 auto ¤tNode = stack.back();
366 if (!currentNode->processed) {
367 currentNode->processed =
true;
368 simplifyBlock(knownValues, currentNode->node->getBlock(),
373 if (currentNode->childIterator != currentNode->node->end()) {
374 auto *childNode = *(currentNode->childIterator++);
376 std::make_unique<CFGStackNode>(knownValues, childNode));
385void CSEDriver::eraseDeadOps(
bool *changed) {
388 for (
auto *op : opsToErase) {
394 *changed = !opsToErase.empty();
401void CSEDriver::simplify(Operation *op,
bool *changed) {
403 ScopedMapTy knownValues;
405 simplifyRegion(knownValues, region);
406 eraseDeadOps(changed);
409void CSEDriver::simplify(Region ®ion,
bool *changed) {
410 ScopedMapTy knownValues;
411 simplifyRegion(knownValues, region);
412 eraseDeadOps(changed);
417 bool *changed,
int64_t *numCSE,
419 CSEDriver driver(rewriter, &domInfo);
420 driver.simplify(op, changed);
422 *numCSE = driver.getNumCSE();
424 *numDCE = driver.getNumDCE();
430 CSEDriver driver(rewriter, &domInfo);
431 driver.simplify(region, changed);
template bool mlir::hasEffect< MemoryEffects::Read >(Operation *)
template bool mlir::hasSingleEffect< MemoryEffects::Read >(Operation *)
A class for computing basic dominance information.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual bool isAddressable() const
Returns true if this resource is addressable (effects on it can alias pointer-based memory).
bool isDisjointFrom(const Resource *other) const
Returns true if this resource is disjoint from another.
DominanceInfoNode * getRootNode(Region *region)
Get the root dominance node of the given region.
bool hasSSADominance(Block *block) const
Return true if operations in the specified block are known to obey SSA dominance requirements.
void invalidate()
Invalidate dominance info.
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr, int64_t *numCSE=nullptr, int64_t *numDCE=nullptr)
Eliminate common subexpressions within the given operation.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash to compute the hash value of operands/results directly.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.