18 #include <condition_variable>
25 #include "llvm/ADT/StringMap.h"
26 #include "llvm/Support/ThreadPool.h"
48 AsyncRuntime() : numRefCountedObjects(0) {}
52 assert(getNumRefCountedObjects() == 0 &&
53 "all ref counted objects must be destroyed");
56 int64_t getNumRefCountedObjects() {
57 return numRefCountedObjects.load(std::memory_order_relaxed);
60 llvm::ThreadPoolInterface &getThreadPool() {
return threadPool; }
63 friend class RefCounted;
67 void addNumRefCountedObjects() {
68 numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
70 void dropNumRefCountedObjects() {
71 numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
74 std::atomic<int64_t> numRefCountedObjects;
75 llvm::DefaultThreadPool threadPool;
84 enum StateEnum : int8_t {
95 State(StateEnum s) : state(s) {}
96 operator StateEnum() {
return state; }
98 bool isUnavailable()
const {
return state == kUnavailable; }
99 bool isAvailable()
const {
return state == kAvailable; }
100 bool isError()
const {
return state == kError; }
101 bool isAvailableOrError()
const {
return isAvailable() || isError(); }
103 const char *debug()
const {
106 return "unavailable";
124 RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
125 : runtime(runtime), refCount(refCount) {
126 runtime->addNumRefCountedObjects();
129 virtual ~RefCounted() {
130 assert(refCount.load() == 0 &&
"reference count must be zero");
131 runtime->dropNumRefCountedObjects();
134 RefCounted(
const RefCounted &) =
delete;
135 RefCounted &operator=(
const RefCounted &) =
delete;
137 void addRef(int64_t count = 1) { refCount.fetch_add(count); }
139 void dropRef(int64_t count = 1) {
140 int64_t previous = refCount.fetch_sub(count);
141 assert(previous >= count &&
"reference count should not go below zero");
142 if (previous == count)
147 virtual void destroy() {
delete this; }
150 AsyncRuntime *runtime;
151 std::atomic<int64_t> refCount;
158 static auto runtime = std::make_unique<AsyncRuntime>();
178 : RefCounted(runtime, 2), state(State::kUnavailable) {}
180 std::atomic<State::StateEnum>
state;
184 std::condition_variable
cv;
194 : RefCounted(runtime, 2), state(State::kUnavailable),
197 std::atomic<State::StateEnum>
state;
204 std::condition_variable
cv;
213 : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
221 std::condition_variable
cv;
227 RefCounted *refCounted =
static_cast<RefCounted *
>(ptr);
228 refCounted->addRef(count);
233 RefCounted *refCounted =
static_cast<RefCounted *
>(ptr);
234 refCounted->dropRef(count);
257 std::unique_lock<std::mutex> lockToken(token->
mu);
258 std::unique_lock<std::mutex> lockGroup(group->
mu);
261 int rank = group->
rank.fetch_add(1);
263 auto onTokenReady = [group, token]() {
265 if (State(token->
state).isError())
274 group->
cv.notify_all();
275 for (
auto &awaiter : group->
awaiters)
280 if (State(token->
state).isAvailableOrError()) {
290 token->
awaiters.emplace_back([group, onTokenReady]() {
293 std::unique_lock<std::mutex> lockGroup(group->
mu);
306 assert(state.isAvailableOrError() &&
"must be terminal state");
307 assert(State(token->
state).isUnavailable() &&
"token must be unavailable");
311 std::unique_lock<std::mutex> lock(token->
mu);
312 token->
state = state;
313 token->
cv.notify_all();
314 for (
auto &awaiter : token->
awaiters)
324 assert(state.isAvailableOrError() &&
"must be terminal state");
325 assert(State(value->
state).isUnavailable() &&
"value must be unavailable");
329 std::unique_lock<std::mutex> lock(value->
mu);
330 value->
state = state;
331 value->
cv.notify_all();
332 for (
auto &awaiter : value->
awaiters)
358 return State(token->
state).isError();
362 return State(value->
state).isError();
370 std::unique_lock<std::mutex> lock(token->
mu);
371 if (!State(token->
state).isAvailableOrError())
373 lock, [token] {
return State(token->
state).isAvailableOrError(); });
377 std::unique_lock<std::mutex> lock(value->
mu);
378 if (!State(value->
state).isAvailableOrError())
380 lock, [value] {
return State(value->
state).isAvailableOrError(); });
384 std::unique_lock<std::mutex> lock(group->
mu);
391 assert(!State(value->
state).isError() &&
"unexpected error state");
397 runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
403 auto execute = [handle, resume]() { (*resume)(handle); };
404 std::unique_lock<std::mutex> lock(token->
mu);
405 if (State(token->
state).isAvailableOrError()) {
409 token->
awaiters.emplace_back([execute]() { execute(); });
416 auto execute = [handle, resume]() { (*resume)(handle); };
417 std::unique_lock<std::mutex> lock(value->
mu);
418 if (State(value->
state).isAvailableOrError()) {
422 value->
awaiters.emplace_back([execute]() { execute(); });
429 auto execute = [handle, resume]() { (*resume)(handle); };
430 std::unique_lock<std::mutex> lock(group->
mu);
435 group->
awaiters.emplace_back([execute]() { execute(); });
448 static thread_local std::thread::id thisId = std::this_thread::get_id();
449 std::cout <<
"Current thread id: " << thisId <<
'\n';
467 auto exportSymbol = [&](llvm::StringRef name,
auto ptr) {
468 assert(exportSymbols.count(name) == 0 &&
"symbol already exists");
469 exportSymbols[name] =
reinterpret_cast<void *
>(ptr);
472 exportSymbol(
"mlirAsyncRuntimeAddRef",
474 exportSymbol(
"mlirAsyncRuntimeDropRef",
476 exportSymbol(
"mlirAsyncRuntimeExecute",
478 exportSymbol(
"mlirAsyncRuntimeGetValueStorage",
480 exportSymbol(
"mlirAsyncRuntimeCreateToken",
482 exportSymbol(
"mlirAsyncRuntimeCreateValue",
484 exportSymbol(
"mlirAsyncRuntimeEmplaceToken",
486 exportSymbol(
"mlirAsyncRuntimeEmplaceValue",
488 exportSymbol(
"mlirAsyncRuntimeSetTokenError",
490 exportSymbol(
"mlirAsyncRuntimeSetValueError",
492 exportSymbol(
"mlirAsyncRuntimeIsTokenError",
494 exportSymbol(
"mlirAsyncRuntimeIsValueError",
496 exportSymbol(
"mlirAsyncRuntimeIsGroupError",
498 exportSymbol(
"mlirAsyncRuntimeAwaitToken",
500 exportSymbol(
"mlirAsyncRuntimeAwaitValue",
502 exportSymbol(
"mlirAsyncRuntimeAwaitTokenAndExecute",
504 exportSymbol(
"mlirAsyncRuntimeAwaitValueAndExecute",
506 exportSymbol(
"mlirAsyncRuntimeCreateGroup",
508 exportSymbol(
"mlirAsyncRuntimeAddTokenToGroup",
510 exportSymbol(
"mlirAsyncRuntimeAwaitAllInGroup",
512 exportSymbol(
"mlirAsyncRuntimeAwaitAllInGroupAndExecute",
514 exportSymbol(
"mlirAsyncRuntimGetNumWorkerThreads",
516 exportSymbol(
"mlirAsyncRuntimePrintCurrentThreadId",
#define MLIR_ASYNC_RUNTIME_EXPORT
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *, CoroHandle, CoroResume)
struct AsyncValue AsyncValue
static void resetDefaultAsyncRuntime()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *, CoroHandle, CoroResume)
MLIR_ASYNC_RUNTIME_EXPORT bool mlirAsyncRuntimeIsValueError(AsyncValue *)
MLIR_ASYNC_RUNTIME_EXPORT AsyncValue * mlirAsyncRuntimeCreateValue(int64_t)
MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_destroy()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimePrintCurrentThreadId()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeEmplaceToken(AsyncToken *)
void(*)(void *) CoroResume
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitToken(AsyncToken *)
struct AsyncToken AsyncToken
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle, CoroResume)
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeEmplaceValue(AsyncValue *)
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume)
MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_init(llvm::StringMap< void * > &exportSymbols)
struct AsyncGroup AsyncGroup
static std::unique_ptr< AsyncRuntime > & getDefaultAsyncRuntimeInstance()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAddRef(RefCountedObjPtr, int64_t)
MLIR_ASYNC_RUNTIME_EXPORT int64_t mlirAsyncRuntimGetNumWorkerThreads()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeSetTokenError(AsyncToken *)
MLIR_ASYNC_RUNTIME_EXPORT bool mlirAsyncRuntimeIsTokenError(AsyncToken *)
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeSetValueError(AsyncValue *)
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeDropRef(RefCountedObjPtr, int64_t)
MLIR_ASYNC_RUNTIME_EXPORT AsyncToken * mlirAsyncRuntimeCreateToken()
MLIR_ASYNC_RUNTIME_EXPORT int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *)
MLIR_ASYNC_RUNTIME_EXPORT bool mlirAsyncRuntimeIsGroupError(AsyncGroup *)
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *)
static void setTokenState(AsyncToken *token, State state)
static AsyncRuntime * getDefaultAsyncRuntime()
MLIR_ASYNC_RUNTIME_EXPORT void mlirAsyncRuntimeAwaitValue(AsyncValue *)
MLIR_ASYNC_RUNTIME_EXPORT AsyncGroup * mlirAsyncRuntimeCreateGroup(int64_t size)
MLIR_ASYNC_RUNTIME_EXPORT ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *)
static void setValueState(AsyncValue *value, State state)
Include the generated interface declarations.
AsyncGroup(AsyncRuntime *runtime, int64_t size)
std::atomic< int > numErrors
std::atomic< int > pendingTokens
std::condition_variable cv
std::vector< std::function< void()> > awaiters
AsyncToken(AsyncRuntime *runtime)
std::atomic< State::StateEnum > state
std::vector< std::function< void()> > awaiters
std::condition_variable cv
AsyncValue(AsyncRuntime *runtime, int64_t size)
std::vector< std::byte > storage
std::vector< std::function< void()> > awaiters
std::atomic< State::StateEnum > state
std::condition_variable cv