MLIR 23.0.0git
SyclRuntimeWrappers.cpp
Go to the documentation of this file.
1//===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements wrappers around the sycl runtime library with C linkage
10//
11//===----------------------------------------------------------------------===//
12
13#include <cstdlib>
14
15#include <level_zero/ze_api.h>
16#include <sycl/ext/oneapi/backend/level_zero.hpp>
17#include <sycl/sycl.hpp>
18
19#ifdef _WIN32
20#define SYCL_RUNTIME_EXPORT __declspec(dllexport)
21#else
22#define SYCL_RUNTIME_EXPORT
23#endif // _WIN32
24
25namespace {
26
/// Invokes \p func and forwards whatever it returns.
///
/// Any exception escaping \p func is reported on stderr and the process is
/// terminated with EXIT_FAILURE, so no exception can propagate through the
/// C-linkage entry points that use this helper.
template <typename F>
auto catchAll(F &&func) {
  try {
    return func();
  } catch (const std::exception &err) {
    fprintf(stderr, "SYCL runtime error: %s\n", err.what());
  } catch (...) {
    fprintf(stderr, "SYCL runtime error: unknown exception was thrown\n");
  }
  // Only reached after one of the handlers above has printed its diagnostic.
  fflush(stderr);
  std::exit(EXIT_FAILURE);
}
41
// Evaluates the Level Zero call `call` exactly once. If the result is not
// ZE_RESULT_SUCCESS, prints the numeric result code to stderr and aborts.
//
// The body is wrapped in do { } while (0) so the macro expands to a single
// statement and composes correctly with if/else; every use must be followed
// by a semicolon. The temporary has a macro-specific name so that a `call`
// expression mentioning a variable called `status` is not shadowed.
#define L0_SAFE_CALL(call)                                                     \
  do {                                                                         \
    const ze_result_t l0SafeCallStatus = (call);                               \
    if (l0SafeCallStatus != ZE_RESULT_SUCCESS) {                               \
      fprintf(stderr, "L0 error %d\n", static_cast<int>(l0SafeCallStatus));    \
      fflush(stderr);                                                          \
      abort();                                                                 \
    }                                                                          \
  } while (0)
51
52} // namespace
53
54static sycl::device getDefaultDevice() {
55 static sycl::device syclDevice;
56 static bool isDeviceInitialised = false;
57 if (!isDeviceInitialised) {
58 auto platformList = sycl::platform::get_platforms();
59 for (const auto &platform : platformList) {
60 auto platformName = platform.get_info<sycl::info::platform::name>();
61 bool isLevelZero = platformName.find("Level-Zero") != std::string::npos;
62 if (!isLevelZero)
63 continue;
64
65 syclDevice = platform.get_devices()[0];
66 isDeviceInitialised = true;
67 return syclDevice;
68 }
69 throw std::runtime_error(
70 "no Level-Zero SYCL platform found; the MLIR SYCL runtime wrapper "
71 "currently requires a Level-Zero backend");
72 } else {
73 return syclDevice;
74 }
75}
76
77static sycl::context getDefaultContext() {
78 static sycl::context syclContext{getDefaultDevice()};
79 return syclContext;
80}
81
82static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
83 void *memPtr = nullptr;
84 if (isShared) {
85 memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(),
87 } else {
88 memPtr = sycl::aligned_alloc_device(64, size, *queue);
89 }
90 if (memPtr == nullptr) {
91 throw std::runtime_error("mem allocation failed!");
92 }
93 return memPtr;
94}
95
96static void deallocDeviceMemory(sycl::queue *queue, void *ptr) {
97 sycl::free(ptr, *queue);
98}
99
100static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
101 assert(data);
102 ze_module_handle_t zeModule;
103 ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
104 nullptr,
105 ZE_MODULE_FORMAT_IL_SPIRV,
106 dataSize,
107 (const uint8_t *)data,
108 nullptr,
109 nullptr};
110 auto zeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
112 auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
114 L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr));
115 return zeModule;
116}
117
118static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) {
119 assert(zeModule);
120 assert(name);
121 ze_kernel_handle_t zeKernel;
122 ze_kernel_desc_t desc = {};
123 desc.pKernelName = name;
124
125 L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
126 sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
127 sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
128 sycl::bundle_state::executable>(
129 {zeModule}, getDefaultContext());
130
131 auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
132 {kernelBundle, zeKernel}, getDefaultContext());
133 return new sycl::kernel(kernel);
134}
135
136static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX,
137 size_t gridY, size_t gridZ, size_t blockX,
138 size_t blockY, size_t blockZ, size_t sharedMemBytes,
139 void **params, size_t paramsCount) {
140 auto syclGlobalRange =
141 sycl::range<3>(blockZ * gridZ, blockY * gridY, blockX * gridX);
142 auto syclLocalRange = sycl::range<3>(blockZ, blockY, blockX);
143 sycl::nd_range<3> syclNdRange(syclGlobalRange, syclLocalRange);
144
145 queue->submit([&](sycl::handler &cgh) {
146 for (size_t i = 0; i < paramsCount; i++) {
147 cgh.set_arg(static_cast<uint32_t>(i), *(static_cast<void **>(params[i])));
148 }
149 cgh.parallel_for(syclNdRange, *kernel);
150 });
151}
152
153// Wrappers
154
155extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() {
156
157 return catchAll([&]() {
158 sycl::queue *queue =
159 new sycl::queue(getDefaultContext(), getDefaultDevice());
160 return queue;
161 });
162}
163
164extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamDestroy(sycl::queue *queue) {
165 catchAll([&]() { delete queue; });
166}
167
168extern "C" SYCL_RUNTIME_EXPORT void *
169mgpuMemAlloc(uint64_t size, sycl::queue *queue, bool isShared) {
170 return catchAll([&]() {
171 return allocDeviceMemory(queue, static_cast<size_t>(size), true);
172 });
173}
174
175extern "C" SYCL_RUNTIME_EXPORT void mgpuMemFree(void *ptr, sycl::queue *queue) {
176 catchAll([&]() {
177 if (ptr) {
178 deallocDeviceMemory(queue, ptr);
179 }
180 });
181}
182
183extern "C" SYCL_RUNTIME_EXPORT ze_module_handle_t
184mgpuModuleLoad(const void *data, size_t gpuBlobSize) {
185 return catchAll([&]() { return loadModule(data, gpuBlobSize); });
186}
187
188extern "C" SYCL_RUNTIME_EXPORT sycl::kernel *
189mgpuModuleGetFunction(ze_module_handle_t module, const char *name) {
190 return catchAll([&]() { return getKernel(module, name); });
191}
192
193extern "C" SYCL_RUNTIME_EXPORT void
194mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ,
195 size_t blockX, size_t blockY, size_t blockZ,
196 size_t sharedMemBytes, sycl::queue *queue, void **params,
197 void ** /*extra*/, size_t paramsCount) {
198 return catchAll([&]() {
199 launchKernel(queue, kernel, gridX, gridY, gridZ, blockX, blockY, blockZ,
200 sharedMemBytes, params, paramsCount);
201 });
202}
203
204extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamSynchronize(sycl::queue *queue) {
205
206 catchAll([&]() { queue->wait(); });
207}
208
209extern "C" SYCL_RUNTIME_EXPORT void
210mgpuModuleUnload(ze_module_handle_t module) {
211
212 catchAll([&]() { L0_SAFE_CALL(zeModuleDestroy(module)); });
213}
214
215extern "C" SYCL_RUNTIME_EXPORT void
216mgpuMemcpy(void *dst, void *src, size_t sizeBytes, sycl::queue *queue) {
217 catchAll([&]() { queue->memcpy(dst, src, sizeBytes).wait(); });
218}
static ze_module_handle_t loadModule(const void *data, size_t dataSize)
static void deallocDeviceMemory(sycl::queue *queue, void *ptr)
static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ, size_t blockX, size_t blockY, size_t blockZ, size_t sharedMemBytes, void **params, size_t paramsCount)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
static void * allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared)
#define SYCL_RUNTIME_EXPORT
static sycl::device getDefaultDevice()
#define L0_SAFE_CALL(call)
SYCL_RUNTIME_EXPORT void mgpuStreamDestroy(sycl::queue *queue)
SYCL_RUNTIME_EXPORT void mgpuMemFree(void *ptr, sycl::queue *queue)
SYCL_RUNTIME_EXPORT void mgpuModuleUnload(ze_module_handle_t module)
SYCL_RUNTIME_EXPORT void * mgpuMemAlloc(uint64_t size, sycl::queue *queue, bool isShared)
SYCL_RUNTIME_EXPORT ze_module_handle_t mgpuModuleLoad(const void *data, size_t gpuBlobSize)
SYCL_RUNTIME_EXPORT void mgpuMemcpy(void *dst, void *src, size_t sizeBytes, sycl::queue *queue)
SYCL_RUNTIME_EXPORT void mgpuStreamSynchronize(sycl::queue *queue)
SYCL_RUNTIME_EXPORT sycl::queue * mgpuStreamCreate()
static sycl::context getDefaultContext()
SYCL_RUNTIME_EXPORT void mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ, size_t blockX, size_t blockY, size_t blockZ, size_t sharedMemBytes, sycl::queue *queue, void **params, void **, size_t paramsCount)
SYCL_RUNTIME_EXPORT sycl::kernel * mgpuModuleGetFunction(ze_module_handle_t module, const char *name)