21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/LLVMContext.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Transforms/Utils/ModuleUtils.h"
33 class SelectObjectAttrImpl
34 :
public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
35 SelectObjectAttrImpl> {
37 gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op)
const;
44 llvm::IRBuilderBase &builder,
52 llvm::IRBuilderBase &builder,
58 SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op)
const {
64 cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
68 if (
auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
69 index = indexAttr.getInt();
72 auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
73 if (obj.getTarget() == target) {
84 if (index < 0 || index >=
static_cast<int64_t
>(objects.size())) {
85 op->emitError(
"the requested target object couldn't be found");
88 return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
92 return moduleName +
"_module";
97 gpu::ObjectAttr
object, Module &module) {
102 bool addNull = (
object.getFormat() == gpu::CompilationTarget::Assembly);
103 StringRef serializedStr =
object.getObject().getValue();
105 ConstantDataArray::getString(module.getContext(), serializedStr, addNull);
106 GlobalVariable *serializedObj =
107 new GlobalVariable(module, serializedCst->getType(),
true,
108 GlobalValue::LinkageTypes::InternalLinkage,
109 serializedCst, moduleName +
"_binary");
110 serializedObj->setAlignment(MaybeAlign(8));
116 if (DictionaryAttr objectProps =
object.getProperties()) {
117 if (
auto section = dyn_cast_or_null<StringAttr>(
118 objectProps.get(gpu::elfSectionName))) {
119 serializedObj->setSection(section.getValue());
122 if (
auto optAttr = dyn_cast_or_null<IntegerAttr>(objectProps.get(
"O")))
123 optLevel = optAttr.getValue();
126 IRBuilder<> builder(module.getContext());
127 auto i32Ty = builder.getInt32Ty();
128 auto i64Ty = builder.getInt64Ty();
129 auto ptrTy = builder.getPtrTy(0);
130 auto voidTy = builder.getVoidTy();
133 auto *modulePtr =
new GlobalVariable(
134 module, ptrTy,
false, GlobalValue::InternalLinkage,
139 GlobalValue::InternalLinkage,
140 moduleName +
"_load", module);
141 loadFn->setSection(
".text.startup");
142 auto *loadBlock = BasicBlock::Create(module.getContext(),
"entry", loadFn);
143 builder.SetInsertPoint(loadBlock);
144 Value *moduleObj = [&] {
145 if (
object.getFormat() == gpu::CompilationTarget::Assembly) {
146 FunctionCallee moduleLoadFn = module.getOrInsertFunction(
149 return builder.CreateCall(moduleLoadFn, {serializedObj, optValue});
151 FunctionCallee moduleLoadFn = module.getOrInsertFunction(
155 return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize});
158 builder.CreateStore(moduleObj, modulePtr);
159 builder.CreateRetVoid();
160 appendToGlobalCtors(module, loadFn, 123);
162 auto *unloadFn = Function::Create(
164 GlobalValue::InternalLinkage, moduleName +
"_unload", module);
165 unloadFn->setSection(
".text.startup");
167 BasicBlock::Create(module.getContext(),
"entry", unloadFn);
168 builder.SetInsertPoint(unloadBlock);
169 FunctionCallee moduleUnloadFn = module.getOrInsertFunction(
171 builder.CreateCall(moduleUnloadFn, builder.CreateLoad(ptrTy, modulePtr));
172 builder.CreateRetVoid();
173 appendToGlobalDtors(module, unloadFn, 123);
179 LogicalResult SelectObjectAttrImpl::embedBinary(
182 assert(operation &&
"The binary operation must be non null.");
186 auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
188 operation->
emitError(
"operation must be a GPU binary");
192 gpu::ObjectAttr
object = getSelectedObject(op);
204 LaunchKernel(Module &module, IRBuilderBase &builder,
207 FunctionCallee getKernelLaunchFn();
210 FunctionCallee getClusterKernelLaunchFn();
213 FunctionCallee getModuleFunctionFn();
216 FunctionCallee getStreamCreateFn();
219 FunctionCallee getStreamDestroyFn();
222 FunctionCallee getStreamSyncFn();
225 Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
228 Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
231 llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
232 mlir::gpu::ObjectAttr
object);
236 IRBuilderBase &builder;
242 PointerType *ptrTy{};
249 Operation *binaryOperation, llvm::IRBuilderBase &builder,
252 assert(launchFuncOperation &&
"The launch func operation must be non null.");
253 if (!launchFuncOperation)
256 auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
258 launchFuncOperation->
emitError(
"operation must be a GPU launch func Op.");
262 auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
264 binaryOperation->
emitError(
"operation must be a GPU binary.");
267 gpu::ObjectAttr
object = getSelectedObject(binOp);
271 return llvm::LaunchKernel(*moduleTranslation.
getLLVMModule(), builder,
273 .createKernelLaunch(launchFuncOp,
object);
276 llvm::LaunchKernel::LaunchKernel(
277 Module &module, IRBuilderBase &builder,
279 : module(module), builder(builder), moduleTranslation(moduleTranslation) {
280 i32Ty = builder.getInt32Ty();
281 i64Ty = builder.getInt64Ty();
282 ptrTy = builder.getPtrTy(0);
283 voidTy = builder.getVoidTy();
284 intPtrTy = builder.getIntPtrTy(module.getDataLayout());
287 llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
288 return module.getOrInsertFunction(
292 intPtrTy, intPtrTy, intPtrTy, i32Ty,
293 ptrTy, ptrTy, ptrTy, i64Ty}),
297 llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
298 return module.getOrInsertFunction(
299 "mgpuLaunchClusterKernel",
303 intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
304 i32Ty, ptrTy, ptrTy, ptrTy}),
308 llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
309 return module.getOrInsertFunction(
310 "mgpuModuleGetFunction",
314 llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
315 return module.getOrInsertFunction(
"mgpuStreamCreate",
319 llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
320 return module.getOrInsertFunction(
325 llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
326 return module.getOrInsertFunction(
327 "mgpuStreamSynchronize",
333 llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
334 StringRef kernelName) {
335 std::string globalName =
336 std::string(formatv(
"{0}_{1}_name", moduleName, kernelName));
338 if (GlobalVariable *gv = module.getGlobalVariable(globalName,
true))
341 return builder.CreateGlobalString(kernelName, globalName);
358 llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
364 structTypes[i] = arg->getType();
366 Type *structTy = StructType::create(module.getContext(), structTypes);
367 Value *argStruct = builder.CreateAlloca(structTy, 0u);
368 Value *argArray = builder.CreateAlloca(
372 Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
373 builder.CreateStore(arg, structMember);
374 Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
375 builder.CreateStore(structMember, arrayMember);
389 llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
390 mlir::gpu::ObjectAttr
object) {
393 assert(v &&
"Value has not been translated.");
399 Value *gx = llvmValue(grid.
x), *gy = llvmValue(grid.
y),
400 *gz = llvmValue(grid.
z);
404 Value *bx = llvmValue(block.
x), *by = llvmValue(block.
y),
405 *bz = llvmValue(block.
z);
408 Value *dynamicMemorySize =
nullptr;
409 if (
mlir::Value dynSz = op.getDynamicSharedMemorySize())
410 dynamicMemorySize = llvmValue(dynSz);
415 Value *argArray = createKernelArgArray(op);
418 StringRef moduleName = op.getKernelModuleName().getValue();
420 Value *modulePtr = module.getGlobalVariable(moduleIdentifier.str(),
true);
422 return op.emitError() <<
"Couldn't find the binary: " << moduleIdentifier;
423 Value *moduleObj = builder.CreateLoad(ptrTy, modulePtr);
424 Value *functionName = getOrCreateFunctionName(moduleName, op.getKernelName());
425 Value *moduleFunction =
426 builder.CreateCall(getModuleFunctionFn(), {moduleObj, functionName});
430 Value *stream =
nullptr;
432 auto destroyStream = make_scope_exit([&]() {
433 builder.CreateCall(getStreamSyncFn(), {stream});
434 builder.CreateCall(getStreamDestroyFn(), {stream});
436 if (
mlir::Value asyncObject = op.getAsyncObject()) {
437 stream = llvmValue(asyncObject);
438 destroyStream.release();
440 stream = builder.CreateCall(getStreamCreateFn(), {});
443 llvm::Constant *paramsCount =
450 if (op.hasClusterSize()) {
452 Value *cx = llvmValue(cluster.
x), *cy = llvmValue(cluster.
y),
453 *cz = llvmValue(cluster.
z);
455 getClusterKernelLaunchFn(),
456 ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
457 dynamicMemorySize, stream, argArray, nullPtr}));
459 builder.CreateCall(getKernelLaunchFn(),
461 bz, dynamicMemorySize, stream,
462 argArray, nullPtr, paramsCount}));
471 SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Twine getModuleIdentifier(StringRef moduleName)
static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ, size_t blockX, size_t blockY, size_t blockZ, size_t sharedMemBytes, void **params, size_t paramsCount)
Attributes are known-constant values of operations.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
static LogicalResult embedBinaryImpl(StringRef moduleName, gpu::ObjectAttr object, Module &module)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerOffloadingLLVMTranslationInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers the offloading LLVM translation interfaces for gpu.select_object.
Include the generated interface declarations.
@ Constant
Constant integer.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
Utility class for the GPU dialect to represent triples of Values accessible through ....