20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/LLVMContext.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/Support/FormatVariadic.h"
30 class SelectObjectAttrImpl
31 :
public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
32 SelectObjectAttrImpl> {
37 llvm::IRBuilderBase &builder,
45 llvm::IRBuilderBase &builder,
49 gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op)
const;
52 std::string getBinaryIdentifier(StringRef binaryName) {
53 return binaryName.str() +
"_bin_cst";
60 SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
65 SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op)
const {
71 cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
75 if (
auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
76 index = indexAttr.getInt();
79 auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
80 if (obj.getTarget() == target) {
91 if (index < 0 || index >=
static_cast<int64_t
>(objects.size())) {
92 op->
emitError(
"the requested target object couldn't be found");
95 return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
101 assert(operation &&
"The binary operation must be non null.");
105 auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
107 operation->
emitError(
"operation must be a GPU binary");
111 gpu::ObjectAttr
object = getSelectedObject(op);
118 llvm::Constant *binary = llvm::ConstantDataArray::getString(
119 builder.getContext(),
object.getObject().getValue(),
false);
120 llvm::GlobalVariable *serializedObj =
121 new llvm::GlobalVariable(*module, binary->getType(),
true,
122 llvm::GlobalValue::LinkageTypes::InternalLinkage,
123 binary, getBinaryIdentifier(op.
getName()));
124 serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
125 serializedObj->setAlignment(llvm::MaybeAlign(8));
134 LaunchKernel(Module &module, IRBuilderBase &builder,
137 FunctionCallee getKernelLaunchFn();
140 FunctionCallee getClusterKernelLaunchFn();
143 FunctionCallee getModuleFunctionFn();
146 FunctionCallee getModuleLoadFn();
149 FunctionCallee getModuleLoadJITFn();
152 FunctionCallee getModuleUnloadFn();
155 FunctionCallee getStreamCreateFn();
158 FunctionCallee getStreamDestroyFn();
161 FunctionCallee getStreamSyncFn();
164 Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
167 Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
171 mlir::gpu::ObjectAttr
object);
175 IRBuilderBase &builder;
180 PointerType *ptrTy{};
187 Operation *binaryOperation, llvm::IRBuilderBase &builder,
190 assert(launchFuncOperation &&
"The launch func operation must be non null.");
191 if (!launchFuncOperation)
194 auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
196 launchFuncOperation->
emitError(
"operation must be a GPU launch func Op.");
200 auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
202 binaryOperation->
emitError(
"operation must be a GPU binary.");
205 gpu::ObjectAttr
object = getSelectedObject(binOp);
209 return llvm::LaunchKernel(*moduleTranslation.
getLLVMModule(), builder,
211 .createKernelLaunch(launchFuncOp,
object);
214 llvm::LaunchKernel::LaunchKernel(
215 Module &module, IRBuilderBase &builder,
217 : module(module), builder(builder), moduleTranslation(moduleTranslation) {
218 i32Ty = builder.getInt32Ty();
219 ptrTy = builder.getPtrTy(0);
220 voidTy = builder.getVoidTy();
221 intPtrTy = builder.getIntPtrTy(module.getDataLayout());
224 llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
225 return module.getOrInsertFunction(
230 intPtrTy, intPtrTy, i32Ty, ptrTy, ptrTy, ptrTy}),
234 llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
235 return module.getOrInsertFunction(
236 "mgpuLaunchClusterKernel",
240 intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
241 i32Ty, ptrTy, ptrTy, ptrTy}),
245 llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
246 return module.getOrInsertFunction(
247 "mgpuModuleGetFunction",
251 llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
252 return module.getOrInsertFunction(
257 llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
258 return module.getOrInsertFunction(
263 llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
264 return module.getOrInsertFunction(
269 llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
270 return module.getOrInsertFunction(
"mgpuStreamCreate",
274 llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
275 return module.getOrInsertFunction(
280 llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
281 return module.getOrInsertFunction(
282 "mgpuStreamSynchronize",
288 llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
289 StringRef kernelName) {
290 std::string globalName =
291 std::string(formatv(
"{0}_{1}_kernel_name", moduleName, kernelName));
293 if (GlobalVariable *gv = module.getGlobalVariable(globalName))
296 return builder.CreateGlobalString(kernelName, globalName);
313 llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
319 structTypes[i] = arg->getType();
321 Type *structTy = StructType::create(module.getContext(), structTypes);
322 Value *argStruct = builder.CreateAlloca(structTy, 0u);
323 Value *argArray = builder.CreateAlloca(
327 Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
328 builder.CreateStore(arg, structMember);
329 Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
330 builder.CreateStore(structMember, arrayMember);
347 llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
348 mlir::gpu::ObjectAttr
object) {
351 assert(v &&
"Value has not been translated.");
357 Value *gx = llvmValue(grid.
x), *gy = llvmValue(grid.
y),
358 *gz = llvmValue(grid.
z);
362 Value *bx = llvmValue(block.
x), *by = llvmValue(block.
y),
363 *bz = llvmValue(block.
z);
366 Value *dynamicMemorySize =
nullptr;
367 if (
mlir::Value dynSz = op.getDynamicSharedMemorySize())
368 dynamicMemorySize = llvmValue(dynSz);
373 Value *argArray = createKernelArgArray(op);
378 DictionaryAttr objectProps =
object.getProperties();
380 if (objectProps && (optAttr = objectProps.get(
"O"))) {
381 auto optLevel = dyn_cast<IntegerAttr>(optAttr);
383 return op.
emitError(
"the optimization level must be an integer");
388 StringRef moduleName = op.getKernelModuleName().getValue();
389 std::string binaryIdentifier = getBinaryIdentifier(moduleName);
390 Value *binary = module.getGlobalVariable(binaryIdentifier,
true);
392 return op.
emitError() <<
"Couldn't find the binary: " << binaryIdentifier;
394 Value *moduleObject =
395 object.getFormat() == gpu::CompilationTarget::Assembly
396 ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
397 : builder.CreateCall(getModuleLoadFn(), {binary});
400 Value *moduleFunction = builder.CreateCall(
401 getModuleFunctionFn(),
403 getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
407 Value *stream =
nullptr;
408 bool handleStream =
false;
409 if (
mlir::Value asyncObject = op.getAsyncObject()) {
410 stream = llvmValue(asyncObject);
413 stream = builder.CreateCall(getStreamCreateFn(), {});
420 if (op.hasClusterSize()) {
422 Value *cx = llvmValue(cluster.
x), *cy = llvmValue(cluster.
y),
423 *cz = llvmValue(cluster.
z);
425 getClusterKernelLaunchFn(),
426 ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
427 dynamicMemorySize, stream, argArray, nullPtr}));
432 dynamicMemorySize, stream, argArray, nullPtr}));
437 builder.CreateCall(getStreamSyncFn(), {stream});
438 builder.CreateCall(getStreamDestroyFn(), {stream});
442 builder.CreateCall(getModuleUnloadFn(), {moduleObject});
static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ, size_t blockX, size_t blockY, size_t blockZ, size_t sharedMemBytes, void **params, size_t paramsCount)
Attributes are known-constant values of operations.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up the remapped LLVM IR values corresponding to a list of MLIR values.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OperationName getName()
The name of an operation is the key identifier for it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Include the generated interface declarations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerOffloadingLLVMTranslationInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers the offloading LLVM translation interfaces for gpu.select_object.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
This class represents an efficient way to signal success or failure.
Utility class for the GPU dialect to represent triples of Values accessible through `.x`, `.y`, and `.z`, similarly to CUDA notation.