MLIR  19.0.0git
AsyncRuntime.cpp
Go to the documentation of this file.
1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements basic Async runtime API for supporting Async dialect
10 // to LLVM dialect lowering.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 #include <atomic>
17 #include <cassert>
18 #include <condition_variable>
19 #include <functional>
20 #include <iostream>
21 #include <mutex>
22 #include <thread>
23 #include <vector>
24 
25 #include "llvm/ADT/StringMap.h"
26 #include "llvm/Support/ThreadPool.h"
27 
28 using namespace mlir::runtime;
29 
30 //===----------------------------------------------------------------------===//
31 // Async runtime API.
32 //===----------------------------------------------------------------------===//
33 
34 namespace mlir {
35 namespace runtime {
36 namespace {
37 
38 // Forward declare class defined below.
39 class RefCounted;
40 
41 // -------------------------------------------------------------------------- //
42 // AsyncRuntime orchestrates all async operations and Async runtime API is built
43 // on top of the default runtime instance.
44 // -------------------------------------------------------------------------- //
45 
46 class AsyncRuntime {
47 public:
48  AsyncRuntime() : numRefCountedObjects(0) {}
49 
50  ~AsyncRuntime() {
51  threadPool.wait(); // wait for the completion of all async tasks
52  assert(getNumRefCountedObjects() == 0 &&
53  "all ref counted objects must be destroyed");
54  }
55 
56  int64_t getNumRefCountedObjects() {
57  return numRefCountedObjects.load(std::memory_order_relaxed);
58  }
59 
60  llvm::ThreadPoolInterface &getThreadPool() { return threadPool; }
61 
62 private:
63  friend class RefCounted;
64 
65  // Count the total number of reference counted objects in this instance
66  // of an AsyncRuntime. For debugging purposes only.
67  void addNumRefCountedObjects() {
68  numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
69  }
70  void dropNumRefCountedObjects() {
71  numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
72  }
73 
74  std::atomic<int64_t> numRefCountedObjects;
75  llvm::DefaultThreadPool threadPool;
76 };
77 
78 // -------------------------------------------------------------------------- //
79 // A state of the async runtime value (token, value or group).
80 // -------------------------------------------------------------------------- //
81 
// Small value wrapper around the lifecycle state of an async runtime object.
// Converts implicitly to and from `StateEnum` so it can be stored in, and read
// from, a `std::atomic<State::StateEnum>`.
class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state can not
    // transition to any other state.
    kAvailable = 1,
    // This underlying value is available and contains an error. This state can
    // not transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  // Const-qualified so that `const State` objects can be converted as well.
  /* implicit */ operator StateEnum() const { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  // Returns a static, human readable name of the state.
  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
    // Only reached when `state` holds a value that is not a named enumerator.
    // Returning here avoids flowing off the end of a non-void function.
    return "unknown";
  }

private:
  StateEnum state;
};
117 
118 // -------------------------------------------------------------------------- //
119 // A base class for all reference counted objects created by the async runtime.
120 // -------------------------------------------------------------------------- //
121 
122 class RefCounted {
123 public:
124  RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
125  : runtime(runtime), refCount(refCount) {
126  runtime->addNumRefCountedObjects();
127  }
128 
129  virtual ~RefCounted() {
130  assert(refCount.load() == 0 && "reference count must be zero");
131  runtime->dropNumRefCountedObjects();
132  }
133 
134  RefCounted(const RefCounted &) = delete;
135  RefCounted &operator=(const RefCounted &) = delete;
136 
137  void addRef(int64_t count = 1) { refCount.fetch_add(count); }
138 
139  void dropRef(int64_t count = 1) {
140  int64_t previous = refCount.fetch_sub(count);
141  assert(previous >= count && "reference count should not go below zero");
142  if (previous == count)
143  destroy();
144  }
145 
146 protected:
147  virtual void destroy() { delete this; }
148 
149 private:
150  AsyncRuntime *runtime;
151  std::atomic<int64_t> refCount;
152 };
153 
154 } // namespace
155 
156 // Returns the default per-process instance of an async runtime.
157 static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
158  static auto runtime = std::make_unique<AsyncRuntime>();
159  return runtime;
160 }
161 
163  return getDefaultAsyncRuntimeInstance().reset();
164 }
165 
166 static AsyncRuntime *getDefaultAsyncRuntime() {
167  return getDefaultAsyncRuntimeInstance().get();
168 }
169 
170 // Async token provides a mechanism to signal asynchronous operation completion.
171 struct AsyncToken : public RefCounted {
172  // AsyncToken created with a reference count of 2 because it will be returned
173  // to the `async.execute` caller and also will be later on emplaced by the
174  // asynchronously executed task. If the caller immediately will drop its
175  // reference we must ensure that the token will be alive until the
176  // asynchronous operation is completed.
177  AsyncToken(AsyncRuntime *runtime)
178  : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}
179 
180  std::atomic<State::StateEnum> state;
181 
182  // Pending awaiters are guarded by a mutex.
183  std::mutex mu;
184  std::condition_variable cv;
185  std::vector<std::function<void()>> awaiters;
186 };
187 
188 // Async value provides a mechanism to access the result of asynchronous
189 // operations. It owns the storage that is used to store/load the value of the
190 // underlying type, and a flag to signal if the value is ready or not.
191 struct AsyncValue : public RefCounted {
192  // AsyncValue similar to an AsyncToken created with a reference count of 2.
193  AsyncValue(AsyncRuntime *runtime, int64_t size)
194  : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
195  storage(size) {}
196 
197  std::atomic<State::StateEnum> state;
198 
199  // Use vector of bytes to store async value payload.
200  std::vector<std::byte> storage;
201 
202  // Pending awaiters are guarded by a mutex.
203  std::mutex mu;
204  std::condition_variable cv;
205  std::vector<std::function<void()>> awaiters;
206 };
207 
208 // Async group provides a mechanism to group together multiple async tokens or
209 // values to await on all of them together (wait for the completion of all
210 // tokens or values added to the group).
211 struct AsyncGroup : public RefCounted {
212  AsyncGroup(AsyncRuntime *runtime, int64_t size)
213  : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
214 
215  std::atomic<int> pendingTokens;
216  std::atomic<int> numErrors;
217  std::atomic<int> rank;
218 
219  // Pending awaiters are guarded by a mutex.
220  std::mutex mu;
221  std::condition_variable cv;
222  std::vector<std::function<void()>> awaiters;
223 };
224 
225 // Adds references to reference counted runtime object.
226 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
227  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
228  refCounted->addRef(count);
229 }
230 
231 // Drops references from reference counted runtime object.
232 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
233  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
234  refCounted->dropRef(count);
235 }
236 
237 // Creates a new `async.token` in not-ready state.
240  return token;
241 }
242 
243 // Creates a new `async.value` in not-ready state.
244 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
245  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
246  return value;
247 }
248 
249 // Create a new `async.group` in empty state.
250 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
251  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
252  return group;
253 }
254 
255 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
256  AsyncGroup *group) {
257  std::unique_lock<std::mutex> lockToken(token->mu);
258  std::unique_lock<std::mutex> lockGroup(group->mu);
259 
260  // Get the rank of the token inside the group before we drop the reference.
261  int rank = group->rank.fetch_add(1);
262 
263  auto onTokenReady = [group, token]() {
264  // Increment the number of errors in the group.
265  if (State(token->state).isError())
266  group->numErrors.fetch_add(1);
267 
268  // If pending tokens go below zero it means that more tokens than the group
269  // size were added to this group.
270  assert(group->pendingTokens > 0 && "wrong group size");
271 
272  // Run all group awaiters if it was the last token in the group.
273  if (group->pendingTokens.fetch_sub(1) == 1) {
274  group->cv.notify_all();
275  for (auto &awaiter : group->awaiters)
276  awaiter();
277  }
278  };
279 
280  if (State(token->state).isAvailableOrError()) {
281  // Update group pending tokens immediately and maybe run awaiters.
282  onTokenReady();
283 
284  } else {
285  // Update group pending tokens when token will become ready. Because this
286  // will happen asynchronously we must ensure that `group` is alive until
287  // then, and re-ackquire the lock.
288  group->addRef();
289 
290  token->awaiters.emplace_back([group, onTokenReady]() {
291  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
292  {
293  std::unique_lock<std::mutex> lockGroup(group->mu);
294  onTokenReady();
295  }
296  group->dropRef();
297  });
298  }
299 
300  return rank;
301 }
302 
303 // Switches `async.token` to available or error state (terminatl state) and runs
304 // all awaiters.
305 static void setTokenState(AsyncToken *token, State state) {
306  assert(state.isAvailableOrError() && "must be terminal state");
307  assert(State(token->state).isUnavailable() && "token must be unavailable");
308 
309  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
310  {
311  std::unique_lock<std::mutex> lock(token->mu);
312  token->state = state;
313  token->cv.notify_all();
314  for (auto &awaiter : token->awaiters)
315  awaiter();
316  }
317 
318  // Async tokens created with a ref count `2` to keep token alive until the
319  // async task completes. Drop this reference explicitly when token emplaced.
320  token->dropRef();
321 }
322 
323 static void setValueState(AsyncValue *value, State state) {
324  assert(state.isAvailableOrError() && "must be terminal state");
325  assert(State(value->state).isUnavailable() && "value must be unavailable");
326 
327  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
328  {
329  std::unique_lock<std::mutex> lock(value->mu);
330  value->state = state;
331  value->cv.notify_all();
332  for (auto &awaiter : value->awaiters)
333  awaiter();
334  }
335 
336  // Async values created with a ref count `2` to keep value alive until the
337  // async task completes. Drop this reference explicitly when value emplaced.
338  value->dropRef();
339 }
340 
341 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
342  setTokenState(token, State::kAvailable);
343 }
344 
345 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
346  setValueState(value, State::kAvailable);
347 }
348 
349 extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
350  setTokenState(token, State::kError);
351 }
352 
353 extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
354  setValueState(value, State::kError);
355 }
356 
357 extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
358  return State(token->state).isError();
359 }
360 
361 extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
362  return State(value->state).isError();
363 }
364 
365 extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
366  return group->numErrors.load() > 0;
367 }
368 
369 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
370  std::unique_lock<std::mutex> lock(token->mu);
371  if (!State(token->state).isAvailableOrError())
372  token->cv.wait(
373  lock, [token] { return State(token->state).isAvailableOrError(); });
374 }
375 
376 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
377  std::unique_lock<std::mutex> lock(value->mu);
378  if (!State(value->state).isAvailableOrError())
379  value->cv.wait(
380  lock, [value] { return State(value->state).isAvailableOrError(); });
381 }
382 
384  std::unique_lock<std::mutex> lock(group->mu);
385  if (group->pendingTokens != 0)
386  group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
387 }
388 
389 // Returns a pointer to the storage owned by the async value.
391  assert(!State(value->state).isError() && "unexpected error state");
392  return value->storage.data();
393 }
394 
395 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
396  auto *runtime = getDefaultAsyncRuntime();
397  runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
398 }
399 
401  CoroHandle handle,
402  CoroResume resume) {
403  auto execute = [handle, resume]() { (*resume)(handle); };
404  std::unique_lock<std::mutex> lock(token->mu);
405  if (State(token->state).isAvailableOrError()) {
406  lock.unlock();
407  execute();
408  } else {
409  token->awaiters.emplace_back([execute]() { execute(); });
410  }
411 }
412 
414  CoroHandle handle,
415  CoroResume resume) {
416  auto execute = [handle, resume]() { (*resume)(handle); };
417  std::unique_lock<std::mutex> lock(value->mu);
418  if (State(value->state).isAvailableOrError()) {
419  lock.unlock();
420  execute();
421  } else {
422  value->awaiters.emplace_back([execute]() { execute(); });
423  }
424 }
425 
427  CoroHandle handle,
428  CoroResume resume) {
429  auto execute = [handle, resume]() { (*resume)(handle); };
430  std::unique_lock<std::mutex> lock(group->mu);
431  if (group->pendingTokens == 0) {
432  lock.unlock();
433  execute();
434  } else {
435  group->awaiters.emplace_back([execute]() { execute(); });
436  }
437 }
438 
440  return getDefaultAsyncRuntime()->getThreadPool().getMaxConcurrency();
441 }
442 
443 //===----------------------------------------------------------------------===//
444 // Small async runtime support library for testing.
445 //===----------------------------------------------------------------------===//
446 
448  static thread_local std::thread::id thisId = std::this_thread::get_id();
449  std::cout << "Current thread id: " << thisId << '\n';
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // MLIR ExecutionEngine dynamic library integration.
454 //===----------------------------------------------------------------------===//
455 
456 // Visual Studio had a bug that fails to compile nested generic lambdas
457 // inside an `extern "C"` function.
458 // https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
459 // The bug is fixed in VS2019 16.1. Separating the declaration and definition is
460 // a work around for older versions of Visual Studio.
461 // NOLINTNEXTLINE(*-identifier-naming): externally called.
462 extern "C" MLIR_ASYNC_RUNTIME_EXPORT void
463 __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols);
464 
465 // NOLINTNEXTLINE(*-identifier-naming): externally called.
466 void __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols) {
467  auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
468  assert(exportSymbols.count(name) == 0 && "symbol already exists");
469  exportSymbols[name] = reinterpret_cast<void *>(ptr);
470  };
471 
472  exportSymbol("mlirAsyncRuntimeAddRef",
474  exportSymbol("mlirAsyncRuntimeDropRef",
476  exportSymbol("mlirAsyncRuntimeExecute",
478  exportSymbol("mlirAsyncRuntimeGetValueStorage",
480  exportSymbol("mlirAsyncRuntimeCreateToken",
482  exportSymbol("mlirAsyncRuntimeCreateValue",
484  exportSymbol("mlirAsyncRuntimeEmplaceToken",
486  exportSymbol("mlirAsyncRuntimeEmplaceValue",
488  exportSymbol("mlirAsyncRuntimeSetTokenError",
490  exportSymbol("mlirAsyncRuntimeSetValueError",
492  exportSymbol("mlirAsyncRuntimeIsTokenError",
494  exportSymbol("mlirAsyncRuntimeIsValueError",
496  exportSymbol("mlirAsyncRuntimeIsGroupError",
498  exportSymbol("mlirAsyncRuntimeAwaitToken",
500  exportSymbol("mlirAsyncRuntimeAwaitValue",
502  exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
504  exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
506  exportSymbol("mlirAsyncRuntimeCreateGroup",
508  exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
510  exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
512  exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
514  exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
516  exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
518 }
519 
520 // NOLINTNEXTLINE(*-identifier-naming): externally called.
523 }
524 
525 } // namespace runtime
526 } // namespace mlir
#define MLIR_ASYNC_RUNTIME_EXPORT
Definition: AsyncRuntime.h:32
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *, CoroHandle, CoroResume)
struct AsyncValue AsyncValue
Definition: AsyncRuntime.h:49
static void resetDefaultAsyncRuntime()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *, CoroHandle, CoroResume)
MLIR_ASYNC_RUNTIME_EXPORT bool mlirAsyncRuntimeIsValueError(AsyncValue *)
MLIR_ASYNC_RUNTIME_EXPORT AsyncValue * mlirAsyncRuntimeCreateValue(int64_t)
MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_destroy()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimePrintCurrentThreadId()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeEmplaceToken(AsyncToken *)
void(*)(void *) CoroResume
Definition: AsyncRuntime.h:58
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitToken(AsyncToken *)
struct AsyncToken AsyncToken
Definition: AsyncRuntime.h:43
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle, CoroResume)
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeEmplaceValue(AsyncValue *)
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume)
std::byte * ValueStorage
Definition: AsyncRuntime.h:52
MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_init(llvm::StringMap< void * > &exportSymbols)
struct AsyncGroup AsyncGroup
Definition: AsyncRuntime.h:46
void * CoroHandle
Definition: AsyncRuntime.h:57
static std::unique_ptr< AsyncRuntime > & getDefaultAsyncRuntimeInstance()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAddRef(RefCountedObjPtr, int64_t)
MLIR_ASYNC_RUNTIME_EXPORT int64_t mlirAsyncRuntimGetNumWorkerThreads()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeSetTokenError(AsyncToken *)
MLIR_ASYNC_RUNTIME_EXPORT bool mlirAsyncRuntimeIsTokenError(AsyncToken *)
void * RefCountedObjPtr
Definition: AsyncRuntime.h:62
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeSetValueError(AsyncValue *)
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeDropRef(RefCountedObjPtr, int64_t)
MLIR_ASYNC_RUNTIME_EXPORT AsyncToken * mlirAsyncRuntimeCreateToken()
MLIR_ASYNC_RUNTIME_EXPORT int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *)
MLIR_ASYNC_RUNTIME_EXPORT bool mlirAsyncRuntimeIsGroupError(AsyncGroup *)
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *)
static void setTokenState(AsyncToken *token, State state)
static AsyncRuntime * getDefaultAsyncRuntime()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitValue(AsyncValue *)
MLIR_ASYNC_RUNTIME_EXPORT AsyncGroup * mlirAsyncRuntimeCreateGroup(int64_t size)
MLIR_ASYNC_RUNTIME_EXPORT ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *)
static void setValueState(AsyncValue *value, State state)
Include the generated interface declarations.
AsyncGroup(AsyncRuntime *runtime, int64_t size)
std::atomic< int > numErrors
std::atomic< int > rank
std::atomic< int > pendingTokens
std::condition_variable cv
std::vector< std::function< void()> > awaiters
AsyncToken(AsyncRuntime *runtime)
std::atomic< State::StateEnum > state
std::vector< std::function< void()> > awaiters
std::condition_variable cv
AsyncValue(AsyncRuntime *runtime, int64_t size)
std::vector< std::byte > storage
std::vector< std::function< void()> > awaiters
std::atomic< State::StateEnum > state
std::condition_variable cv