14 #ifndef MLIR_SUPPORT_THREADLOCALCACHE_H
15 #define MLIR_SUPPORT_THREADLOCALCACHE_H
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Mutex.h"
26 template <
typename ValueT>
28 struct PerInstanceState;
30 using PointerAndFlag = std::pair<ValueT *, std::atomic<bool>>;
46 std::shared_ptr<PointerAndFlag> ptr =
47 std::make_shared<PointerAndFlag>(std::make_pair(
nullptr,
false));
54 std::weak_ptr<PerInstanceState> keepalive;
66 Owner(Observer &observer)
67 : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
68 observer.ptr->second =
true;
69 observer.ptr->first = value.get();
72 if (std::shared_ptr<PointerAndFlag> ptr = ptrRef.lock()) {
78 Owner(Owner &&) =
default;
79 Owner &operator=(Owner &&) =
default;
81 std::unique_ptr<ValueT> value;
82 std::weak_ptr<PointerAndFlag> ptrRef;
89 struct PerInstanceState {
94 void remove(ValueT *value) {
97 llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
98 auto it = llvm::find_if(instances, [&](Owner &instance) {
99 return instance.value.get() == value;
101 assert(it != instances.end() &&
"expected value to exist in cache");
118 struct CacheType :
public llvm::SmallDenseMap<PerInstanceState *, Observer> {
123 for (
auto &[instance, observer] : *
this)
124 if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
125 state->remove(observer.ptr->first);
130 void clearExpiredEntries() {
131 for (
auto it = this->begin(), e = this->end(); it != e;) {
133 if (!curIt->second.ptr->second)
149 CacheType &staticCache = getStaticCache();
150 Observer &threadInstance = staticCache[perInstanceState.get()];
151 if (ValueT *value = threadInstance.ptr->first)
156 llvm::sys::SmartScopedLock<true> threadInstanceLock(
157 perInstanceState->instanceMutex);
158 perInstanceState->instances.emplace_back(threadInstance);
160 threadInstance.keepalive = perInstanceState;
165 staticCache.clearExpiredEntries();
166 return *threadInstance.ptr->first;
177 static CacheType &getStaticCache() {
178 static LLVM_THREAD_LOCAL CacheType cache;
182 std::shared_ptr<PerInstanceState> perInstanceState =
183 std::make_shared<PerInstanceState>();
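This class provides support for defining a thread-local object with non-static storage duration: each ThreadLocalCache instance lazily creates and holds a separate value for every thread that calls get() on it.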
ThreadLocalCache()=default
ValueT & get()
Return an instance of the value type for the current thread.
Include the generated interface declarations.