13#include "llvm/ADT/Twine.h" 
   15#include "level_zero/ze_api.h" 
   22#include <unordered_set> 
   27auto catchAll(F &&func) {
 
   30  } 
catch (
const std::exception &e) {
 
   31    std::cerr << 
"An exception was thrown: " << e.what() << std::endl;
 
   34    std::cerr << 
"An unknown exception was thrown." << std::endl;
 
   39#define L0_SAFE_CALL(call)                                                     \ 
   41    ze_result_t status = (call);                                               \ 
   42    if (status != ZE_RESULT_SUCCESS) {                                         \ 
   43      const char *errorString;                                                 \ 
   44      zeDriverGetLastErrorDescription(NULL, &errorString);                     \ 
   45      std::cerr << "L0 error " << status << ": " << errorString << std::endl;  \ 
 
   58static ze_driver_handle_t 
getDriver(uint32_t idx = 0) {
 
   59  ze_init_driver_type_desc_t driver_type = {};
 
   60  driver_type.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
 
   61  driver_type.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU;
 
   62  driver_type.pNext = 
nullptr;
 
   63  uint32_t driverCount{0};
 
   64  thread_local static std::vector<ze_driver_handle_t> drivers;
 
   65  thread_local static bool isDriverInitialised{
false};
 
   66  if (isDriverInitialised && idx < drivers.size())
 
   68  L0_SAFE_CALL(zeInitDrivers(&driverCount, 
nullptr, &driver_type));
 
   70    throw std::runtime_error(
"No L0 drivers found.");
 
   71  drivers.resize(driverCount);
 
   72  L0_SAFE_CALL(zeInitDrivers(&driverCount, drivers.data(), &driver_type));
 
   73  if (idx >= driverCount)
 
   74    throw std::runtime_error((llvm::Twine(
"Requested driver idx out-of-bound, " 
   75                                          "number of availabe drivers: ") +
 
   76                              std::to_string(driverCount))
 
   78  isDriverInitialised = 
true;
 
 
   82static ze_device_handle_t 
getDevice(
const uint32_t driverIdx = 0,
 
   83                                    const int32_t devIdx = 0) {
 
   84  thread_local static ze_device_handle_t l0Device;
 
   85  thread_local int32_t currDevIdx{-1};
 
   86  thread_local uint32_t currDriverIdx{0};
 
   87  if (currDriverIdx == driverIdx && currDevIdx == devIdx)
 
   90  uint32_t deviceCount{0};
 
   91  L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, 
nullptr));
 
   93    throw std::runtime_error(
"getDevice failed: did not find L0 device.");
 
   94  if (
static_cast<int>(deviceCount) < devIdx + 1)
 
   95    throw std::runtime_error(
"getDevice failed: devIdx out-of-bounds.");
 
   96  std::vector<ze_device_handle_t> devices(deviceCount);
 
   97  L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, devices.data()));
 
   98  l0Device = devices[devIdx];
 
   99  currDriverIdx = driverIdx;
 
 
  105static ze_context_handle_t 
getContext(ze_driver_handle_t driver) {
 
  106  thread_local static ze_context_handle_t context;
 
  107  thread_local static bool isContextInitialised{
false};
 
  108  if (isContextInitialised)
 
  110  ze_context_desc_t ctxtDesc = {ZE_STRUCTURE_TYPE_CONTEXT_DESC, 
nullptr, 0};
 
  111  L0_SAFE_CALL(zeContextCreate(driver, &ctxtDesc, &context));
 
  112  isContextInitialised = 
true;
 
 
  134    std::unique_ptr<std::remove_pointer<ze_context_handle_t>::type,
 
  137    std::unique_ptr<std::remove_pointer<ze_command_list_handle_t>::type,
 
  159    uint32_t computeEngineOrdinal = -1u, copyEngineOrdinal = -1u;
 
  160    ze_device_properties_t deviceProperties{};
 
  162    uint32_t queueGroupCount = 0;
 
  164        device, &queueGroupCount, 
nullptr));
 
  165    std::vector<ze_command_queue_group_properties_t> queueGroupProperties(
 
  168        device, &queueGroupCount, queueGroupProperties.data()));
 
  170    for (uint32_t queueGroupIdx = 0; queueGroupIdx < queueGroupCount;
 
  172      const auto &group = queueGroupProperties[queueGroupIdx];
 
  173      if (group.flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE)
 
  174        computeEngineOrdinal = queueGroupIdx;
 
  175      else if (group.flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COPY) {
 
  176        copyEngineOrdinal = queueGroupIdx;
 
  179      if (copyEngineOrdinal != -1u && computeEngineOrdinal != -1u)
 
  184    if (copyEngineOrdinal == -1u)
 
  185      copyEngineOrdinal = computeEngineOrdinal;
 
  187    assert(copyEngineOrdinal != -1u && computeEngineOrdinal != -1u &&
 
  188           "Expected two engines to be available.");
 
  191    ze_command_queue_desc_t cmdQueueDesc{
 
  192        ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
 
  197        ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,
 
  198        ZE_COMMAND_QUEUE_PRIORITY_NORMAL};
 
  200    ze_command_list_handle_t rawCmdListCopy = 
nullptr;
 
  202                                              &cmdQueueDesc, &rawCmdListCopy));
 
  206    cmdQueueDesc.ordinal = computeEngineOrdinal;
 
  207    ze_command_list_handle_t rawCmdListCompute = 
