AsyncRuntime.cpp

//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements basic Async runtime API for supporting Async dialect
// to LLVM dialect lowering.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/AsyncRuntime.h"

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ThreadPool.h"

using namespace mlir::runtime;

//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace runtime {
namespace {

// Forward declare class defined below.
class RefCounted;

// -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations and Async runtime API is built
// on top of the default runtime instance.
// -------------------------------------------------------------------------- //

class AsyncRuntime {
public:
  AsyncRuntime() : numRefCountedObjects(0) {}

  ~AsyncRuntime() {
    threadPool.wait(); // wait for the completion of all async tasks
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  int64_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

  llvm::ThreadPoolInterface &getThreadPool() { return threadPool; }

private:
  friend class RefCounted;

  // Count the total number of reference counted objects in this instance
  // of an AsyncRuntime. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int64_t> numRefCountedObjects;
  llvm::DefaultThreadPool threadPool;
};

// -------------------------------------------------------------------------- //
// A state of the async runtime value (token, value or group).
// -------------------------------------------------------------------------- //

class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state can not
    // transition to any other state.
    kAvailable = 1,
    // This underlying value is available and contains an error. This state can
    // not transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  /* implicit */ operator StateEnum() { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
  }

private:
  StateEnum state;
};

// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
// -------------------------------------------------------------------------- //

class RefCounted {
public:
  RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
      : runtime(runtime), refCount(refCount) {
    runtime->addNumRefCountedObjects();
  }

  virtual ~RefCounted() {
    assert(refCount.load() == 0 && "reference count must be zero");
    runtime->dropNumRefCountedObjects();
  }

  RefCounted(const RefCounted &) = delete;
  RefCounted &operator=(const RefCounted &) = delete;

  void addRef(int64_t count = 1) { refCount.fetch_add(count); }

  void dropRef(int64_t count = 1) {
    int64_t previous = refCount.fetch_sub(count);
    assert(previous >= count && "reference count should not go below zero");
    if (previous == count)
      destroy();
  }

protected:
  virtual void destroy() { delete this; }

private:
  AsyncRuntime *runtime;
  std::atomic<int64_t> refCount;
};

} // namespace

// Returns the default per-process instance of an async runtime.
static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
  static auto runtime = std::make_unique<AsyncRuntime>();
  return runtime;
}

static void resetDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().reset();
}

static AsyncRuntime *getDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().get();
}
169
170// Async token provides a mechanism to signal asynchronous operation completion.
171struct AsyncToken : public RefCounted {
172 // AsyncToken created with a reference count of 2 because it will be returned
173 // to the `async.execute` caller and also will be later on emplaced by the
174 // asynchronously executed task. If the caller immediately will drop its
175 // reference we must ensure that the token will be alive until the
176 // asynchronous operation is completed.
177 AsyncToken(AsyncRuntime *runtime)
178 : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}
179
180 std::atomic<State::StateEnum> state;
181
182 // Pending awaiters are guarded by a mutex.
183 std::mutex mu;
184 std::condition_variable cv;
185 std::vector<std::function<void()>> awaiters;
186};
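
// Example: a typical token lifecycle, sketched as hypothetical client code
// (the calls are the C API defined below; the ordering mirrors what lowered
// `async.execute` code does).
//
//   AsyncToken *token = mlirAsyncRuntimeCreateToken();  // refCount == 2
//   // ... the asynchronously executed task eventually signals completion:
//   mlirAsyncRuntimeEmplaceToken(token);  // drops the task's reference
//   // ... the caller waits for completion and releases its own reference:
//   mlirAsyncRuntimeAwaitToken(token);
//   mlirAsyncRuntimeDropRef(token, 1);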

// Async value provides a mechanism to access the result of asynchronous
// operations. It owns the storage that is used to store/load the value of the
// underlying type, and a flag to signal if the value is ready or not.
struct AsyncValue : public RefCounted {
  // AsyncValue similar to an AsyncToken created with a reference count of 2.
  AsyncValue(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
        storage(size) {}

  std::atomic<State::StateEnum> state;

  // Use vector of bytes to store async value payload.
  std::vector<std::byte> storage;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
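
// Example: producing and consuming a 4-byte async value, sketched as
// hypothetical client code built on the C API defined below.
//
//   AsyncValue *value = mlirAsyncRuntimeCreateValue(/*size=*/4);
//   auto *storage =
//       reinterpret_cast<int32_t *>(mlirAsyncRuntimeGetValueStorage(value));
//   *storage = 42;                        // producer writes the payload
//   mlirAsyncRuntimeEmplaceValue(value);  // and marks the value available
//   mlirAsyncRuntimeAwaitValue(value);    // consumer blocks until available
//   int32_t result = *storage;            // then reads the payload
//   mlirAsyncRuntimeDropRef(value, 1);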

// Async group provides a mechanism to group together multiple async tokens or
// values to await on all of them together (wait for the completion of all
// tokens or values added to the group).
struct AsyncGroup : public RefCounted {
  AsyncGroup(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}

  std::atomic<int> pendingTokens;
  std::atomic<int> numErrors;
  std::atomic<int> rank;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};

// Adds references to reference counted runtime object.
extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->addRef(count);
}

// Drops references from reference counted runtime object.
extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->dropRef(count);
}

// Creates a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
  return token;
}

// Creates a new `async.value` in not-ready state.
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
  return value;
}

// Create a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
  return group;
}

extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
                                                   AsyncGroup *group) {
  std::unique_lock<std::mutex> lockToken(token->mu);
  std::unique_lock<std::mutex> lockGroup(group->mu);

  // Get the rank of the token inside the group before we drop the reference.
  int rank = group->rank.fetch_add(1);

  auto onTokenReady = [group, token]() {
    // Increment the number of errors in the group.
    if (State(token->state).isError())
      group->numErrors.fetch_add(1);

    // If pending tokens go below zero it means that more tokens than the group
    // size were added to this group.
    assert(group->pendingTokens > 0 && "wrong group size");

    // Run all group awaiters if it was the last token in the group.
    if (group->pendingTokens.fetch_sub(1) == 1) {
      group->cv.notify_all();
      for (auto &awaiter : group->awaiters)
        awaiter();
    }
  };

  if (State(token->state).isAvailableOrError()) {
    // Update group pending tokens immediately and maybe run awaiters.
    onTokenReady();

  } else {
    // Update group pending tokens when token will become ready. Because this
    // will happen asynchronously we must ensure that `group` is alive until
    // then, and re-acquire the lock.
    group->addRef();

    token->awaiters.emplace_back([group, onTokenReady]() {
      // Make sure that `dropRef` does not destroy the mutex owned by the lock.
      {
        std::unique_lock<std::mutex> lockGroup(group->mu);
        onTokenReady();
      }
      group->dropRef();
    });
  }

  return rank;
}
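
// Example: awaiting two tokens through a group, sketched as hypothetical
// client code (`tokenA` and `tokenB` are assumed to be emplaced by their
// producers at some point).
//
//   AsyncGroup *group = mlirAsyncRuntimeCreateGroup(/*size=*/2);
//   mlirAsyncRuntimeAddTokenToGroup(tokenA, group);  // rank 0
//   mlirAsyncRuntimeAddTokenToGroup(tokenB, group);  // rank 1
//   mlirAsyncRuntimeAwaitAllInGroup(group);          // blocks until both ready
//   bool anyFailed = mlirAsyncRuntimeIsGroupError(group);
//   mlirAsyncRuntimeDropRef(group, 1);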

// Switches `async.token` to available or error state (terminal state) and runs
// all awaiters.
static void setTokenState(AsyncToken *token, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(token->state).isUnavailable() && "token must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(token->mu);
    token->state = state;
    token->cv.notify_all();
    for (auto &awaiter : token->awaiters)
      awaiter();
  }

  // Async tokens created with a ref count `2` to keep token alive until the
  // async task completes. Drop this reference explicitly when token emplaced.
  token->dropRef();
}

static void setValueState(AsyncValue *value, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(value->state).isUnavailable() && "value must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(value->mu);
    value->state = state;
    value->cv.notify_all();
    for (auto &awaiter : value->awaiters)
      awaiter();
  }

  // Async values created with a ref count `2` to keep value alive until the
  // async task completes. Drop this reference explicitly when value emplaced.
  value->dropRef();
}

extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
  setTokenState(token, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
  setValueState(value, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
  setTokenState(token, State::kError);
}

extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
  setValueState(value, State::kError);
}

extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
  return State(token->state).isError();
}

extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
  return State(value->state).isError();
}

extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
  return group->numErrors.load() > 0;
}

extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
  std::unique_lock<std::mutex> lock(token->mu);
  if (!State(token->state).isAvailableOrError())
    token->cv.wait(
        lock, [token] { return State(token->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
  std::unique_lock<std::mutex> lock(value->mu);
  if (!State(value->state).isAvailableOrError())
    value->cv.wait(
        lock, [value] { return State(value->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens != 0)
    group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
}

// Returns a pointer to the storage owned by the async value.
extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
  assert(!State(value->state).isError() && "unexpected error state");
  return value->storage.data();
}

extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
  auto *runtime = getDefaultAsyncRuntime();
  runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
}
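
// Example: any function matching `CoroResume` (void(*)(void *)) can be
// scheduled on the runtime thread pool; in the real lowering the arguments
// are an LLVM coroutine handle and its resume function. `printHello` below is
// a hypothetical stand-in, not part of this runtime.
//
//   static void printHello(void *handle) { std::puts("hello"); }
//   ...
//   mlirAsyncRuntimeExecute(/*handle=*/nullptr, /*resume=*/&printHello);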

extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(token->mu);
  if (State(token->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    token->awaiters.emplace_back([execute]() { execute(); });
  }
}

extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(value->mu);
  if (State(value->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    value->awaiters.emplace_back([execute]() { execute(); });
  }
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                          CoroHandle handle,
                                                          CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens == 0) {
    lock.unlock();
    execute();
  } else {
    group->awaiters.emplace_back([execute]() { execute(); });
  }
}

extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
  return getDefaultAsyncRuntime()->getThreadPool().getMaxConcurrency();
}

//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//

extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local std::thread::id thisId = std::this_thread::get_id();
  std::cout << "Current thread id: " << thisId << '\n';
}

//===----------------------------------------------------------------------===//
// MLIR ExecutionEngine dynamic library integration.
//===----------------------------------------------------------------------===//

// Visual Studio had a bug that fails to compile nested generic lambdas
// inside an `extern "C"` function.
// https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
// The bug is fixed in VS2019 16.1. Separating the declaration and definition is
// a work around for older versions of Visual Studio.
// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" MLIR_ASYNC_RUNTIME_EXPORT void
__mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols);

// NOLINTNEXTLINE(*-identifier-naming): externally called.
void __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols) {
  auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
    assert(exportSymbols.count(name) == 0 && "symbol already exists");
    exportSymbols[name] = reinterpret_cast<void *>(ptr);
  };

  exportSymbol("mlirAsyncRuntimeAddRef",
               &mlir::runtime::mlirAsyncRuntimeAddRef);
  exportSymbol("mlirAsyncRuntimeDropRef",
               &mlir::runtime::mlirAsyncRuntimeDropRef);
  exportSymbol("mlirAsyncRuntimeExecute",
               &mlir::runtime::mlirAsyncRuntimeExecute);
  exportSymbol("mlirAsyncRuntimeGetValueStorage",
               &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
  exportSymbol("mlirAsyncRuntimeCreateToken",
               &mlir::runtime::mlirAsyncRuntimeCreateToken);
  exportSymbol("mlirAsyncRuntimeCreateValue",
               &mlir::runtime::mlirAsyncRuntimeCreateValue);
  exportSymbol("mlirAsyncRuntimeEmplaceToken",
               &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
  exportSymbol("mlirAsyncRuntimeEmplaceValue",
               &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
  exportSymbol("mlirAsyncRuntimeSetTokenError",
               &mlir::runtime::mlirAsyncRuntimeSetTokenError);
  exportSymbol("mlirAsyncRuntimeSetValueError",
               &mlir::runtime::mlirAsyncRuntimeSetValueError);
  exportSymbol("mlirAsyncRuntimeIsTokenError",
               &mlir::runtime::mlirAsyncRuntimeIsTokenError);
  exportSymbol("mlirAsyncRuntimeIsValueError",
               &mlir::runtime::mlirAsyncRuntimeIsValueError);
  exportSymbol("mlirAsyncRuntimeIsGroupError",
               &mlir::runtime::mlirAsyncRuntimeIsGroupError);
  exportSymbol("mlirAsyncRuntimeAwaitToken",
               &mlir::runtime::mlirAsyncRuntimeAwaitToken);
  exportSymbol("mlirAsyncRuntimeAwaitValue",
               &mlir::runtime::mlirAsyncRuntimeAwaitValue);
  exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
  exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
  exportSymbol("mlirAsyncRuntimeCreateGroup",
               &mlir::runtime::mlirAsyncRuntimeCreateGroup);
  exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
               &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
  exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
               &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
  exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
               &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
}

// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_destroy() {
  resetDefaultAsyncRuntime();
}

} // namespace runtime
} // namespace mlir