MLIR 23.0.0git
RocmRuntimeWrappers.cpp
Go to the documentation of this file.
1//===- RocmRuntimeWrappers.cpp - MLIR ROCM runtime wrapper library --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements C wrappers around the ROCM library for easy linking in ORC jit.
10// Also adds some debugging helpers that are helpful when writing MLIR code to
11// run on GPUs.
12//
13//===----------------------------------------------------------------------===//
14
15#include <cassert>
16#include <numeric>
17
19#include "llvm/ADT/ArrayRef.h"
20
21#include "hip/hip_runtime.h"
22
/// Evaluates the HIP call `expr` exactly once. When the returned hipError_t is
/// non-zero, prints the stringified call and the HIP error name (or
/// "<unknown>" when HIP provides no name) to stderr. Execution continues
/// either way; the error is reported, not propagated.
#define HIP_REPORT_IF_ERROR(expr)                                              \
  [](hipError_t result) {                                                      \
    if (result) {                                                              \
      const char *name = hipGetErrorName(result);                              \
      fprintf(stderr, "'%s' failed with '%s'\n", #expr,                        \
              name ? name : "<unknown>");                                      \
    }                                                                          \
  }(expr)
32
/// Per-thread device ordinal last requested through mgpuSetDefaultDevice;
/// starts at 0.
static thread_local int32_t defaultDevice = 0;
34
35extern "C" hipModule_t mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
36 hipModule_t module = nullptr;
37 HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
38 return module;
39}
40
41extern "C" hipModule_t mgpuModuleLoadJIT(void *data, int optLevel,
42 size_t /*assmeblySize*/) {
43 assert(false && "This function is not available in HIP.");
44 return nullptr;
45}
46
47extern "C" void mgpuModuleUnload(hipModule_t module) {
48 HIP_REPORT_IF_ERROR(hipModuleUnload(module));
49}
50
51extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module,
52 const char *name) {
53 hipFunction_t function = nullptr;
54 HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function, module, name));
55 return function;
56}
57
58// The wrapper uses intptr_t instead of ROCM's unsigned int to match
59// the type of MLIR's index type. This avoids the need for casts in the
60// generated MLIR code.
61extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
62 intptr_t gridY, intptr_t gridZ,
63 intptr_t blockX, intptr_t blockY,
64 intptr_t blockZ, int32_t smem,
65 hipStream_t stream, void **params,
66 void **extra, size_t /*paramsCount*/) {
67 HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
68 blockX, blockY, blockZ, smem,
69 stream, params, extra));
70}
71
72extern "C" hipStream_t mgpuStreamCreate() {
73 hipStream_t stream = nullptr;
74 HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
75 return stream;
76}
77
78extern "C" void mgpuStreamDestroy(hipStream_t stream) {
79 HIP_REPORT_IF_ERROR(hipStreamDestroy(stream));
80}
81
82extern "C" void mgpuStreamSynchronize(hipStream_t stream) {
83 return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream));
84}
85
86extern "C" void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event) {
87 HIP_REPORT_IF_ERROR(hipStreamWaitEvent(stream, event, /*flags=*/0));
88}
89
90extern "C" hipEvent_t mgpuEventCreate() {
91 hipEvent_t event = nullptr;
92 HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming));
93 return event;
94}
95
96extern "C" void mgpuEventDestroy(hipEvent_t event) {
97 HIP_REPORT_IF_ERROR(hipEventDestroy(event));
98}
99
100extern "C" void mgpuEventSynchronize(hipEvent_t event) {
101 HIP_REPORT_IF_ERROR(hipEventSynchronize(event));
102}
103
104extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
105 HIP_REPORT_IF_ERROR(hipEventRecord(event, stream));
106}
107
108extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/,
109 bool /*isHostShared*/) {
110 void *ptr;
111 HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes));
112 return ptr;
113}
114
115extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
116 HIP_REPORT_IF_ERROR(hipFree(ptr));
117}
118
119extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
120 hipStream_t stream) {
122 hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
123}
124
125extern "C" void mgpuMemset32(void *dst, int value, size_t count,
126 hipStream_t stream) {
127 HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(dst),
128 value, count, stream));
129}
130
131extern "C" void mgpuMemset16(void *dst, int short value, size_t count,
132 hipStream_t stream) {
133 HIP_REPORT_IF_ERROR(hipMemsetD16Async(reinterpret_cast<hipDeviceptr_t>(dst),
134 value, count, stream));
135}
136
137/// Helper functions for writing mlir example code
138
139// Allows to register byte array with the ROCM runtime. Helpful until we have
140// transfer functions implemented.
141extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
142 HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
143}
144
145// Allows to register a MemRef with the ROCm runtime. Helpful until we have
146// transfer functions implemented.
147extern "C" void
149 int64_t elementSizeBytes) {
150
151 llvm::SmallVector<int64_t, 4> denseStrides(rank);
152 llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
153 llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
154
155 std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
156 std::multiplies<int64_t>());
157 auto sizeBytes = denseStrides.front() * elementSizeBytes;
158
159 // Only densely packed tensors are currently supported.
160 std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
161 denseStrides.end());
162 denseStrides.back() = 1;
163 assert(strides == llvm::ArrayRef(denseStrides));
164
165 auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
166 mgpuMemHostRegister(ptr, sizeBytes);
167}
168
169// Allows to unregister byte array with the ROCM runtime. Helpful until we have
170// transfer functions implemented.
171extern "C" void mgpuMemHostUnregister(void *ptr) {
172 HIP_REPORT_IF_ERROR(hipHostUnregister(ptr));
173}
174
175// Allows to unregister a MemRef with the ROCm runtime. Helpful until we have
176// transfer functions implemented.
177extern "C" void
179 StridedMemRefType<char, 1> *descriptor,
180 int64_t elementSizeBytes) {
181 auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
183}
184
185template <typename T>
186void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) {
187 HIP_REPORT_IF_ERROR(hipSetDevice(0));
189 hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0));
190}
191
193mgpuMemGetDeviceMemRef1dFloat(float *allocated, float *aligned, int64_t offset,
194 int64_t size, int64_t stride) {
195 float *devicePtr = nullptr;
196 mgpuMemGetDevicePointer(aligned, &devicePtr);
197 return {devicePtr, devicePtr, offset, {size}, {stride}};
198}
199
201mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated, int32_t *aligned,
202 int64_t offset, int64_t size, int64_t stride) {
203 int32_t *devicePtr = nullptr;
204 mgpuMemGetDevicePointer(aligned, &devicePtr);
205 return {devicePtr, devicePtr, offset, {size}, {stride}};
206}
207
208extern "C" void mgpuSetDefaultDevice(int32_t device) {
209 defaultDevice = device;
210 HIP_REPORT_IF_ERROR(hipSetDevice(device));
211}
static thread_local int32_t defaultDevice
void mgpuMemset32(void *dst, int value, size_t count, hipStream_t stream)
void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes)
Helper functions for writing mlir example code.
void mgpuMemset16(void *dst, int short value, size_t count, hipStream_t stream)
hipModule_t mgpuModuleLoadJIT(void *data, int optLevel, size_t)
void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event)
hipEvent_t mgpuEventCreate()
void mgpuEventSynchronize(hipEvent_t event)
void mgpuStreamDestroy(hipStream_t stream)
void mgpuMemHostUnregister(void *ptr)
void mgpuStreamSynchronize(hipStream_t stream)
StridedMemRefType< int32_t, 1 > mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated, int32_t *aligned, int64_t offset, int64_t size, int64_t stride)
void mgpuModuleUnload(hipModule_t module)
void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr)
StridedMemRefType< float, 1 > mgpuMemGetDeviceMemRef1dFloat(float *allocated, float *aligned, int64_t offset, int64_t size, int64_t stride)
void mgpuMemcpy(void *dst, void *src, size_t sizeBytes, hipStream_t stream)
void mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType< char, 1 > *descriptor, int64_t elementSizeBytes)
hipFunction_t mgpuModuleGetFunction(hipModule_t module, const char *name)
hipModule_t mgpuModuleLoad(void *data, size_t)
void mgpuEventDestroy(hipEvent_t event)
void mgpuEventRecord(hipEvent_t event, hipStream_t stream)
void * mgpuMemAlloc(uint64_t sizeBytes, hipStream_t, bool)
#define HIP_REPORT_IF_ERROR(expr)
void mgpuMemFree(void *ptr, hipStream_t)
void mgpuSetDefaultDevice(int32_t device)
void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX, intptr_t gridY, intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, hipStream_t stream, void **params, void **extra, size_t)
void mgpuMemHostUnregisterMemRef(int64_t rank, StridedMemRefType< char, 1 > *descriptor, int64_t elementSizeBytes)
hipStream_t mgpuStreamCreate()
StridedMemRef descriptor type with static rank.