MLIR  21.0.0git
VulkanRuntimeWrappers.cpp
Go to the documentation of this file.
1 //===- VulkanRuntimeWrappers.cpp - MLIR Vulkan runner wrapper library -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements C runtime wrappers around the VulkanRuntime.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <iostream>
14 #include <mutex>
15 #include <numeric>
16 #include <string>
17 #include <vector>
18 
19 #include "VulkanRuntime.h"
20 
21 // Explicitly export entry points to the vulkan-runtime-wrapper.
22 
23 #ifdef _WIN32
24 #define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
25 #else
26 #define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
27 #endif // _WIN32
28 
29 namespace {
30 
class VulkanModule;

// Handle object whose address is returned (as an opaque pointer) from
// `mgpuModuleGetFunction`. It pairs an entry-point name with the module the
// lookup was made on.
struct VulkanFunction {
  // Module this function was requested from. This struct never deletes it.
  VulkanModule *module;
  // Private copy of the entry-point name given at construction.
  std::string name;

  // `entryPoint` must be a valid NUL-terminated string; it is copied.
  VulkanFunction(VulkanModule *owner, const char *entryPoint)
      : module(owner), name(entryPoint) {}
};
41 
42 // Class to own a copy of the SPIR-V provided to `mgpuModuleLoad` and to manage
43 // allocation of pointers returned from `mgpuModuleGetFunction`.
44 class VulkanModule {
45 public:
46  VulkanModule(const uint8_t *ptr, size_t sizeInBytes)
47  : blob(ptr, ptr + sizeInBytes) {}
48  ~VulkanModule() = default;
49 
50  VulkanFunction *getFunction(const char *name) {
51  return functions.emplace_back(std::make_unique<VulkanFunction>(this, name))
52  .get();
53  }
54 
55  uint8_t *blobData() { return blob.data(); }
56  size_t blobSizeInBytes() const { return blob.size(); }
57 
58 private:
59  std::vector<uint8_t> blob;
60  std::vector<std::unique_ptr<VulkanFunction>> functions;
61 };
62 
63 class VulkanRuntimeManager {
64 public:
65  VulkanRuntimeManager() = default;
66  VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
67  VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
68  ~VulkanRuntimeManager() = default;
69 
70  void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
71  const VulkanHostMemoryBuffer &memBuffer) {
72  std::lock_guard<std::mutex> lock(mutex);
73  vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
74  }
75 
76  void setEntryPoint(const char *entryPoint) {
77  std::lock_guard<std::mutex> lock(mutex);
78  vulkanRuntime.setEntryPoint(entryPoint);
79  }
80 
81  void setNumWorkGroups(NumWorkGroups numWorkGroups) {
82  std::lock_guard<std::mutex> lock(mutex);
83  vulkanRuntime.setNumWorkGroups(numWorkGroups);
84  }
85 
86  void setShaderModule(uint8_t *shader, uint32_t size) {
87  std::lock_guard<std::mutex> lock(mutex);
88  vulkanRuntime.setShaderModule(shader, size);
89  }
90 
91  void runOnVulkan() {
92  std::lock_guard<std::mutex> lock(mutex);
93  if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
94  failed(vulkanRuntime.updateHostMemoryBuffers()) ||
95  failed(vulkanRuntime.destroy())) {
96  std::cerr << "runOnVulkan failed";
97  }
98  }
99 
100 private:
101  VulkanRuntime vulkanRuntime;
102  std::mutex mutex;
103 };
104 
105 } // namespace
106 
// Host-side view of a rank-N memref with element type T, received by pointer
// by the `_mlir_ciface_fillResource*` helpers in this file.
//
// NOTE(review): the struct head and the two leading pointer members were
// restored; the field order is assumed to mirror the MLIR ranked memref
// descriptor (allocated pointer, aligned pointer, offset, sizes, strides).
// Confirm against the upstream source before relying on the layout.
template <typename T, int N>
struct MemRefDescriptor {
  // Base pointer of the buffer; the fill helpers write through this pointer.
  T *allocated;
  // Presumably the aligned data pointer; not read anywhere in this file.
  T *aligned;
  int64_t offset;
  // Per-dimension extents; the fill helpers multiply these for the count.
  int64_t sizes[N];
  int64_t strides[N];
};
115 
116 extern "C" {
117 
118 //===----------------------------------------------------------------------===//
119 //
120 // Wrappers intended for mlir-runner. Uses of GPU dialect operations get
121 // lowered to calls to these functions by GPUToLLVMConversionPass.
122 //
123 //===----------------------------------------------------------------------===//
124 
126  return new VulkanRuntimeManager();
127 }
128 
129 VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager) {
130  delete static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
131 }
132 
134  // Currently a no-op as the other operations are synchronous.
135 }
136 
138  size_t gpuBlobSize) {
139  // gpuBlobSize is the size of the data in bytes.
140  return new VulkanModule(static_cast<const uint8_t *>(data), gpuBlobSize);
141 }
142 
144  delete static_cast<VulkanModule *>(vkModule);
145 }
146 
148  const char *name) {
149  if (!vkModule)
150  abort();
151  return static_cast<VulkanModule *>(vkModule)->getFunction(name);
152 }
153 
155 mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
156  size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/,
157  size_t /*smem*/, void *vkRuntimeManager, void **params,
158  void ** /*extra*/, size_t paramsCount) {
159  auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
160 
161  // GpuToLLVMConversionPass with the kernelBarePtrCallConv and
162  // kernelIntersperseSizeCallConv options will set up the params array like:
163  // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... }
164  const size_t paramsPerMemRef = 2;
165  if (paramsCount % paramsPerMemRef != 0) {
166  abort(); // This would indicate a serious calling convention mismatch.
167  }
168  const DescriptorSetIndex setIndex = 0;
169  BindingIndex bindIndex = 0;
170  for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
171  void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]);
172  size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]);
173  VulkanHostMemoryBuffer memBuffer{memrefBufferBasePtr,
174  static_cast<uint32_t>(memrefBufferSize)};
175  manager->setResourceData(setIndex, bindIndex, memBuffer);
176  ++bindIndex;
177  }
178 
179  manager->setNumWorkGroups(NumWorkGroups{static_cast<uint32_t>(gridX),
180  static_cast<uint32_t>(gridY),
181  static_cast<uint32_t>(gridZ)});
182 
183  auto function = static_cast<VulkanFunction *>(vkKernel);
184  // Expected size should be in bytes.
185  manager->setShaderModule(
186  function->module->blobData(),
187  static_cast<uint32_t>(function->module->blobSizeInBytes()));
188  manager->setEntryPoint(function->name.c_str());
189 
190  manager->runOnVulkan();
191 }
192 
193 //===----------------------------------------------------------------------===//
194 //
195 // Miscellaneous utility functions that can be directly used by tests.
196 //
197 //===----------------------------------------------------------------------===//
198 
199 /// Fills the given 1D float memref with the given float value.
202  float value) {
203  std::fill_n(ptr->allocated, ptr->sizes[0], value);
204 }
205 
206 /// Fills the given 2D float memref with the given float value.
209  float value) {
210  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
211 }
212 
213 /// Fills the given 3D float memref with the given float value.
216  float value) {
217  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
218  value);
219 }
220 
221 /// Fills the given 1D int memref with the given int value.
224  int32_t value) {
225  std::fill_n(ptr->allocated, ptr->sizes[0], value);
226 }
227 
228 /// Fills the given 2D int memref with the given int value.
231  int32_t value) {
232  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
233 }
234 
235 /// Fills the given 3D int memref with the given int value.
238  int32_t value) {
239  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
240  value);
241 }
242 
243 /// Fills the given 1D int memref with the given int8 value.
246  int8_t value) {
247  std::fill_n(ptr->allocated, ptr->sizes[0], value);
248 }
249 
250 /// Fills the given 2D int memref with the given int8 value.
253  int8_t value) {
254  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
255 }
256 
257 /// Fills the given 3D int memref with the given int8 value.
260  int8_t value) {
261  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
262  value);
263 }
264 }
VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource1DInt8(MemRefDescriptor< int8_t, 1 > *ptr, int8_t value)
Fills the given 1D int memref with the given int8 value.
VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamSynchronize(void *)
VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ, size_t, size_t, size_t, size_t, void *vkRuntimeManager, void **params, void **, size_t paramsCount)
VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager)
VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource2DFloat(MemRefDescriptor< float, 2 > *ptr, float value)
Fills the given 2D float memref with the given float value.
VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource1DFloat(MemRefDescriptor< float, 1 > *ptr, float value)
Fills the given 1D float memref with the given float value.
#define VULKAN_WRAPPER_SYMBOL_EXPORT
VULKAN_WRAPPER_SYMBOL_EXPORT void * mgpuModuleGetFunction(void *vkModule, const char *name)
VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource2DInt(MemRefDescriptor< int32_t, 2 > *ptr, int32_t value)
Fills the given 2D int memref with the given int value.
VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource2DInt8(MemRefDescriptor< int8_t, 2 > *ptr, int8_t value)
Fills the given 2D int memref with the given int8 value.
VULKAN_WRAPPER_SYMBOL_EXPORT void * mgpuModuleLoad(const void *data, size_t gpuBlobSize)
VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuModuleUnload(void *vkModule)
VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource3DInt(MemRefDescriptor< int32_t, 3 > *ptr, int32_t value)
Fills the given 3D int memref with the given int value.
VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource3DFloat(MemRefDescriptor< float, 3 > *ptr, float value)
Fills the given 3D float memref with the given float value.
VULKAN_WRAPPER_SYMBOL_EXPORT void * mgpuStreamCreate()
VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource1DInt(MemRefDescriptor< int32_t, 1 > *ptr, int32_t value)
Fills the given 1D int memref with the given int value.
VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource3DInt8(MemRefDescriptor< int8_t, 3 > *ptr, int8_t value)
Fills the given 3D int memref with the given int8 value.
uint32_t BindingIndex
Definition: VulkanRuntime.h:25
uint32_t DescriptorSetIndex
Definition: VulkanRuntime.h:24
Vulkan runtime.
Definition: VulkanRuntime.h:93
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
Struct containing the number of local workgroups to dispatch for each dimension.
Definition: VulkanRuntime.h:49
Struct containing information regarding a host memory buffer.
Definition: VulkanRuntime.h:40