19#include "llvm/ADT/DenseMapInfo.h"
20#include "llvm/ADT/ScopedHashTable.h"
21#include "llvm/Support/Allocator.h"
22#include "llvm/Support/RecyclingAllocator.h"
29 static unsigned getHashValue(
const Operation *opC) {
31 const_cast<Operation *
>(opC),
36 static bool isEqual(
const Operation *lhsC,
const Operation *rhsC) {
37 auto *
lhs =
const_cast<Operation *
>(lhsC);
38 auto *
rhs =
const_cast<Operation *
>(rhsC);
42 const_cast<Operation *
>(lhsC),
const_cast<Operation *
>(rhsC),
52 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
53 : rewriter(rewriter), domInfo(domInfo) {}
56 void simplify(Operation *op,
bool *changed =
nullptr);
59 void simplify(Region &region,
bool *changed =
nullptr);
61 int64_t getNumCSE()
const {
return numCSE; }
62 int64_t getNumDCE()
const {
return numDCE; }
66 using AllocatorTy = llvm::RecyclingAllocator<
67 llvm::BumpPtrAllocator,
68 llvm::ScopedHashTableVal<Operation *, Operation *>>;
69 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
70 SimpleOperationInfo, AllocatorTy>;
77 using MemEffectsCache =
83 : scope(knownValues), node(node), childIterator(node->begin()) {}
86 ScopedMapTy::ScopeTy scope;
89 DominanceInfoNode::const_iterator childIterator;
92 bool processed =
false;
97 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
98 bool hasSSADominance);
99 void simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
bool hasSSADominance);
100 void simplifyRegion(ScopedMapTy &knownValues, Region &region);
104 void eraseDeadOp(Operation *op);
106 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
107 Operation *existing,
bool hasSSADominance);
111 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
114 RewriterBase &rewriter;
116 DominanceInfo *domInfo =
nullptr;
117 MemEffectsCache memEffectsCache;
125void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues,
Operation *op,
127 bool hasSSADominance) {
131 if (isa<UnknownLoc>(existing->
getLoc()) && !isa<UnknownLoc>(op->
getLoc()))
138 if (hasSSADominance) {
149 auto wasVisited = [&](OpOperand &operand) {
150 return !knownValues.count(operand.getOwner());
152 if (
auto *rewriteListener =
153 dyn_cast_if_present<RewriterBase::Listener>(rewriter.
getListener()))
155 if (all_of(v.getUses(), wasVisited))
156 rewriteListener->notifyOperationReplaced(op, existing);
168bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
172 "expected read effect on fromOp");
174 "expected read effect on toOp");
178 SmallVector<MemoryEffects::EffectInstance> readEffects;
179 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(fromOp)) {
180 SmallVector<MemoryEffects::EffectInstance> fromEffects;
181 memOp.getEffects(fromEffects);
183 if (isa<MemoryEffects::Read>(e.getEffect()))
184 readEffects.push_back(e);
187 Operation *nextOp = fromOp->getNextNode();
189 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp,
nullptr));
191 auto memEffectsCachePair =
result.first->second;
192 if (memEffectsCachePair.second ==
nullptr) {
195 nextOp = memEffectsCachePair.first;
202 while (nextOp && nextOp != toOp) {
203 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
210 std::make_pair(nextOp, MemoryEffects::Write::get());
215 if (isa<MemoryEffects::Write>(effect.getEffect())) {
218 SideEffects::Resource *writeResource = effect.getResource();
220 llvm::any_of(readEffects, [&](
const auto &readEffect) {
221 SideEffects::Resource *readResource = readEffect.getResource();
226 if (readEffect.getValue() && !writeResource->
isAddressable())
233 result.first->second = {nextOp, MemoryEffects::Write::get()};
238 nextOp = nextOp->getNextNode();
244 result.first->second = std::make_pair(toOp->getPrevNode(),
nullptr);
249LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
251 bool hasSSADominance) {
253 if (op->
hasTrait<OpTrait::IsTerminator>())
259 [](Region &r) { return r.empty() || r.hasOneBlock(); }))
272 if (
auto *existing = knownValues.lookup(op)) {
274 !hasOtherSideEffectingOpInBetween(existing, op)) {
278 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
282 knownValues.insert(op, op);
287 if (
auto *existing = knownValues.lookup(op)) {
288 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
293 knownValues.insert(op, op);
297void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
298 bool hasSSADominance) {
299 for (
auto &op : llvm::make_early_inc_range(*bb)) {
315 ScopedMapTy nestedKnownValues;
317 simplifyRegion(nestedKnownValues, region);
321 simplifyRegion(knownValues, region);
326 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
330 memEffectsCache.clear();
333void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
342 ScopedMapTy::ScopeTy scope(knownValues);
343 simplifyBlock(knownValues, &region.
front(), hasSSADominance);
350 if (!hasSSADominance)
359 std::deque<std::unique_ptr<CFGStackNode>> stack;
362 stack.emplace_back(std::make_unique<CFGStackNode>(
365 while (!stack.empty()) {
366 auto &currentNode = stack.back();
369 if (!currentNode->processed) {
370 currentNode->processed =
true;
371 simplifyBlock(knownValues, currentNode->node->getBlock(),
376 if (currentNode->childIterator != currentNode->node->end()) {
377 auto *childNode = *(currentNode->childIterator++);
379 std::make_unique<CFGStackNode>(knownValues, childNode));
388void CSEDriver::eraseDeadOp(Operation *op) {
399void CSEDriver::simplify(Operation *op,
bool *changed) {
401 ScopedMapTy knownValues;
403 simplifyRegion(knownValues, region);
405 *changed = numCSE || numDCE;
408void CSEDriver::simplify(Region &region,
bool *changed) {
409 ScopedMapTy knownValues;
410 simplifyRegion(knownValues, region);
412 *changed = numCSE || numDCE;
417 bool *changed,
int64_t *numCSE,
419 CSEDriver driver(rewriter, &domInfo);
420 driver.simplify(op, changed);
422 *numCSE = driver.getNumCSE();
424 *numDCE = driver.getNumDCE();
430 CSEDriver driver(rewriter, &domInfo);
431 driver.simplify(region, changed);
template bool mlir::hasEffect< MemoryEffects::Read >(Operation *)
template bool mlir::hasSingleEffect< MemoryEffects::Read >(Operation *)
A class for computing basic dominance information.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OpTrait::IsTerminator>().
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual bool isAddressable() const
Returns true if this resource is addressable (effects on it can alias pointer-based memory).
bool isDisjointFrom(const Resource *other) const
Returns true if this resource is disjoint from another.
DominanceInfoNode * getRootNode(Region *region)
Get the root dominance node of the given region.
bool hasSSADominance(Block *block) const
Return true if operations in the specified block are known to obey SSA dominance requirements.
void invalidate()
Invalidate dominance info.
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr, int64_t *numCSE=nullptr, int64_t *numDCE=nullptr)
Eliminate common subexpressions within the given operation.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash to compute the hash value of operands/results directly.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.