nullptr;
 
 
 
  235    std::unique_ptr<std::remove_pointer<ze_event_handle_t>::type,
 
  238    std::unique_ptr<std::remove_pointer<ze_event_pool_handle_t>::type,
 
  271    assert(
takenEvents.empty() && 
"Some events were not released");
 
 
  275    ze_event_pool_desc_t eventPoolDesc = {};
 
  276    eventPoolDesc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE;
 
  277    eventPoolDesc.count = numEvents;
 
  279    ze_event_pool_handle_t rawPool = 
nullptr;
 
  281                                   &
rtCtx->device, &rawPool));
 
 
  288    ze_event_handle_t rawEvent = 
nullptr;
 
  294      rawEvent = uniqueEvent.get();
 
  298        throw std::runtime_error(
"DynamicEventPool: reached max events limit");
 
  303      ze_event_desc_t eventDesc = {
 
  304          ZE_STRUCTURE_TYPE_EVENT_DESC, 
nullptr,
 
  306          ZE_EVENT_SCOPE_FLAG_DEVICE, ZE_EVENT_SCOPE_FLAG_HOST};
 
  308      ze_event_handle_t newEvent = 
nullptr;
 
  310          zeEventCreate(
eventPools.back().get(), &eventDesc, &newEvent));
 
 
  323           "Attempting to release unknown or already released event");
 
 
 
  354  void sync(ze_event_handle_t explicitEvent = 
nullptr) {
 
  355    ze_event_handle_t syncEvent{
nullptr};
 
  356    if (!explicitEvent) {
 
  358      syncEvent = lastImplicitEventPtr ? *lastImplicitEventPtr : 
nullptr;
 
  360      syncEvent = explicitEvent;
 
  364          syncEvent, std::numeric_limits<uint64_t>::max()));
 
 
  372  template <
typename Func>
 
  374    ze_event_handle_t newImplicitEvent = 
dynEventPool.takeEvent();
 
  376    const uint32_t numWaitEvents = lastImplicitEventPtr ? 1 : 0;
 
  377    std::forward<Func>(op)(newImplicitEvent, numWaitEvents,
 
  378                           lastImplicitEventPtr);
 
 
 
  383static ze_module_handle_t 
loadModule(
const void *data, 
size_t dataSize) {
 
  385  ze_module_handle_t zeModule;
 
  386  ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
 
  388                           ZE_MODULE_FORMAT_IL_SPIRV,
 
  390                           (
const uint8_t *)data,
 
  393  ze_module_build_log_handle_t buildLogHandle;
 
  396                     &zeModule, &buildLogHandle);
 
  397  if (
result != ZE_RESULT_SUCCESS) {
 
  398    std::cerr << 
"Error creating module, error code: " << 
result << std::endl;
 
  400    L0_SAFE_CALL(zeModuleBuildLogGetString(buildLogHandle, &logSize, 
nullptr));
 
  401    std::string buildLog(
" ", logSize);
 
  403        zeModuleBuildLogGetString(buildLogHandle, &logSize, buildLog.data()));
 
  404    std::cerr << 
