MLIR 23.0.0git
RocmRuntimeWrappers.cpp
Go to the documentation of this file.
1//===- RocmRuntimeWrappers.cpp - MLIR ROCM runtime wrapper library --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements C wrappers around the ROCM library for easy linking in ORC jit.
10// Also adds some debugging helpers that are helpful when writing MLIR code to
11// run on GPUs.
12//
13//===----------------------------------------------------------------------===//
14
#include <cassert>
#include <numeric>

#include "mlir/ExecutionEngine/CRunnerUtils.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

#include "hip/hip_runtime.h"
22
// Evaluates `expr` (a HIP runtime call returning hipError_t) and, on error,
// prints the expression text together with the HIP error name to stderr.
// Errors are reported but never propagated to the caller. Implemented as an
// immediately-invoked lambda so the macro expands to a single expression and
// `expr` is evaluated exactly once.
#define HIP_REPORT_IF_ERROR(expr)                                              \
  [](hipError_t result) {                                                      \
    if (!result)                                                               \
      return;                                                                  \
    const char *name = hipGetErrorName(result);                                \
    if (!name)                                                                 \
      name = "<unknown>";                                                      \
    fprintf(stderr, "'%s' failed with '%s'\n", #expr, name);                   \
  }(expr)
32
// Device ordinal last set via mgpuSetDefaultDevice. Thread-local so each host
// thread can target a different GPU; defaults to device 0.
thread_local static int32_t defaultDevice = 0;
34
35extern "C" hipModule_t mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
36 hipModule_t module = nullptr;
37 HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
38 return module;
39}
40
// JIT-compiling a module from assembly is a CUDA-only feature; this entry
// point exists only to mirror the CUDA wrapper's interface and always
// asserts. Returns nullptr in builds where assertions are disabled.
extern "C" hipModule_t mgpuModuleLoadJIT(void *data, int optLevel,
                                         size_t /*assemblySize*/) {
  assert(false && "This function is not available in HIP.");
  return nullptr;
}
46
// Unloads a module previously loaded with mgpuModuleLoad.
extern "C" void mgpuModuleUnload(hipModule_t module) {
  HIP_REPORT_IF_ERROR(hipModuleUnload(module));
}
50
51extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module,
52 const char *name) {
53 hipFunction_t function = nullptr;
54 HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function, module, name));
55 return function;
56}
57
// The wrapper uses intptr_t instead of ROCM's unsigned int to match
// the type of MLIR's index type. This avoids the need for casts in the
// generated MLIR code.
//
// Launches `function` on `stream` with the given grid/block dimensions,
// `smem` bytes of dynamic shared memory, and kernel arguments passed through
// `params`/`extra` untouched to hipModuleLaunchKernel. `paramsCount` is
// accepted only for interface compatibility and is unused here.
extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
                                 intptr_t gridY, intptr_t gridZ,
                                 intptr_t blockX, intptr_t blockY,
                                 intptr_t blockZ, int32_t smem,
                                 hipStream_t stream, void **params,
                                 void **extra, size_t /*paramsCount*/) {
  HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
                                            blockX, blockY, blockZ, smem,
                                            stream, params, extra));
}
71
72// Cooperative launch entry point. The cluster dimensions are accepted to
73// match the CUDA wrapper signature, but HIP does not support thread block
74// clusters; passing nonzero cluster dimensions is a usage error.
76 hipFunction_t function, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
77 intptr_t clusterX, intptr_t clusterY, intptr_t clusterZ, intptr_t blockX,
78 intptr_t blockY, intptr_t blockZ, int32_t smem, hipStream_t stream,
79 void **params, void ** /*extra*/) {
80 if (clusterX != 0 || clusterY != 0 || clusterZ != 0) {
81 fprintf(stderr,
82 "mgpuLaunchKernelCooperative: HIP does not support thread block "
83 "clusters (got cluster=%ld,%ld,%ld)\n",
84 clusterX, clusterY, clusterZ);
85 abort();
86 }
88 hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, blockX,
89 blockY, blockZ, smem, stream, params));
90}
91
92extern "C" hipStream_t mgpuStreamCreate() {
93 hipStream_t stream = nullptr;
94 HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
95 return stream;
96}
97
// Destroys a stream previously created with mgpuStreamCreate.
extern "C" void mgpuStreamDestroy(hipStream_t stream) {
  HIP_REPORT_IF_ERROR(hipStreamDestroy(stream));
}
101
102extern "C" void mgpuStreamSynchronize(hipStream_t stream) {
103 return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream));
104}
105
// Makes all future work submitted to `stream` wait until `event` completes.
extern "C" void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event) {
  HIP_REPORT_IF_ERROR(hipStreamWaitEvent(stream, event, /*flags=*/0));
}
109
110extern "C" hipEvent_t mgpuEventCreate() {
111 hipEvent_t event = nullptr;
112 HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming));
113 return event;
114}
115
// Destroys an event previously created with mgpuEventCreate.
extern "C" void mgpuEventDestroy(hipEvent_t event) {
  HIP_REPORT_IF_ERROR(hipEventDestroy(event));
}
119
// Blocks the host thread until `event` has completed.
extern "C" void mgpuEventSynchronize(hipEvent_t event) {
  HIP_REPORT_IF_ERROR(hipEventSynchronize(event));
}
123
// Records `event` on `stream`: the event completes once all work previously
// queued on the stream has finished.
extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
  HIP_REPORT_IF_ERROR(hipEventRecord(event, stream));
}
127
128extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/,
129 bool /*isHostShared*/) {
130 void *ptr;
131 HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes));
132 return ptr;
133}
134
// Frees device memory allocated with mgpuMemAlloc. The stream parameter is
// accepted for interface compatibility and unused; the free is synchronous.
extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
  HIP_REPORT_IF_ERROR(hipFree(ptr));
}
138
139extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
140 hipStream_t stream) {
142 hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
143}
144
145extern "C" void mgpuMemset32(void *dst, int value, size_t count,
146 hipStream_t stream) {
147 HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(dst),
148 value, count, stream));
149}
150
151extern "C" void mgpuMemset16(void *dst, int short value, size_t count,
152 hipStream_t stream) {
153 HIP_REPORT_IF_ERROR(hipMemsetD16Async(reinterpret_cast<hipDeviceptr_t>(dst),
154 value, count, stream));
155}
156
/// Helper functions for writing mlir example code

