MLIR  20.0.0git
SelectObjectAttr.cpp
Go to the documentation of this file.
1 //===- ObjectHandler.cpp - Implements base ObjectManager attributes -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the `OffloadingLLVMTranslationAttrInterface` for the
10 // `SelectObject` attribute.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
19 
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/LLVMContext.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 // Implementation of the `OffloadingLLVMTranslationAttrInterface` model.
30 class SelectObjectAttrImpl
31  : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
32  SelectObjectAttrImpl> {
33 public:
34  // Translates a `gpu.binary`, embedding the binary into a host LLVM module as
35  // global binary string.
36  LogicalResult embedBinary(Attribute attribute, Operation *operation,
37  llvm::IRBuilderBase &builder,
38  LLVM::ModuleTranslation &moduleTranslation) const;
39 
40  // Translates a `gpu.launch_func` to a sequence of LLVM instructions resulting
41  // in a kernel launch call.
42  LogicalResult launchKernel(Attribute attribute,
43  Operation *launchFuncOperation,
44  Operation *binaryOperation,
45  llvm::IRBuilderBase &builder,
46  LLVM::ModuleTranslation &moduleTranslation) const;
47 
48  // Returns the selected object for embedding.
49  gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
50 };
51 // Returns an identifier for the global string holding the binary.
52 std::string getBinaryIdentifier(StringRef binaryName) {
53  return binaryName.str() + "_bin_cst";
54 }
55 } // namespace
56 
58  DialectRegistry &registry) {
59  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
60  SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
61  });
62 }
63 
64 gpu::ObjectAttr
65 SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
66  ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();
67 
68  // Obtain the index of the object to select.
69  int64_t index = -1;
70  if (Attribute target =
71  cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
72  .getTarget()) {
73  // If the target attribute is a number it is the index. Otherwise compare
74  // the attribute to every target inside the object array to find the index.
75  if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
76  index = indexAttr.getInt();
77  } else {
78  for (auto [i, attr] : llvm::enumerate(objects)) {
79  auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
80  if (obj.getTarget() == target) {
81  index = i;
82  }
83  }
84  }
85  } else {
86  // If the target attribute is null then it's selecting the first object in
87  // the object array.
88  index = 0;
89  }
90 
91  if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
92  op->emitError("the requested target object couldn't be found");
93  return nullptr;
94  }
95  return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
96 }
97 
98 LogicalResult SelectObjectAttrImpl::embedBinary(
99  Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
100  LLVM::ModuleTranslation &moduleTranslation) const {
101  assert(operation && "The binary operation must be non null.");
102  if (!operation)
103  return failure();
104 
105  auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
106  if (!op) {
107  operation->emitError("operation must be a GPU binary");
108  return failure();
109  }
110 
111  gpu::ObjectAttr object = getSelectedObject(op);
112  if (!object)
113  return failure();
114 
115  llvm::Module *module = moduleTranslation.getLLVMModule();
116 
117  // Embed the object as a global string.
118  llvm::Constant *binary = llvm::ConstantDataArray::getString(
119  builder.getContext(), object.getObject().getValue(), false);
120  llvm::GlobalVariable *serializedObj =
121  new llvm::GlobalVariable(*module, binary->getType(), true,
122  llvm::GlobalValue::LinkageTypes::InternalLinkage,
123  binary, getBinaryIdentifier(op.getName()));
124  serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
125  serializedObj->setAlignment(llvm::MaybeAlign(8));
126  serializedObj->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
127  return success();
128 }
129 
130 namespace llvm {
131 namespace {
132 class LaunchKernel {
133 public:
134  LaunchKernel(Module &module, IRBuilderBase &builder,
135  mlir::LLVM::ModuleTranslation &moduleTranslation);
136  // Get the kernel launch callee.
137  FunctionCallee getKernelLaunchFn();
138 
139  // Get the kernel launch callee.
140  FunctionCallee getClusterKernelLaunchFn();
141 
142  // Get the module function callee.
143  FunctionCallee getModuleFunctionFn();
144 
145  // Get the module load callee.
146  FunctionCallee getModuleLoadFn();
147 
148  // Get the module load JIT callee.
149  FunctionCallee getModuleLoadJITFn();
150 
151  // Get the module unload callee.
152  FunctionCallee getModuleUnloadFn();
153 
154  // Get the stream create callee.
155  FunctionCallee getStreamCreateFn();
156 
157  // Get the stream destroy callee.
158  FunctionCallee getStreamDestroyFn();
159 
160  // Get the stream sync callee.
161  FunctionCallee getStreamSyncFn();
162 
163  // Ger or create the function name global string.
164  Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
165 
166  // Create the void* kernel array for passing the arguments.
167  Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
168 
169  // Create the full kernel launch.
170  llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
171  mlir::gpu::ObjectAttr object);
172 
173 private:
174  Module &module;
175  IRBuilderBase &builder;
176  mlir::LLVM::ModuleTranslation &moduleTranslation;
177  Type *i32Ty{};
178  Type *i64Ty{};
179  Type *voidTy{};
180  Type *intPtrTy{};
181  PointerType *ptrTy{};
182 };
183 } // namespace
184 } // namespace llvm
185 
187  Attribute attribute, Operation *launchFuncOperation,
188  Operation *binaryOperation, llvm::IRBuilderBase &builder,
189  LLVM::ModuleTranslation &moduleTranslation) const {
190 
191  assert(launchFuncOperation && "The launch func operation must be non null.");
192  if (!launchFuncOperation)
193  return failure();
194 
195  auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
196  if (!launchFuncOp) {
197  launchFuncOperation->emitError("operation must be a GPU launch func Op.");
198  return failure();
199  }
200 
201  auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
202  if (!binOp) {
203  binaryOperation->emitError("operation must be a GPU binary.");
204  return failure();
205  }
206  gpu::ObjectAttr object = getSelectedObject(binOp);
207  if (!object)
208  return failure();
209 
210  return llvm::LaunchKernel(*moduleTranslation.getLLVMModule(), builder,
211  moduleTranslation)
212  .createKernelLaunch(launchFuncOp, object);
213 }
214 
215 llvm::LaunchKernel::LaunchKernel(
216  Module &module, IRBuilderBase &builder,
217  mlir::LLVM::ModuleTranslation &moduleTranslation)
218  : module(module), builder(builder), moduleTranslation(moduleTranslation) {
219  i32Ty = builder.getInt32Ty();
220  i64Ty = builder.getInt64Ty();
221  ptrTy = builder.getPtrTy(0);
222  voidTy = builder.getVoidTy();
223  intPtrTy = builder.getIntPtrTy(module.getDataLayout());
224 }
225 
226 llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
227  return module.getOrInsertFunction(
228  "mgpuLaunchKernel",
229  FunctionType::get(voidTy,
230  ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
231  intPtrTy, intPtrTy, intPtrTy, i32Ty,
232  ptrTy, ptrTy, ptrTy, i64Ty}),
233  false));
234 }
235 
236 llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
237  return module.getOrInsertFunction(
238  "mgpuLaunchClusterKernel",
240  voidTy,
241  ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
242  intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
243  i32Ty, ptrTy, ptrTy, ptrTy}),
244  false));
245 }
246 
247 llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
248  return module.getOrInsertFunction(
249  "mgpuModuleGetFunction",
250  FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
251 }
252 
253 llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
254  return module.getOrInsertFunction(
255  "mgpuModuleLoad",
256  FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false));
257 }
258 
259 llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
260  return module.getOrInsertFunction(
261  "mgpuModuleLoadJIT",
262  FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false));
263 }
264 
265 llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
266  return module.getOrInsertFunction(
267  "mgpuModuleUnload",
268  FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
269 }
270 
271 llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
272  return module.getOrInsertFunction("mgpuStreamCreate",
273  FunctionType::get(ptrTy, false));
274 }
275 
276 llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
277  return module.getOrInsertFunction(
278  "mgpuStreamDestroy",
279  FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
280 }
281 
282 llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
283  return module.getOrInsertFunction(
284  "mgpuStreamSynchronize",
285  FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
286 }
287 
288 // Generates an LLVM IR dialect global that contains the name of the given
289 // kernel function as a C string, and returns a pointer to its beginning.
290 llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
291  StringRef kernelName) {
292  std::string globalName =
293  std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));
294 
295  if (GlobalVariable *gv = module.getGlobalVariable(globalName))
296  return gv;
297 
298  return builder.CreateGlobalString(kernelName, globalName);
299 }
300 
301 // Creates a struct containing all kernel parameters on the stack and returns
302 // an array of type-erased pointers to the fields of the struct. The array can
303 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
304 // The generated code is essentially as follows:
305 //
306 // %struct = alloca(sizeof(struct { Parameters... }))
307 // %array = alloca(NumParameters * sizeof(void *))
308 // for (i : [0, NumParameters))
309 // %fieldPtr = llvm.getelementptr %struct[0, i]
310 // llvm.store parameters[i], %fieldPtr
311 // %elementPtr = llvm.getelementptr %array[i]
312 // llvm.store %fieldPtr, %elementPtr
313 // return %array
314 llvm::Value *
315 llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
316  SmallVector<Value *> args =
317  moduleTranslation.lookupValues(op.getKernelOperands());
318  SmallVector<Type *> structTypes(args.size(), nullptr);
319 
320  for (auto [i, arg] : llvm::enumerate(args))
321  structTypes[i] = arg->getType();
322 
323  Type *structTy = StructType::create(module.getContext(), structTypes);
324  Value *argStruct = builder.CreateAlloca(structTy, 0u);
325  Value *argArray = builder.CreateAlloca(
326  ptrTy, ConstantInt::get(intPtrTy, structTypes.size()));
327 
328  for (auto [i, arg] : enumerate(args)) {
329  Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
330  builder.CreateStore(arg, structMember);
331  Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
332  builder.CreateStore(structMember, arrayMember);
333  }
334  return argArray;
335 }
336 
337 // Emits LLVM IR to launch a kernel function:
338 // %0 = call %binarygetter
339 // %1 = call %moduleLoad(%0)
340 // %2 = <see generateKernelNameConstant>
341 // %3 = call %moduleGetFunction(%1, %2)
342 // %4 = call %streamCreate()
343 // %5 = <see generateParamsArray>
344 // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
345 // call %streamSynchronize(%4)
346 // call %streamDestroy(%4)
347 // call %moduleUnload(%1)
348 llvm::LogicalResult
349 llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
350  mlir::gpu::ObjectAttr object) {
351  auto llvmValue = [&](mlir::Value value) -> Value * {
352  Value *v = moduleTranslation.lookupValue(value);
353  assert(v && "Value has not been translated.");
354  return v;
355  };
356 
357  // Get grid dimensions.
358  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
359  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
360  *gz = llvmValue(grid.z);
361 
362  // Get block dimensions.
363  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
364  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
365  *bz = llvmValue(block.z);
366 
367  // Get dynamic shared memory size.
368  Value *dynamicMemorySize = nullptr;
369  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
370  dynamicMemorySize = llvmValue(dynSz);
371  else
372  dynamicMemorySize = ConstantInt::get(i32Ty, 0);
373 
374  // Create the argument array.
375  Value *argArray = createKernelArgArray(op);
376 
377  // Default JIT optimization level.
378  llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
379  // Check if there's an optimization level embedded in the object.
380  DictionaryAttr objectProps = object.getProperties();
381  mlir::Attribute optAttr;
382  if (objectProps && (optAttr = objectProps.get("O"))) {
383  auto optLevel = dyn_cast<IntegerAttr>(optAttr);
384  if (!optLevel)
385  return op.emitError("the optimization level must be an integer");
386  optV = llvm::ConstantInt::get(i32Ty, optLevel.getValue());
387  }
388 
389  // Load the kernel module.
390  StringRef moduleName = op.getKernelModuleName().getValue();
391  std::string binaryIdentifier = getBinaryIdentifier(moduleName);
392  Value *binary = module.getGlobalVariable(binaryIdentifier, true);
393  if (!binary)
394  return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
395 
396  auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
397  if (!binaryVar)
398  return op.emitError() << "Binary is not a global variable: "
399  << binaryIdentifier;
400  llvm::Constant *binaryInit = binaryVar->getInitializer();
401  auto binaryDataSeq =
402  dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
403  if (!binaryDataSeq)
404  return op.emitError() << "Couldn't find binary data array: "
405  << binaryIdentifier;
406  llvm::Constant *binarySize =
407  llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
408  binaryDataSeq->getElementByteSize());
409 
410  Value *moduleObject =
411  object.getFormat() == gpu::CompilationTarget::Assembly
412  ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
413  : builder.CreateCall(getModuleLoadFn(), {binary, binarySize});
414 
415  // Load the kernel function.
416  Value *moduleFunction = builder.CreateCall(
417  getModuleFunctionFn(),
418  {moduleObject,
419  getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
420 
421  // Get the stream to use for execution. If there's no async object then create
422  // a stream to make a synchronous kernel launch.
423  Value *stream = nullptr;
424  bool handleStream = false;
425  if (mlir::Value asyncObject = op.getAsyncObject()) {
426  stream = llvmValue(asyncObject);
427  } else {
428  handleStream = true;
429  stream = builder.CreateCall(getStreamCreateFn(), {});
430  }
431 
432  llvm::Constant *paramsCount =
433  llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());
434 
435  // Create the launch call.
436  Value *nullPtr = ConstantPointerNull::get(ptrTy);
437 
438  // Launch kernel with clusters if cluster size is specified.
439  if (op.hasClusterSize()) {
440  mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
441  Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
442  *cz = llvmValue(cluster.z);
443  builder.CreateCall(
444  getClusterKernelLaunchFn(),
445  ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
446  dynamicMemorySize, stream, argArray, nullPtr}));
447  } else {
448  builder.CreateCall(getKernelLaunchFn(),
449  ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
450  bz, dynamicMemorySize, stream,
451  argArray, nullPtr, paramsCount}));
452  }
453 
454  // Sync & destroy the stream, for synchronous launches.
455  if (handleStream) {
456  builder.CreateCall(getStreamSyncFn(), {stream});
457  builder.CreateCall(getStreamDestroyFn(), {stream});
458  }
459 
460  // Unload the kernel module.
461  builder.CreateCall(getModuleUnloadFn(), {moduleObject});
462 
463  return success();
464 }
@ None
static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ, size_t blockX, size_t blockY, size_t blockZ, size_t sharedMemBytes, void **params, size_t paramsCount)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
Definition: CallGraph.h:229
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void registerOffloadingLLVMTranslationInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers the offloading LLVM translation interfaces for gpu.select_object.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition: GPUDialect.h:38