"Build log:\n" << buildLog << std::endl;
 
 
  426                                    ze_event_handle_t event) {
 
  427  assert(stream && 
"Invalid stream");
 
  428  assert(event && 
"Invalid event");
 
 
  442      zeEventHostSynchronize(event, std::numeric_limits<uint64_t>::max()));
 
 
  456  return catchAll([&]() {
 
  457    void *memPtr = 
nullptr;
 
  458    constexpr size_t alignment{64};
 
  459    ze_device_mem_alloc_desc_t deviceDesc = {};
 
  460    deviceDesc.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC;
 
  462      ze_host_mem_alloc_desc_t hostDesc = {};
 
  463      hostDesc.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC;
 
  465                                    &hostDesc, size, alignment,
 
  473      throw std::runtime_error(
"mem allocation failed!");
 
 
  484extern "C" void mgpuMemcpy(
void *dst, 
void *src, 
size_t sizeBytes,
 
  486  stream->
enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
 
  487                        ze_event_handle_t *waitEvents) {
 
  490        numWaitEvents, waitEvents));
 
 
  494template <
typename PATTERN_TYPE>
 
  495static void mgpuMemset(
void *dst, PATTERN_TYPE value, 
size_t count,
 
  502  stream->
enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
 
  503                        ze_event_handle_t *waitEvents) {
 
  505        listType, dst, &value, 
sizeof(PATTERN_TYPE),
 
  506        count * 
sizeof(PATTERN_TYPE), newEvent, numWaitEvents, waitEvents));
 
 
  509extern "C" void mgpuMemset32(
void *dst, 
unsigned int value, 
size_t count,
 
 
  514extern "C" void mgpuMemset16(
void *dst, 
unsigned short value, 
size_t count,
 
 
  520                                             size_t gpuBlobSize) {
 
  521  return catchAll([&]() { 
return loadModule(data, gpuBlobSize); });
 
 
  526  assert(module && name);
 
  527  ze_kernel_handle_t zeKernel;
 
  528  ze_kernel_desc_t desc = {};
 
  529  desc.pKernelName = name;
 
 
  535                                 size_t gridY, 
size_t gridZ, 
size_t blockX,
 
  536                                 size_t blockY, 
size_t blockZ,
 
  538                                 void **params, 
void ** ,
 
  539                                 size_t paramsCount) {
 
  541  if (sharedMemBytes > 0) {
 
  542    paramsCount = paramsCount - 1; 
 
  544        zeKernelSetArgumentValue(kernel, paramsCount, sharedMemBytes, 
nullptr));
 
  546  for (
size_t i = 0; i < paramsCount; ++i)
 
  547    L0_SAFE_CALL(zeKernelSetArgumentValue(kernel, 
static_cast<uint32_t
>(i),
 
  548                                          sizeof(
void *), params[i]));
 
  549  L0_SAFE_CALL(zeKernelSetGroupSize(kernel, blockX, blockY, blockZ));
 
  550  ze_group_count_t dispatch;
 
  551  dispatch.groupCountX = 
static_cast<uint32_t
>(gridX);
 
  552  dispatch.groupCountY = 
static_cast<uint32_t
>(gridY);
 
  553  dispatch.groupCountZ = 
static_cast<uint32_t
>(gridZ);
 
  554  stream->
enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
 
  555                        ze_event_handle_t *waitEvents) {
 
  558        numWaitEvents, waitEvents));
 
 
std::unique_ptr< std::remove_pointer< ze_event_handle_t >::type, ZeEventDeleter > UniqueZeEvent
 
void mgpuSetDefaultDevice(int32_t devIdx)
 
static ze_module_handle_t loadModule(const void *data, size_t dataSize)
 
static L0RTContextWrapper & getRtContext()
 
void mgpuMemset16(void *dst, unsigned short value, size_t count, StreamWrapper *stream)
 
#define L0_SAFE_CALL(call)
 
