15#ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H 
   16#define MLIR_SUPPORT_CYCLICREPLACERCACHE_H 
   18#include "llvm/ADT/DenseMap.h" 
   19#include "llvm/ADT/DenseSet.h" 
   20#include "llvm/ADT/SmallVector.h" 
   51template <
typename InT, 
typename OutT>
 
   61      : cycleBreaker(std::move(cycleBreaker)) {}
 
 
   76      ReplacementFrame &currFrame = cache.replacementStack.back();
 
   77      size_t currFrameIndex = cache.replacementStack.size() - 1;
 
   78      return currFrame.dependentFrames.count(currFrameIndex);
 
 
   84      assert(!this->result && 
"cache entry already resolved");
 
   85      cache.finalizeReplacement(element, result);
 
   86      this->result = std::move(result);
 
 
   90    const std::optional<OutT> &
get()
 const { 
return result; }
 
   94    CacheEntry() = 
delete;
 
   96               std::optional<OutT> result = std::nullopt)
 
   97        : cache(cache), element(std::move(element)), result(result) {}
 
  101    std::optional<OutT> 
result;
 
 
  118  void finalizeReplacement(InT element, OutT 
result);
 
  123  struct DependentReplacement {
 
  127    size_t highestDependentFrame;
 
  131  struct ReplacementFrame {
 
  138    std::set<size_t, std::greater<size_t>> dependentFrames;
 
  148  bool resolvingCycle = 
false;
 
  151template <
typename InT, 
typename OutT>
 
  154  assert(!resolvingCycle &&
 
  155         "illegal recursive invocation while breaking cycle");
 
  157  if (
auto it = standaloneCache.find(element); it != standaloneCache.end())
 
  158    return CacheEntry(*
this, element, it->second);
 
  160  if (
auto it = dependentCache.find(element); it != dependentCache.end()) {
 
  163    ReplacementFrame &currFrame = replacementStack.back();
 
  164    currFrame.dependentFrames.insert(it->second.highestDependentFrame);
 
  165    return CacheEntry(*
this, element, it->second.replacement);
 
  168  auto [it, 
inserted] = cyclicElementFrame.try_emplace(element);
 
  171    resolvingCycle = 
true;
 
  172    std::optional<OutT> 
result = cycleBreaker(element);
 
  173    resolvingCycle = 
false;
 
  176      size_t dependentFrame = it->second.back();
 
  177      dependentCache[element] = {*
result, dependentFrame};
 
  178      ReplacementFrame &currFrame = replacementStack.back();
 
  181      currFrame.dependentFrames.insert(dependentFrame);
 
  191    assert(it->second.size() <= 2 && 
"illegal 3rd repeat of input");
 
  196  it->second.push_back(replacementStack.size());
 
  197  replacementStack.emplace_back();
 
 
 
  202template <
typename InT, 
typename OutT>
 
  203void CyclicReplacerCache<InT, OutT>::finalizeReplacement(InT element,
 
  205  ReplacementFrame &currFrame = replacementStack.back();
 
  208  currFrame.dependentFrames.erase(replacementStack.size() - 1);
 
  210  auto prevLayerIter = ++replacementStack.rbegin();
 
  211  if (prevLayerIter == replacementStack.rend()) {
 
  213    assert(currFrame.dependentFrames.empty() &&
 
  214           "internal error: top-level dependent replacement");
 
  216    standaloneCache[element] = 
result;
 
  217  } 
else if (currFrame.dependentFrames.empty()) {
 
  219    standaloneCache[element] = 
result;
 
  222    size_t highestDependentFrame = *currFrame.dependentFrames.begin();
 
  223    dependentCache[element] = {
result, highestDependentFrame};
 
  226    prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),
 
  227                                          currFrame.dependentFrames.end());
 
  231    replacementStack[highestDependentFrame].dependingReplacements.insert(
 
  236  for (InT key : currFrame.dependingReplacements)
 
  237    dependentCache.erase(key);
 
  239  replacementStack.pop_back();
 
  240  auto it = cyclicElementFrame.find(element);
 
  241  it->second.pop_back();
 
  242  if (it->second.empty())
 
  243    cyclicElementFrame.erase(it);
 
  254template <
typename InT, 
typename OutT>
 
  263      : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
 
 
  266    auto cacheEntry = cache.lookupOrInit(element);
 
  267    if (std::optional<OutT> 
result = cacheEntry.get())
 
  270    OutT 
result = replacer(element);
 
  271    cacheEntry.resolve(
result);
 
 
 
 
if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method, ...
 
if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, ...
 
CachedCyclicReplacer()=delete
 
OutT operator()(InT element)
 
CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
 
typename CyclicReplacerCache< InT, OutT >::CycleBreakerFn CycleBreakerFn
 
std::function< OutT(InT)> ReplacerFn
 
A cache for replacer-like functions that map values between two domains.
 
CacheEntry lookupOrInit(InT element)
Lookup the cache for a pre-calculated replacement for element.
 
std::function< std::optional< OutT >(InT)> CycleBreakerFn
User-provided replacement function & cycle-breaking functions.
 
CyclicReplacerCache()=delete
 
CyclicReplacerCache(CycleBreakerFn cycleBreaker)
 
Include the generated interface declarations.
 
A possibly unresolved cache entry.
 
void resolve(OutT result)
Resolve an unresolved cache entry by providing the result to be stored in the cache.
 
friend class CyclicReplacerCache
 
bool wasRepeated() const
Check whether this node was repeated during recursive replacements.
 
const std::optional< OutT > & get() const
Get the resolved result if one exists.