15#ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H
16#define MLIR_SUPPORT_CYCLICREPLACERCACHE_H
18#include "llvm/ADT/DenseMap.h"
19#include "llvm/ADT/DenseSet.h"
20#include "llvm/ADT/SmallVector.h"
51template <
typename InT,
typename OutT>
61 : cycleBreaker(std::move(cycleBreaker)) {}
76 ReplacementFrame &currFrame = cache.replacementStack.back();
77 size_t currFrameIndex = cache.replacementStack.size() - 1;
78 return currFrame.dependentFrames.count(currFrameIndex);
84 assert(!this->result &&
"cache entry already resolved");
85 cache.finalizeReplacement(element, result);
86 this->result = std::move(result);
/// Get the resolved result if one exists. The optional holds a value when the
/// entry was created from a cache hit (lookupOrInit passes the cached value to
/// the constructor) or after resolve() has stored a result; otherwise it is
/// std::nullopt.
90 const std::optional<OutT> &
get()
const {
return result; }
/// Default construction is disabled: every entry must be bound to the cache
/// that created it, because resolve() reports the result back through that
/// cache (cache.finalizeReplacement).
94 CacheEntry() =
delete;
96 std::optional<OutT> result = std::nullopt)
97 : cache(cache), element(std::move(element)), result(result) {}
/// The replacement for `element`, if known. It is initialized by the
/// constructor (std::nullopt by default, or a cached value on a cache hit) and
/// assigned by resolve().
101 std::optional<OutT>
result;
/// Record `result` as the replacement of `element` and pop the top
/// replacement frame (invoked from CacheEntry::resolve).
/// - If the top frame has no remaining dependent frames (or is the outermost
///   frame), the result goes into the standalone cache.
/// - Otherwise it goes into the dependent cache, tagged with the highest
///   dependent frame, and the enclosing frame inherits those dependencies.
/// Dependent-cache entries registered as depending on the popped frame are
/// purged.
/// NOTE(review): assumes `element` is the element whose frame is currently on
/// top of the replacement stack; nothing here checks that.
118 void finalizeReplacement(InT element, OutT
result);
123 struct DependentReplacement {
127 size_t highestDependentFrame;
131 struct ReplacementFrame {
138 std::set<size_t, std::greater<size_t>> dependentFrames;
/// True only while the user-provided cycle breaker is running. lookupOrInit
/// asserts that it is false, so the cycle breaker must not call back into the
/// cache recursively.
148 bool resolvingCycle =
false;
151template <
typename InT,
typename OutT>
154 assert(!resolvingCycle &&
155 "illegal recursive invocation while breaking cycle");
157 if (
auto it = standaloneCache.find(element); it != standaloneCache.end())
158 return CacheEntry(*
this, element, it->second);
160 if (
auto it = dependentCache.find(element); it != dependentCache.end()) {
163 ReplacementFrame &currFrame = replacementStack.back();
164 currFrame.dependentFrames.insert(it->second.highestDependentFrame);
165 return CacheEntry(*
this, element, it->second.replacement);
168 auto [it,
inserted] = cyclicElementFrame.try_emplace(element);
171 resolvingCycle =
true;
172 std::optional<OutT>
result = cycleBreaker(element);
173 resolvingCycle =
false;
176 size_t dependentFrame = it->second.back();
177 dependentCache[element] = {*
result, dependentFrame};
178 ReplacementFrame &currFrame = replacementStack.back();
181 currFrame.dependentFrames.insert(dependentFrame);
191 assert(it->second.size() <= 2 &&
"illegal 3rd repeat of input");
196 it->second.push_back(replacementStack.size());
197 replacementStack.emplace_back();
202template <
typename InT,
typename OutT>
203void CyclicReplacerCache<InT, OutT>::finalizeReplacement(InT element,
205 ReplacementFrame &currFrame = replacementStack.back();
208 currFrame.dependentFrames.erase(replacementStack.size() - 1);
210 auto prevLayerIter = ++replacementStack.rbegin();
211 if (prevLayerIter == replacementStack.rend()) {
213 assert(currFrame.dependentFrames.empty() &&
214 "internal error: top-level dependent replacement");
216 standaloneCache[element] =
result;
217 }
else if (currFrame.dependentFrames.empty()) {
219 standaloneCache[element] =
result;
222 size_t highestDependentFrame = *currFrame.dependentFrames.begin();
223 dependentCache[element] = {
result, highestDependentFrame};
226 prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),
227 currFrame.dependentFrames.end());
231 replacementStack[highestDependentFrame].dependingReplacements.insert(
236 for (InT key : currFrame.dependingReplacements)
237 dependentCache.erase(key);
239 replacementStack.pop_back();
240 auto it = cyclicElementFrame.find(element);
241 it->second.pop_back();
242 if (it->second.empty())
243 cyclicElementFrame.erase(it);
254template <
typename InT,
typename OutT>
263 : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
266 auto cacheEntry = cache.lookupOrInit(element);
267 if (std::optional<OutT>
result = cacheEntry.get())
270 OutT
result = replacer(element);
271 cacheEntry.resolve(
result);
… if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method, the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens).
… if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method, the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, …
CachedCyclicReplacer()=delete
OutT operator()(InT element)
CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
typename CyclicReplacerCache< InT, OutT >::CycleBreakerFn CycleBreakerFn
std::function< OutT(InT)> ReplacerFn
A cache for replacer-like functions that map values between two domains.
CacheEntry lookupOrInit(InT element)
Lookup the cache for a pre-calculated replacement for element.
std::function< std::optional< OutT >(InT)> CycleBreakerFn
User-provided replacement and cycle-breaking functions.
CyclicReplacerCache()=delete
CyclicReplacerCache(CycleBreakerFn cycleBreaker)
Include the generated interface declarations.
A possibly unresolved cache entry.
void resolve(OutT result)
Resolve an unresolved cache entry by providing the result to be stored in the cache.
friend class CyclicReplacerCache
bool wasRepeated() const
Check whether this node was repeated during recursive replacements.
const std::optional< OutT > & get() const
Get the resolved result if one exists.