MLIR  20.0.0git
ThreadLocalCache.h
Go to the documentation of this file.
1 //===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a definition of the ThreadLocalCache class. This class
10 // provides support for defining thread local objects with non-static duration.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_SUPPORT_THREADLOCALCACHE_H
15 #define MLIR_SUPPORT_THREADLOCALCACHE_H
16 
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Mutex.h"
21 
22 namespace mlir {
23 /// This class provides support for defining a thread local object with non
24 /// static storage duration. This is very useful for situations in which a data
25 /// cache has very large lock contention.
26 template <typename ValueT>
28  struct PerInstanceState;
29 
30  using PointerAndFlag = std::pair<ValueT *, std::atomic<bool>>;
31 
32  /// The "observer" is owned by a thread-local cache instance. It is
33  /// constructed the first time a `ThreadLocalCache` instance is accessed by a
34  /// thread, unless `perInstanceState` happens to get re-allocated to the same
35  /// address as a previous one. A `thread_local` instance of this class is
36  /// destructed when the thread in which it lives is destroyed.
37  ///
38  /// This class is called the "observer" because while values cached in
39  /// thread-local caches are owned by `PerInstanceState`, a reference is stored
40  /// via this class in the TLC. With a double pointer, it knows when the
41  /// referenced value has been destroyed.
42  struct Observer {
43  /// This is the double pointer, explicitly allocated because we need to keep
44  /// the address stable if the TLC map re-allocates. It is owned by the
45  /// observer and shared with the value owner.
46  std::shared_ptr<PointerAndFlag> ptr =
47  std::make_shared<PointerAndFlag>(std::make_pair(nullptr, false));
48  /// Because the `Owner` instance that lives inside `PerInstanceState`
49  /// contains a reference to the double pointer, and likewise this class
50  /// contains a reference to the value, we need to synchronize destruction of
51  /// the TLC and the `PerInstanceState` to avoid racing. This weak pointer is
52  /// acquired during TLC destruction if the `PerInstanceState` hasn't entered
53  /// its destructor yet, and prevents it from happening.
54  std::weak_ptr<PerInstanceState> keepalive;
55  };
56 
57  /// This struct owns the cache entries. It contains a reference back to the
58  /// reference inside the cache so that it can be written to null to indicate
59  /// that the cache entry is invalidated. It needs to do this because
60  /// `perInstanceState` could get re-allocated to the same pointer and we don't
61  /// remove entries from the TLC when it is deallocated. Thus, we have to reset
62  /// the TLC entries to a starting state in case the `ThreadLocalCache` lives
63  /// shorter than the threads.
64  struct Owner {
65  /// Save a pointer to the reference and write it to the newly created entry.
66  Owner(Observer &observer)
67  : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
68  observer.ptr->second = true;
69  observer.ptr->first = value.get();
70  }
71  ~Owner() {
72  if (std::shared_ptr<PointerAndFlag> ptr = ptrRef.lock()) {
73  ptr->first = nullptr;
74  ptr->second = false;
75  }
76  }
77 
78  Owner(Owner &&) = default;
79  Owner &operator=(Owner &&) = default;
80 
81  std::unique_ptr<ValueT> value;
82  std::weak_ptr<PointerAndFlag> ptrRef;
83  };
84 
85  // Keep a separate shared_ptr protected state that can be acquired atomically
86  // instead of using shared_ptr's for each value. This avoids a problem
87  // where the instance shared_ptr is locked() successfully, and then the
88  // ThreadLocalCache gets destroyed before remove() can be called successfully.
89  struct PerInstanceState {
90  /// Remove the given value entry. This is called when a thread local cache
91  /// is destructing but still contains references to values owned by the
92  /// `PerInstanceState`. Removal is required because it prevents writeback to
93  /// a pointer that was deallocated.
94  void remove(ValueT *value) {
95  // Erase the found value directly, because it is guaranteed to be in the
96  // list.
97  llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
98  auto it = llvm::find_if(instances, [&](Owner &instance) {
99  return instance.value.get() == value;
100  });
101  assert(it != instances.end() && "expected value to exist in cache");
102  instances.erase(it);
103  }
104 
105  /// Owning pointers to all of the values that have been constructed for this
106  /// object in the static cache.
107  SmallVector<Owner, 1> instances;
108 
109  /// A mutex used when a new thread instance has been added to the cache for
110  /// this object.
111  llvm::sys::SmartMutex<true> instanceMutex;
112  };
113 
114  /// The type used for the static thread_local cache. This is a map between an
115  /// instance of the non-static cache and a weak reference to an instance of
116  /// ValueT. We use a weak reference here so that the object can be destroyed
117  /// without needing to lock access to the cache itself.
118  struct CacheType : public llvm::SmallDenseMap<PerInstanceState *, Observer> {
119  ~CacheType() {
120  // Remove the values of this cache that haven't already expired. This is
121  // required because if we don't remove them, they will contain a reference
122  // back to the data here that is being destroyed.
123  for (auto &[instance, observer] : *this)
124  if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
125  state->remove(observer.ptr->first);
126  }
127 
128  /// Clear out any unused entries within the map. This method is not
129  /// thread-safe, and should only be called by the same thread as the cache.
130  void clearExpiredEntries() {
131  for (auto it = this->begin(), e = this->end(); it != e;) {
132  auto curIt = it++;
133  if (!curIt->second.ptr->second)
134  this->erase(curIt);
135  }
136  }
137  };
138 
139 public:
140  ThreadLocalCache() = default;
142  // No cleanup is necessary here as the shared_pointer memory will go out of
143  // scope and invalidate the weak pointers held by the thread_local caches.
144  }
145 
146  /// Return an instance of the value type for the current thread.
147  ValueT &get() {
148  // Check for an already existing instance for this thread.
149  CacheType &staticCache = getStaticCache();
150  Observer &threadInstance = staticCache[perInstanceState.get()];
151  if (ValueT *value = threadInstance.ptr->first)
152  return *value;
153 
154  // Otherwise, create a new instance for this thread.
155  {
156  llvm::sys::SmartScopedLock<true> threadInstanceLock(
157  perInstanceState->instanceMutex);
158  perInstanceState->instances.emplace_back(threadInstance);
159  }
160  threadInstance.keepalive = perInstanceState;
161 
162  // Before returning the new instance, take the chance to clear out any used
163  // entries in the static map. The cache is only cleared within the same
164  // thread to remove the need to lock the cache itself.
165  staticCache.clearExpiredEntries();
166  return *threadInstance.ptr->first;
167  }
168  ValueT &operator*() { return get(); }
169  ValueT *operator->() { return &get(); }
170 
171 private:
172  ThreadLocalCache(ThreadLocalCache &&) = delete;
173  ThreadLocalCache(const ThreadLocalCache &) = delete;
174  ThreadLocalCache &operator=(const ThreadLocalCache &) = delete;
175 
176  /// Return the static thread local instance of the cache type.
177  static CacheType &getStaticCache() {
178  static LLVM_THREAD_LOCAL CacheType cache;
179  return cache;
180  }
181 
182  std::shared_ptr<PerInstanceState> perInstanceState =
183  std::make_shared<PerInstanceState>();
184 };
185 } // namespace mlir
186 
187 #endif // MLIR_SUPPORT_THREADLOCALCACHE_H
This class provides support for defining a thread-local object with non-static storage duration.
ValueT & get()
Return an instance of the value type for the current thread.
Include the generated interface declarations.