21 #include "llvm/ADT/DenseMapInfo.h"
22 #include "llvm/ADT/ScopedHashTable.h"
23 #include "llvm/Support/Allocator.h"
24 #include "llvm/Support/RecyclingAllocator.h"
28 #define GEN_PASS_DEF_CSE
29 #include "mlir/Transforms/Passes.h.inc"
36 static unsigned getHashValue(
const Operation *opC) {
44 auto *lhs =
const_cast<Operation *
>(lhsC);
45 auto *rhs =
const_cast<Operation *
>(rhsC);
48 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
49 rhs == getTombstoneKey() || rhs == getEmptyKey())
63 : rewriter(rewriter), domInfo(domInfo) {}
/// Accessor that returns the current value of the `numCSE` member.
/// NOTE(review): presumably the count of operations eliminated by CSE —
/// confirm against where `numCSE` is incremented (not visible here).
68 int64_t getNumCSE()
const {
return numCSE; }
/// Accessor that returns the current value of the `numDCE` member.
/// NOTE(review): presumably the count of trivially dead operations removed —
/// confirm against where `numDCE` is incremented (not visible here).
69 int64_t getNumDCE()
const {
return numDCE; }
73 using AllocatorTy = llvm::RecyclingAllocator<
74 llvm::BumpPtrAllocator,
75 llvm::ScopedHashTableVal<Operation *, Operation *>>;
77 SimpleOperationInfo, AllocatorTy>;
84 using MemEffectsCache =
90 : scope(knownValues), node(node), childIterator(node->begin()) {}
93 ScopedMapTy::ScopeTy scope;
96 DominanceInfoNode::const_iterator childIterator;
99 bool processed =
false;
104 LogicalResult simplifyOperation(ScopedMapTy &knownValues,
Operation *op,
105 bool hasSSADominance);
106 void simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
bool hasSSADominance);
107 void simplifyRegion(ScopedMapTy &knownValues,
Region &region);
109 void replaceUsesAndDelete(ScopedMapTy &knownValues,
Operation *op,
110 Operation *existing,
bool hasSSADominance);
120 std::vector<Operation *> opsToErase;
122 MemEffectsCache memEffectsCache;
130 void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues,
Operation *op,
132 bool hasSSADominance) {
136 if (hasSSADominance) {
139 if (
auto *rewriteListener =
140 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
141 rewriteListener->notifyOperationReplaced(op, existing);
145 opsToErase.push_back(op);
149 auto wasVisited = [&](
OpOperand &operand) {
150 return !knownValues.count(operand.getOwner());
152 if (
auto *rewriteListener =
153 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
155 if (all_of(v.getUses(), wasVisited))
156 rewriteListener->notifyOperationReplaced(op, existing);
165 opsToErase.push_back(op);
171 if (isa<UnknownLoc>(existing->
getLoc()) && !isa<UnknownLoc>(op->
getLoc()))
177 bool CSEDriver::hasOtherSideEffectingOpInBetween(
Operation *fromOp,
180 assert(hasEffect<MemoryEffects::Read>(fromOp) &&
181 "expected read effect on fromOp");
182 assert(hasEffect<MemoryEffects::Read>(toOp) &&
183 "expected read effect on toOp");
184 Operation *nextOp = fromOp->getNextNode();
186 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp,
nullptr));
188 auto memEffectsCachePair = result.first->second;
189 if (memEffectsCachePair.second ==
nullptr) {
192 nextOp = memEffectsCachePair.first;
199 while (nextOp && nextOp != toOp) {
200 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
206 result.first->second =
212 if (isa<MemoryEffects::Write>(effect.getEffect())) {
217 nextOp = nextOp->getNextNode();
219 result.first->second = std::make_pair(toOp,
nullptr);
224 LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
226 bool hasSSADominance) {
233 opsToErase.push_back(op);
241 [](
Region &r) { return r.empty() || r.hasOneBlock(); }))
250 if (!hasSingleEffect<MemoryEffects::Read>(op))
254 if (
auto *existing = knownValues.lookup(op)) {
256 !hasOtherSideEffectingOpInBetween(existing, op)) {
260 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
264 knownValues.insert(op, op);
269 if (
auto *existing = knownValues.lookup(op)) {
270 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
276 knownValues.insert(op, op);
280 void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
281 bool hasSSADominance) {
282 for (
auto &op : *bb) {
289 ScopedMapTy nestedKnownValues;
291 simplifyRegion(nestedKnownValues, region);
295 simplifyRegion(knownValues, region);
300 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
304 memEffectsCache.clear();
307 void CSEDriver::simplifyRegion(ScopedMapTy &knownValues,
Region &region) {
312 bool hasSSADominance = domInfo->hasSSADominance(&region);
316 ScopedMapTy::ScopeTy scope(knownValues);
317 simplifyBlock(knownValues, &region.
front(), hasSSADominance);
324 if (!hasSSADominance)
333 std::deque<std::unique_ptr<CFGStackNode>> stack;
336 stack.emplace_back(std::make_unique<CFGStackNode>(
337 knownValues, domInfo->getRootNode(&region)));
339 while (!stack.empty()) {
340 auto &currentNode = stack.back();
343 if (!currentNode->processed) {
344 currentNode->processed =
true;
345 simplifyBlock(knownValues, currentNode->node->getBlock(),
350 if (currentNode->childIterator != currentNode->node->end()) {
351 auto *childNode = *(currentNode->childIterator++);
353 std::make_unique<CFGStackNode>(knownValues, childNode));
364 ScopedMapTy knownValues;
366 simplifyRegion(knownValues, region);
369 for (
auto *op : opsToErase)
370 rewriter.eraseOp(op);
372 *
changed = !opsToErase.empty();
381 CSEDriver driver(rewriter, &domInfo);
387 struct CSE :
public impl::CSEBase<CSE> {
388 void runOnOperation()
override;
392 void CSE::runOnOperation() {
395 CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
397 driver.simplify(getOperation(), &
changed);
400 numCSE = driver.getNumCSE();
401 numDCE = driver.getNumDCE();
405 return markAllAnalysesPreserved();
409 markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
A class for computing basic dominance information.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents an operand of an operation.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait.
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
std::unique_ptr< Pass > createCSEPass()
Creates a pass to perform common sub expression elimination.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.