20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/LLVMContext.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/Support/FormatVariadic.h"
30 class SelectObjectAttrImpl
31 :
public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
32 SelectObjectAttrImpl> {
37 llvm::IRBuilderBase &builder,
45 llvm::IRBuilderBase &builder,
49 gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op)
const;
52 std::string getBinaryIdentifier(StringRef binaryName) {
53 return binaryName.str() +
"_bin_cst";
60 SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
65 SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op)
const {
71 cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
75 if (
auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
76 index = indexAttr.getInt();
79 auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
80 if (obj.getTarget() == target) {
91 if (index < 0 || index >=
static_cast<int64_t
>(objects.size())) {
92 op->
emitError(
"the requested target object couldn't be found");
95 return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
101 assert(operation &&
"The binary operation must be non null.");
105 auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
107 operation->
emitError(
"operation must be a GPU binary");
111 gpu::ObjectAttr
object = getSelectedObject(op);
118 llvm::Constant *binary = llvm::ConstantDataArray::getString(
119 builder.getContext(),
object.getObject().getValue(),
false);
120 llvm::GlobalVariable *serializedObj =
121 new llvm::GlobalVariable(*module, binary->getType(),
true,
122 llvm::GlobalValue::LinkageTypes::InternalLinkage,
123 binary, getBinaryIdentifier(op.
getName()));
124 serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
125 serializedObj->setAlignment(llvm::MaybeAlign(8));
134 LaunchKernel(Module &module, IRBuilderBase &builder,
137 FunctionCallee getKernelLaunchFn();
140 FunctionCallee getClusterKernelLaunchFn();
143 FunctionCallee getModuleFunctionFn();
146 FunctionCallee getModuleLoadFn();
149 FunctionCallee getModuleLoadJITFn();
152 FunctionCallee getModuleUnloadFn();
155 FunctionCallee getStreamCreateFn();
158 FunctionCallee getStreamDestroyFn();
161 FunctionCallee getStreamSyncFn();
164 Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
167 Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
171 mlir::gpu::ObjectAttr
object);
175 IRBuilderBase &builder;
181 PointerType *ptrTy{};
188 Operation *binaryOperation, llvm::IRBuilderBase &builder,
191 assert(launchFuncOperation &&
"The launch func operation must be non null.");
192 if (!launchFuncOperation)
195 auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
197 launchFuncOperation->
emitError(
"operation must be a GPU launch func Op.");
201 auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
203 binaryOperation->
emitError(
"operation must be a GPU binary.");
206 gpu::ObjectAttr
object = getSelectedObject(binOp);
210 return llvm::LaunchKernel(*moduleTranslation.
getLLVMModule(), builder,
212 .createKernelLaunch(launchFuncOp,
object);
215 llvm::LaunchKernel::LaunchKernel(
216 Module &module, IRBuilderBase &builder,
218 : module(module), builder(builder), moduleTranslation(moduleTranslation) {
219 i32Ty = builder.getInt32Ty();
220 i64Ty = builder.getInt64Ty();
221 ptrTy = builder.getPtrTy(0);
222 voidTy = builder.getVoidTy();
223 intPtrTy = builder.getIntPtrTy(module.getDataLayout());
226 llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
227 return module.getOrInsertFunction(
231 intPtrTy, intPtrTy, intPtrTy, i32Ty,
232 ptrTy, ptrTy, ptrTy, i64Ty}),
236 llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
237 return module.getOrInsertFunction(
238 "mgpuLaunchClusterKernel",
242 intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
243 i32Ty, ptrTy, ptrTy, ptrTy}),
247 llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
248 return module.getOrInsertFunction(
249 "mgpuModuleGetFunction",
253 llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
254 return module.getOrInsertFunction(
259 llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
260 return module.getOrInsertFunction(
265 llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
266 return module.getOrInsertFunction(
271 llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
272 return module.getOrInsertFunction(
"mgpuStreamCreate",
276 llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
277 return module.getOrInsertFunction(
282 llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
283 return module.getOrInsertFunction(
284 "mgpuStreamSynchronize",
290 llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
291 StringRef kernelName) {
292 std::string globalName =
293 std::string(formatv(
"{0}_{1}_kernel_name", moduleName, kernelName));
295 if (GlobalVariable *gv = module.getGlobalVariable(globalName))
298 return builder.CreateGlobalString(kernelName, globalName);
315 llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
321 structTypes[i] = arg->getType();
323 Type *structTy = StructType::create(module.getContext(), structTypes);
324 Value *argStruct = builder.CreateAlloca(structTy, 0u);
325 Value *argArray = builder.CreateAlloca(
329 Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
330 builder.CreateStore(arg, structMember);
331 Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
332 builder.CreateStore(structMember, arrayMember);
349 llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
350 mlir::gpu::ObjectAttr
object) {
353 assert(v &&
"Value has not been translated.");
359 Value *gx = llvmValue(grid.
x), *gy = llvmValue(grid.
y),
360 *gz = llvmValue(grid.
z);
364 Value *bx = llvmValue(block.
x), *by = llvmValue(block.
y),
365 *bz = llvmValue(block.
z);
368 Value *dynamicMemorySize =
nullptr;
369 if (
mlir::Value dynSz = op.getDynamicSharedMemorySize())
370 dynamicMemorySize = llvmValue(dynSz);
375 Value *argArray = createKernelArgArray(op);
380 DictionaryAttr objectProps =
object.getProperties();
382 if (objectProps && (optAttr = objectProps.get(
"O"))) {
383 auto optLevel = dyn_cast<IntegerAttr>(optAttr);
385 return op.
emitError(
"the optimization level must be an integer");
390 StringRef moduleName = op.getKernelModuleName().getValue();
391 std::string binaryIdentifier = getBinaryIdentifier(moduleName);
392 Value *binary = module.getGlobalVariable(binaryIdentifier,
true);
394 return op.
emitError() <<
"Couldn't find the binary: " << binaryIdentifier;
396 auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
398 return op.
emitError() <<
"Binary is not a global variable: "
400 llvm::Constant *binaryInit = binaryVar->getInitializer();
402 dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
404 return op.
emitError() <<
"Couldn't find binary data array: "
406 llvm::Constant *binarySize =
408 binaryDataSeq->getElementByteSize());
410 Value *moduleObject =
411 object.getFormat() == gpu::CompilationTarget::Assembly
412 ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
413 : builder.CreateCall(getModuleLoadFn(), {binary, binarySize});
416 Value *moduleFunction = builder.CreateCall(
417 getModuleFunctionFn(),
419 getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
423 Value *stream =
nullptr;
424 bool handleStream =
false;
425 if (
mlir::Value asyncObject = op.getAsyncObject()) {
426 stream = llvmValue(asyncObject);
429 stream = builder.CreateCall(getStreamCreateFn(), {});
432 llvm::Constant *paramsCount =
439 if (op.hasClusterSize()) {
441 Value *cx = llvmValue(cluster.
x), *cy = llvmValue(cluster.
y),
442 *cz = llvmValue(cluster.
z);
444 getClusterKernelLaunchFn(),
445 ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
446 dynamicMemorySize, stream, argArray, nullPtr}));
448 builder.CreateCall(getKernelLaunchFn(),
450 bz, dynamicMemorySize, stream,
451 argArray, nullPtr, paramsCount}));
456 builder.CreateCall(getStreamSyncFn(), {stream});
457 builder.CreateCall(getStreamDestroyFn(), {stream});
461 builder.CreateCall(getModuleUnloadFn(), {moduleObject});
static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ, size_t blockX, size_t blockY, size_t blockZ, size_t sharedMemBytes, void **params, size_t paramsCount)
Attributes are known-constant values of operations.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values corresponding to the given MLIR values.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OperationName getName()
The name of an operation is the key identifier for it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Include the generated interface declarations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerOffloadingLLVMTranslationInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers the offloading LLVM translation interfaces for gpu.select_object.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
This class represents an efficient way to signal success or failure.
Utility class for the GPU dialect to represent triples of Values accessible through ....