MLIR  20.0.0git
CyclicReplacerCache.h
Go to the documentation of this file.
1 //===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains helper classes for caching replacer-like functions that
10 // map values between two domains. They are able to handle replacer logic that
11 // contains self-recursion.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H
16 #define MLIR_SUPPORT_CYCLICREPLACERCACHE_H
17 
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include <functional>
22 #include <optional>
23 #include <set>
24 
25 namespace mlir {
26 
27 //===----------------------------------------------------------------------===//
28 // CyclicReplacerCache
29 //===----------------------------------------------------------------------===//
30 
31 /// A cache for replacer-like functions that map values between two domains. The
32 /// difference compared to just using a map to cache in-out pairs is that this
33 /// class is able to handle replacer logic that is self-recursive (and thus may
34 /// cause infinite recursion in the naive case).
35 ///
36 /// This class provides a hook for the user to perform cycle pruning when a
37 /// cycle is identified, and is able to perform context-sensitive caching so
38 /// that the replacement result for an input that is part of a pruned cycle can
39 /// be distinct from the replacement result for the same input when it is not
40 /// part of a cycle.
41 ///
42 /// In addition, this class allows deferring cycle pruning until specific inputs
43 /// are repeated. This is useful for cases where not all elements in a cycle can
44 /// perform pruning. The user still must guarantee that at least one element in
45 /// any given cycle can perform pruning. If that guarantee is violated, an
46 /// assertion will eventually trip instead of recursing infinitely (the run-time
47 /// is linearly bounded by the maximum cycle length of its input).
48 ///
49 /// WARNING: This class works best with InT & OutT that are trivial scalar
50 /// types. The input/output elements will be frequently copied and hashed.
51 template <typename InT, typename OutT>
53 public:
54  /// User-provided replacement function & cycle-breaking functions.
55  /// The cycle-breaking function must not make any more recursive invocations
56  /// to this cached replacer.
57  using CycleBreakerFn = std::function<std::optional<OutT>(InT)>;
58 
59  CyclicReplacerCache() = delete;
61  : cycleBreaker(std::move(cycleBreaker)) {}
62 
63  /// A possibly unresolved cache entry.
64  /// If unresolved, the entry must be resolved before it goes out of scope.
65  struct CacheEntry {
66  public:
67  ~CacheEntry() { assert(result && "unresovled cache entry"); }
68 
69  /// Check whether this node was repeated during recursive replacements.
70  /// This only makes sense to be called after all recursive replacements are
71  /// completed and the current element has resurfaced to the top of the
72  /// replacement stack.
73  bool wasRepeated() const {
74  // If the top frame includes itself as a dependency, then it must have
75  // been repeated.
76  ReplacementFrame &currFrame = cache.replacementStack.back();
77  size_t currFrameIndex = cache.replacementStack.size() - 1;
78  return currFrame.dependentFrames.count(currFrameIndex);
79  }
80 
81  /// Resolve an unresolved cache entry by providing the result to be stored
82  /// in the cache.
83  void resolve(OutT result) {
84  assert(!this->result && "cache entry already resolved");
85  cache.finalizeReplacement(element, result);
86  this->result = std::move(result);
87  }
88 
89  /// Get the resolved result if one exists.
90  const std::optional<OutT> &get() const { return result; }
91 
92  private:
93  friend class CyclicReplacerCache;
94  CacheEntry() = delete;
95  CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
96  std::optional<OutT> result = std::nullopt)
97  : cache(cache), element(std::move(element)), result(result) {}
98 
100  InT element;
101  std::optional<OutT> result;
102  };
103 
104  /// Lookup the cache for a pre-calculated replacement for `element`.
105  /// If one exists, a resolved CacheEntry will be returned. Otherwise, an
106  /// unresolved CacheEntry will be returned, and the caller must resolve it
107  /// with the calculated replacement so it can be registered in the cache for
108  /// future use.
109  /// Multiple unresolved CacheEntries may be retrieved. However, any unresolved
110  /// CacheEntries that are returned must be resolved in reverse order of
111  /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
112  /// the first retrieved CacheEntry must be resolved last. This should be
113  /// natural when used as a stack / inside recursion.
114  CacheEntry lookupOrInit(InT element);
115 
116 private:
117  /// Register the replacement in the cache and update the replacementStack.
118  void finalizeReplacement(InT element, OutT result);
119 
120  CycleBreakerFn cycleBreaker;
121  llvm::DenseMap<InT, OutT> standaloneCache;
122 
123  struct DependentReplacement {
124  OutT replacement;
125  /// The highest replacement frame index that this cache entry is dependent
126  /// on.
127  size_t highestDependentFrame;
128  };
130 
131  struct ReplacementFrame {
132  /// The set of elements that is only legal while under this current frame.
133  /// They need to be removed from the cache when this frame is popped off the
134  /// replacement stack.
135  llvm::DenseSet<InT> dependingReplacements;
136  /// The set of frame indices that this current frame's replacement is
137  /// dependent on, ordered from highest to lowest.
138  std::set<size_t, std::greater<size_t>> dependentFrames;
139  };
140  /// Every element currently in the progress of being replaced pushes a frame
141  /// onto this stack.
142  llvm::SmallVector<ReplacementFrame> replacementStack;
143  /// Maps from each input element to its indices on the replacement stack.
145  /// If set to true, we are currently asking an element to break a cycle. No
146  /// more recursive invocations are allowed while this is true (the replacement
147  /// stack can no longer grow).
148  bool resolvingCycle = false;
149 };
150 
151 template <typename InT, typename OutT>
152 typename CyclicReplacerCache<InT, OutT>::CacheEntry
154  assert(!resolvingCycle &&
155  "illegal recursive invocation while breaking cycle");
156 
157  if (auto it = standaloneCache.find(element); it != standaloneCache.end())
158  return CacheEntry(*this, element, it->second);
159 
160  if (auto it = dependentCache.find(element); it != dependentCache.end()) {
161  // Update the current top frame (the element that invoked this current
162  // replacement) to include any dependencies the cache entry had.
163  ReplacementFrame &currFrame = replacementStack.back();
164  currFrame.dependentFrames.insert(it->second.highestDependentFrame);
165  return CacheEntry(*this, element, it->second.replacement);
166  }
167 
168  auto [it, inserted] = cyclicElementFrame.try_emplace(element);
169  if (!inserted) {
170  // This is a repeat of a known element. Try to break cycle here.
171  resolvingCycle = true;
172  std::optional<OutT> result = cycleBreaker(element);
173  resolvingCycle = false;
174  if (result) {
175  // Cycle was broken.
176  size_t dependentFrame = it->second.back();
177  dependentCache[element] = {*result, dependentFrame};
178  ReplacementFrame &currFrame = replacementStack.back();
179  // If this is a repeat, there is no replacement frame to pop. Mark the top
180  // frame as being dependent on this element.
181  currFrame.dependentFrames.insert(dependentFrame);
182 
183  return CacheEntry(*this, element, *result);
184  }
185 
186  // Cycle could not be broken.
187  // A legal setup must ensure at least one element of each cycle can break
188  // cycles. Under this setup, each element can be seen at most twice before
189  // the cycle is broken. If we see an element more than twice, we know this
190  // is an illegal setup.
191  assert(it->second.size() <= 2 && "illegal 3rd repeat of input");
192  }
193 
194  // Otherwise, either this is the first time we see this element, or this
195  // element could not break this cycle.
196  it->second.push_back(replacementStack.size());
197  replacementStack.emplace_back();
198 
199  return CacheEntry(*this, element);
200 }
201 
202 template <typename InT, typename OutT>
204  OutT result) {
205  ReplacementFrame &currFrame = replacementStack.back();
206  // With the conclusion of this replacement frame, the current element is no
207  // longer a dependent element.
208  currFrame.dependentFrames.erase(replacementStack.size() - 1);
209 
210  auto prevLayerIter = ++replacementStack.rbegin();
211  if (prevLayerIter == replacementStack.rend()) {
212  // If this is the last frame, there should be zero dependents.
213  assert(currFrame.dependentFrames.empty() &&
214  "internal error: top-level dependent replacement");
215  // Cache standalone result.
216  standaloneCache[element] = result;
217  } else if (currFrame.dependentFrames.empty()) {
218  // Cache standalone result.
219  standaloneCache[element] = result;
220  } else {
221  // Cache dependent result.
222  size_t highestDependentFrame = *currFrame.dependentFrames.begin();
223  dependentCache[element] = {result, highestDependentFrame};
224 
225  // The previous frame inherits the same dependent frames.
226  prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),
227  currFrame.dependentFrames.end());
228 
229  // Mark this current replacement as a depending replacement on the closest
230  // dependent frame.
231  replacementStack[highestDependentFrame].dependingReplacements.insert(
232  element);
233  }
234 
235  // All depending replacements in the cache must be purged.
236  for (InT key : currFrame.dependingReplacements)
237  dependentCache.erase(key);
238 
239  replacementStack.pop_back();
240  auto it = cyclicElementFrame.find(element);
241  it->second.pop_back();
242  if (it->second.empty())
243  cyclicElementFrame.erase(it);
244 }
245 
246 //===----------------------------------------------------------------------===//
247 // CachedCyclicReplacer
248 //===----------------------------------------------------------------------===//
249 
250 /// A helper class for cases where the input/output types of the replacer
251 /// function are identical to the types stored in the cache. This class wraps
252 /// the user-provided replacer function, and can be used in place of the user
253 /// function.
254 template <typename InT, typename OutT>
256 public:
257  using ReplacerFn = std::function<OutT(InT)>;
260 
263  : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
264 
265  OutT operator()(InT element) {
266  auto cacheEntry = cache.lookupOrInit(element);
267  if (std::optional<OutT> result = cacheEntry.get())
268  return *result;
269 
270  OutT result = replacer(element);
271  cacheEntry.resolve(result);
272  return result;
273  }
274 
275 private:
276  ReplacerFn replacer;
278 };
279 
280 } // namespace mlir
281 
282 #endif // MLIR_SUPPORT_CYCLICREPLACERCACHE_H
A helper class for cases where the input/output types of the replacer function are identical to the types stored in the cache.
typename CyclicReplacerCache< InT, OutT >::CycleBreakerFn CycleBreakerFn
CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
std::function< OutT(InT)> ReplacerFn
A cache for replacer-like functions that map values between two domains.
CacheEntry lookupOrInit(InT element)
Lookup the cache for a pre-calculated replacement for element.
std::function< std::optional< OutT >(InT)> CycleBreakerFn
User-provided replacement function & cycle-breaking functions.
CyclicReplacerCache(CycleBreakerFn cycleBreaker)
Include the generated interface declarations.
A possibly unresolved cache entry.
void resolve(OutT result)
Resolve an unresolved cache entry by providing the result to be stored in the cache.
bool wasRepeated() const
Check whether this node was repeated during recursive replacements.
const std::optional< OutT > & get() const
Get the resolved result if one exists.