14 #ifndef MLIR_SUPPORT_THREADLOCALCACHE_H
15 #define MLIR_SUPPORT_THREADLOCALCACHE_H
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Mutex.h"
26 template <
typename ValueT>
28 struct PerInstanceState;
44 std::shared_ptr<ValueT *> ptr = std::make_shared<ValueT *>(
nullptr);
51 std::weak_ptr<PerInstanceState> keepalive;
63 Owner(Observer &observer)
64 : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
65 *observer.ptr = value.get();
68 if (std::shared_ptr<ValueT *> ptr = ptrRef.lock())
72 Owner(Owner &&) =
default;
73 Owner &operator=(Owner &&) =
default;
75 std::unique_ptr<ValueT> value;
76 std::weak_ptr<ValueT *> ptrRef;
83 struct PerInstanceState {
88 void remove(ValueT *value) {
91 llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
92 auto it = llvm::find_if(instances, [&](Owner &instance) {
93 return instance.value.get() == value;
95 assert(it != instances.end() &&
"expected value to exist in cache");
112 struct CacheType :
public llvm::SmallDenseMap<PerInstanceState *, Observer> {
117 for (
auto &[instance, observer] : *
this)
118 if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
119 state->remove(*observer.ptr);
124 void clearExpiredEntries() {
125 for (
auto it = this->begin(), e = this->end(); it != e;) {
127 if (!*curIt->second.ptr)
143 CacheType &staticCache = getStaticCache();
144 Observer &threadInstance = staticCache[perInstanceState.get()];
145 if (ValueT *value = *threadInstance.ptr)
150 llvm::sys::SmartScopedLock<true> threadInstanceLock(
151 perInstanceState->instanceMutex);
152 perInstanceState->instances.emplace_back(threadInstance);
154 threadInstance.keepalive = perInstanceState;
159 staticCache.clearExpiredEntries();
160 return **threadInstance.ptr;
171 static CacheType &getStaticCache() {
172 static LLVM_THREAD_LOCAL CacheType cache;
176 std::shared_ptr<PerInstanceState> perInstanceState =
177 std::make_shared<PerInstanceState>();
This class provides support for defining a thread-local object with non-static storage duration.
ThreadLocalCache()=default
ValueT & get()
Return an instance of the value type for the current thread.
Include the generated interface declarations.