13#include "level_zero/ze_api.h"
23#include <unordered_set>
// catchAll: runs the callable `func` (taken as a forwarding reference F&&) and
// routes exceptions escaping it to the handlers below.
// - std::exception: prints "An exception was thrown: " followed by e.what()
//   to std::cerr.
// - any other exception type: prints "An unknown exception was thrown." to
//   std::cerr.
// NOTE(review): the `try` block, the invocation of `func`, and whatever each
// handler does after printing (return a default value? rethrow?) are on lines
// not visible in this excerpt — confirm against the full source.
28auto catchAll(F &&func) {
31 }
catch (
const std::exception &e) {
32 std::cerr <<
"An exception was thrown: " << e.what() << std::endl;
35 std::cerr <<
"An unknown exception was thrown." << std::endl;
40#define L0_SAFE_CALL(call) \
42 ze_result_t status = (call); \
43 if (status != ZE_RESULT_SUCCESS) { \
44 const char *errorString; \
45 zeDriverGetLastErrorDescription(NULL, &errorString); \
46 std::cerr << "L0 error " << status << ": " << errorString << std::endl; \
// getDriver: obtains the Level Zero driver at index `idx`. Only GPU drivers
// are requested (ZE_INIT_DRIVER_TYPE_FLAG_GPU in the init descriptor).
// - Drivers are enumerated with zeInitDrivers in two passes: first the count
//   (null handle array), then the handles into `drivers`.
// - `drivers` and `isDriverInitialised` are thread_local, so each thread keeps
//   its own cached list.
// - When the cache is populated and `idx` is within it, the check at the top
//   presumably returns early; its body is not visible here.
// - Throws std::runtime_error("No L0 drivers found.") under a condition on a
//   line not visible here (presumably driverCount == 0).
// - Throws std::runtime_error when `idx` is not below the driver count.
// NOTE(review): the final return statement is not visible; presumably it
// yields drivers[idx].
// NOTE(review): the out-of-range message misspells "available" as "availabe".
59static ze_driver_handle_t
getDriver(uint32_t idx = 0) {
60 ze_init_driver_type_desc_t driver_type = {};
61 driver_type.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
62 driver_type.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU;
63 driver_type.pNext =
nullptr;
64 uint32_t driverCount{0};
65 thread_local static std::vector<ze_driver_handle_t> drivers;
66 thread_local static bool isDriverInitialised{
false};
67 if (isDriverInitialised && idx < drivers.size())
69 L0_SAFE_CALL(zeInitDrivers(&driverCount,
nullptr, &driver_type));
71 throw std::runtime_error(
"No L0 drivers found.");
72 drivers.resize(driverCount);
73 L0_SAFE_CALL(zeInitDrivers(&driverCount, drivers.data(), &driver_type));
74 if (idx >= driverCount)
75 throw std::runtime_error(std::string(
"Requested driver idx out-of-bound, "
76 "number of availabe drivers: ") +
77 std::to_string(driverCount));
78 isDriverInitialised =
true;
// getDevice: looks up the Level Zero device at index `devIdx` of the driver at
// index `driverIdx`, caching the most recent result.
// - l0Device / currDevIdx / currDriverIdx are block-scope thread_local, which
//   implies static storage, so the cache is per thread.
// - currDevIdx starts at -1, presumably as a "nothing cached yet" marker.
// - Throws std::runtime_error if no device is found, or if `devIdx` is
//   negative or not below the device count.
// NOTE(review): a first call with driverIdx == 0 and devIdx == -1 matches the
// initial cache values at the check below.
// NOTE(review): the early return after the cache check, the definition of
// `driver`, the zero-device condition, and the currDevIdx update / final
// return are on lines not visible in this excerpt.
82static ze_device_handle_t
getDevice(
const uint32_t driverIdx = 0,
83 const int32_t devIdx = 0) {
84 thread_local static ze_device_handle_t l0Device;
85 thread_local int32_t currDevIdx{-1};
86 thread_local uint32_t currDriverIdx{0};
87 if (currDriverIdx == driverIdx && currDevIdx == devIdx)
90 uint32_t deviceCount{0};
91 L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount,
nullptr));
93 throw std::runtime_error(
"getDevice failed: did not find L0 device.");
// Reject negative indices: devices[devIdx] below would otherwise index out of
// bounds. Comparing as unsigned also avoids the signed overflow that
// `devIdx + 1` had for devIdx == INT32_MAX.
94 if (
devIdx < 0 || static_cast<uint32_t>(devIdx) >= deviceCount)
95 throw std::runtime_error(
"getDevice failed: devIdx out-of-bounds.");
96 std::vector<ze_device_handle_t> devices(deviceCount);
97 L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, devices.data()));
98 l0Device = devices[devIdx];
99 currDriverIdx = driverIdx;
// getContext: on the first call in each thread, creates a Level Zero context
// for `driver` with zeContextCreate (default ze_context_desc_t, flags 0).
// The result is cached in thread_local storage.
// NOTE(review): once the cache is populated, the visible check does not look
// at `driver`. A later call with a different driver presumably receives the
// context created for the first one — confirm this is intended.
// NOTE(review): the early return and the final return are on lines not
// visible in this excerpt.
105static ze_context_handle_t
getContext(ze_driver_handle_t driver) {
106 thread_local static ze_context_handle_t context;
107 thread_local static bool isContextInitialised{
false};
108 if (isContextInitialised)
110 ze_context_desc_t ctxtDesc = {ZE_STRUCTURE_TYPE_CONTEXT_DESC,
nullptr, 0};
111 L0_SAFE_CALL(zeContextCreate(driver, &ctxtDesc, &context));
112 isContextInitialised =
true;
// Owning smart-pointer aliases for Level Zero context and command-list handles
// (std::unique_ptr over the handle's pointee type). The deleter types and
// alias names are on lines not visible here. The declaration list at the end
// of this file names them ZeContextDeleter / UniqueZeContext and
// ZeCommandListDeleter / UniqueZeCommandList.
134 std::unique_ptr<std::remove_pointer<ze_context_handle_t>::type,
137 std::unique_ptr<std::remove_pointer<ze_command_list_handle_t>::type,
// Fragment of the L0RTContextWrapper setup. It picks command-queue group
// ordinals for a compute engine and a copy engine, then creates one command
// list per engine. The declaration list at the end of this file names the
// members immCmdListCopy / immCmdListCompute.
// - Both ordinals start at -1u, which the later checks treat as "not found".
// - A group whose flags include COMPUTE becomes the compute engine.
// - COPY is only checked for groups without COMPUTE (else-if), so a group with
//   both flags never becomes the copy engine.
// - If no copy group was found, the copy ordinal falls back to the compute
//   ordinal. The assert therefore effectively requires only that a compute
//   group exists, despite its "two engines" message.
// - One ze_command_queue_desc_t (ASYNCHRONOUS mode, NORMAL priority) is
//   reused for both lists; `ordinal` is set to the compute engine before the
//   second list is created.
// NOTE(review): the queue-group property queries, the loop's exit after both
// ordinals are found, and the command-list creation calls are on lines not
// visible in this excerpt.
159 uint32_t computeEngineOrdinal = -1u, copyEngineOrdinal = -1u;
160 ze_device_properties_t deviceProperties{};
162 uint32_t queueGroupCount = 0;
164 device, &queueGroupCount,
nullptr));
165 std::vector<ze_command_queue_group_properties_t> queueGroupProperties(
168 device, &queueGroupCount, queueGroupProperties.data()));
170 for (uint32_t queueGroupIdx = 0; queueGroupIdx < queueGroupCount;
172 const auto &group = queueGroupProperties[queueGroupIdx];
173 if (group.flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE)
174 computeEngineOrdinal = queueGroupIdx;
175 else if (group.flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COPY) {
176 copyEngineOrdinal = queueGroupIdx;
179 if (copyEngineOrdinal != -1u && computeEngineOrdinal != -1u)
184 if (copyEngineOrdinal == -1u)
185 copyEngineOrdinal = computeEngineOrdinal;
187 assert(copyEngineOrdinal != -1u && computeEngineOrdinal != -1u &&
188 "Expected two engines to be available.");
191 ze_command_queue_desc_t cmdQueueDesc{
192 ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
197 ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,
198 ZE_COMMAND_QUEUE_PRIORITY_NORMAL};
200 ze_command_list_handle_t rawCmdListCopy =
nullptr;
202 &cmdQueueDesc, &rawCmdListCopy));
206 cmdQueueDesc.ordinal = computeEngineOrdinal;
207 ze_command_list_handle_t rawCmdListCompute =
nullptr;
// Owning smart-pointer aliases for Level Zero event and event-pool handles
// (std::unique_ptr over the handle's pointee type). The deleter types and
// alias names are on lines not visible here. The declaration list at the end
// of this file names them ZeEventDeleter / UniqueZeEvent and
// ZeEventPoolDeleter / UniqueZeEventPool.
235 std::unique_ptr<std::remove_pointer<ze_event_handle_t>::type,
238 std::unique_ptr<std::remove_pointer<ze_event_pool_handle_t>::type,
// Fragments of DynamicEventPool member functions. Their headers are on lines
// not visible here; the declaration list at the end of this file lists the
// members.
// Debug check that `takenEvents` is empty, i.e. no event is still outstanding.
// The enclosing function is not visible (presumably the destructor).
271 assert(
takenEvents.empty() &&
"Some events were not released");
// Presumably createNewPool(size_t numEvents): builds a host-visible event pool
// (ZE_EVENT_POOL_FLAG_HOST_VISIBLE) sized for `numEvents` events on
// rtCtx->device.
// NOTE(review): eventPoolDesc.stype is not assigned in the visible lines;
// confirm a hidden line sets ZE_STRUCTURE_TYPE_EVENT_POOL_DESC.
// NOTE(review): the descriptor's `count` field is uint32_t in the Level Zero
// headers, so numEvents (size_t) narrows implicitly.
275 ze_event_pool_desc_t eventPoolDesc = {};
276 eventPoolDesc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE;
277 eventPoolDesc.count = numEvents;
279 ze_event_pool_handle_t rawPool =
nullptr;
281 &
rtCtx->device, &rawPool));
// Presumably takeEvent(): hands out an event handle.
// - It appears to reuse an already-created event (`uniqueEvent`, from code not
//   visible here) when possible.
// - It throws std::runtime_error when the events limit is reached.
// - Otherwise it creates a new event in the most recent pool
//   (eventPools.back()). New events signal with device scope and wait with
//   host scope (last two fields of ze_event_desc_t). The `index` field's line
//   is not visible.
288 ze_event_handle_t rawEvent =
nullptr;
294 rawEvent = uniqueEvent.get();
298 throw std::runtime_error(
"DynamicEventPool: reached max events limit");
303 ze_event_desc_t eventDesc = {
304 ZE_STRUCTURE_TYPE_EVENT_DESC,
nullptr,
306 ZE_EVENT_SCOPE_FLAG_DEVICE, ZE_EVENT_SCOPE_FLAG_HOST};
308 ze_event_handle_t newEvent =
nullptr;
310 zeEventCreate(
eventPools.back().get(), &eventDesc, &newEvent));
// Presumably from releaseEvent(): message used when the event being released
// is not currently tracked as taken. The surrounding code is not visible.
323 "Attempting to release unknown or already released event");
// StreamWrapper::sync: blocks the host until an event completes. With no
// explicit event it uses the stream's most recent implicit event (if any);
// otherwise it uses `explicitEvent`. The wait passes the maximum uint64_t
// timeout.
// NOTE(review): the declaration of lastImplicitEventPtr, the head of the
// host-synchronize call, and any guard for syncEvent == nullptr are on lines
// not visible here.
354 void sync(ze_event_handle_t explicitEvent =
nullptr) {
355 ze_event_handle_t syncEvent{
nullptr};
356 if (!explicitEvent) {
358 syncEvent = lastImplicitEventPtr ? *lastImplicitEventPtr :
nullptr;
360 syncEvent = explicitEvent;
364 syncEvent, std::numeric_limits<uint64_t>::max()));
// StreamWrapper::enqueueOp: takes a fresh event from the DynamicEventPool and
// calls `op(newEvent, numWaitEvents, waitEvents)`. The wait list is the
// previous implicit event (1 entry) when there is one, otherwise empty
// (0 entries, null pointer), so each op can wait on the one before it.
// NOTE(review): presumably the new event is then recorded as the latest
// implicit event; those lines are not visible here.
372 template <
typename Func>
374 ze_event_handle_t newImplicitEvent =
dynEventPool.takeEvent();
376 const uint32_t numWaitEvents = lastImplicitEventPtr ? 1 : 0;
377 std::forward<Func>(op)(newImplicitEvent, numWaitEvents,
378 lastImplicitEventPtr);
// loadModule: creates a Level Zero module from the `dataSize` bytes at `data`,
// interpreted according to `format` (native binary by default).
// On a non-success result it prints the error code and the module build log
// to std::cerr.
// NOTE(review): the parameter line with `data`/`dataSize`, the module-creation
// call head (presumably zeModuleCreate), the declaration of `logSize`, and the
// handling after the log is printed are on lines not visible in this excerpt.
383static ze_module_handle_t
385 ze_module_format_t format = ZE_MODULE_FORMAT_NATIVE) {
387 ze_module_handle_t zeModule;
388 ze_module_desc_t desc = {
389 ZE_STRUCTURE_TYPE_MODULE_DESC,
nullptr, format, dataSize,
390 (
const uint8_t *)data,
nullptr,
nullptr};
392 ze_module_build_log_handle_t buildLogHandle;
395 &zeModule, &buildLogHandle);
396 if (
result != ZE_RESULT_SUCCESS) {
397 std::cerr <<
"Error creating module, error code: " <<
result << std::endl;
399 L0_SAFE_CALL(zeModuleBuildLogGetString(buildLogHandle, &logSize,
nullptr));
// Size the buffer with the (count, ch) constructor: `logSize` characters,
// which zeModuleBuildLogGetString then overwrites. The previous
// std::string(" ", logSize) used the (const char *, count) overload and read
// `logSize` bytes from a two-byte string literal, an out-of-bounds read
// whenever logSize > 2.
400 std::string buildLog(
logSize, ' ');
402 zeModuleBuildLogGetString(buildLogHandle, &logSize, buildLog.data()));
403 std::cerr <<
"Build log:\n" << buildLog << std::endl;
// Tail of what appears to be mgpuStreamWaitEvent(StreamWrapper *stream,
// ze_event_handle_t event), per the declaration list at the end of this file.
// It asserts that both handles are non-null; the rest of its body is not
// visible.
425 ze_event_handle_t event) {
426 assert(stream &&
"Invalid stream");
427 assert(event &&
"Invalid event");
// Presumably from mgpuEventSynchronize: blocks the host on `event` with the
// maximum uint64_t timeout.
441 zeEventHostSynchronize(event, std::numeric_limits<uint64_t>::max()));
// Body fragment of mgpuMemAlloc. The declaration list at the end of this file
// gives its signature as void *mgpuMemAlloc(uint64_t size,
// StreamWrapper *stream, bool isShared).
// - The work runs inside catchAll, so exceptions thrown here go to its
//   handlers.
// - The allocation uses 64-byte alignment and is described by a device and a
//   host allocation descriptor.
// - It throws "mem allocation failed!" under a condition not visible here.
// NOTE(review): the allocation call is only partly visible (it takes
// &hostDesc, size, alignment). Presumably `isShared` selects between a shared
// and a device-only allocation — confirm.
455 return catchAll([&]() {
456 void *memPtr =
nullptr;
457 constexpr size_t alignment{64};
458 ze_device_mem_alloc_desc_t deviceDesc = {};
459 deviceDesc.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC;
461 ze_host_mem_alloc_desc_t hostDesc = {};
462 hostDesc.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC;
464 &hostDesc, size, alignment,
472 throw std::runtime_error(
"mem allocation failed!");
// mgpuMemcpy: C entry point that enqueues an operation on `stream` through
// StreamWrapper::enqueueOp. It forwards the new event and the wait list to
// the append call; presumably that call is a memory copy of `sizeBytes` bytes
// from `src` to `dst`.
// NOTE(review): the `stream` parameter line and the head of the append call
// are not visible in this excerpt.
483extern "C" void mgpuMemcpy(
void *dst,
void *src,
size_t sizeBytes,
485 stream->
enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
486 ze_event_handle_t *waitEvents) {
489 numWaitEvents, waitEvents));
// mgpuMemset<PATTERN_TYPE>: enqueues a fill on `stream` through
// StreamWrapper::enqueueOp. It writes `count` elements at `dst` using `value`
// as the pattern: the pattern size is sizeof(PATTERN_TYPE) and the total size
// is count * sizeof(PATTERN_TYPE) bytes. The argument order matches
// zeCommandListAppendMemoryFill; the call head is not visible.
// `listType` selects the command list; how it is chosen is on lines not
// visible here.
// NOTE(review): the pattern is passed as &value, the address of a parameter.
// Confirm the driver does not read the pattern after this function returns,
// since the fill is appended to an asynchronous stream.
493template <
typename PATTERN_TYPE>
494static void mgpuMemset(
void *dst, PATTERN_TYPE value,
size_t count,
501 stream->
enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
502 ze_event_handle_t *waitEvents) {
504 listType, dst, &value,
sizeof(PATTERN_TYPE),
505 count *
sizeof(PATTERN_TYPE), newEvent, numWaitEvents, waitEvents));
// C entry point for 32-bit fills. Only part of the signature is visible (the
// stream parameter and the body are not); presumably it forwards to
// mgpuMemset<unsigned int>.
508extern "C" void mgpuMemset32(
void *dst,
unsigned int value,
size_t count,
// C entry point for 16-bit fills. Only part of the signature is visible;
// presumably it forwards to mgpuMemset<unsigned short>.
513extern "C" void mgpuMemset16(
void *dst,
unsigned short value,
size_t count,
// Tail of mgpuModuleLoad: loads the `gpuBlobSize` bytes at `data` via
// loadModule in its default (native) format, inside catchAll.
519 size_t gpuBlobSize) {
520 return catchAll([&]() {
return loadModule(data, gpuBlobSize); });
// Body fragment of mgpuModuleLoadJIT: loads `data` as SPIR-V
// (ZE_MODULE_FORMAT_IL_SPIRV), taking its size from strlen.
// NOTE(review): binary SPIR-V contains zero bytes right after the magic
// number (the header's version word has zero high and low bytes). strlen
// would therefore stop there and truncate the module. Confirm what callers
// pass as `data`.
// NOTE(review): `optLevel` (declaration list) is not used in these lines.
524 return catchAll([&]() {
525 return loadModule(data, strlen(
reinterpret_cast<char *
>(data)),
526 ZE_MODULE_FORMAT_IL_SPIRV);
// Body fragment of mgpuModuleGetFunction. It requires non-null `module` and
// `name` (assert) and prepares a kernel descriptor whose kernel name is
// `name`.
// NOTE(review): desc.stype is not assigned in the visible lines. Confirm the
// hidden code sets ZE_STRUCTURE_TYPE_KERNEL_DESC before the kernel is created.
532 assert(module && name);
533 ze_kernel_handle_t zeKernel;
534 ze_kernel_desc_t desc = {};
535 desc.pKernelName = name;
// Tail of mgpuLaunchKernel's parameter list and its body. The full signature
// is in the declaration list at the end of this file.
// - When sharedMemBytes > 0, the last argument index is set with size
//   sharedMemBytes and a null value (presumably the dynamic local/shared
//   memory argument). That index is excluded from the loop below.
// - Every remaining argument i is set with size sizeof(void *) from
//   params[i].
// - The group size comes from blockX/Y/Z. The group counts come from
//   gridX/Y/Z cast to uint32_t.
// - The launch is enqueued through StreamWrapper::enqueueOp; the append call
//   head is not visible.
// NOTE(review): if sharedMemBytes > 0 and paramsCount == 0, the size_t
// decrement wraps around and the loop reads past `params`.
// NOTE(review): blockX/Y/Z (size_t) go to zeKernelSetGroupSize, whose
// parameters are uint32_t, without the explicit casts used for the grid.
// NOTE(review): treating every argument as pointer-sized assumes how callers
// pack `params` — confirm.
541 size_t gridY,
size_t gridZ,
size_t blockX,
542 size_t blockY,
size_t blockZ,
544 void **params,
void ** ,
545 size_t paramsCount) {
547 if (sharedMemBytes > 0) {
548 paramsCount = paramsCount - 1;
550 zeKernelSetArgumentValue(kernel, paramsCount, sharedMemBytes,
nullptr));
552 for (
size_t i = 0; i < paramsCount; ++i)
553 L0_SAFE_CALL(zeKernelSetArgumentValue(kernel,
static_cast<uint32_t
>(i),
554 sizeof(
void *), params[i]));
555 L0_SAFE_CALL(zeKernelSetGroupSize(kernel, blockX, blockY, blockZ));
556 ze_group_count_t dispatch;
557 dispatch.groupCountX =
static_cast<uint32_t
>(gridX);
558 dispatch.groupCountY =
static_cast<uint32_t
>(gridY);
559 dispatch.groupCountZ =
static_cast<uint32_t
>(gridZ);
560 stream->
enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
561 ze_event_handle_t *waitEvents) {
564 numWaitEvents, waitEvents));
// NOTE(review): the lines below are bare declaration signatures with no bodies
// and no terminating semicolons. They cover entities whose partial
// definitions appear earlier in this file (e.g. getDriver, getDevice,
// loadModule, DynamicEventPool, StreamWrapper). As written they are not valid
// C++ and look like an extracted outline rather than source — confirm
// whether they belong in this file.
std::unique_ptr< std::remove_pointer< ze_event_handle_t >::type, ZeEventDeleter > UniqueZeEvent
void mgpuSetDefaultDevice(int32_t devIdx)
static L0RTContextWrapper & getRtContext()
static ze_module_handle_t loadModule(const void *data, size_t dataSize, ze_module_format_t format=ZE_MODULE_FORMAT_NATIVE)
void mgpuMemset16(void *dst, unsigned short value, size_t count, StreamWrapper *stream)
#define L0_SAFE_CALL(call)
static void mgpuMemset(void *dst, PATTERN_TYPE value, size_t count, StreamWrapper *stream)
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static DynamicEventPool & getDynamicEventPool()
std::unique_ptr< std::remove_pointer< ze_context_handle_t >::type, ZeContextDeleter > UniqueZeContext
void * mgpuMemAlloc(uint64_t size, StreamWrapper *stream, bool isShared)
void mgpuStreamDestroy(StreamWrapper *stream)
ze_module_handle_t mgpuModuleLoad(const void *data, size_t gpuBlobSize)
void mgpuEventSynchronize(ze_event_handle_t event)
void mgpuModuleUnload(ze_module_handle_t module)
static ze_driver_handle_t getDriver(uint32_t idx=0)
void mgpuMemset32(void *dst, unsigned int value, size_t count, StreamWrapper *stream)
StreamWrapper * mgpuStreamCreate()
std::unique_ptr< std::remove_pointer< ze_command_list_handle_t >::type, ZeCommandListDeleter > UniqueZeCommandList
void mgpuEventDestroy(ze_event_handle_t event)
void mgpuStreamSynchronize(StreamWrapper *stream)
ze_kernel_handle_t mgpuModuleGetFunction(ze_module_handle_t module, const char *name)
ze_module_handle_t mgpuModuleLoadJIT(void *data, int optLevel)
void mgpuMemcpy(void *dst, void *src, size_t sizeBytes, StreamWrapper *stream)
void mgpuStreamWaitEvent(StreamWrapper *stream, ze_event_handle_t event)
void mgpuMemFree(void *ptr, StreamWrapper *stream)
void mgpuEventRecord(ze_event_handle_t event, StreamWrapper *stream)
ze_event_handle_t mgpuEventCreate()
std::unique_ptr< std::remove_pointer< ze_event_pool_handle_t >::type, ZeEventPoolDeleter > UniqueZeEventPool
void mgpuLaunchKernel(ze_kernel_handle_t kernel, size_t gridX, size_t gridY, size_t gridZ, size_t blockX, size_t blockY, size_t blockZ, size_t sharedMemBytes, StreamWrapper *stream, void **params, void **, size_t paramsCount)
void createNewPool(size_t numEvents)
L0RTContextWrapper * rtCtx
DynamicEventPool & operator=(const DynamicEventPool &)=delete
static constexpr size_t numEventsPerPool
void releaseEvent(ze_event_handle_t event)
DynamicEventPool(DynamicEventPool &&) noexcept=default
DynamicEventPool(L0RTContextWrapper *rtCtx)
std::vector< UniqueZeEventPool > eventPools
std::unordered_map< ze_event_handle_t, UniqueZeEvent > takenEvents
ze_event_handle_t takeEvent()
std::vector< UniqueZeEvent > availableEvents
DynamicEventPool(const DynamicEventPool &)=delete
size_t currentEventsLimit
UniqueZeCommandList immCmdListCopy
L0RTContextWrapper()=default
L0RTContextWrapper & operator=(const L0RTContextWrapper &)=delete
uint32_t copyEngineMaxMemoryFillPatternSize
ze_device_handle_t device
L0RTContextWrapper(L0RTContextWrapper &&) noexcept=default
L0RTContextWrapper(const uint32_t driverIdx=0, const int32_t devIdx=0)
ze_driver_handle_t driver
UniqueZeCommandList immCmdListCompute
L0RTContextWrapper(const L0RTContextWrapper &)=delete
ze_event_handle_t * getLastImplicitEventPtr()
void sync(ze_event_handle_t explicitEvent=nullptr)
StreamWrapper(DynamicEventPool &dynEventPool)
std::deque< ze_event_handle_t > implicitEventStack
DynamicEventPool & dynEventPool
void enqueueOp(Func &&op)
void operator()(ze_command_list_handle_t cmdList) const
void operator()(ze_context_handle_t ctx) const
void operator()(ze_event_handle_t event) const
void operator()(ze_event_pool_handle_t pool) const