static void mgpuMemset(void *dst, PATTERN_TYPE value, size_t count, StreamWrapper *stream)
 
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
 
static DynamicEventPool & getDynamicEventPool()
 
std::unique_ptr< std::remove_pointer< ze_context_handle_t >::type, ZeContextDeleter > UniqueZeContext
 
void * mgpuMemAlloc(uint64_t size, StreamWrapper *stream, bool isShared)
 
void mgpuStreamDestroy(StreamWrapper *stream)
 
ze_module_handle_t mgpuModuleLoad(const void *data, size_t gpuBlobSize)
 
void mgpuEventSynchronize(ze_event_handle_t event)
 
void mgpuModuleUnload(ze_module_handle_t module)
 
static ze_driver_handle_t getDriver(uint32_t idx=0)
 
void mgpuMemset32(void *dst, unsigned int value, size_t count, StreamWrapper *stream)
 
StreamWrapper * mgpuStreamCreate()
 
std::unique_ptr< std::remove_pointer< ze_command_list_handle_t >::type, ZeCommandListDeleter > UniqueZeCommandList
 
void mgpuEventDestroy(ze_event_handle_t event)
 
void mgpuStreamSynchronize(StreamWrapper *stream)
 
ze_kernel_handle_t mgpuModuleGetFunction(ze_module_handle_t module, const char *name)
 
void mgpuMemcpy(void *dst, void *src, size_t sizeBytes, StreamWrapper *stream)
 
void mgpuStreamWaitEvent(StreamWrapper *stream, ze_event_handle_t event)
 
void mgpuMemFree(void *ptr, StreamWrapper *stream)
 
void mgpuEventRecord(ze_event_handle_t event, StreamWrapper *stream)
 
ze_event_handle_t mgpuEventCreate()
 
std::unique_ptr< std::remove_pointer< ze_event_pool_handle_t >::type, ZeEventPoolDeleter > UniqueZeEventPool
 
void mgpuLaunchKernel(ze_kernel_handle_t kernel, size_t gridX, size_t gridY, size_t gridZ, size_t blockX, size_t blockY, size_t blockZ, size_t sharedMemBytes, StreamWrapper *stream, void **params, void **, size_t paramsCount)
 
void createNewPool(size_t numEvents)
 
L0RTContextWrapper * rtCtx
 
DynamicEventPool & operator=(const DynamicEventPool &)=delete
 
static constexpr size_t numEventsPerPool
 
void releaseEvent(ze_event_handle_t event)
 
DynamicEventPool(DynamicEventPool &&) noexcept=default
 
DynamicEventPool(L0RTContextWrapper *rtCtx)
 
std::vector< UniqueZeEventPool > eventPools
 
std::unordered_map< ze_event_handle_t, UniqueZeEvent > takenEvents
 
ze_event_handle_t takeEvent()
 
std::vector< UniqueZeEvent > availableEvents
 
DynamicEventPool(const DynamicEventPool &)=delete
 
size_t currentEventsLimit
 
UniqueZeCommandList immCmdListCopy
 
L0RTContextWrapper()=default
 
L0RTContextWrapper & operator=(const L0RTContextWrapper &)=delete
 
uint32_t copyEngineMaxMemoryFillPatternSize
 
ze_device_handle_t device
 
L0RTContextWrapper(L0RTContextWrapper &&) noexcept=default
 
L0RTContextWrapper(const uint32_t driverIdx=0, const int32_t devIdx=0)
 
ze_driver_handle_t driver
 
UniqueZeCommandList immCmdListCompute
 
L0RTContextWrapper(const L0RTContextWrapper &)=delete
 
ze_event_handle_t * getLastImplicitEventPtr()
 
void sync(ze_event_handle_t explicitEvent=nullptr)
 
StreamWrapper(DynamicEventPool &dynEventPool)
 
std::deque< ze_event_handle_t > implicitEventStack
 
DynamicEventPool & dynEventPool
 
void enqueueOp(Func &&op)
 
void operator()(ze_command_list_handle_t cmdList) const
 
void operator()(ze_context_handle_t ctx) const
 
void operator()(ze_event_handle_t event) const
 
void operator()(ze_event_pool_handle_t pool) const