MLIR 23.0.0git
CSE.cpp
Go to the documentation of this file.
1//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass performs a simple common sub-expression elimination
10// algorithm on operations within a region.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Transforms/CSE.h"
15
16#include "mlir/IR/Dominance.h"
19#include "mlir/Pass/Pass.h"
21#include "llvm/ADT/DenseMapInfo.h"
22#include "llvm/ADT/ScopedHashTable.h"
23#include "llvm/Support/Allocator.h"
24#include "llvm/Support/RecyclingAllocator.h"
25#include <deque>
26
27namespace mlir {
28#define GEN_PASS_DEF_CSEPASS
29#include "mlir/Transforms/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33
34namespace {
/// DenseMapInfo used as the key-info of CSE's known-value table. Instead of
/// hashing/comparing the raw pointer values, it hands the pointed-to
/// operations to an operation-level hash and equivalence check.
/// NOTE(review): those calls are elided from this view; presumably
/// OperationEquivalence::computeHash / OperationEquivalence::isEquivalentTo —
/// confirm.
35struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
/// Hash an operation by its contents (const_cast is needed because the
/// hashing helper takes a mutable Operation *).
36 static unsigned getHashValue(const Operation *opC) {
38 const_cast<Operation *>(opC),
42 }
/// Two keys are equal if they are the same pointer, or if neither is a
/// DenseMap sentinel and the operation-level equivalence check accepts them.
43 static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
44 auto *lhs = const_cast<Operation *>(lhsC);
45 auto *rhs = const_cast<Operation *>(rhsC);
46 if (lhs == rhs)
47 return true;
// Empty/tombstone sentinels are not real operations, so they must never
// reach the operation-level comparison below.
48 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
49 rhs == getTombstoneKey() || rhs == getEmptyKey())
50 return false;
52 const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
54 }
55};
56} // namespace
57
58namespace {
59/// Simple common sub-expression elimination.
/// The driver walks the regions of an operation, redirects uses of operations
/// that match an earlier, still-visible operation, and collects trivially dead
/// operations. All collected operations are erased at the end of simplify().
60class CSEDriver {
61public:
62 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
63 : rewriter(rewriter), domInfo(domInfo) {}
64
65 /// Simplify all operations within the given op.
  /// If `changed` is non-null it is set to whether any operation was erased.
66 void simplify(Operation *op, bool *changed = nullptr);
67
  /// CSE statistic accumulated while simplifying.
68 int64_t getNumCSE() const { return numCSE; }
  /// Number of trivially dead operations scheduled for erasure.
69 int64_t getNumDCE() const { return numDCE; }
70
71private:
72 /// Scoped table of known operations. Lookups use SimpleOperationInfo's hash
  /// and equality instead of plain pointer identity; table entries are backed
  /// by a recycling bump-pointer allocator.
73 using AllocatorTy = llvm::RecyclingAllocator<
74 llvm::BumpPtrAllocator,
75 llvm::ScopedHashTableVal<Operation *, Operation *>>;
76 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
77 SimpleOperationInfo, AllocatorTy>;
78
79 /// Cache holding MemoryEffects information between two operations. The first
80 /// operation is stored as the key. The second operation is stored inside a
81 /// pair in the value. The pair also holds the MemoryEffects between those
82 /// two operations. If the MemoryEffects is nullptr then we assume there is
83 /// no operation with MemoryEffects::Write between the two operations.
  /// Only writes that may conflict with the key operation's reads are
  /// recorded. The cache is cleared at the end of every simplifyBlock().
84 using MemEffectsCache =
86
87 /// Represents a single entry in the depth-first traversal of a dominator tree.
88 struct CFGStackNode {
89 CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
90 : scope(knownValues), node(node), childIterator(node->begin()) {}
91
92 /// Scope for the known values; it stays open for as long as the node lives.
93 ScopedMapTy::ScopeTy scope;
94
    /// Next child of `node` that still has to be visited.
96 DominanceInfoNode::const_iterator childIterator;
97
98 /// Whether this node's block has already been simplified.
99 bool processed = false;
100 };
101
102 /// Attempt to eliminate a redundant operation. Returns success if the
103 /// operation was marked for removal, failure otherwise.
104 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
105 bool hasSSADominance);
  /// Simplify the operations of one block, including the regions they hold.
106 void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
  /// Simplify the blocks of `region`. Multi-block regions are only handled
  /// when they have SSA dominance.
107 void simplifyRegion(ScopedMapTy &knownValues, Region &region);
108
  /// Redirect uses of `op` to the equivalent `existing` operation, and
  /// schedule `op` for erasure once it has no remaining uses.
109 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
110 Operation *existing, bool hasSSADominance);
111
112 /// Check whether an operation between the two operations may write to a
113 /// resource read by the first one.
114 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
115
116 /// A rewriter for modifying the IR.
117 RewriterBase &rewriter;
118
119 /// Operations marked as dead and to be erased.
120 std::vector<Operation *> opsToErase;
  /// Dominance information, queried for SSA dominance and dominator trees.
