MLIR 22.0.0git
ThreadLocalCache.h
Go to the documentation of this file.
1//===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file contains the definition of the ThreadLocalCache class, which
// provides support for defining thread-local objects with non-static storage
// duration.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_SUPPORT_THREADLOCALCACHE_H
15#define MLIR_SUPPORT_THREADLOCALCACHE_H
16
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/DenseMap.h"
19#include "llvm/Support/ManagedStatic.h"
20#include "llvm/Support/Mutex.h"
21
22namespace mlir {
/// This class provides support for defining a thread-local object with
/// non-static storage duration. It is useful when a shared data cache would
/// otherwise suffer from heavy lock contention.
26template <typename ValueT>
28 struct PerInstanceState;
29
30 using PointerAndFlag = std::pair<ValueT *, std::atomic<bool>>;
31
32 /// The "observer" is owned by a thread-local cache instance. It is
33 /// constructed the first time a `ThreadLocalCache` instance is accessed by a
34 /// thread, unless `perInstanceState` happens to get re-allocated to the same
35 /// address as a previous one. A `thread_local` instance of this class is
36 /// destructed when the thread in which it lives is destroyed.
37 ///
38 /// This class is called the "observer" because while values cached in
39 /// thread-local caches are owned by `PerInstanceState`, a reference is stored
40 /// via this class in the TLC. With a double pointer, it knows when the
41 /// referenced value has been destroyed.
42 struct Observer {
43 /// This is the double pointer, explicitly allocated because we need to keep
44 /// the address stable if the TLC map re-allocates. It is owned by the
45 /// observer and shared with the value owner.
46 std::shared_ptr<PointerAndFlag> ptr =
47 std::make_shared<PointerAndFlag>(std::make_pair(nullptr, false));
48 /// Because the `Owner` instance that lives inside `PerInstanceState`
49 /// contains a reference to the double pointer, and likewise this class
50 /// contains a reference to the value, we need to synchronize destruction of
51 /// the TLC and the `PerInstanceState` to avoid racing. This weak pointer is
52 /// acquired during TLC destruction if the `PerInstanceState` hasn't entered
53 /// its destructor yet, and prevents it from happening.
54 std::weak_ptr<PerInstanceState> keepalive;
55 };
56
57 /// This struct owns the cache entries. It contains a reference back to the
58 /// reference inside the cache so that it can be written to null to indicate
59 /// that the cache entry is invalidated. It needs to do this because
60 /// `perInstanceState` could get re-allocated to the same pointer and we don't
61 /// remove entries from the TLC when it is deallocated. Thus, we have to reset
62 /// the TLC entries to a starting state in case the `ThreadLocalCache` lives
63 /// shorter than the threads.
64 struct Owner {
65 /// Save a pointer to the reference and write it to the newly created entry.
66 Owner(Observer &observer)
67 : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
68 observer.ptr->second = true;
69 observer.ptr->first = value.get();
70 }
71 ~Owner() {
72 if (std::shared_ptr<PointerAndFlag> ptr = ptrRef.lock()) {
73 ptr->first = nullptr;
74 ptr->second = false;
75 }
76 }
77
78 Owner(Owner &&) = default;
79 Owner &operator=(Owner &&) = default;
80
81 std::unique_ptr<ValueT> value;
82 std::weak_ptr<PointerAndFlag> ptrRef;
83 };
84
85 // Keep a separate shared_ptr protected state that can be acquired atomically
86 // instead of using shared_ptr's for each value. This avoids a problem
87 // where the instance shared_ptr is locked() successfully, and then the
88 // ThreadLocalCache gets destroyed before remove() can be called successfully.
89 struct PerInstanceState {
90 /// Remove the given value entry. This is called when a thread local cache
91 /// is destructing but still contains references to values owned by the
92 /// `PerInstanceState`. Removal is required because it prevents writeback to
93 /// a pointer that was deallocated.
94 void remove(ValueT *value) {
95 // Erase the found value directly, because it is guaranteed to be in the
96 // list.
97 llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
98 auto it = llvm::find_if(instances, [&](Owner &instance) {
99 return instance.value.get() == value;
100 });
101 assert(it != instances.end() && "expected value to exist in cache");
102 instances.erase(it);
103 }
104
105 /// Owning pointers to all of the values that have been constructed for this
106 /// object in the static cache.
107 SmallVector<Owner, 1> instances;
108
109 /// A mutex used when a new thread instance has been added to the cache for
110 /// this object.
111 llvm::sys::SmartMutex<true> instanceMutex;
112 };
113
114 /// The type used for the static thread_local cache. This is a map between an
115 /// instance of the non-static cache and a weak reference to an instance of
116 /// ValueT. We use a weak reference here so that the object can be destroyed
117 /// without needing to lock access to the cache itself.
118 struct CacheType : public llvm::SmallDenseMap<PerInstanceState *, Observer> {
119 ~CacheType() {
120 // Remove the values of this cache that haven't already expired. This is
121 // required because if we don't remove them, they will contain a reference
122 // back to the data here that is being destroyed.
123 for (auto &[instance, observer] : *this)
124 if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
125 state->remove(observer.ptr->first);
126 }
127
128 /// Clear out any unused entries within the map. This method is not
129 /// thread-safe, and should only be called by the same thread as the cache.
130 void clearExpiredEntries() {
131 for (auto it = this->begin(), e = this->end(); it != e;) {
132 auto curIt = it++;
133 if (!curIt->second.ptr->second)
134 this->erase(curIt);
135 }
136 }
137 };
138
139public:
140 ThreadLocalCache() = default;
142 // No cleanup is necessary here as the shared_pointer memory will go out of
143 // scope and invalidate the weak pointers held by the thread_local caches.
144 }
145
146 /// Return an instance of the value type for the current thread.
147 ValueT &get() {
148 // Check for an already existing instance for this thread.
149 CacheType &staticCache = getStaticCache();
150 Observer &threadInstance = staticCache[perInstanceState.get()];
151 if (ValueT *value = threadInstance.ptr->first)
152 return *value;
153
154 // Otherwise, create a new instance for this thread.
155 {
156 llvm::sys::SmartScopedLock<true> threadInstanceLock(
157 perInstanceState->instanceMutex);
158 perInstanceState->instances.emplace_back(threadInstance);
159 }
160 threadInstance.keepalive = perInstanceState;
161
162 // Before returning the new instance, take the chance to clear out any used
163 // entries in the static map. The cache is only cleared within the same
164 // thread to remove the need to lock the cache itself.
165 staticCache.clearExpiredEntries();
166 return *threadInstance.ptr->first;
167 }
168 ValueT &operator*() { return get(); }
169 ValueT *operator->() { return &get(); }
170
171private:
173 ThreadLocalCache(const ThreadLocalCache &) = delete;
174 ThreadLocalCache &operator=(const ThreadLocalCache &) = delete;
175
176 /// Return the static thread local instance of the cache type.
177 static CacheType &getStaticCache() {
178 static LLVM_THREAD_LOCAL CacheType cache;
179 return cache;
180 }
181
182 std::shared_ptr<PerInstanceState> perInstanceState =
183 std::make_shared<PerInstanceState>();
184};
185} // namespace mlir
186
187#endif // MLIR_SUPPORT_THREADLOCALCACHE_H
This class provides support for defining a thread-local object with non-static storage duration.
ValueT & get()
Return an instance of the value type for the current thread.
Include the generated interface declarations.