#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
namespace {
// Implementation of the `OffloadingLLVMTranslationAttrInterface` model for the
// `#gpu.select_object` attribute.
class SelectObjectAttrImpl
    : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
          SelectObjectAttrImpl> {
public:
  LogicalResult embedBinary(Attribute attribute, Operation *operation,
                            llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) const;
  LogicalResult launchKernel(Attribute attribute,
                             Operation *launchFuncOperation,
                             Operation *binaryOperation,
                             llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) const;
  // Returns the object selected by the offloading handler attribute.
  gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
};

// Returns the identifier of the global constant holding the serialized binary.
std::string getBinaryIdentifier(StringRef binaryName) {
  return binaryName.str() + "_bin_cst";
}
} // namespace
// Registers the offloading LLVM translation interface for `gpu.select_object`.
void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
    SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
  });
}
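// Hedged usage sketch (not part of the original source): a host pipeline would
// typically wire this interface up before running the LLVM IR translation,
// roughly as follows; `registry` and `context` are illustrative names.
//
//   mlir::DialectRegistry registry;
//   mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);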
gpu::ObjectAttr
SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
  ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();

  // Obtain the index of the object to select.
  int64_t index = -1;
  if (Attribute target =
          cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
              .getTarget()) {
    // An integer target is used directly as the index; otherwise compare the
    // target against every object in the array.
    if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      index = indexAttr.getInt();
    } else {
      for (auto [i, attr] : llvm::enumerate(objects)) {
        auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
        if (obj.getTarget() == target)
          index = i;
      }
    }
  } else {
    // A null target selects the first object in the array.
    index = 0;
  }

  if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
    op->emitError("the requested target object couldn't be found");
    return nullptr;
  }
  return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
}
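// For reference, a hedged sketch of the op this selection logic inspects
// (MLIR shown inside a comment; the target attributes are illustrative):
//
//   gpu.binary @kernels <#gpu.select_object<1>>
//       [#gpu.object<#nvvm.target, "...">, #gpu.object<#rocdl.target, "...">]
//
// An integer such as `1` selects by position, while a target attribute such
// as `#gpu.select_object<#rocdl.target>` selects the object whose target
// matches.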
LogicalResult SelectObjectAttrImpl::embedBinary(
    Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  assert(operation && "The binary operation must be non null.");
  if (!operation)
    return failure();

  auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
  if (!op) {
    operation->emitError("operation must be a GPU binary");
    return failure();
  }

  gpu::ObjectAttr object = getSelectedObject(op);
  if (!object)
    return failure();

  // Embed the selected object as a constant global string.
  llvm::Module *module = moduleTranslation.getLLVMModule();
  llvm::Constant *binary = llvm::ConstantDataArray::getString(
      builder.getContext(), object.getObject().getValue(), false);
  llvm::GlobalVariable *serializedObj =
      new llvm::GlobalVariable(*module, binary->getType(), true,
                               llvm::GlobalValue::LinkageTypes::InternalLinkage,
                               binary, getBinaryIdentifier(op.getName()));

  // Propagate the ELF section if the object carries one as a property.
  if (object.getProperties()) {
    if (auto section = mlir::dyn_cast_or_null<mlir::StringAttr>(
            object.getProperties().get(gpu::elfSectionName)))
      serializedObj->setSection(section.getValue());
  }
  serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
  serializedObj->setAlignment(llvm::MaybeAlign(8));
  return success();
}
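// For a `gpu.binary @kernels ...` op the code above emits, roughly (sketch):
//
//   @kernels_bin_cst = internal constant [N x i8] c"<serialized object>", align 8
//
// with the section name taken from the object's ELF-section property when it
// is present.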
namespace llvm {
namespace {
// Helper emitting the runtime calls that load a module and launch a kernel.
class LaunchKernel {
public:
  LaunchKernel(Module &module, IRBuilderBase &builder,
               mlir::LLVM::ModuleTranslation &moduleTranslation);

  // Callees for the `mgpu*` runtime wrapper functions.
  FunctionCallee getKernelLaunchFn();
  FunctionCallee getClusterKernelLaunchFn();
  FunctionCallee getModuleFunctionFn();
  FunctionCallee getModuleLoadFn();
  FunctionCallee getModuleLoadJITFn();
  FunctionCallee getModuleUnloadFn();
  FunctionCallee getStreamCreateFn();
  FunctionCallee getStreamDestroyFn();
  FunctionCallee getStreamSyncFn();

  // Gets or creates the global string holding the kernel name.
  Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
  // Packs the kernel operands into an array of type-erased pointers.
  Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
  // Emits the full launch sequence for `op` using the selected object.
  llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                         mlir::gpu::ObjectAttr object);

private:
  Module &module;
  IRBuilderBase &builder;
  mlir::LLVM::ModuleTranslation &moduleTranslation;
  Type *i32Ty{}, *i64Ty{}, *voidTy{}, *intPtrTy{};
  PointerType *ptrTy{};
};
} // namespace
} // namespace llvm
LogicalResult SelectObjectAttrImpl::launchKernel(
    Attribute attribute, Operation *launchFuncOperation,
    Operation *binaryOperation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  assert(launchFuncOperation && "The launch func operation must be non null.");
  if (!launchFuncOperation)
    return failure();

  auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
  if (!launchFuncOp) {
    launchFuncOperation->emitError("operation must be a GPU launch func op.");
    return failure();
  }

  auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
  if (!binOp) {
    binaryOperation->emitError("operation must be a GPU binary.");
    return failure();
  }

  gpu::ObjectAttr object = getSelectedObject(binOp);
  if (!object)
    return failure();

  return llvm::LaunchKernel(*moduleTranslation.getLLVMModule(), builder,
                            moduleTranslation)
      .createKernelLaunch(launchFuncOp, object);
}
llvm::LaunchKernel::LaunchKernel(
    Module &module, IRBuilderBase &builder,
    mlir::LLVM::ModuleTranslation &moduleTranslation)
    : module(module), builder(builder), moduleTranslation(moduleTranslation) {
  i32Ty = builder.getInt32Ty();
  i64Ty = builder.getInt64Ty();
  ptrTy = builder.getPtrTy(0);
  voidTy = builder.getVoidTy();
  intPtrTy = builder.getIntPtrTy(module.getDataLayout());
}
llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchKernel",
      FunctionType::get(voidTy,
                        ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
                                          intPtrTy, intPtrTy, intPtrTy, i32Ty,
                                          ptrTy, ptrTy, ptrTy, i64Ty}),
                        false));
}

llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchClusterKernel",
      FunctionType::get(voidTy,
                        ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
                                          intPtrTy, intPtrTy, intPtrTy,
                                          intPtrTy, intPtrTy, intPtrTy, i32Ty,
                                          ptrTy, ptrTy, ptrTy}),
                        false));
}
llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
  return module.getOrInsertFunction(
      "mgpuModuleGetFunction",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
  return module.getOrInsertFunction(
      "mgpuModuleLoad",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
  return module.getOrInsertFunction(
      "mgpuModuleLoadJIT",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
  return module.getOrInsertFunction(
      "mgpuModuleUnload",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
  return module.getOrInsertFunction("mgpuStreamCreate",
                                    FunctionType::get(ptrTy, false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
  return module.getOrInsertFunction(
      "mgpuStreamDestroy",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
  return module.getOrInsertFunction(
      "mgpuStreamSynchronize",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}
llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
                                                         StringRef kernelName) {
  std::string globalName =
      std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));

  // Reuse the global if it already exists.
  if (GlobalVariable *gv = module.getGlobalVariable(globalName))
    return gv;

  return builder.CreateGlobalString(kernelName, globalName);
}
llvm::Value *
llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
  SmallVector<Value *> args =
      moduleTranslation.lookupValues(op.getKernelOperands());
  SmallVector<Type *> structTypes(args.size(), nullptr);
  for (auto [i, arg] : llvm::enumerate(args))
    structTypes[i] = arg->getType();

  // Pack the arguments into a stack struct and build the array of type-erased
  // pointers to its fields expected by the runtime launch entry points.
  Type *structTy = StructType::create(module.getContext(), structTypes);
  Value *argStruct = builder.CreateAlloca(structTy, 0u);
  Value *argArray = builder.CreateAlloca(
      ptrTy, ConstantInt::get(intPtrTy, structTypes.size()));
  for (auto [i, arg] : llvm::enumerate(args)) {
    Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
    builder.CreateStore(arg, structMember);
    Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
    builder.CreateStore(structMember, arrayMember);
  }
  return argArray;
}
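// Sketch of the IR this produces for a kernel taking, say, (i32, ptr) -- the
// value names are illustrative:
//
//   %argStruct = alloca { i32, ptr }
//   %argArray  = alloca ptr, <intptr> 2
//   ; each field of %argStruct receives the operand, and a pointer to the
//   ; field is stored into the matching slot of %argArray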
llvm::LogicalResult
llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                       mlir::gpu::ObjectAttr object) {
  auto llvmValue = [&](mlir::Value value) -> Value * {
    Value *v = moduleTranslation.lookupValue(value);
    assert(v && "Value has not been translated.");
    return v;
  };

  // Get grid dimensions.
  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
        *gz = llvmValue(grid.z);

  // Get block dimensions.
  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
        *bz = llvmValue(block.z);

  // Get the dynamic shared memory size, defaulting to zero.
  Value *dynamicMemorySize = nullptr;
  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
    dynamicMemorySize = llvmValue(dynSz);
  else
    dynamicMemorySize = ConstantInt::get(i32Ty, 0);
  // Create the argument array.
  Value *argArray = createKernelArgArray(op);

  // Default JIT optimization level; overridden by the object's "O" property
  // when present.
  llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
  DictionaryAttr objectProps = object.getProperties();
  mlir::Attribute optAttr;
  if (objectProps && (optAttr = objectProps.get("O"))) {
    auto optLevel = dyn_cast<IntegerAttr>(optAttr);
    if (!optLevel)
      return op.emitError("the optimization level must be an integer");
    optV = llvm::ConstantInt::get(i32Ty, optLevel.getInt());
  }
  // Resolve the embedded binary global for the kernel module.
  StringRef moduleName = op.getKernelModuleName().getValue();
  std::string binaryIdentifier = getBinaryIdentifier(moduleName);
  Value *binary = module.getGlobalVariable(binaryIdentifier, true);
  if (!binary)
    return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;

  auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
  if (!binaryVar)
    return op.emitError() << "Binary is not a global variable: "
                          << binaryIdentifier;

  llvm::Constant *binaryInit = binaryVar->getInitializer();
  auto binaryDataSeq =
      dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
  if (!binaryDataSeq)
    return op.emitError() << "Couldn't find binary data array: "
                          << binaryIdentifier;

  llvm::Constant *binarySize =
      llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
                                        binaryDataSeq->getElementByteSize());

  // Assembly objects are JIT-compiled at load time; binary objects are loaded
  // directly.
  Value *moduleObject =
      object.getFormat() == gpu::CompilationTarget::Assembly
          ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
          : builder.CreateCall(getModuleLoadFn(), {binary, binarySize});

  // Load the kernel function.
  Value *moduleFunction = builder.CreateCall(
      getModuleFunctionFn(),
      {moduleObject,
       getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
  // Get the stream to use for execution. If there's no async object then
  // create a stream to make a synchronous kernel launch.
  Value *stream = nullptr;
  bool handleStream = false;
  if (mlir::Value asyncObject = op.getAsyncObject()) {
    stream = llvmValue(asyncObject);
  } else {
    handleStream = true;
    stream = builder.CreateCall(getStreamCreateFn(), {});
  }

  llvm::Constant *paramsCount =
      llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());
  Value *nullPtr = ConstantPointerNull::get(ptrTy);

  // Launch the kernel, using the cluster-aware entry point when a cluster
  // size is present.
  if (op.hasClusterSize()) {
    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
          *cz = llvmValue(cluster.z);
    builder.CreateCall(
        getClusterKernelLaunchFn(),
        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
                           dynamicMemorySize, stream, argArray, nullPtr}));
  } else {
    builder.CreateCall(getKernelLaunchFn(),
                       ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
                                          bz, dynamicMemorySize, stream,
                                          argArray, nullPtr, paramsCount}));
  }

  // Sync and destroy the stream only if it was created here.
  if (handleStream) {
    builder.CreateCall(getStreamSyncFn(), {stream});
    builder.CreateCall(getStreamDestroyFn(), {stream});
  }

  // Unload the kernel module.
  builder.CreateCall(getModuleUnloadFn(), {moduleObject});

  return success();
}
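// Taken together, for a synchronous `gpu.launch_func` the generated host code
// reduces to a call sequence roughly like (sketch):
//
//   %mod    = mgpuModuleLoad(@<module>_bin_cst, <size>)  ; or mgpuModuleLoadJIT
//   %fn     = mgpuModuleGetFunction(%mod, @<module>_<kernel>_kernel_name)
//   %stream = mgpuStreamCreate()
//   mgpuLaunchKernel(%fn, grid..., block..., smem, %stream, %args, null, n)
//   mgpuStreamSynchronize(%stream)
//   mgpuStreamDestroy(%stream)
//   mgpuModuleUnload(%mod)
//
// When the op carries an async dependency, the existing stream is used and the
// synchronize/destroy calls are omitted.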