121 DominanceInfo *domInfo = nullptr;
  /// Per-block cache used by hasOtherSideEffectingOpInBetween().
122 MemEffectsCache memEffectsCache;
123
124 // Various statistics.
125 int64_t numCSE = 0;
126 int64_t numDCE = 0;
127};
128} // namespace
129
130void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
131 Operation *existing,
132 bool hasSSADominance) {
133 // If we find one then replace all uses of the current operation with the
134 // existing one and mark it for deletion. We can only replace an operand in
135 // an operation if it has not been visited yet.
136 if (hasSSADominance) {
137 // If the region has SSA dominance, then we are guaranteed to have not
138 // visited any use of the current operation.
139 if (auto *rewriteListener =
140 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
141 rewriteListener->notifyOperationReplaced(op, existing);
142 // Replace all uses, but do not remove the operation yet. This does not
143 // notify the listener because the original op is not erased.
144 rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
145 opsToErase.push_back(op);
146 } else {
147 // When the region does not have SSA dominance, we need to check if we
148 // have visited a use before replacing any use.
149 auto wasVisited = [&](OpOperand &operand) {
150 return !knownValues.count(operand.getOwner());
151 };
152 if (auto *rewriteListener =
153 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
154 for (Value v : op->getResults())
155 if (all_of(v.getUses(), wasVisited))
156 rewriteListener->notifyOperationReplaced(op, existing);
157
158 // Replace all uses, but do not remove the operation yet. This does not
159 // notify the listener because the original op is not erased.
160 rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
161 wasVisited);
162
163 // There may be some remaining uses of the operation.
164 if (op->use_empty())
165 opsToErase.push_back(op);
166 }
167
168 // If the existing operation has an unknown location and the current
169 // operation doesn't, then set the existing op's location to that of the
170 // current op.
171 if (isa<UnknownLoc>(existing->getLoc()) && !isa<UnknownLoc>(op->getLoc()))
172 existing->setLoc(op->getLoc());
173
174 ++numCSE;
175}
176
177bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
178 Operation *toOp) {
179 assert(fromOp->getBlock() == toOp->getBlock());
180 assert(hasEffect<MemoryEffects::Read>(fromOp) &&
181 "expected read effect on fromOp");
182 assert(hasEffect<MemoryEffects::Read>(toOp) &&
183 "expected read effect on toOp");
184
185 // Collect the resources of fromOp's read effects. A write can only block
186 // CSE if its resource is not disjoint from one of these.
187 SmallPtrSet<SideEffects::Resource *, 1> readResources;
188 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(fromOp)) {
189 SmallVector<MemoryEffects::EffectInstance> fromEffects;
190 memOp.getEffects(fromEffects);
191 for (const auto &e : fromEffects)
192 if (isa<MemoryEffects::Read>(e.getEffect()))
193 readResources.insert(e.getResource());
194 }
195
196 Operation *nextOp = fromOp->getNextNode();
197 auto result =
198 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
199 if (result.second) {
200 auto memEffectsCachePair = result.first->second;
201 if (memEffectsCachePair.second == nullptr) {
202 // No MemoryEffects::Write has been detected until the cached operation.
203 // Continue looking from the cached operation to toOp.
204 nextOp = memEffectsCachePair.first;
205 } else {
206 // MemoryEffects::Write has been detected before so there is no need to
207 // check further.
208 return true;
209 }
210 }
211 while (nextOp && nextOp != toOp) {
212 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
213 getEffectsRecursively(nextOp);
214 if (!effects) {
215 // TODO: Do we need to handle other effects generically?
216 // If the operation does not implement the MemoryEffectOpInterface we
217 // conservatively assume it writes.
218 result.first->second =
219 std::make_pair(nextOp, MemoryEffects::Write::get());
220 return true;
221 }
222
223 for (const MemoryEffects::EffectInstance &effect : *effects) {
224 if (isa<MemoryEffects::Write>(effect.getEffect())) {
225 // A write on a resource disjoint from all read resources cannot
226 // conflict with the reads being CSE'd.
227 auto *writeResource = effect.getResource();
228 bool canConflict = llvm::any_of(readResources, [&](auto *readResource) {
229 return !writeResource->isDisjointFrom(readResource);
230 });
231 if (canConflict) {
232 result.first->second = {nextOp, MemoryEffects::Write::get()};
233 return true;
234 }
235 }
236 }
237 nextOp = nextOp->getNextNode();
238 }
239 result.first->second = std::make_pair(toOp, nullptr);
240 return false;
241}
242
243/// Attempt to eliminate a redundant operation.
244LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
245 Operation *op,
246 bool hasSSADominance) {
247 // Don't simplify terminator operations.
248 if (op->hasTrait<OpTrait::IsTerminator>())
249 return failure();
250
251 // If the operation is already trivially dead just add it to the erase list.
252 if (isOpTriviallyDead(op)) {
253 opsToErase.push_back(op);
254 ++numDCE;
255 return success();
256 }
257
258 // Don't simplify operations with regions that have multiple blocks.
259 // TODO: We need additional tests to verify that we handle such IR correctly.
260 if (!llvm::all_of(op->getRegions(),
261 [](Region &r) { return r.empty() || r.hasOneBlock(); }))
262 return failure();
263
264 // Some simple use case of operation with memory side-effect are dealt with
265 // here. Operations with no side-effect are done after.
266 if (!isMemoryEffectFree(op)) {
267 // TODO: Only basic use case for operations with MemoryEffects::Read can be
268 // eleminated now. More work needs to be done for more complicated patterns
269 // and other side-effects.
271 return failure();
272
273 // Look for an existing definition for the operation.
274 if (auto *existing = knownValues.lookup(op)) {
275 if (existing->getBlock() == op->getBlock() &&
276 !hasOtherSideEffectingOpInBetween(existing, op)) {
277 // The operation that can be deleted has been reach with no
278 // side-effecting operations in between the existing operation and
279 // this one so we can remove the duplicate.
280 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
281 return success();
282 }
283 }
284 knownValues.insert(op, op);
285 return failure();
286 }
287
288 // Look for an existing definition for the operation.
289 if (auto *existing = knownValues.lookup(op)) {
290 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
291 ++numCSE;
292 return success();
293 }
294
295 // Otherwise, we add this operation to the known values map.
296 knownValues.insert(op, op);
297 return failure();
298}
299
300void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
301 bool hasSSADominance) {
302 for (auto &op : *bb) {
303 // Most operations don't have regions, so fast path that case.
304 if (op.getNumRegions() != 0) {
305 // If this operation is isolated above, we can't process nested regions
306 // with the given 'knownValues' map. This would cause the insertion of
307 // implicit captures in explicit capture only regions.
308 if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
309 ScopedMapTy nestedKnownValues;
310 for (auto &region : op.getRegions())
311 simplifyRegion(nestedKnownValues, region);
312 } else {
313 // Otherwise, process nested regions normally.
314 for (auto &region : op.getRegions())
315 simplifyRegion(knownValues, region);
316 }
317 }
318
319 // If the operation is simplified, we don't process any held regions.
320 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
321 continue;
322 }
323 // Clear the MemoryEffects cache since its usage is by block only.
324 memEffectsCache.clear();
325}
326
/// Simplify the operations of `region`, recording known operations in
/// `knownValues`. A single-block region is processed under one new scope.
/// Multi-block regions are only processed when they have SSA dominance. Their
/// blocks are then visited depth-first over the region's dominator tree.
/// Each tree node opens a scope, so values recorded for a block remain
/// visible while the blocks it dominates are simplified, and are dropped once
/// that subtree is done.
327void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
328 // If the region is empty there is nothing to do.
329 if (region.empty())
330 return;
331
332 bool hasSSADominance = domInfo->hasSSADominance(&region);
333
334 // If the region only contains one block, then simplify it directly.
335 if (region.hasOneBlock()) {
// Open a scope so the values recorded for this block are dropped on return.
336 ScopedMapTy::ScopeTy scope(knownValues);
337 simplifyBlock(knownValues, &region.front(), hasSSADominance);
338 return;
339 }
340
341 // Multi-block regions without SSA dominance are skipped for now.
342 // TODO: Regions without SSA dominance should define a different
343 // traversal order which is appropriate and can be used here.
344 if (!hasSSADominance)
345 return;
346
347 // Note, deque is being used here because there were significant performance
348 // gains over vector when the container becomes very large due to the
349 // specific access patterns. If/when these performance issues are no
350 // longer a problem we can change this to vector. For more information see
351 // the llvm mailing list discussion on this:
352 // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
353 std::deque<std::unique_ptr<CFGStackNode>> stack;
354
355 // Process the nodes of the dom tree for this region.
356 stack.emplace_back(std::make_unique<CFGStackNode>(
357 knownValues, domInfo->getRootNode(&region)));
358
359 while (!stack.empty()) {
360 auto &currentNode = stack.back();
361
362 // Simplify this node's block the first time the node is on top.
363 if (!currentNode->processed) {
364 currentNode->processed = true;
365 simplifyBlock(knownValues, currentNode->node->getBlock(),
366 hasSSADominance);
367 }
368
369 // Then descend into the next unvisited child, if any.
370 if (currentNode->childIterator != currentNode->node->end()) {
371 auto *childNode = *(currentNode->childIterator++);
372 stack.emplace_back(
373 std::make_unique<CFGStackNode>(knownValues, childNode));
374 } else {
375 // Finally, once the node and all of its children have been processed,
376 // pop it; destroying the node also closes its known-values scope.
377 stack.pop_back();
378 }
379 }
380}
381
382void CSEDriver::simplify(Operation *op, bool *changed) {
383 /// Simplify all regions.
384 ScopedMapTy knownValues;
385 for (auto &region : op->getRegions())
386 simplifyRegion(knownValues, region);
387
388 /// Erase any operations that were marked as dead during simplification.
389 for (auto *op : opsToErase)
390 rewriter.eraseOp(op);
391 if (changed)
392 *changed = !opsToErase.empty();
393
394 // Note: CSE does currently not remove ops with regions, so DominanceInfo
395 // does not have to be invalidated.
396}
397
399 DominanceInfo &domInfo, Operation *op,
400 bool *changed) {
// Run a fresh CSEDriver over `op` with the caller's rewriter and dominance
// info. `changed`, if non-null, reports whether any op was erased. The
// driver's CSE/DCE counters are not surfaced from here.
401 CSEDriver driver(rewriter, &domInfo);
402 driver.simplify(op, changed);
403}
404
405namespace {
406/// CSE pass.
407struct CSE : public impl::CSEPassBase<CSE> {
408 void runOnOperation() override;
409};
410} // namespace
411
412void CSE::runOnOperation() {
413 // Simplify the IR.
414 IRRewriter rewriter(&getContext());
415 CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
416 bool changed = false;
417 driver.simplify(getOperation(), &changed);
418
419 // Set statistics.
420 numCSE = driver.getNumCSE();
421 numDCE = driver.getNumDCE();
422
423 // If there was no change to the IR, we mark all analyses as preserved.
424 if (!changed)
425 return markAllAnalysesPreserved();
426
427 // We currently don't remove region operations, so mark dominance as
428 // preserved.
429 markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
430}
return success()
lhs
b getContext())
template bool mlir::hasEffect< MemoryEffects::Read >(Operation *)
template bool mlir::hasSingleEffect< MemoryEffects::Read >(Operation *)
A class for computing basic dominance information.
Definition Dominance.h:140
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:247
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:881
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:778
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:786
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:703
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:706
result_range getResults()
Definition Operation.h:444
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
DominanceInfoNode * getRootNode(Region *region)
Get the root dominance node of the given region.
Definition Dominance.h:74
bool hasSSADominance(Block *block) const
Return true if operations in the specified block are known to obey SSA dominance requirements.
Definition Dominance.h:92
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
Definition Dominance.h:30
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition CSE.cpp:398
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash to compute the hash value of operands/results directly.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.