MLIR 22.0.0git
CSE.cpp
Go to the documentation of this file.
1//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass performs a simple common sub-expression elimination
10// algorithm on operations within a region.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Transforms/CSE.h"
15
16#include "mlir/IR/Dominance.h"
19#include "mlir/Pass/Pass.h"
21#include "llvm/ADT/DenseMapInfo.h"
22#include "llvm/ADT/ScopedHashTable.h"
23#include "llvm/Support/Allocator.h"
24#include "llvm/Support/RecyclingAllocator.h"
25#include <deque>
26
27namespace mlir {
28#define GEN_PASS_DEF_CSE
29#include "mlir/Transforms/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33
34namespace {
35struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
36 static unsigned getHashValue(const Operation *opC) {
38 const_cast<Operation *>(opC),
42 }
43 static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
44 auto *lhs = const_cast<Operation *>(lhsC);
45 auto *rhs = const_cast<Operation *>(rhsC);
46 if (lhs == rhs)
47 return true;
48 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
49 rhs == getTombstoneKey() || rhs == getEmptyKey())
50 return false;
52 const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
54 }
55};
56} // namespace
57
58namespace {
59/// Simple common sub-expression elimination.
60class CSEDriver {
61public:
62 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
63 : rewriter(rewriter), domInfo(domInfo) {}
64
65 /// Simplify all operations within the given op.
66 void simplify(Operation *op, bool *changed = nullptr);
67
68 int64_t getNumCSE() const { return numCSE; }
69 int64_t getNumDCE() const { return numDCE; }
70
71private:
72 /// Shared implementation of operation elimination and scoped map definitions.
73 using AllocatorTy = llvm::RecyclingAllocator<
74 llvm::BumpPtrAllocator,
75 llvm::ScopedHashTableVal<Operation *, Operation *>>;
76 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
77 SimpleOperationInfo, AllocatorTy>;
78
79 /// Cache holding MemoryEffects information between two operations. The first
80 /// operation is stored as the key. The second operation is stored inside a
81 /// pair in the value. The pair also holds the MemoryEffects between those
82 /// two operations. If the MemoryEffects is nullptr then we assume there is
83 /// no operation with MemoryEffects::Write between the two operations.
84 using MemEffectsCache =
86
87 /// Represents a single entry in the depth first traversal of a CFG.
88 struct CFGStackNode {
89 CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
90 : scope(knownValues), node(node), childIterator(node->begin()) {}
91
92 /// Scope for the known values.
93 ScopedMapTy::ScopeTy scope;
94
96 DominanceInfoNode::const_iterator childIterator;
97
98 /// If this node has been fully processed yet or not.
99 bool processed = false;
100 };
101
102 /// Attempt to eliminate a redundant operation. Returns success if the
103 /// operation was marked for removal, failure otherwise.
104 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
105 bool hasSSADominance);
106 void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
107 void simplifyRegion(ScopedMapTy &knownValues, Region &region);
108
109 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
110 Operation *existing, bool hasSSADominance);
111
112 /// Check if there are side-effecting operations (memory writes, or ops whose
113 /// effects are unknown) strictly between the two operations.
114 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
116 /// A rewriter for modifying the IR.
117 RewriterBase &rewriter;
118
119 /// Operations marked as dead and to be erased.
120 std::vector<Operation *> opsToErase;
121 DominanceInfo *domInfo = nullptr;
122 MemEffectsCache memEffectsCache;
124 // Various statistics.
125 int64_t numCSE = 0;
126 int64_t numDCE = 0;
128} // namespace
129
130void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
131 Operation *existing,
132 bool hasSSADominance) {
133 // If we find one then replace all uses of the current operation with the
134 // existing one and mark it for deletion. We can only replace an operand in
135 // an operation if it has not been visited yet.
136 if (hasSSADominance) {
137 // If the region has SSA dominance, then we are guaranteed to have not
138 // visited any use of the current operation.
139 if (auto *rewriteListener =
140 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
141 rewriteListener->notifyOperationReplaced(op, existing);
142 // Replace all uses, but do not remove the operation yet. This does not
143 // notify the listener because the original op is not erased.
144 rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
145 opsToErase.push_back(op);
146 } else {
147 // When the region does not have SSA dominance, we need to check if we
148 // have visited a use before replacing any use.
149 auto wasVisited = [&](OpOperand &operand) {
150 return !knownValues.count(operand.getOwner());
151 };
152 if (auto *rewriteListener =
153 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
154 for (Value v : op->getResults())
155 if (all_of(v.getUses(), wasVisited))
156 rewriteListener->notifyOperationReplaced(op, existing);
157
158 // Replace all uses, but do not remove the operation yet. This does not
159 // notify the listener because the original op is not erased.
160 rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
161 wasVisited);
163 // There may be some remaining uses of the operation.
164 if (op->use_empty())
165 opsToErase.push_back(op);
166 }
167
168 // If the existing operation has an unknown location and the current
169 // operation doesn't, then set the existing op's location to that of the
170 // current op.
171 if (isa<UnknownLoc>(existing->getLoc()) && !isa<UnknownLoc>(op->getLoc()))
172 existing->setLoc(op->getLoc());
173
174 ++numCSE;
175}
176
177bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
178 Operation *toOp) {
179 assert(fromOp->getBlock() == toOp->getBlock());
180 assert(hasEffect<MemoryEffects::Read>(fromOp) &&
181 "expected read effect on fromOp");
182 assert(hasEffect<MemoryEffects::Read>(toOp) &&
183 "expected read effect on toOp");
184 Operation *nextOp = fromOp->getNextNode();
185 auto result =
186 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
187 if (result.second) {
188 auto memEffectsCachePair = result.first->second;
189 if (memEffectsCachePair.second == nullptr) {
190 // No MemoryEffects::Write has been detected until the cached operation.
191 // Continue looking from the cached operation to toOp.
192 nextOp = memEffectsCachePair.first;
193 } else {
194 // MemoryEffects::Write has been detected before so there is no need to
195 // check further.
196 return true;
197 }
198 }
199 while (nextOp && nextOp != toOp) {
200 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
201 getEffectsRecursively(nextOp);
202 if (!effects) {
203 // TODO: Do we need to handle other effects generically?
204 // If the operation does not implement the MemoryEffectOpInterface we
205 // conservatively assume it writes.
206 result.first->second =
207 std::make_pair(nextOp, MemoryEffects::Write::get());
208 return true;
209 }
210
211 for (const MemoryEffects::EffectInstance &effect : *effects) {
212 if (isa<MemoryEffects::Write>(effect.getEffect())) {
213 result.first->second = {nextOp, MemoryEffects::Write::get()};
214 return true;
215 }
216 }
217 nextOp = nextOp->getNextNode();
218 }
219 result.first->second = std::make_pair(toOp, nullptr);
220 return false;
221}
222
223/// Attempt to eliminate a redundant operation.
// Returns success() when `op` was queued for erasure, either because it is
// trivially dead or because an equivalent `existing` operation replaced it.
// Returns failure() otherwise. On the failure paths at the end of each branch,
// `op` is first inserted into `knownValues` so that later duplicates can be
// matched against it.
224LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
225 Operation *op,
226 bool hasSSADominance) {
227 // Don't simplify terminator operations.
228 if (op->hasTrait<OpTrait::IsTerminator>())
229 return failure();
230
231 // If the operation is already trivially dead just add it to the erase list.
232 if (isOpTriviallyDead(op)) {
233 opsToErase.push_back(op);
234 ++numDCE;
235 return success();
236 }
237
238 // Don't simplify operations with regions that have multiple blocks.
239 // TODO: We need additional tests to verify that we handle such IR correctly.
240 if (!llvm::all_of(op->getRegions(),
241 [](Region &r) { return r.empty() || r.hasOneBlock(); }))
242 return failure();
243
244 // Some simple use cases of operations with memory side-effects are dealt
245 // with here. Operations with no side-effect are handled afterwards.
246 if (!isMemoryEffectFree(op)) {
247 // TODO: Only basic use cases for operations with MemoryEffects::Read can be
248 // eliminated now. More work needs to be done for more complicated patterns
249 // and other side-effects.
// NOTE(review): the condition guarding the `return failure();` below is
// missing from this copy of the file, so as shown every op with memory
// effects would bail out here. Judging from the TODO above, presumably only
// ops whose sole effect is MemoryEffects::Read continue past it. Confirm
// against the real source.
251 return failure();
252
253 // Look for an existing definition for the operation.
254 if (auto *existing = knownValues.lookup(op)) {
255 if (existing->getBlock() == op->getBlock() &&
256 !hasOtherSideEffectingOpInBetween(existing, op)) {
257 // The operation that can be deleted has been reached with no
258 // side-effecting operations in between the existing operation and
259 // this one so we can remove the duplicate.
260 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
261 return success();
262 }
263 }
// Either no equivalent op was found, or it cannot be reused: record `op` so
// it becomes the candidate for later duplicates.
264 knownValues.insert(op, op);
265 return failure();
266 }
267
268 // Look for an existing definition for the operation.
269 if (auto *existing = knownValues.lookup(op)) {
270 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
// NOTE(review): replaceUsesAndDelete already increments numCSE, so the
// increment below appears to count each memory-effect-free elimination twice.
// The read-effect path above counts only once. Confirm whether this is
// intended.
271 ++numCSE;
272 return success();
273 }
274
275 // Otherwise, we add this operation to the known values map.
276 knownValues.insert(op, op);
277 return failure();
278}
279
280void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
281 bool hasSSADominance) {
282 for (auto &op : *bb) {
283 // Most operations don't have regions, so fast path that case.
284 if (op.getNumRegions() != 0) {
285 // If this operation is isolated above, we can't process nested regions
286 // with the given 'knownValues' map. This would cause the insertion of
287 // implicit captures in explicit capture only regions.
288 if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
289 ScopedMapTy nestedKnownValues;
290 for (auto &region : op.getRegions())
291 simplifyRegion(nestedKnownValues, region);
292 } else {
293 // Otherwise, process nested regions normally.
294 for (auto &region : op.getRegions())
295 simplifyRegion(knownValues, region);
296 }
297 }
298
299 // If the operation is simplified, we don't process any held regions.
300 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
301 continue;
302 }
303 // Clear the MemoryEffects cache since its usage is by block only.
304 memEffectsCache.clear();
305}
306
307void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
308 // If the region is empty there is nothing to do.
309 if (region.empty())
310 return;
311
312 bool hasSSADominance = domInfo->hasSSADominance(&region);
313
314 // If the region only contains one block, then simplify it directly.
315 if (region.hasOneBlock()) {
316 ScopedMapTy::ScopeTy scope(knownValues);
317 simplifyBlock(knownValues, &region.front(), hasSSADominance);
318 return;
319 }
320
321 // If the region does not have dominanceInfo, then skip it.
322 // TODO: Regions without SSA dominance should define a different
323 // traversal order which is appropriate and can be used here.
324 if (!hasSSADominance)
325 return;
326
327 // Note, deque is being used here because there was significant performance
328 // gains over vector when the container becomes very large due to the
329 // specific access patterns. If/when these performance issues are no
330 // longer a problem we can change this to vector. For more information see
331 // the llvm mailing list discussion on this:
332 // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
333 std::deque<std::unique_ptr<CFGStackNode>> stack;
334
335 // Process the nodes of the dom tree for this region.
336 stack.emplace_back(std::make_unique<CFGStackNode>(
337 knownValues, domInfo->getRootNode(&region)));
338
339 while (!stack.empty()) {
340 auto &currentNode = stack.back();
341
342 // Check to see if we need to process this node.
343 if (!currentNode->processed) {
344 currentNode->processed = true;
345 simplifyBlock(knownValues, currentNode->node->getBlock(),
346 hasSSADominance);
347 }
348
349 // Otherwise, check to see if we need to process a child node.
350 if (currentNode->childIterator != currentNode->node->end()) {
351 auto *childNode = *(currentNode->childIterator++);
352 stack.emplace_back(
353 std::make_unique<CFGStackNode>(knownValues, childNode));
354 } else {
355 // Finally, if the node and all of its children have been processed
356 // then we delete the node.
357 stack.pop_back();
358 }
359 }
360}
361
362void CSEDriver::simplify(Operation *op, bool *changed) {
363 /// Simplify all regions.
364 ScopedMapTy knownValues;
365 for (auto &region : op->getRegions())
366 simplifyRegion(knownValues, region);
367
368 /// Erase any operations that were marked as dead during simplification.
369 for (auto *op : opsToErase)
370 rewriter.eraseOp(op);
371 if (changed)
372 *changed = !opsToErase.empty();
373
374 // Note: CSE does currently not remove ops with regions, so DominanceInfo
375 // does not have to be invalidated.
376}
377
379 DominanceInfo &domInfo, Operation *op,
380 bool *changed) {
381 CSEDriver driver(rewriter, &domInfo);
382 driver.simplify(op, changed);
383}
384
385namespace {
386/// CSE pass.
387struct CSE : public impl::CSEBase<CSE> {
388 void runOnOperation() override;
389};
390} // namespace
391
392void CSE::runOnOperation() {
393 // Simplify the IR.
394 IRRewriter rewriter(&getContext());
395 CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
396 bool changed = false;
397 driver.simplify(getOperation(), &changed);
398
399 // Set statistics.
400 numCSE = driver.getNumCSE();
401 numDCE = driver.getNumDCE();
402
403 // If there was no change to the IR, we mark all analyses as preserved.
404 if (!changed)
405 return markAllAnalysesPreserved();
406
407 // We currently don't remove region operations, so mark dominance as
408 // preserved.
409 markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
410}
411
412std::unique_ptr<Pass> mlir::createCSEPass() { return std::make_unique<CSE>(); }
return success()
lhs
b getContext())
template bool mlir::hasSingleEffect< MemoryEffects::Read >(Operation *)
A class for computing basic dominance information.
Definition Dominance.h:140
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:226
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:757
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_range getResults()
Definition Operation.h:415
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
DominanceInfoNode * getRootNode(Region *region)
Get the root dominance node of the given region.
Definition Dominance.h:74
bool hasSSADominance(Block *block) const
Return true if operations in the specified block are known to obey SSA dominance requirements.
Definition Dominance.h:92
::mlir::Pass::Statistic numDCE
Definition CSE.cpp:162
::mlir::Pass::Statistic numCSE
Explicitly declare the TypeID for this class.
Definition CSE.cpp:161
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
std::unique_ptr< Pass > createCSEPass()
Creates a pass to perform common sub expression elimination.
Definition CSE.cpp:412
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::DomTreeNodeBase< Block > DominanceInfoNode
Definition Dominance.h:30
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition CSE.cpp:378
std::optional< llvm::SmallVector< MemoryEffects::EffectInstance > > getEffectsRecursively(Operation *rootOp)
Returns the side effects of an operation.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code directHashValue(Value v)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.