MLIR  20.0.0git
ThreadLocalCache.h
Go to the documentation of this file.
1 //===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a definition of the ThreadLocalCache class. This class
10 // provides support for defining thread local objects with non-static duration.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_SUPPORT_THREADLOCALCACHE_H
15 #define MLIR_SUPPORT_THREADLOCALCACHE_H
16 
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Mutex.h"
21 
22 namespace mlir {
23 /// This class provides support for defining a thread local object with non
24 /// static storage duration. This is very useful for situations in which a data
25 /// cache has very large lock contention.
26 template <typename ValueT>
28  struct PerInstanceState;
29 
30  /// The "observer" is owned by a thread-local cache instance. It is
31  /// constructed the first time a `ThreadLocalCache` instance is accessed by a
32  /// thread, unless `perInstanceState` happens to get re-allocated to the same
33  /// address as a previous one. A `thread_local` instance of this class is
34  /// destructed when the thread in which it lives is destroyed.
35  ///
36  /// This class is called the "observer" because while values cached in
37  /// thread-local caches are owned by `PerInstanceState`, a reference is stored
38  /// via this class in the TLC. With a double pointer, it knows when the
39  /// referenced value has been destroyed.
40  struct Observer {
41  /// This is the double pointer, explicitly allocated because we need to keep
42  /// the address stable if the TLC map re-allocates. It is owned by the
43  /// observer and shared with the value owner.
44  std::shared_ptr<ValueT *> ptr = std::make_shared<ValueT *>(nullptr);
45  /// Because the `Owner` instance that lives inside `PerInstanceState`
46  /// contains a reference to the double pointer, and likewise this class
47  /// contains a reference to the value, we need to synchronize destruction of
48  /// the TLC and the `PerInstanceState` to avoid racing. This weak pointer is
49  /// acquired during TLC destruction if the `PerInstanceState` hasn't entered
50  /// its destructor yet, and prevents it from happening.
51  std::weak_ptr<PerInstanceState> keepalive;
52  };
53 
54  /// This struct owns the cache entries. It contains a reference back to the
55  /// reference inside the cache so that it can be written to null to indicate
56  /// that the cache entry is invalidated. It needs to do this because
57  /// `perInstanceState` could get re-allocated to the same pointer and we don't
58  /// remove entries from the TLC when it is deallocated. Thus, we have to reset
59  /// the TLC entries to a starting state in case the `ThreadLocalCache` lives
60  /// shorter than the threads.
61  struct Owner {
62  /// Save a pointer to the reference and write it to the newly created entry.
63  Owner(Observer &observer)
64  : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
65  *observer.ptr = value.get();
66  }
67  ~Owner() {
68  if (std::shared_ptr<ValueT *> ptr = ptrRef.lock())
69  *ptr = nullptr;
70  }
71 
72  Owner(Owner &&) = default;
73  Owner &operator=(Owner &&) = default;
74 
75  std::unique_ptr<ValueT> value;
76  std::weak_ptr<ValueT *> ptrRef;
77  };
78 
79  // Keep a separate shared_ptr protected state that can be acquired atomically
80  // instead of using shared_ptr's for each value. This avoids a problem
81  // where the instance shared_ptr is locked() successfully, and then the
82  // ThreadLocalCache gets destroyed before remove() can be called successfully.
83  struct PerInstanceState {
84  /// Remove the given value entry. This is called when a thread local cache
85  /// is destructing but still contains references to values owned by the
86  /// `PerInstanceState`. Removal is required because it prevents writeback to
87  /// a pointer that was deallocated.
88  void remove(ValueT *value) {
89  // Erase the found value directly, because it is guaranteed to be in the
90  // list.
91  llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
92  auto it = llvm::find_if(instances, [&](Owner &instance) {
93  return instance.value.get() == value;
94  });
95  assert(it != instances.end() && "expected value to exist in cache");
96  instances.erase(it);
97  }
98 
99  /// Owning pointers to all of the values that have been constructed for this
100  /// object in the static cache.
101  SmallVector<Owner, 1> instances;
102 
103  /// A mutex used when a new thread instance has been added to the cache for
104  /// this object.
105  llvm::sys::SmartMutex<true> instanceMutex;
106  };
107 
108  /// The type used for the static thread_local cache. This is a map between an
109  /// instance of the non-static cache and a weak reference to an instance of
110  /// ValueT. We use a weak reference here so that the object can be destroyed
111  /// without needing to lock access to the cache itself.
112  struct CacheType : public llvm::SmallDenseMap<PerInstanceState *, Observer> {
113  ~CacheType() {
114  // Remove the values of this cache that haven't already expired. This is
115  // required because if we don't remove them, they will contain a reference
116  // back to the data here that is being destroyed.
117  for (auto &[instance, observer] : *this)
118  if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
119  state->remove(*observer.ptr);
120  }
121 
122  /// Clear out any unused entries within the map. This method is not
123  /// thread-safe, and should only be called by the same thread as the cache.
124  void clearExpiredEntries() {
125  for (auto it = this->begin(), e = this->end(); it != e;) {
126  auto curIt = it++;
127  if (!*curIt->second.ptr)
128  this->erase(curIt);
129  }
130  }
131  };
132 
133 public:
134  ThreadLocalCache() = default;
136  // No cleanup is necessary here as the shared_pointer memory will go out of
137  // scope and invalidate the weak pointers held by the thread_local caches.
138  }
139 
140  /// Return an instance of the value type for the current thread.
141  ValueT &get() {
142  // Check for an already existing instance for this thread.
143  CacheType &staticCache = getStaticCache();
144  Observer &threadInstance = staticCache[perInstanceState.get()];
145  if (ValueT *value = *threadInstance.ptr)
146  return *value;
147 
148  // Otherwise, create a new instance for this thread.
149  {
150  llvm::sys::SmartScopedLock<true> threadInstanceLock(
151  perInstanceState->instanceMutex);
152  perInstanceState->instances.emplace_back(threadInstance);
153  }
154  threadInstance.keepalive = perInstanceState;
155 
156  // Before returning the new instance, take the chance to clear out any used
157  // entries in the static map. The cache is only cleared within the same
158  // thread to remove the need to lock the cache itself.
159  staticCache.clearExpiredEntries();
160  return **threadInstance.ptr;
161  }
162  ValueT &operator*() { return get(); }
163  ValueT *operator->() { return &get(); }
164 
165 private:
166  ThreadLocalCache(ThreadLocalCache &&) = delete;
167  ThreadLocalCache(const ThreadLocalCache &) = delete;
168  ThreadLocalCache &operator=(const ThreadLocalCache &) = delete;
169 
170  /// Return the static thread local instance of the cache type.
171  static CacheType &getStaticCache() {
172  static LLVM_THREAD_LOCAL CacheType cache;
173  return cache;
174  }
175 
176  std::shared_ptr<PerInstanceState> perInstanceState =
177  std::make_shared<PerInstanceState>();
178 };
179 } // namespace mlir
180 
181 #endif // MLIR_SUPPORT_THREADLOCALCACHE_H
This class provides support for defining a thread local object with non static storage duration.
ValueT & get()
Return an instance of the value type for the current thread.
Include the generated interface declarations.