// Allows to register byte array with the ROCM runtime. Helpful until we have
// transfer functions implemented.
//
// Page-locks the `sizeBytes`-byte host range starting at `ptr` so the GPU
// can access it; pair with mgpuMemHostUnregister.
extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
  HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
}
164
165// Allows to register a MemRef with the ROCm runtime. Helpful until we have
166// transfer functions implemented.
167extern "C" void
169 int64_t elementSizeBytes) {
170
171 llvm::SmallVector<int64_t, 4> denseStrides(rank);
172 llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
173 llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
174
175 std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
176 std::multiplies<int64_t>());
177 auto sizeBytes = denseStrides.front() * elementSizeBytes;
178
179 // Only densely packed tensors are currently supported.
180 std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
181 denseStrides.end());
182 denseStrides.back() = 1;
183 assert(strides == llvm::ArrayRef(denseStrides));
184
185 auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
186 mgpuMemHostRegister(ptr, sizeBytes);
187}
188
// Allows to unregister byte array with the ROCM runtime. Helpful until we have
// transfer functions implemented.
//
// Undoes a prior mgpuMemHostRegister for the range starting at `ptr`.
extern "C" void mgpuMemHostUnregister(void *ptr) {
  HIP_REPORT_IF_ERROR(hipHostUnregister(ptr));
}
194
195// Allows to unregister a MemRef with the ROCm runtime. Helpful until we have
196// transfer functions implemented.
197extern "C" void
199 StridedMemRefType<char, 1> *descriptor,
200 int64_t elementSizeBytes) {
201 auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
203}
204
205template <typename T>
206void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) {
207 HIP_REPORT_IF_ERROR(hipSetDevice(0));
209 hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0));
210}
211
213mgpuMemGetDeviceMemRef1dFloat(float *allocated, float *aligned, int64_t offset,
214 int64_t size, int64_t stride) {
215 float *devicePtr = nullptr;
216 mgpuMemGetDevicePointer(aligned, &devicePtr);
217 return {devicePtr, devicePtr, offset, {size}, {stride}};
218}
219
221mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated, int32_t *aligned,
222 int64_t offset, int64_t size, int64_t stride) {
223 int32_t *devicePtr = nullptr;
224 mgpuMemGetDevicePointer(aligned, &devicePtr);
225 return {devicePtr, devicePtr, offset, {size}, {stride}};
226}
227
// Records `device` as this thread's default device and makes it the active
// HIP device for subsequent runtime calls issued from this thread.
extern "C" void mgpuSetDefaultDevice(int32_t device) {
  defaultDevice = device;
  HIP_REPORT_IF_ERROR(hipSetDevice(device));
}
static thread_local int32_t defaultDevice
void mgpuMemset32(void *dst, int value, size_t count, hipStream_t stream)
void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes)
Helper functions for writing mlir example code.
void mgpuMemset16(void *dst, int short value, size_t count, hipStream_t stream)
hipModule_t mgpuModuleLoadJIT(void *data, int optLevel, size_t)
void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event)
hipEvent_t mgpuEventCreate()
void mgpuEventSynchronize(hipEvent_t event)
void mgpuStreamDestroy(hipStream_t stream)
void mgpuMemHostUnregister(void *ptr)
void mgpuStreamSynchronize(hipStream_t stream)
StridedMemRefType< int32_t, 1 > mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated, int32_t *aligned, int64_t offset, int64_t size, int64_t stride)
void mgpuModuleUnload(hipModule_t module)
void mgpuLaunchKernelCooperative(hipFunction_t function, intptr_t gridX, intptr_t gridY, intptr_t gridZ, intptr_t clusterX, intptr_t clusterY, intptr_t clusterZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, hipStream_t stream, void **params, void **)
void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr)
StridedMemRefType< float, 1 > mgpuMemGetDeviceMemRef1dFloat(float *allocated, float *aligned, int64_t offset, int64_t size, int64_t stride)
void mgpuMemcpy(void *dst, void *src, size_t sizeBytes, hipStream_t stream)
void mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType< char, 1 > *descriptor, int64_t elementSizeBytes)
hipFunction_t mgpuModuleGetFunction(hipModule_t module, const char *name)
hipModule_t mgpuModuleLoad(void *data, size_t)
void mgpuEventDestroy(hipEvent_t event)
void mgpuEventRecord(hipEvent_t event, hipStream_t stream)
void * mgpuMemAlloc(uint64_t sizeBytes, hipStream_t, bool)
#define HIP_REPORT_IF_ERROR(expr)
void mgpuMemFree(void *ptr, hipStream_t)
void mgpuSetDefaultDevice(int32_t device)
void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX, intptr_t gridY, intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, hipStream_t stream, void **params, void **extra, size_t)
void mgpuMemHostUnregisterMemRef(int64_t rank, StridedMemRefType< char, 1 > *descriptor, int64_t elementSizeBytes)
hipStream_t mgpuStreamCreate()
StridedMemRef descriptor type with static rank.