21#include "llvm/ADT/ScopeExit.h"
22#include "llvm/IR/Constants.h"
23#include "llvm/IR/IRBuilder.h"
24#include "llvm/IR/LLVMContext.h"
25#include "llvm/IR/Module.h"
26#include "llvm/Support/FormatVariadic.h"
27#include "llvm/Transforms/Utils/ModuleUtils.h"
// Implementation model for gpu::OffloadingLLVMTranslationAttrInterface. It
// derives from the interface's FallbackModel and is attached to
// SelectObjectAttr by the registration code in this file.
// NOTE(review): the listing numbers fused onto these lines skip values, so
// parts of the class are not visible in this excerpt.
33class SelectObjectAttrImpl
34 :
public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
35 SelectObjectAttrImpl> {
// Returns the gpu::ObjectAttr chosen for the given gpu::BinaryOp.
37 gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op)
const;
// Takes the attribute, the operation to embed, an IR builder and the module
// translation state; reports success or failure as a LogicalResult.
43 LogicalResult embedBinary(Attribute attribute, Operation *operation,
44 llvm::IRBuilderBase &builder,
45 LLVM::ModuleTranslation &moduleTranslation)
const;
// Trailing parameters of the launchKernel declaration. Its opening line is
// not visible here; the out-of-line definition shows the full signature.
50 Operation *launchFuncOperation,
51 Operation *binaryOperation,
52 llvm::IRBuilderBase &builder,
53 LLVM::ModuleTranslation &moduleTranslation)
const;
// Picks one entry of the binary op's object array and returns it as a
// gpu::ObjectAttr.
// An integer `target` is used directly as the index into `objects`; the
// visible loop instead compares `target` against each object's own target.
// An index outside [0, objects.size()) is reported as an error on `op`.
// NOTE(review): several original lines are elided from this listing (the
// fused numbers skip), including the lines that declare `index` and bind
// `target` -- presumably from the SelectObjectAttr handler cast below -- and
// the value returned after the error is emitted. TODO confirm against the
// full file.
58SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op)
const {
59 ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();
64 cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
// Integer target: interpret it as a position in `objects`.
68 if (
auto indexAttr = mlir::dyn_cast<IntegerAttr>(
target)) {
69 index = indexAttr.getInt();
// Search the objects for one whose target equals the requested target
// (the line connecting this loop to the branch above is elided).
71 for (
auto [i, attr] : llvm::enumerate(objects)) {
// NOTE(review): `obj` is produced by dyn_cast and used on the next line
// without a null check -- confirm every element of `objects` is a
// gpu::ObjectAttr.
72 auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
73 if (obj.getTarget() ==
target) {
// Reject indices that do not address an element of `objects`.
84 if (index < 0 || index >=
static_cast<int64_t
>(objects.size())) {
85 op->emitError(
"the requested target object couldn't be found");
88 return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
// Yields `moduleName` with the suffix "_module" appended. The function's
// signature line is elided from this listing.
92 return moduleName +
"_module";
// Tail of the parameter list and body of the helper that embeds a serialized
// GPU object into `module`. The visible code:
//   1. creates an internal constant global named "<moduleName>_binary" that
//      holds the serialized bytes,
//   2. creates an internal pointer global initialized to null,
//   3. builds a "<moduleName>_load" function that calls a module-load entry
//      point and stores the result into that pointer global, and registers
//      it as a global constructor,
//   4. builds a "<moduleName>_unload" function that passes the stored pointer
//      to "mgpuModuleUnload", and registers it as a global destructor.
// NOTE(review): the fused listing numbers skip values, so some lines (the
// first signature line, a few declarations and all closing braces) are not
// visible in this excerpt.
97 gpu::ObjectAttr
object, Module &module) {
// A terminating null is appended only when the object is in Assembly format.
102 bool addNull = (
object.getFormat() == gpu::CompilationTarget::Assembly);
103 StringRef serializedStr =
object.getObject().getValue();
// Constant byte array built from the serialized string; the left-hand side
// of this statement is elided, but the result is used as `serializedCst`.
105 ConstantDataArray::getString(module.getContext(), serializedStr, addNull);
// Constant, internal-linkage global initialized with the serialized bytes.
106 GlobalVariable *serializedObj =
107 new GlobalVariable(module, serializedCst->getType(),
true,
108 GlobalValue::LinkageTypes::InternalLinkage,
109 serializedCst, moduleName +
"_binary");
110 serializedObj->setAlignment(MaybeAlign(8));
111 serializedObj->setUnnamedAddr(GlobalValue::UnnamedAddr::None);
// Optimization level defaults to a 32-bit zero and may be overridden by the
// object's "O" property below.
114 auto optLevel = APInt::getZero(32);
116 if (DictionaryAttr objectProps =
object.getProperties()) {
// A string-valued property (its key is on an elided line) selects the
// section the binary global is placed in.
117 if (
auto section = dyn_cast_or_null<StringAttr>(
119 serializedObj->setSection(section.getValue());
// NOTE(review): `optAttr.getValue()` keeps the attribute's own bit width,
// and `optLevel` is later passed to ConstantInt::get with an i32 type --
// confirm the "O" property is always a 32-bit integer.
122 if (
auto optAttr = dyn_cast_or_null<IntegerAttr>(objectProps.get(
"O")))
123 optLevel = optAttr.getValue();
// Local builder and the LLVM types used by the generated functions.
126 IRBuilder<> builder(module.getContext());
127 auto i32Ty = builder.getInt32Ty();
128 auto i64Ty = builder.getInt64Ty();
129 auto ptrTy = builder.getPtrTy(0);
130 auto voidTy = builder.getVoidTy();
// Non-constant, internal-linkage pointer global, initially null; the load
// function stores the loaded module handle here. Its name argument is on an
// elided line.
133 auto *modulePtr =
new GlobalVariable(
134 module, ptrTy,
false, GlobalValue::InternalLinkage,
135 ConstantPointerNull::get(ptrTy),
// void() function "<moduleName>_load", placed in section ".text.startup".
138 auto *loadFn = Function::Create(FunctionType::get(voidTy,
false),
139 GlobalValue::InternalLinkage,
140 moduleName +
"_load", module);
141 loadFn->setSection(
".text.startup");
142 auto *loadBlock = BasicBlock::Create(module.getContext(),
"entry", loadFn);
143 builder.SetInsertPoint(loadBlock);
// The lambda emits the call that produces the module handle.
144 Value *moduleObj = [&] {
// Assembly objects go through "mgpuModuleLoadJIT"(ptr, i32) with the
// optimization level as the second argument.
145 if (
object.getFormat() == gpu::CompilationTarget::Assembly) {
146 FunctionCallee moduleLoadFn =
module.getOrInsertFunction(
147 "mgpuModuleLoadJIT", FunctionType::get(ptrTy, {ptrTy, i32Ty}, false));
148 Constant *optValue = ConstantInt::get(i32Ty, optLevel);
149 return builder.CreateCall(moduleLoadFn, {serializedObj, optValue});
// Other formats go through "mgpuModuleLoad"(ptr, i64) with the byte size of
// the embedded data as the second argument.
151 FunctionCallee moduleLoadFn =
module.getOrInsertFunction(
152 "mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false));
154 ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0));
155 return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize});
// Store the handle, finish the load function and register it as a global
// constructor with priority 123.
158 builder.CreateStore(moduleObj, modulePtr);
159 builder.CreateRetVoid();
160 appendToGlobalCtors(module, loadFn, 123);
// void() function "<moduleName>_unload".
// NOTE(review): it is placed in ".text.startup" like the load function --
// confirm that section is intended for a destructor.
162 auto *unloadFn = Function::Create(
163 FunctionType::get(voidTy,
false),
164 GlobalValue::InternalLinkage, moduleName +
"_unload", module);
165 unloadFn->setSection(
".text.startup");
167 BasicBlock::Create(module.getContext(),
"entry", unloadFn);
168 builder.SetInsertPoint(unloadBlock);
// Load the stored handle and pass it to "mgpuModuleUnload"(ptr), then
// register the function as a global destructor with priority 123.
169 FunctionCallee moduleUnloadFn =
module.getOrInsertFunction(
170 "mgpuModuleUnload", FunctionType::get(voidTy, ptrTy, false));
171 builder.CreateCall(moduleUnloadFn, builder.CreateLoad(ptrTy, modulePtr));
172 builder.CreateRetVoid();
173 appendToGlobalDtors(module, unloadFn, 123);
// Interface entry point for embedding a GPU binary.
// The visible part asserts that `operation` is non-null, casts it to
// gpu::BinaryOp (emitting "operation must be a GPU binary" on the failure
// path) and resolves the selected object.
// NOTE(review): the remaining lines, including the guards around the error
// and the final return, are elided from this listing.
179LogicalResult SelectObjectAttrImpl::embedBinary(
180 Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
181 LLVM::ModuleTranslation &moduleTranslation)
const {
182 assert(operation &&
"The binary operation must be non null.");
186 auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
188 operation->
emitError(
"operation must be a GPU binary");
// Choose which object of the binary is used.
192 gpu::ObjectAttr
object = getSelectedObject(op);
// Member declarations of llvm::LaunchKernel, the helper that emits the LLVM
// IR for a kernel launch.
// NOTE(review): the class header, access specifiers and several data members
// are elided from this listing (the fused numbers skip values).
// Constructor: binds the module, the builder and the translation state.
204 LaunchKernel(Module &module, IRBuilderBase &builder,
205 mlir::LLVM::ModuleTranslation &moduleTranslation);
// Callee used for the regular (non-cluster) kernel launch call.
207 FunctionCallee getKernelLaunchFn();
// Callee for "mgpuLaunchClusterKernel".
210 FunctionCallee getClusterKernelLaunchFn();
// Callee for "mgpuModuleGetFunction".
213 FunctionCallee getModuleFunctionFn();
// Callee for "mgpuStreamCreate".
216 FunctionCallee getStreamCreateFn();
// Callee used to destroy a stream.
219 FunctionCallee getStreamDestroyFn();
// Callee for "mgpuStreamSynchronize".
222 FunctionCallee getStreamSyncFn();
// Returns the global string holding `kernelName`, creating it if needed.
225 Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
// Builds the array of pointers to the kernel arguments.
228 Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
// Emits the instruction sequence that launches the kernel.
231 llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
232 mlir::gpu::ObjectAttr
object);
// References held for the lifetime of the helper, plus a cached pointer type.
236 IRBuilderBase &builder;
237 mlir::LLVM::ModuleTranslation &moduleTranslation;
242 PointerType *ptrTy{};
// Interface entry point for translating a GPU launch.
// The visible part checks that `launchFuncOperation` is a gpu::LaunchFuncOp
// and `binaryOperation` is a gpu::BinaryOp (emitting an error on each failure
// path), resolves the selected object, and hands the launch to
// llvm::LaunchKernel::createKernelLaunch.
// NOTE(review): the guards around the error emissions and the early returns
// are elided from this listing.
247LogicalResult SelectObjectAttrImpl::launchKernel(
248 Attribute attribute, Operation *launchFuncOperation,
249 Operation *binaryOperation, llvm::IRBuilderBase &builder,
250 LLVM::ModuleTranslation &moduleTranslation)
const {
252 assert(launchFuncOperation &&
"The launch func operation must be non null.");
// Explicit null check in addition to the assert above; its body is elided.
253 if (!launchFuncOperation)
256 auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
258 launchFuncOperation->
emitError(
"operation must be a GPU launch func Op.");
// NOTE(review): no null check on `binaryOperation` is visible before this
// dyn_cast -- confirm callers always pass a non-null operation.
262 auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
264 binaryOperation->
emitError(
"operation must be a GPU binary.");
267 gpu::ObjectAttr
object = getSelectedObject(binOp);
// Build a temporary LaunchKernel helper on the LLVM module being constructed
// and let it emit the launch sequence.
271 return llvm::LaunchKernel(*moduleTranslation.
getLLVMModule(), builder,
273 .createKernelLaunch(launchFuncOp,
object);
// Binds the module, builder and translation state, then caches the LLVM
// types used when declaring and calling the launch functions. `intPtrTy` is
// the integer type the module's data layout gives for pointers.
276llvm::LaunchKernel::LaunchKernel(
277 Module &module, IRBuilderBase &builder,
278 mlir::LLVM::ModuleTranslation &moduleTranslation)
279 : module(module), builder(builder), moduleTranslation(moduleTranslation) {
280 i32Ty = builder.getInt32Ty();
281 i64Ty = builder.getInt64Ty();
282 ptrTy = builder.getPtrTy(0);
283 voidTy = builder.getVoidTy();
284 intPtrTy = builder.getIntPtrTy(module.getDataLayout());
// Gets or inserts the declaration of the regular kernel-launch function. It
// returns void and takes twelve parameters: one pointer, six pointer-sized
// integers, one i32, three pointers and a trailing i64. This matches the
// twelve-argument call emitted by createKernelLaunch.
// NOTE(review): the line carrying the function's name string is elided from
// this listing.
287llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
288 return module.getOrInsertFunction(
290 FunctionType::get(voidTy,
291 ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
292 intPtrTy, intPtrTy, intPtrTy, i32Ty,
293 ptrTy, ptrTy, ptrTy, i64Ty}),
// Gets or inserts the declaration of "mgpuLaunchClusterKernel". It takes
// fourteen parameters: one pointer, nine pointer-sized integers, one i32 and
// three pointers. This matches the fourteen-argument call emitted by
// createKernelLaunch.
// NOTE(review): the lines carrying the return type are elided from this
// listing.
297llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
298 return module.getOrInsertFunction(
299 "mgpuLaunchClusterKernel",
302 ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
303 intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
304 i32Ty, ptrTy, ptrTy, ptrTy}),
// Gets or inserts the declaration of "mgpuModuleGetFunction", which takes two
// pointers and returns a pointer.
308llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
309 return module.getOrInsertFunction(
310 "mgpuModuleGetFunction",
311 FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
// Gets or inserts the declaration of "mgpuStreamCreate", which takes no
// parameters and returns a pointer.
314llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
315 return module.getOrInsertFunction("mgpuStreamCreate",
316 FunctionType::get(ptrTy, false));
// Gets or inserts the declaration of the stream-destroy function, which takes
// one pointer and returns void.
// NOTE(review): the line carrying the function's name string is elided from
// this listing.
319llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
320 return module.getOrInsertFunction(
322 FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
// Gets or inserts the declaration of "mgpuStreamSynchronize", which takes one
// pointer and returns void.
325llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
326 return module.getOrInsertFunction(
327 "mgpuStreamSynchronize",
328 FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
// Returns a value for the global string that holds `kernelName`. The global
// is named "<moduleName>_<kernelName>_name".
// NOTE(review): the body of the lookup branch is elided from this listing;
// presumably it returns the existing global -- TODO confirm.
333llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
334 StringRef kernelName) {
335 std::string globalName =
336 std::string(formatv(
"{0}_{1}_name", moduleName, kernelName));
// Look for an already-created global; `true` allows internal-linkage ones.
338 if (GlobalVariable *gv = module.getGlobalVariable(globalName,
true))
// Otherwise create the global string with that name.
341 return builder.CreateGlobalString(kernelName, globalName);
// Packs the kernel arguments for the launch call. The visible code allocates
// one struct with a field per argument and one array of pointers with an
// element per argument, stores each argument into its struct field, and
// stores the address of that field into the matching array slot.
// NOTE(review): the initializer of `args`, the header of the second loop and
// the return statement are elided from this listing.
358llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
359 SmallVector<Value *> args =
361 SmallVector<Type *> structTypes(args.size(),
nullptr);
// Record the LLVM type of every argument.
363 for (
auto [i, arg] : llvm::enumerate(args))
364 structTypes[i] = arg->getType();
// Stack storage: the argument struct and the pointer array.
366 Type *structTy = StructType::create(module.getContext(), structTypes);
367 Value *argStruct = builder.CreateAlloca(structTy, 0u);
368 Value *argArray = builder.CreateAlloca(
369 ptrTy, ConstantInt::get(intPtrTy, structTypes.size()));
// For argument `i`: store the value in its field, then store the field's
// address in slot `i` of the array.
372 Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
373 builder.CreateStore(arg, structMember);
374 Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
375 builder.CreateStore(structMember, arrayMember);
// Emits the LLVM IR for one kernel launch. The visible steps are:
//   - translate the grid and block sizes and the dynamic shared memory size,
//   - build the kernel argument array,
//   - load the module handle from its global and obtain the kernel function
//     through the "mgpuModuleGetFunction" callee,
//   - use the op's async object as the stream, or create a new stream,
//   - call the cluster launch function when the op has a cluster size, and
//     the regular launch function otherwise.
// NOTE(review): the fused listing numbers skip values, so several lines
// (lambda body, `moduleIdentifier` declaration, branch keywords, closing
// braces and the final return) are not visible in this excerpt.
389llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
390 mlir::gpu::ObjectAttr
object) {
// Helper mapping an MLIR value to its LLVM value; it asserts that a
// translation exists.
391 auto llvmValue = [&](mlir::Value value) -> Value * {
393 assert(v &&
"Value has not been translated.");
// Grid dimensions.
398 mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
399 Value *gx = llvmValue(grid.
x), *gy = llvmValue(grid.
y),
400 *gz = llvmValue(grid.
z);
// Block dimensions.
403 mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
404 Value *bx = llvmValue(block.
x), *by = llvmValue(block.
y),
405 *bz = llvmValue(block.
z);
// Dynamic shared memory size: the op's operand when present, an i32 zero
// otherwise (the line separating the two assignments is elided).
408 Value *dynamicMemorySize =
nullptr;
409 if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
410 dynamicMemorySize = llvmValue(dynSz);
412 dynamicMemorySize = ConstantInt::get(i32Ty, 0);
415 Value *argArray = createKernelArgArray(op);
// Find the global that holds the loaded module handle; a missing global is
// reported as an error on the op.
418 StringRef moduleName = op.getKernelModuleName().getValue();
420 Value *modulePtr =
module.getGlobalVariable(moduleIdentifier.str(), true);
422 return op.emitError() <<
"Couldn't find the binary: " << moduleIdentifier;
423 Value *moduleObj = builder.CreateLoad(ptrTy, modulePtr);
424 Value *functionName = getOrCreateFunctionName(moduleName, op.getKernelName());
425 Value *moduleFunction =
426 builder.CreateCall(getModuleFunctionFn(), {moduleObj, functionName});
// Stream handling. The scope-exit callback emits a synchronize call followed
// by a destroy call on `stream`. It is cancelled with release() when the op
// supplies its own async object, so only a stream created here is
// synchronized and destroyed.
430 Value *stream =
nullptr;
432 auto destroyStream = make_scope_exit([&]() {
433 builder.CreateCall(getStreamSyncFn(), {stream});
434 builder.CreateCall(getStreamDestroyFn(), {stream});
436 if (mlir::Value asyncObject = op.getAsyncObject()) {
437 stream = llvmValue(asyncObject);
438 destroyStream.release();
440 stream = builder.CreateCall(getStreamCreateFn(), {});
// Number of kernel operands as an i64; passed only to the regular launch.
443 llvm::Constant *paramsCount =
444 llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());
// Null pointer passed as the argument that follows `argArray` in both calls.
447 Value *nullPtr = ConstantPointerNull::get(ptrTy);
// Cluster launch: the cluster dimensions precede the grid dimensions.
450 if (op.hasClusterSize()) {
451 mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
452 Value *cx = llvmValue(cluster.
x), *cy = llvmValue(cluster.
y),
453 *cz = llvmValue(cluster.
z);
455 getClusterKernelLaunchFn(),
456 ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
457 dynamicMemorySize, stream, argArray, nullPtr}));
// Regular launch: same arguments without the cluster dimensions, plus the
// trailing parameter count.
459 builder.CreateCall(getKernelLaunchFn(),
460 ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
461 bz, dynamicMemorySize, stream,
462 argArray, nullPtr, paramsCount}));
// Attaches SelectObjectAttrImpl to SelectObjectAttr as an interface
// implementation, using the context `ctx`. The enclosing function's other
// lines are elided from this listing.
471 SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
static Twine getModuleIdentifier(StringRef moduleName)
static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ, size_t blockX, size_t blockY, size_t blockZ, size_t sharedMemBytes, void **params, size_t paramsCount)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
MLIRContext is the top-level object for a collection of MLIR operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
static LogicalResult embedBinaryImpl(StringRef moduleName, gpu::ObjectAttr object, Module &module)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
constexpr StringLiteral elfSectionName
void registerOffloadingLLVMTranslationInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers the offloading LLVM translation interfaces for gpu.select_object.
Include the generated interface declarations.
@ Constant
Constant integer.