15 #ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H
16 #define MLIR_SUPPORT_CYCLICREPLACERCACHE_H
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/SmallVector.h"
51 template <
typename InT,
typename OutT>
61 : cycleBreaker(std::move(cycleBreaker)) {}
76 ReplacementFrame &currFrame = cache.replacementStack.back();
77 size_t currFrameIndex = cache.replacementStack.size() - 1;
78 return currFrame.dependentFrames.count(currFrameIndex);
84 assert(!this->result &&
"cache entry already resolved");
85 cache.finalizeReplacement(element, result);
86 this->result = std::move(result);
90 const std::optional<OutT> &
get()
const {
return result; }
96 std::optional<OutT> result = std::nullopt)
97 : cache(cache), element(std::move(element)), result(result) {}
101 std::optional<OutT> result;
118 void finalizeReplacement(InT element, OutT result);
123 struct DependentReplacement {
127 size_t highestDependentFrame;
131 struct ReplacementFrame {
138 std::set<size_t, std::greater<size_t>> dependentFrames;
148 bool resolvingCycle =
false;
151 template <
typename InT,
typename OutT>
152 typename CyclicReplacerCache<InT, OutT>::CacheEntry
154 assert(!resolvingCycle &&
155 "illegal recursive invocation while breaking cycle");
157 if (
auto it = standaloneCache.find(element); it != standaloneCache.end())
158 return CacheEntry(*
this, element, it->second);
160 if (
auto it = dependentCache.find(element); it != dependentCache.end()) {
163 ReplacementFrame &currFrame = replacementStack.back();
164 currFrame.dependentFrames.insert(it->second.highestDependentFrame);
165 return CacheEntry(*
this, element, it->second.replacement);
168 auto [it, inserted] = cyclicElementFrame.try_emplace(element);
171 resolvingCycle =
true;
172 std::optional<OutT> result = cycleBreaker(element);
173 resolvingCycle =
false;
176 size_t dependentFrame = it->second.back();
177 dependentCache[element] = {*result, dependentFrame};
178 ReplacementFrame &currFrame = replacementStack.back();
181 currFrame.dependentFrames.insert(dependentFrame);
191 assert(it->second.size() <= 2 &&
"illegal 3rd repeat of input");
196 it->second.push_back(replacementStack.size());
197 replacementStack.emplace_back();
202 template <
typename InT,
typename OutT>
205 ReplacementFrame &currFrame = replacementStack.back();
208 currFrame.dependentFrames.erase(replacementStack.size() - 1);
210 auto prevLayerIter = ++replacementStack.rbegin();
211 if (prevLayerIter == replacementStack.rend()) {
213 assert(currFrame.dependentFrames.empty() &&
214 "internal error: top-level dependent replacement");
216 standaloneCache[element] = result;
217 }
else if (currFrame.dependentFrames.empty()) {
219 standaloneCache[element] = result;
222 size_t highestDependentFrame = *currFrame.dependentFrames.begin();
223 dependentCache[element] = {result, highestDependentFrame};
226 prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),
227 currFrame.dependentFrames.end());
231 replacementStack[highestDependentFrame].dependingReplacements.insert(
236 for (InT key : currFrame.dependingReplacements)
237 dependentCache.erase(key);
239 replacementStack.pop_back();
240 auto it = cyclicElementFrame.find(element);
241 it->second.pop_back();
242 if (it->second.empty())
243 cyclicElementFrame.erase(it);
254 template <
typename InT,
typename OutT>
263 : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
266 auto cacheEntry = cache.lookupOrInit(element);
267 if (std::optional<OutT> result = cacheEntry.get())
270 OutT result = replacer(element);
271 cacheEntry.resolve(result);
A helper class for cases where the input/output types of the replacer function are identical to the ty...
CachedCyclicReplacer()=delete
OutT operator()(InT element)
typename CyclicReplacerCache< InT, OutT >::CycleBreakerFn CycleBreakerFn
CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
std::function< OutT(InT)> ReplacerFn
A cache for replacer-like functions that map values between two domains.
CacheEntry lookupOrInit(InT element)
Look up a pre-calculated replacement for element in the cache.
std::function< std::optional< OutT >(InT)> CycleBreakerFn
User-provided replacement and cycle-breaking functions.
CyclicReplacerCache()=delete
CyclicReplacerCache(CycleBreakerFn cycleBreaker)
Include the generated interface declarations.
A possibly unresolved cache entry.
void resolve(OutT result)
Resolve an unresolved cache entry by providing the result to be stored in the cache.
bool wasRepeated() const
Check whether this node was repeated during recursive replacements.
const std::optional< OutT > & get() const
Get the resolved result if one exists.