21 #include "llvm/ADT/DenseMapInfo.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/ScopedHashTable.h"
24 #include "llvm/Support/Allocator.h"
25 #include "llvm/Support/RecyclingAllocator.h"
29 #define GEN_PASS_DEF_CSE
30 #include "mlir/Transforms/Passes.h.inc"
37 static unsigned getHashValue(
const Operation *opC) {
45 auto *lhs =
const_cast<Operation *
>(lhsC);
46 auto *rhs =
const_cast<Operation *
>(rhsC);
49 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
50 rhs == getTombstoneKey() || rhs == getEmptyKey())
64 : rewriter(rewriter), domInfo(domInfo) {}
69 int64_t getNumCSE()
const {
return numCSE; }
70 int64_t getNumDCE()
const {
return numDCE; }
74 using AllocatorTy = llvm::RecyclingAllocator<
75 llvm::BumpPtrAllocator,
76 llvm::ScopedHashTableVal<Operation *, Operation *>>;
78 SimpleOperationInfo, AllocatorTy>;
85 using MemEffectsCache =
91 : scope(knownValues), node(node), childIterator(node->begin()) {}
94 ScopedMapTy::ScopeTy scope;
97 DominanceInfoNode::const_iterator childIterator;
100 bool processed =
false;
105 LogicalResult simplifyOperation(ScopedMapTy &knownValues,
Operation *op,
106 bool hasSSADominance);
107 void simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
bool hasSSADominance);
108 void simplifyRegion(ScopedMapTy &knownValues,
Region ®ion);
110 void replaceUsesAndDelete(ScopedMapTy &knownValues,
Operation *op,
111 Operation *existing,
bool hasSSADominance);
121 std::vector<Operation *> opsToErase;
123 MemEffectsCache memEffectsCache;
131 void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues,
Operation *op,
133 bool hasSSADominance) {
137 if (hasSSADominance) {
140 if (
auto *rewriteListener =
141 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
142 rewriteListener->notifyOperationReplaced(op, existing);
146 opsToErase.push_back(op);
150 auto wasVisited = [&](
OpOperand &operand) {
151 return !knownValues.count(operand.getOwner());
153 if (
auto *rewriteListener =
154 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
156 if (all_of(v.getUses(), wasVisited))
157 rewriteListener->notifyOperationReplaced(op, existing);
166 opsToErase.push_back(op);
172 if (isa<UnknownLoc>(existing->
getLoc()) && !isa<UnknownLoc>(op->
getLoc()))
178 bool CSEDriver::hasOtherSideEffectingOpInBetween(
Operation *fromOp,
182 isa<MemoryEffectOpInterface>(fromOp) &&
183 cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() &&
184 isa<MemoryEffectOpInterface>(toOp) &&
185 cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>());
186 Operation *nextOp = fromOp->getNextNode();
188 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp,
nullptr));
190 auto memEffectsCachePair = result.first->second;
191 if (memEffectsCachePair.second ==
nullptr) {
194 nextOp = memEffectsCachePair.first;
201 while (nextOp && nextOp != toOp) {
202 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
208 result.first->second =
214 if (isa<MemoryEffects::Write>(effect.getEffect())) {
219 nextOp = nextOp->getNextNode();
221 result.first->second = std::make_pair(toOp,
nullptr);
226 LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
228 bool hasSSADominance) {
235 opsToErase.push_back(op);
243 return r.getBlocks().empty() || llvm::hasSingleElement(r.getBlocks());
250 auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
258 if (
auto *existing = knownValues.lookup(op)) {
260 !hasOtherSideEffectingOpInBetween(existing, op)) {
264 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
268 knownValues.insert(op, op);
273 if (
auto *existing = knownValues.lookup(op)) {
274 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
280 knownValues.insert(op, op);
284 void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
Block *bb,
285 bool hasSSADominance) {
286 for (
auto &op : *bb) {
293 ScopedMapTy nestedKnownValues;
295 simplifyRegion(nestedKnownValues, region);
299 simplifyRegion(knownValues, region);
304 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
308 memEffectsCache.clear();
311 void CSEDriver::simplifyRegion(ScopedMapTy &knownValues,
Region ®ion) {
316 bool hasSSADominance = domInfo->hasSSADominance(®ion);
320 ScopedMapTy::ScopeTy scope(knownValues);
321 simplifyBlock(knownValues, ®ion.
front(), hasSSADominance);
328 if (!hasSSADominance)
337 std::deque<std::unique_ptr<CFGStackNode>> stack;
340 stack.emplace_back(std::make_unique<CFGStackNode>(
341 knownValues, domInfo->getRootNode(®ion)));
343 while (!stack.empty()) {
344 auto ¤tNode = stack.back();
347 if (!currentNode->processed) {
348 currentNode->processed =
true;
349 simplifyBlock(knownValues, currentNode->node->getBlock(),
354 if (currentNode->childIterator != currentNode->node->end()) {
355 auto *childNode = *(currentNode->childIterator++);
357 std::make_unique<CFGStackNode>(knownValues, childNode));
368 ScopedMapTy knownValues;
370 simplifyRegion(knownValues, region);
373 for (
auto *op : opsToErase)
374 rewriter.eraseOp(op);
376 *
changed = !opsToErase.empty();
385 CSEDriver driver(rewriter, &domInfo);
391 struct CSE :
public impl::CSEBase<CSE> {
392 void runOnOperation()
override;
396 void CSE::runOnOperation() {
399 CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
401 driver.simplify(getOperation(), &
changed);
404 numCSE = driver.getNumCSE();
405 numDCE = driver.getNumDCE();
409 return markAllAnalysesPreserved();
413 markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
A class for computing basic dominance information.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents an operand of an operation.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
std::unique_ptr< Pass > createCSEPass()
Creates a pass to perform common sub expression elimination.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
The following effect indicates that the operation reads from some resource.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.