MLIR 22.0.0git
CyclicReplacerCache.h
Go to the documentation of this file.
1//===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains helper classes for caching replacer-like functions that
10// map values between two domains. They are able to handle replacer logic that
11// contains self-recursion.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H
16#define MLIR_SUPPORT_CYCLICREPLACERCACHE_H
17
18#include "llvm/ADT/DenseMap.h"
19#include "llvm/ADT/DenseSet.h"
20#include "llvm/ADT/SmallVector.h"
21#include <functional>
22#include <optional>
23#include <set>
24
25namespace mlir {
26
27//===----------------------------------------------------------------------===//
28// CyclicReplacerCache
29//===----------------------------------------------------------------------===//
30
31/// A cache for replacer-like functions that map values between two domains. The
32/// difference compared to just using a map to cache in-out pairs is that this
33/// class is able to handle replacer logic that is self-recursive (and thus may
34/// cause infinite recursion in the naive case).
35///
36/// This class provides a hook for the user to perform cycle pruning when a
37/// cycle is identified, and is able to perform context-sensitive caching so
38/// that the replacement result for an input that is part of a pruned cycle can
39/// be distinct from the replacement result for the same input when it is not
40/// part of a cycle.
41///
42/// In addition, this class allows deferring cycle pruning until specific inputs
43/// are repeated. This is useful for cases where not all elements in a cycle can
44/// perform pruning. The user still must guarantee that at least one element in
45/// any given cycle can perform pruning. If that guarantee is violated, an
46/// assertion will eventually be tripped instead of recursing infinitely (the
47/// run-time is linearly bounded by the maximum cycle length of its input).
48///
49/// WARNING: This class works best with InT & OutT that are trivial scalar
50/// types. The input/output elements will be frequently copied and hashed.
51template <typename InT, typename OutT>
53public:
54 /// User-provided cycle-breaking function: returns a replacement when the
55 /// cycle can be broken at the given input, and std::nullopt otherwise. It
56 /// must not make any more recursive invocations to this cached replacer.
57 using CycleBreakerFn = std::function<std::optional<OutT>(InT)>;
58
61 : cycleBreaker(std::move(cycleBreaker)) {}
62
63 /// A possibly unresolved cache entry.
64 /// If unresolved, the entry must be resolved before it goes out of scope.
65 struct CacheEntry {
66 public:
67 ~CacheEntry() { assert(result && "unresovled cache entry"); }
68
69 /// Check whether this node was repeated during recursive replacements.
70 /// This only makes sense to be called after all recursive replacements are
71 /// completed and the current element has resurfaced to the top of the
72 /// replacement stack.
73 bool wasRepeated() const {
74 // If the top frame includes itself as a dependency, then it must have
75 // been repeated.
76 ReplacementFrame &currFrame = cache.replacementStack.back();
77 size_t currFrameIndex = cache.replacementStack.size() - 1;
78 return currFrame.dependentFrames.count(currFrameIndex);
79 }
80
81 /// Resolve an unresolved cache entry by providing the result to be stored
82 /// in the cache.
83 void resolve(OutT result) {
84 assert(!this->result && "cache entry already resolved");
85 cache.finalizeReplacement(element, result);
86 this->result = std::move(result);
87 }
88
89 /// Get the resolved result if one exists.
90 const std::optional<OutT> &get() const { return result; }
91
92 private:
93 friend class CyclicReplacerCache;
94 CacheEntry() = delete;
95 CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
96 std::optional<OutT> result = std::nullopt)
97 : cache(cache), element(std::move(element)), result(result) {}
98
100 InT element;
101 std::optional<OutT> result;
102 };
103
104 /// Lookup the cache for a pre-calculated replacement for `element`.
105 /// If one exists, a resolved CacheEntry will be returned. Otherwise, an
106 /// unresolved CacheEntry will be returned, and the caller must resolve it
107 /// with the calculated replacement so it can be registered in the cache for
108 /// future use.
109 /// Multiple unresolved CacheEntries may be retrieved. However, any unresolved
110 /// CacheEntries that are returned must be resolved in reverse order of
111 /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
112 /// the first retrieved CacheEntry must be resolved last. This should be
113 /// natural when used as a stack / inside recursion.
114 CacheEntry lookupOrInit(InT element);
115
116private:
117 /// Register the replacement in the cache and update the replacementStack.
118 void finalizeReplacement(InT element, OutT result);
119
120 CycleBreakerFn cycleBreaker;
121 llvm::DenseMap<InT, OutT> standaloneCache;
122
/// A cached replacement that is only valid while the frame it depends on is
/// still on the replacement stack. Field order matters: instances are built
/// with `{replacement, highestDependentFrame}` aggregate initialization.
123 struct DependentReplacement {
/// The replacement result recorded for the cached input.
124 OutT replacement;
125 /// The highest replacement frame index that this cache entry is dependent
126 /// on.
127 size_t highestDependentFrame;
128 };
130
/// Bookkeeping for a single in-progress replacement. One frame is pushed onto
/// the replacement stack for each element whose replacement has started but
/// has not yet been finalized.
131 struct ReplacementFrame {
132 /// The set of elements that is only legal while under this current frame.
133 /// They need to be removed from the cache when this frame is popped off the
134 /// replacement stack.
135 llvm::DenseSet<InT> dependingReplacements;
136 /// The set of frame indices that this current frame's replacement is
137 /// dependent on, ordered from highest to lowest.
138 std::set<size_t, std::greater<size_t>> dependentFrames;
139 };
140 /// Every element currently in the process of being replaced pushes a frame
141 /// onto this stack.
143 /// Maps from each input element to its indices on the replacement stack.
145 /// If set to true, we are currently asking an element to break a cycle. No
146 /// more recursive invocations are allowed while this is true (the replacement
147 /// stack can no longer grow).
148 bool resolvingCycle = false;
149};
150
151template <typename InT, typename OutT>
154 assert(!resolvingCycle &&
155 "illegal recursive invocation while breaking cycle");
156
157 if (auto it = standaloneCache.find(element); it != standaloneCache.end())
158 return CacheEntry(*this, element, it->second);
159
160 if (auto it = dependentCache.find(element); it != dependentCache.end()) {
161 // Update the current top frame (the element that invoked this current
162 // replacement) to include any dependencies the cache entry had.
163 ReplacementFrame &currFrame = replacementStack.back();
164 currFrame.dependentFrames.insert(it->second.highestDependentFrame);
165 return CacheEntry(*this, element, it->second.replacement);
166 }
167
168 auto [it, inserted] = cyclicElementFrame.try_emplace(element);
169 if (!inserted) {
170 // This is a repeat of a known element. Try to break cycle here.
171 resolvingCycle = true;
172 std::optional<OutT> result = cycleBreaker(element);
173 resolvingCycle = false;
174 if (result) {
175 // Cycle was broken.
176 size_t dependentFrame = it->second.back();
177 dependentCache[element] = {*result, dependentFrame};
178 ReplacementFrame &currFrame = replacementStack.back();
179 // If this is a repeat, there is no replacement frame to pop. Mark the top
180 // frame as being dependent on this element.
181 currFrame.dependentFrames.insert(dependentFrame);
182
183 return CacheEntry(*this, element, *result);
184 }
185
186 // Cycle could not be broken.
187 // A legal setup must ensure at least one element of each cycle can break
188 // cycles. Under this setup, each element can be seen at most twice before
189 // the cycle is broken. If we see an element more than twice, we know this
190 // is an illegal setup.
191 assert(it->second.size() <= 2 && "illegal 3rd repeat of input");
192 }
193
194 // Otherwise, either this is the first time we see this element, or this
195 // element could not break this cycle.
196 it->second.push_back(replacementStack.size());
197 replacementStack.emplace_back();
198
199 return CacheEntry(*this, element);
200}
201
202template <typename InT, typename OutT>
203void CyclicReplacerCache<InT, OutT>::finalizeReplacement(InT element,
204 OutT result) {
205 ReplacementFrame &currFrame = replacementStack.back();
206 // With the conclusion of this replacement frame, the current element is no
207 // longer a dependent element.
208 currFrame.dependentFrames.erase(replacementStack.size() - 1);
209
210 auto prevLayerIter = ++replacementStack.rbegin();
211 if (prevLayerIter == replacementStack.rend()) {
212 // If this is the last frame, there should be zero dependents.
213 assert(currFrame.dependentFrames.empty() &&
214 "internal error: top-level dependent replacement");
215 // Cache standalone result.
216 standaloneCache[element] = result;
217 } else if (currFrame.dependentFrames.empty()) {
218 // Cache standalone result.
219 standaloneCache[element] = result;
220 } else {
221 // Cache dependent result.
222 size_t highestDependentFrame = *currFrame.dependentFrames.begin();
223 dependentCache[element] = {result, highestDependentFrame};
224
225 // Otherwise, the previous frame inherits the same dependent frames.
226 prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),
227 currFrame.dependentFrames.end());
228
229 // Mark this current replacement as a depending replacement on the closest
230 // dependent frame.
231 replacementStack[highestDependentFrame].dependingReplacements.insert(
232 element);
233 }
234
235 // All depending replacements in the cache must be purged.
236 for (InT key : currFrame.dependingReplacements)
237 dependentCache.erase(key);
238
239 replacementStack.pop_back();
240 auto it = cyclicElementFrame.find(element);
241 it->second.pop_back();
242 if (it->second.empty())
243 cyclicElementFrame.erase(it);
244}
245
246//===----------------------------------------------------------------------===//
247// CachedCyclicReplacer
248//===----------------------------------------------------------------------===//
249
250/// A helper class for cases where the input/output types of the replacer
251/// function are identical to the types stored in the cache. This class wraps
252/// the user-provided replacer function, and can be used in place of the user
253/// function.
254template <typename InT, typename OutT>
256public:
257 using ReplacerFn = std::function<OutT(InT)>;
260
263 : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
264
265 OutT operator()(InT element) {
266 auto cacheEntry = cache.lookupOrInit(element);
267 if (std::optional<OutT> result = cacheEntry.get())
268 return *result;
269
270 OutT result = replacer(element);
271 cacheEntry.resolve(result);
272 return result;
273 }
274
275private:
276 ReplacerFn replacer;
278};
279
280} // namespace mlir
281
282#endif // MLIR_SUPPORT_CYCLICREPLACERCACHE_H
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
typename CyclicReplacerCache< InT, OutT >::CycleBreakerFn CycleBreakerFn
std::function< OutT(InT)> ReplacerFn
A cache for replacer-like functions that map values between two domains.
CacheEntry lookupOrInit(InT element)
Lookup the cache for a pre-calculated replacement for element.
std::function< std::optional< OutT >(InT)> CycleBreakerFn
User-provided replacement function & cycle-breaking functions.
CyclicReplacerCache(CycleBreakerFn cycleBreaker)
Include the generated interface declarations.
A possibly unresolved cache entry.
void resolve(OutT result)
Resolve an unresolved cache entry by providing the result to be stored in the cache.
bool wasRepeated() const
Check whether this node was repeated during recursive replacements.
const std::optional< OutT > & get() const
Get the resolved result if one exists.