#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
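// Implementation of the `gpu.select_object` offloading handler as an external
// model of OffloadingLLVMTranslationAttrInterface, used when translating
// `gpu.binary` and `gpu.launch_func` ops to LLVM IR.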
class SelectObjectAttrImpl
    : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
          SelectObjectAttrImpl> {
public:
  LogicalResult embedBinary(Attribute attribute, Operation *operation,
                            llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) const;

  LogicalResult launchKernel(Attribute attribute,
                             Operation *launchFuncOperation,
                             Operation *binaryOperation,
                             llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) const;

  // Returns the object selected by the offloading handler attribute.
  gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
};

// Returns the name of the global constant holding the serialized binary.
std::string getBinaryIdentifier(StringRef binaryName) {
  return binaryName.str() + "_bin_cst";
}
// Registers the offloading LLVM translation interfaces for gpu.select_object.
void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
    SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
  });
}
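// Resolves which object of the `gpu.binary` the select_object attribute picks:
// either by integer index, or by matching its target attribute against each
// object's target.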
gpu::ObjectAttr
SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
  ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();

  // Obtain the index of the object to select: an integer target is used
  // directly as the index; otherwise the target is compared against every
  // object in the array.
  int64_t index = -1;
  if (Attribute target =
          cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
              .getTarget()) {
    if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      index = indexAttr.getInt();
    } else {
      for (auto [i, attr] : llvm::enumerate(objects)) {
        auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
        if (obj.getTarget() == target)
          index = i;
      }
    }
  } else {
    // A null target selects the first object.
    index = 0;
  }

  if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
    op->emitError("the requested target object couldn't be found");
    return nullptr;
  }
  return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
}
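// Embeds the selected object into the host LLVM module as an internal constant
// global string named `<binary-sym-name>_bin_cst`.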
LogicalResult SelectObjectAttrImpl::embedBinary(
    Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  assert(operation && "The binary operation must be non null.");
  if (!operation)
    return failure();

  auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
  if (!op) {
    operation->emitError("operation must be a GPU binary");
    return failure();
  }

  gpu::ObjectAttr object = getSelectedObject(op);
  if (!object)
    return failure();

  llvm::Module *module = moduleTranslation.getLLVMModule();

  // Embed the object as a constant global string.
  llvm::Constant *binary = llvm::ConstantDataArray::getString(
      builder.getContext(), object.getObject().getValue(), false);
  llvm::GlobalVariable *serializedObj =
      new llvm::GlobalVariable(*module, binary->getType(), true,
                               llvm::GlobalValue::LinkageTypes::InternalLinkage,
                               binary, getBinaryIdentifier(op.getName()));

  // Propagate an explicit "section" object property, if present.
  if (object.getProperties()) {
    if (auto section = mlir::dyn_cast_or_null<mlir::StringAttr>(
            object.getProperties().get("section"))) {
      serializedObj->setSection(section.getValue());
    }
  }
  serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
  serializedObj->setAlignment(llvm::MaybeAlign(8));
  return success();
}
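// Helper that emits the sequence of runtime calls implementing a kernel
// launch. The mgpu* callees are expected to be provided by MLIR's GPU runtime
// wrapper libraries (the CUDA/ROCm/SYCL runtime wrappers under
// mlir/lib/ExecutionEngine).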
namespace llvm {
class LaunchKernel {
public:
  LaunchKernel(Module &module, IRBuilderBase &builder,
               mlir::LLVM::ModuleTranslation &moduleTranslation);

  // Lazily declared mgpu* runtime wrapper callees.
  FunctionCallee getKernelLaunchFn();
  FunctionCallee getClusterKernelLaunchFn();
  FunctionCallee getModuleFunctionFn();
  FunctionCallee getModuleLoadFn();
  FunctionCallee getModuleLoadJITFn();
  FunctionCallee getModuleUnloadFn();
  FunctionCallee getStreamCreateFn();
  FunctionCallee getStreamDestroyFn();
  FunctionCallee getStreamSyncFn();

  Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
  Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
  llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                         mlir::gpu::ObjectAttr object);

private:
  Module &module;
  IRBuilderBase &builder;
  mlir::LLVM::ModuleTranslation &moduleTranslation;
  Type *i32Ty{}, *i64Ty{}, *voidTy{}, *intPtrTy{};
  PointerType *ptrTy{};
};
} // namespace llvm
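// Lowers a `gpu.launch_func` that uses the given `gpu.binary` into a kernel
// launch via the LaunchKernel helper above.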
LogicalResult SelectObjectAttrImpl::launchKernel(
    Attribute attribute, Operation *launchFuncOperation,
    Operation *binaryOperation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  assert(launchFuncOperation && "The launch func operation must be non null.");
  if (!launchFuncOperation)
    return failure();

  auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
  if (!launchFuncOp) {
    launchFuncOperation->emitError("operation must be a GPU launch func Op.");
    return failure();
  }

  auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
  if (!binOp) {
    binaryOperation->emitError("operation must be a GPU binary.");
    return failure();
  }

  gpu::ObjectAttr object = getSelectedObject(binOp);
  if (!object)
    return failure();

  return llvm::LaunchKernel(*moduleTranslation.getLLVMModule(), builder,
                            moduleTranslation)
      .createKernelLaunch(launchFuncOp, object);
}
llvm::LaunchKernel::LaunchKernel(
    Module &module, IRBuilderBase &builder,
    mlir::LLVM::ModuleTranslation &moduleTranslation)
    : module(module), builder(builder), moduleTranslation(moduleTranslation) {
  i32Ty = builder.getInt32Ty();
  i64Ty = builder.getInt64Ty();
  ptrTy = builder.getPtrTy(0);
  voidTy = builder.getVoidTy();
  intPtrTy = builder.getIntPtrTy(module.getDataLayout());
}
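// The getters below lazily insert declarations of the mgpu* runtime wrapper
// functions into the host module. As a rough sketch (an assumption derived
// from the FunctionType and call sites below, not the authoritative runtime
// declaration), mgpuLaunchKernel has the host-side shape:
//
//   void mgpuLaunchKernel(void *function, intptr_t gridX, intptr_t gridY,
//                         intptr_t gridZ, intptr_t blockX, intptr_t blockY,
//                         intptr_t blockZ, int32_t dynamicSharedMemory,
//                         void *stream, void **params, void **extra,
//                         int64_t paramsCount);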
llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchKernel",
      FunctionType::get(voidTy,
                        ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
                                          intPtrTy, intPtrTy, intPtrTy, i32Ty,
                                          ptrTy, ptrTy, ptrTy, i64Ty}),
                        false));
}
llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchClusterKernel",
      FunctionType::get(voidTy,
                        ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
                                          intPtrTy, intPtrTy, intPtrTy,
                                          intPtrTy, intPtrTy, intPtrTy,
                                          i32Ty, ptrTy, ptrTy, ptrTy}),
                        false));
}
llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
  return module.getOrInsertFunction(
      "mgpuModuleGetFunction",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
}
llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
  return module.getOrInsertFunction(
      "mgpuModuleLoad",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false));
}
llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
  return module.getOrInsertFunction(
      "mgpuModuleLoadJIT",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false));
}
llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
  return module.getOrInsertFunction(
      "mgpuModuleUnload",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}
llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
  return module.getOrInsertFunction("mgpuStreamCreate",
                                    FunctionType::get(ptrTy, false));
}
llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
  return module.getOrInsertFunction(
      "mgpuStreamDestroy",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}
llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
  return module.getOrInsertFunction(
      "mgpuStreamSynchronize",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}
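// Returns the global string holding the kernel name, creating it on first use
// under the name "<module>_<kernel>_kernel_name".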
llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
                                                         StringRef kernelName) {
  std::string globalName =
      std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));

  if (GlobalVariable *gv = module.getGlobalVariable(globalName))
    return gv;

  return builder.CreateGlobalString(kernelName, globalName);
}
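// Packs the kernel operands: each argument is stored into a stack-allocated
// struct, and an array of type-erased pointers to the struct fields is
// returned, matching the `void **params` argument of the launch wrapper.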
llvm::Value *
llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
  SmallVector<Value *> args =
      moduleTranslation.lookupValues(op.getKernelOperands());
  SmallVector<Type *> structTypes(args.size(), nullptr);
  for (auto [i, arg] : llvm::enumerate(args))
    structTypes[i] = arg->getType();

  Type *structTy = StructType::create(module.getContext(), structTypes);
  Value *argStruct = builder.CreateAlloca(structTy, 0u);
  Value *argArray = builder.CreateAlloca(
      ptrTy, ConstantInt::get(intPtrTy, structTypes.size()));

  for (auto [i, arg] : llvm::enumerate(args)) {
    Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
    builder.CreateStore(arg, structMember);
    Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
    builder.CreateStore(structMember, arrayMember);
  }
  return argArray;
}
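// Emits the full launch sequence: load the module (JIT-compiling assembly
// objects), look up the kernel, obtain or create a stream, launch the kernel,
// then synchronize and destroy the stream for synchronous launches, and
// finally unload the module.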
llvm::LogicalResult
llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                       mlir::gpu::ObjectAttr object) {
  auto llvmValue = [&](mlir::Value value) -> Value * {
    Value *v = moduleTranslation.lookupValue(value);
    assert(v && "Value has not been translated.");
    return v;
  };

  // Get the grid dimensions.
  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
        *gz = llvmValue(grid.z);

  // Get the block dimensions.
  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
        *bz = llvmValue(block.z);
  // Get the dynamic shared memory size, defaulting to zero.
  Value *dynamicMemorySize = nullptr;
  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
    dynamicMemorySize = llvmValue(dynSz);
  else
    dynamicMemorySize = ConstantInt::get(i32Ty, 0);
  Value *argArray = createKernelArgArray(op);
  // Set the JIT optimization level from the object's "O" property, if any.
  llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
  DictionaryAttr objectProps = object.getProperties();
  mlir::Attribute optAttr;
  if (objectProps && (optAttr = objectProps.get("O"))) {
    auto optLevel = dyn_cast<IntegerAttr>(optAttr);
    if (!optLevel)
      return op.emitError("the optimization level must be an integer");
    optV = llvm::ConstantInt::get(i32Ty, optLevel.getInt());
  }
  // Find the global holding the serialized binary and compute its size.
  StringRef moduleName = op.getKernelModuleName().getValue();
  std::string binaryIdentifier = getBinaryIdentifier(moduleName);
  Value *binary = module.getGlobalVariable(binaryIdentifier, true);
  if (!binary)
    return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;

  auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
  if (!binaryVar)
    return op.emitError() << "Binary is not a global variable: "
                          << binaryIdentifier;

  llvm::Constant *binaryInit = binaryVar->getInitializer();
  auto binaryDataSeq =
      dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
  if (!binaryDataSeq)
    return op.emitError() << "Couldn't find binary data array: "
                          << binaryIdentifier;

  llvm::Constant *binarySize =
      llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
                                        binaryDataSeq->getElementByteSize());
  Value *moduleObject =
      object.getFormat() == gpu::CompilationTarget::Assembly
          ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
          : builder.CreateCall(getModuleLoadFn(), {binary, binarySize});
  // Get the kernel function handle from the loaded module.
  Value *moduleFunction = builder.CreateCall(
      getModuleFunctionFn(),
      {moduleObject,
       getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
  // Use the async object as the stream if present; otherwise create a stream
  // here and manage it for a synchronous launch.
  Value *stream = nullptr;
  bool handleStream = false;
  if (mlir::Value asyncObject = op.getAsyncObject()) {
    stream = llvmValue(asyncObject);
  } else {
    handleStream = true;
    stream = builder.CreateCall(getStreamCreateFn(), {});
  }
  llvm::Constant *paramsCount =
      llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());
  Value *nullPtr = ConstantPointerNull::get(ptrTy);
  // Launch the kernel, using the cluster variant when a cluster size is set.
  if (op.hasClusterSize()) {
    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
          *cz = llvmValue(cluster.z);
    builder.CreateCall(
        getClusterKernelLaunchFn(),
        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
                           dynamicMemorySize, stream, argArray, nullPtr}));
  } else {
    builder.CreateCall(getKernelLaunchFn(),
                       ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
                                          bz, dynamicMemorySize, stream,
                                          argArray, nullPtr, paramsCount}));
  }
  // For synchronous launches, wait on and destroy the stream created above.
  if (handleStream) {
    builder.CreateCall(getStreamSyncFn(), {stream});
    builder.CreateCall(getStreamDestroyFn(), {stream});
  }
  // Unload the module once the launch has been issued.
  builder.CreateCall(getModuleUnloadFn(), {moduleObject});

  return success();
}