MLIR 23.0.0git
CSE.cpp
Go to the documentation of this file.
1//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements common sub-expression elimination as a library utility.
10// The matching CSE pass is a thin wrapper over the APIs declared here.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Transforms/CSE.h"
15
16#include "mlir/IR/Dominance.h"
19#include "llvm/ADT/DenseMapInfo.h"
20#include "llvm/ADT/ScopedHashTable.h"
21#include "llvm/Support/Allocator.h"
22#include "llvm/Support/RecyclingAllocator.h"
23#include <deque>
24
25using namespace mlir;
26
27namespace {
28struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
29 static unsigned getHashValue(const Operation *opC) {
31 const_cast<Operation *>(opC),
35 }
36 static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
37 auto *lhs = const_cast<Operation *>(lhsC);
38 auto *rhs = const_cast<Operation *>(rhsC);
39 if (lhs == rhs)
40 return true;
41 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
42 rhs == getTombstoneKey() || rhs == getEmptyKey())
43 return false;
45 const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
47 }
48};
49} // namespace
50
51namespace {
52/// Simple common sub-expression elimination.
53class CSEDriver {
54public:
55 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
56 : rewriter(rewriter), domInfo(domInfo) {}
57
58 /// Simplify all operations within the given op.
59 void simplify(Operation *op, bool *changed = nullptr);
60
61 /// Simplify operations within the given region.
62 void simplify(Region &region, bool *changed = nullptr);
63
64 int64_t getNumCSE() const { return numCSE; }
65 int64_t getNumDCE() const { return numDCE; }
66
67private:
68 /// Shared implementation of operation elimination and scoped map definitions.
69 using AllocatorTy = llvm::RecyclingAllocator<
70 llvm::BumpPtrAllocator,
71 llvm::ScopedHashTableVal<Operation *, Operation *>>;
72 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
73 SimpleOperationInfo, AllocatorTy>;
74
75 /// Cache holding MemoryEffects information between two operations. The first
76 /// operation is stored has the key. The second operation is stored inside a
77 /// pair in the value. The pair also hold the MemoryEffects between those
78 /// two operations. If the MemoryEffects is nullptr then we assume there is
79 /// no operation with MemoryEffects::Write between the two operations.
80 using MemEffectsCache =
82
83 /// Represents a single entry in the depth first traversal of a CFG.
84 struct CFGStackNode {
85 CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
86 : scope(knownValues), node(node), childIterator(node->begin()) {}
87
88 /// Scope for the known values.
89 ScopedMapTy::ScopeTy scope;
90
92 DominanceInfoNode::const_iterator childIterator;
93
94 /// If this node has been fully processed yet or not.
95 bool processed = false;
96 };
97
98 /// Attempt to eliminate a redundant operation. Returns success if the
99 /// operation was marked for removal, failure otherwise.
100 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
101 bool hasSSADominance);
102 void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
103 void simplifyRegion(ScopedMapTy &knownValues, Region &region);
104
105 /// Erase all operations queued for deletion by the simplification routines.
106 void eraseDeadOps(bool *changed);
107
108 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
109 Operation *existing, bool hasSSADominance);
110
111 /// Check if there is side-effecting operations other than the given effect
112 /// between the two operations.
113 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
114
115 /// A rewriter for modifying the IR.
116 RewriterBase &rewriter;
117
118 /// Operations marked as dead and to be erased.
119 std::vector<Operation *> opsToErase;
120 DominanceInfo *domInfo = nullptr;
121 MemEffectsCache memEffectsCache;
122
123 // Various statistics.
124 int64_t numCSE = 0;
125 int64_t numDCE = 0;
126};
127} // namespace
128
129void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
130 Operation *existing,
131 bool hasSSADominance) {
132 // If we find one then replace all uses of the current operation with the
133 // existing one and mark it for deletion. We can only replace an operand in
134 // an operation if it has not been visited yet.
135 if (hasSSADominance) {
136 // If the region has SSA dominance, then we are guaranteed to have not
137 // visited any use of the current operation.
138 // Replace all uses, but do not remove the operation yet.
139 rewriter.replaceAllOpUsesWith(op, existing->getResults());
140 opsToErase.push_back(op);
141 } else {
142 // When the region does not have SSA dominance, we need to check if we
143 // have visited a use before replacing any use.
144 auto wasVisited = [&](OpOperand &operand) {
145 return !knownValues.count(operand.getOwner());
146 };
147 if (auto *rewriteListener =
148 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
149 for (Value v : op->getResults())
150 if (all_of(v.getUses(), wasVisited))
151 rewriteListener->notifyOperationReplaced(op, existing);
152
153 // Replace all uses, but do not remove the operation yet. This does not
154 // notify the listener because the original op is not erased.
155 rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
156 wasVisited);
157
158 // There may be some remaining uses of the operation.
159 if (op->use_empty())
160 opsToErase.push_back(op);
161 }
162
163 // If the existing operation has an unknown location and the current
164 // operation doesn't, then set the existing op's location to that of the
165 // current op.
166 if (isa<UnknownLoc>(existing->getLoc()) && !isa<UnknownLoc>(op->getLoc()))
167 existing->setLoc(op->getLoc());
168
169 ++numCSE;
170}
171
172bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
173 Operation *toOp) {
174 assert(fromOp->getBlock() == toOp->getBlock());
175 assert(hasEffect<MemoryEffects::Read>(fromOp) &&
176 "expected read effect on fromOp");
177 assert(hasEffect<MemoryEffects::Read>(toOp) &&
178 "expected read effect on toOp");
179
180 // Collect the read effects of fromOp. A write can only block CSE if it
181 // can conflict with one of these reads.
182 SmallVector<MemoryEffects::EffectInstance> readEffects;
183 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(fromOp)) {
184 SmallVector<MemoryEffects::EffectInstance> fromEffects;
185 memOp.getEffects(fromEffects);
186 for (MemoryEffects::EffectInstance &e : fromEffects)
187 if (isa<MemoryEffects::Read>(e.getEffect()))
188 readEffects.push_back(e);
189 }
190
191 Operation *nextOp = fromOp->getNextNode();
192 auto result =
193 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
194 if (!result.second) {
195 auto memEffectsCachePair = result.first->second;
196 if (memEffectsCachePair.second == nullptr) {
197 // No MemoryEffects::Write has been detected until the cached operation.
198 // Continue looking from the cached operation to toOp.
199 nextOp = memEffectsCachePair.first;
200 } else {
201 // MemoryEffects::Write has been detected before so there is no need to
202 // check further.
203 return true;
204 }
205 }
206 while (nextOp && nextOp != toOp) {
207 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
208 getEffectsRecursively(nextOp);
209 if (!effects) {
210 // TODO: Do we need to handle other effects generically?
211 // If the operation does not implement the MemoryEffectOpInterface we
212 // conservatively assume it writes.
213 result.first->second =
214 std::make_pair(nextOp, MemoryEffects::Write::get());
215 return true;
216 }
217
218 for (const MemoryEffects::EffectInstance &effect : *effects) {
219 if (isa<MemoryEffects::Write>(effect.getEffect())) {
220 // A write on a resource disjoint from all read resources cannot
221 // conflict with the reads being CSE'd.
222 SideEffects::Resource *writeResource = effect.getResource();
223 bool canConflict =
224 llvm::any_of(readEffects, [&](const auto &readEffect) {
225 SideEffects::Resource *readResource = readEffect.getResource();
226 if (writeResource->isDisjointFrom(readResource))
227 return false;
228 // A pointer-based access to an addressable resource cannot
229 // conflict with a non-addressable resource.
230 if (readEffect.getValue() && !writeResource->isAddressable())
231 return false;
232 if (effect.getValue() && !readResource->isAddressable())
233 return false;
234 return true;
235 });
236 if (canConflict) {
237 result.first->second = {nextOp, MemoryEffects::Write::get()};
238 return true;
239 }
240 }
241 }
242 nextOp = nextOp->getNextNode();
243 }
244 result.first->second = std::make_pair(toOp, nullptr);
245 return false;
246}
247
248/// Attempt to eliminate a redundant operation.
249LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
250 Operation *op,
251 bool hasSSADominance) {
252 // Don't simplify terminator operations.
253 if (op->hasTrait<OpTrait::IsTerminator>())
254 return failure();
255
256 // Don't simplify operations with regions that have multiple blocks.
257 // TODO: We need additional tests to verify that we handle such IR correctly.
258 if (!llvm::all_of(op->getRegions(),
259 [](Region &r) { return r.empty() || r.hasOneBlock(); }))
260 return failure();
261
262 // Some simple use case of operation with memory side-effect are dealt with
263 // here. Operations with no side-effect are done after.
264 if (!isMemoryEffectFree(op)) {
265 // TODO: Only basic use case for operations with MemoryEffects::Read can be
266 // eleminated now. More work needs to be done for more complicated patterns
267 // and other side-effects.
269 return failure();
270
271 // Look for an existing definition for the operation.
272 if (auto *existing = knownValues.lookup(op)) {
273 if (existing->getBlock() == op->getBlock() &&
274 !hasOtherSideEffectingOpInBetween(existing, op)) {
275 // The operation that can be deleted has been reach with no
276 // side-effecting operations in between the existing operation and
277 // this one so we can remove the duplicate.
278 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
279 return success();
280 }
281 }
282 knownValues.insert(op, op);
283 return failure();
284 }
285
286 // Look for an existing definition for the operation.
287 if (auto *existing = knownValues.lookup(op)) {
288 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
289 return success();
290 }
291
292 // Otherwise, we add this operation to the known values map.
293 knownValues.insert(op, op);
294 return failure();
295}
296
297void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
298 bool hasSSADominance) {
299 for (auto &op : llvm::make_early_inc_range(*bb)) {
300 // If the operation is already trivially dead just add it to the erase list.
301 // This also avoids calling `simplifyRegion` on dead region ops
302 // unnecessarily.
303 if (isOpTriviallyDead(&op)) {
304 opsToErase.push_back(&op);
305 ++numDCE;
306 continue;
307 }
308
309 // Most operations don't have regions, so fast path that case.
310 if (op.getNumRegions() != 0) {
311 // If this operation is isolated above, we can't process nested regions
312 // with the given 'knownValues' map. This would cause the insertion of
313 // implicit captures in explicit capture only regions.
314 if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
315 ScopedMapTy nestedKnownValues;
316 for (auto &region : op.getRegions())
317 simplifyRegion(nestedKnownValues, region);
318 } else {
319 // Otherwise, process nested regions normally.
320 for (auto &region : op.getRegions())
321 simplifyRegion(knownValues, region);
322 }
323 }
324
325 // If the operation is simplified, we don't process any held regions.
326 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
327 continue;
328 }
329 // Clear the MemoryEffects cache since its usage is by block only.
330 memEffectsCache.clear();
331}
332
333void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
334 // If the region is empty there is nothing to do.
335 if (region.empty())
336 return;
337
338 bool hasSSADominance = domInfo->hasSSADominance(&region);
339
340 // If the region only contains one block, then simplify it directly.
341 if (region.hasOneBlock()) {
342 ScopedMapTy::ScopeTy scope(knownValues);
343 simplifyBlock(knownValues, &region.front(), hasSSADominance);
344 return;
345 }
346
347 // If the region does not have dominanceInfo, then skip it.
348 // TODO: Regions without SSA dominance should define a different
349 // traversal order which is appropriate and can be used here.
350 if (!hasSSADominance)
351 return;
352
353 // Note, deque is being used here because there was significant performance
354 // gains over vector when the container becomes very large due to the
355 // specific access patterns. If/when these performance issues are no
356 // longer a problem we can change this to vector. For more information see
357 // the llvm mailing list discussion on this:
358 // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
359 std::deque<std::unique_ptr<CFGStackNode>> stack;
360
361 // Process the nodes of the dom tree for this region.
362 stack.emplace_back(std::make_unique<CFGStackNode>(
363 knownValues, domInfo->getRootNode(&region)));
364
365 while (!stack.empty()) {
366 auto &currentNode = stack.back();
367
368 // Check to see if we need to process this node.
369 if (!currentNode->processed) {
370 currentNode->processed = true;
371 simplifyBlock(knownValues, currentNode->node->getBlock(),
372 hasSSADominance);
373 }
374
375 // Otherwise, check to see if we need to process a child node.
376 if (currentNode->childIterator != currentNode->node->end()) {
377 auto *childNode = *(currentNode->childIterator++);
378 stack.emplace_back(
379 std::make_unique<CFGStackNode>(knownValues, childNode));
380 } else {
381 // Finally, if the node and all of its children have been processed
382 // then we delete the node.
383 stack.pop_back();
384 }
385 }
386}
387
388void CSEDriver::eraseDeadOps(bool *changed) {
389 // Erase any operations that were marked as dead during simplification, and
390 // remove their associated dominator trees.
391 for (auto *op : opsToErase) {
392 for (Region &region : op->getRegions())
393 domInfo->invalidate(&region);
394 rewriter.eraseOp(op);
395 }
396 if (changed)
397 *changed = !opsToErase.empty();
398 opsToErase.clear();
399
400 // Note: CSE does currently not remove ops with regions, so DominanceInfo
401 // does not have to be invalidated.
402}
403
404void CSEDriver::simplify(Operation *op, bool *changed) {
405 // Simplify all regions.
406 ScopedMapTy knownValues;
407 for (auto &region : op->getRegions())
408 simplifyRegion(knownValues, region);
409 eraseDeadOps(changed);
410}
411
412void CSEDriver::simplify(Region &region, bool *changed) {
413 ScopedMapTy knownValues;
414 simplifyRegion(knownValues, region);
415 eraseDeadOps(changed);
416}
417
419 DominanceInfo &domInfo, Operation *op,
420 bool *changed, int64_t *numCSE,
421 int64_t *numDCE) {
422 CSEDriver driver(rewriter, &domInfo);
423 driver.simplify(op, changed);
424 if (numCSE)
425 *numCSE = driver.getNumCSE();
426 if (numDCE)
427 *numDCE = driver.getNumDCE();
428}
429
431 DominanceInfo &domInfo, Region &region,
432 bool *changed) {
433 CSEDriver driver(rewriter, &domInfo);
434 driver.simplify(region, changed);
435}
return success()
lhs
template bool mlir::hasEffect< MemoryEffects::Read >(Operation *)
template bool mlir::hasSingleEffect< MemoryEffects::Read >(Operation *)
A class for computing basic dominance information.
Definition Dominance.h:140
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:244
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:878
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:783
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
result_range getResults()
Definition Operation.h:441
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual bool isAddressable() const
Returns true if this resource is addressable (effects on it can alias pointer-based memory).
bool isDisjointFrom(const Resource *other) const
Returns true if this resource is disjoint from another.
DominanceInfoNode * getRootNode(Region *region)
Get the root dominance node of the given region.
Definition Dominance.h:74
bool hasSSADominance(Block *block) const
Return true if operations in the specified block are known to obey SSA dominance requirements.
Definition Dominance.h:92
void invalidate()
Invalidate dominance info.
Definition Dominance.cpp:37
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr, int64_t *numCSE=nullptr, int64_t *numDCE=nullptr)
Eliminate common subexpressions within the given operation.
Definition CSE.cpp:418
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
Definition Dominance.h:30
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash to compute the hash value of operands/results directly.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.