MLIR 23.0.0git
CSE.cpp
Go to the documentation of this file.
//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass performs a simple common sub-expression elimination
// algorithm on operations within a region.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Transforms/CSE.h"
15
16#include "mlir/IR/Dominance.h"
19#include "mlir/Pass/Pass.h"
21#include "llvm/ADT/DenseMapInfo.h"
22#include "llvm/ADT/ScopedHashTable.h"
23#include "llvm/Support/Allocator.h"
24#include "llvm/Support/RecyclingAllocator.h"
25#include <deque>
26
27namespace mlir {
28#define GEN_PASS_DEF_CSEPASS
29#include "mlir/Transforms/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33
34namespace {
35struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
36 static unsigned getHashValue(const Operation *opC) {
38 const_cast<Operation *>(opC),
42 }
43 static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
44 auto *lhs = const_cast<Operation *>(lhsC);
45 auto *rhs = const_cast<Operation *>(rhsC);
46 if (lhs == rhs)
47 return true;
48 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
49 rhs == getTombstoneKey() || rhs == getEmptyKey())
50 return false;
52 const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
54 }
55};
56} // namespace
57
58namespace {
59/// Simple common sub-expression elimination.
60class CSEDriver {
61public:
62 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
63 : rewriter(rewriter), domInfo(domInfo) {}
64
65 /// Simplify all operations within the given op.
66 void simplify(Operation *op, bool *changed = nullptr);
67
68 int64_t getNumCSE() const { return numCSE; }
69 int64_t getNumDCE() const { return numDCE; }
70
71private:
72 /// Shared implementation of operation elimination and scoped map definitions.
73 using AllocatorTy = llvm::RecyclingAllocator<
74 llvm::BumpPtrAllocator,
75 llvm::ScopedHashTableVal<Operation *, Operation *>>;
76 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
77 SimpleOperationInfo, AllocatorTy>;
78
79 /// Cache holding MemoryEffects information between two operations. The first
80 /// operation is stored has the key. The second operation is stored inside a
81 /// pair in the value. The pair also hold the MemoryEffects between those
82 /// two operations. If the MemoryEffects is nullptr then we assume there is
83 /// no operation with MemoryEffects::Write between the two operations.
84 using MemEffectsCache =
86
87 /// Represents a single entry in the depth first traversal of a CFG.
88 struct CFGStackNode {
89 CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
90 : scope(knownValues), node(node), childIterator(node->begin()) {}
91
92 /// Scope for the known values.
93 ScopedMapTy::ScopeTy scope;
94
96 DominanceInfoNode::const_iterator childIterator;
97
98 /// If this node has been fully processed yet or not.
99 bool processed = false;
100 };
101
102 /// Attempt to eliminate a redundant operation. Returns success if the
103 /// operation was marked for removal, failure otherwise.
104 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
105 bool hasSSADominance);
106 void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
107 void simplifyRegion(ScopedMapTy &knownValues, Region &region);
108
109 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
110 Operation *existing, bool hasSSADominance);
111
112 /// Check if there is side-effecting operations other than the given effect
113 /// between the two operations.
114 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
115
116 /// A rewriter for modifying the IR.
117 RewriterBase &rewriter;
118
119 /// Operations marked as dead and to be erased.
120 std::vector<Operation *> opsToErase;
121 DominanceInfo *domInfo = nullptr;
122 MemEffectsCache memEffectsCache;
123
124 // Various statistics.
128} // namespace
130void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
131 Operation *existing,
132 bool hasSSADominance) {
133 // If we find one then replace all uses of the current operation with the
134 // existing one and mark it for deletion. We can only replace an operand in
135 // an operation if it has not been visited yet.
136 if (hasSSADominance) {
137 // If the region has SSA dominance, then we are guaranteed to have not
138 // visited any use of the current operation.
139 if (auto *rewriteListener =
140 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
141 rewriteListener->notifyOperationReplaced(op, existing);
142 // Replace all uses, but do not remove the operation yet. This does not
143 // notify the listener because the original op is not erased.
144 rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
145 opsToErase.push_back(op);
146 } else {
147 // When the region does not have SSA dominance, we need to check if we
148 // have visited a use before replacing any use.
149 auto wasVisited = [&](OpOperand &operand) {
150 return !knownValues.count(operand.getOwner());
151 };
152 if (auto *rewriteListener =
153 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
154 for (Value v : op->getResults())
155 if (all_of(v.getUses(), wasVisited))
156 rewriteListener->notifyOperationReplaced(op, existing);
157
158 // Replace all uses, but do not remove the operation yet. This does not
159 // notify the listener because the original op is not erased.
160 rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
161 wasVisited);
162
163 // There may be some remaining uses of the operation.
164 if (op->use_empty())
165 opsToErase.push_back(op);
168 // If the existing operation has an unknown location and the current
169 // operation doesn't, then set the existing op's location to that of the
170 // current op.
171 if (isa<UnknownLoc>(existing->getLoc()) && !isa<UnknownLoc>(op->getLoc()))
172 existing->setLoc(op->getLoc());
173
174 ++numCSE;
175}
177bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
178 Operation *toOp) {
179 assert(fromOp->getBlock() == toOp->getBlock());
180 assert(hasEffect<MemoryEffects::Read>(fromOp) &&
181 "expected read effect on fromOp");
182 assert(hasEffect<MemoryEffects::Read>(toOp) &&
183 "expected read effect on toOp");
184
185 // Collect the read effects of fromOp. A write can only block CSE if it
186 // can conflict with one of these reads.
188 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(fromOp)) {
190 memOp.getEffects(fromEffects);
191 for (MemoryEffects::EffectInstance &e : fromEffects)
192 if (isa<MemoryEffects::Read>(e.getEffect()))
193 readEffects.push_back(e);
194 }
195
196 Operation *nextOp = fromOp->getNextNode();
197 auto result =
198 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
199 if (result.second) {
200 auto memEffectsCachePair = result.first->second;
201 if (memEffectsCachePair.second == nullptr) {
202 // No MemoryEffects::Write has been detected until the cached operation.
203 // Continue looking from the cached operation to toOp.
204 nextOp = memEffectsCachePair.first;
205 } else {
206 // MemoryEffects::Write has been detected before so there is no need to
207 // check further.
208 return true;
209 }
210 }
211 while (nextOp && nextOp != toOp) {
212 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
213 getEffectsRecursively(nextOp);
214 if (!effects) {
215 // TODO: Do we need to handle other effects generically?
216 // If the operation does not implement the MemoryEffectOpInterface we
217 // conservatively assume it writes.
218 result.first->second =
219 std::make_pair(nextOp, MemoryEffects::Write::get());
220 return true;
221 }
222
223 for (const MemoryEffects::EffectInstance &effect : *effects) {
224 if (isa<MemoryEffects::Write>(effect.getEffect())) {
225 // A write on a resource disjoint from all read resources cannot
226 // conflict with the reads being CSE'd.
227 SideEffects::Resource *writeResource = effect.getResource();
228 bool canConflict =
229 llvm::any_of(readEffects, [&](const auto &readEffect) {
230 SideEffects::Resource *readResource = readEffect.getResource();
231 if (writeResource->isDisjointFrom(readResource))
232 return false;
233 // A pointer-based access to an addressable resource cannot
234 // conflict with a non-addressable resource.
235 if (readEffect.getValue() && !writeResource->isAddressable())
236 return false;
237 if (effect.getValue() && !readResource->isAddressable())
238 return false;
239 return true;
240 });
241 if (canConflict) {
242 result.first->second = {nextOp, MemoryEffects::Write::get()};
243 return true;
244 }
245 }
246 }
247 nextOp = nextOp->getNextNode();
248 }
249 result.first->second = std::make_pair(toOp, nullptr);
250 return false;
251}
252
253/// Attempt to eliminate a redundant operation.
254LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
255 Operation *op,
256 bool hasSSADominance) {
257 // Don't simplify terminator operations.
259 return failure();
260
261 // If the operation is already trivially dead just add it to the erase list.
262 if (isOpTriviallyDead(op)) {
263 opsToErase.push_back(op);
264 ++numDCE;
265 return success();
266 }
267
268 // Don't simplify operations with regions that have multiple blocks.
269 // TODO: We need additional tests to verify that we handle such IR correctly.
270 if (!llvm::all_of(op->getRegions(),
271 [](Region &r) { return r.empty() || r.hasOneBlock(); }))
272 return failure();
273
274 // Some simple use case of operation with memory side-effect are dealt with
275 // here. Operations with no side-effect are done after.
276 if (!isMemoryEffectFree(op)) {
277 // TODO: Only basic use case for operations with MemoryEffects::Read can be
278 // eleminated now. More work needs to be done for more complicated patterns
279 // and other side-effects.
281 return failure();
282
283 // Look for an existing definition for the operation.
284 if (auto *existing = knownValues.lookup(op)) {
285 if (existing->getBlock() == op->getBlock() &&
286 !hasOtherSideEffectingOpInBetween(existing, op)) {
287 // The operation that can be deleted has been reach with no
288 // side-effecting operations in between the existing operation and
289 // this one so we can remove the duplicate.
290 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
291 return success();
292 }
293 }
294 knownValues.insert(op, op);
295 return failure();
296 }
297
298 // Look for an existing definition for the operation.
299 if (auto *existing = knownValues.lookup(op)) {
300 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
301 return success();
302 }
303
304 // Otherwise, we add this operation to the known values map.
305 knownValues.insert(op, op);
306 return failure();
307}
308
309void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
310 bool hasSSADominance) {
311 for (auto &op : *bb) {
312 // Most operations don't have regions, so fast path that case.
313 if (op.getNumRegions() != 0) {
314 // If this operation is isolated above, we can't process nested regions
315 // with the given 'knownValues' map. This would cause the insertion of
316 // implicit captures in explicit capture only regions.
317 if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
318 ScopedMapTy nestedKnownValues;
319 for (auto &region : op.getRegions())
320 simplifyRegion(nestedKnownValues, region);
321 } else {
322 // Otherwise, process nested regions normally.
323 for (auto &region : op.getRegions())
324 simplifyRegion(knownValues, region);
325 }
326 }
327
328 // If the operation is simplified, we don't process any held regions.
329 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
330 continue;
331 }
332 // Clear the MemoryEffects cache since its usage is by block only.
333 memEffectsCache.clear();
334}
335
336void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
337 // If the region is empty there is nothing to do.
338 if (region.empty())
339 return;
340
341 bool hasSSADominance = domInfo->hasSSADominance(&region);
342
343 // If the region only contains one block, then simplify it directly.
344 if (region.hasOneBlock()) {
345 ScopedMapTy::ScopeTy scope(knownValues);
346 simplifyBlock(knownValues, &region.front(), hasSSADominance);
347 return;
348 }
349
350 // If the region does not have dominanceInfo, then skip it.
351 // TODO: Regions without SSA dominance should define a different
352 // traversal order which is appropriate and can be used here.
353 if (!hasSSADominance)
354 return;
355
356 // Note, deque is being used here because there was significant performance
357 // gains over vector when the container becomes very large due to the
358 // specific access patterns. If/when these performance issues are no
359 // longer a problem we can change this to vector. For more information see
360 // the llvm mailing list discussion on this:
361 // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
362 std::deque<std::unique_ptr<CFGStackNode>> stack;
363
364 // Process the nodes of the dom tree for this region.
365 stack.emplace_back(std::make_unique<CFGStackNode>(
366 knownValues, domInfo->getRootNode(&region)));
367
368 while (!stack.empty()) {
369 auto &currentNode = stack.back();
370
371 // Check to see if we need to process this node.
372 if (!currentNode->processed) {
373 currentNode->processed = true;
374 simplifyBlock(knownValues, currentNode->node->getBlock(),
375 hasSSADominance);
376 }
377
378 // Otherwise, check to see if we need to process a child node.
379 if (currentNode->childIterator != currentNode->node->end()) {
380 auto *childNode = *(currentNode->childIterator++);
381 stack.emplace_back(
382 std::make_unique<CFGStackNode>(knownValues, childNode));
383 } else {
384 // Finally, if the node and all of its children have been processed
385 // then we delete the node.
386 stack.pop_back();
387 }
388 }
389}
390
391void CSEDriver::simplify(Operation *op, bool *changed) {
392 /// Simplify all regions.
393 ScopedMapTy knownValues;
394 for (auto &region : op->getRegions())
395 simplifyRegion(knownValues, region);
396
397 /// Erase any operations that were marked as dead during simplification.
398 for (auto *op : opsToErase)
399 rewriter.eraseOp(op);
400 if (changed)
401 *changed = !opsToErase.empty();
402
403 // Note: CSE does currently not remove ops with regions, so DominanceInfo
404 // does not have to be invalidated.
405}
406
408 DominanceInfo &domInfo, Operation *op,
409 bool *changed) {
410 CSEDriver driver(rewriter, &domInfo);
411 driver.simplify(op, changed);
412}
413
414namespace {
415/// CSE pass.
416struct CSE : public impl::CSEPassBase<CSE> {
417 void runOnOperation() override;
418};
419} // namespace
420
421void CSE::runOnOperation() {
422 // Simplify the IR.
423 IRRewriter rewriter(&getContext());
424 CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
425 bool changed = false;
426 driver.simplify(getOperation(), &changed);
427
428 // Set statistics.
429 numCSE = driver.getNumCSE();
430 numDCE = driver.getNumDCE();
431
432 // If there was no change to the IR, we mark all analyses as preserved.
433 if (!changed)
434 return markAllAnalysesPreserved();
435
436 // We currently don't remove region operations, so mark dominance as
437 // preserved.
438 markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
439}
return success()
lhs
b getContext())
template bool mlir::hasSingleEffect< MemoryEffects::Read >(Operation *)
A class for computing basic dominance information.
Definition Dominance.h:140
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
This class represents an operand of an operation.
Definition Value.h:254
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:244
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:878
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:783
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
result_range getResults()
Definition Operation.h:441
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents a specific resource that an effect applies to.
virtual bool isAddressable() const
Returns true if this resource is addressable (effects on it can alias pointer-based memory).
bool isDisjointFrom(const Resource *other) const
Returns true if this resource is disjoint from another.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
DominanceInfoNode * getRootNode(Region *region)
Get the root dominance node of the given region.
Definition Dominance.h:74
bool hasSSADominance(Block *block) const
Return true if operations in the specified block are known to obey SSA dominance requirements.
Definition Dominance.h:92
::mlir::Pass::Statistic numDCE
Definition CSE.cpp:167
::mlir::Pass::Statistic numCSE
Explicitly declare the TypeID for this class.
Definition CSE.cpp:166
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
Definition Dominance.h:30
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition CSE.cpp:407
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash to compute the hash value of operands/results directly.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.