MLIR  20.0.0git
ConvertLaunchFuncToVulkanCalls.cpp
Go to the documentation of this file.
1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose separate external functions in IR for each of them; instead
// we expose a few external functions to wrapper libraries which manage the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
18 
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/Pass/Pass.h"
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 
34 static constexpr const char *kCInterfaceVulkanLaunch =
35  "_mlir_ciface_vulkanLaunch";
36 static constexpr const char *kDeinitVulkan = "deinitVulkan";
37 static constexpr const char *kRunOnVulkan = "runOnVulkan";
38 static constexpr const char *kInitVulkan = "initVulkan";
39 static constexpr const char *kSetBinaryShader = "setBinaryShader";
40 static constexpr const char *kSetEntryPoint = "setEntryPoint";
41 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
42 static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
43 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
44 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
45 static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types";
46 static constexpr const char *kVulkanLaunch = "vulkanLaunch";
47 
48 namespace {
49 
50 /// A pass to convert vulkan launch call op into a sequence of Vulkan
51 /// runtime calls in the following order:
52 ///
53 /// * initVulkan -- initializes vulkan runtime
54 /// * bindMemRef -- binds memref
55 /// * setBinaryShader -- sets the binary shader data
56 /// * setEntryPoint -- sets the entry point name
57 /// * setNumWorkGroups -- sets the number of a local workgroups
58 /// * runOnVulkan -- runs vulkan runtime
59 /// * deinitVulkan -- deinitializes vulkan runtime
60 ///
61 class VulkanLaunchFuncToVulkanCallsPass
62  : public impl::ConvertVulkanLaunchFuncToVulkanCallsPassBase<
63  VulkanLaunchFuncToVulkanCallsPass> {
64 private:
65  void initializeCachedTypes() {
66  llvmFloatType = Float32Type::get(&getContext());
67  llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
68  llvmPointerType = LLVM::LLVMPointerType::get(&getContext());
69  llvmInt32Type = IntegerType::get(&getContext(), 32);
70  llvmInt64Type = IntegerType::get(&getContext(), 64);
71  }
72 
73  Type getMemRefType(uint32_t rank, Type elemenType) {
74  // According to the MLIR doc memref argument is converted into a
75  // pointer-to-struct argument of type:
76  // template <typename Elem, size_t Rank>
77  // struct {
78  // Elem *allocated;
79  // Elem *aligned;
80  // int64_t offset;
81  // int64_t sizes[Rank]; // omitted when rank == 0
82  // int64_t strides[Rank]; // omitted when rank == 0
83  // };
84  auto llvmArrayRankElementSizeType =
85  LLVM::LLVMArrayType::get(getInt64Type(), rank);
86 
87  // Create a type
88  // `!llvm<"{ `element-type`*, `element-type`*, i64,
89  // [`rank` x i64], [`rank` x i64]}">`.
90  return LLVM::LLVMStructType::getLiteral(
91  &getContext(),
92  {llvmPointerType, llvmPointerType, getInt64Type(),
93  llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
94  }
95 
96  Type getVoidType() { return llvmVoidType; }
97  Type getPointerType() { return llvmPointerType; }
98  Type getInt32Type() { return llvmInt32Type; }
99  Type getInt64Type() { return llvmInt64Type; }
100 
101  /// Creates an LLVM global for the given `name`.
102  Value createEntryPointNameConstant(StringRef name, Location loc,
103  OpBuilder &builder);
104 
105  /// Declares all needed runtime functions.
106  void declareVulkanFunctions(Location loc);
107 
108  /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
109  bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
110  return (callOp.getCallee() && *callOp.getCallee() == kVulkanLaunch &&
111  callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
112  }
113 
114  /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
115  /// op.
116  bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
117  return (callOp.getCallee() &&
118  *callOp.getCallee() == kCInterfaceVulkanLaunch &&
119  callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
120  }
121 
122  /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
123  /// runtime calls.
124  void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
125 
126  /// Creates call to `bindMemRef` for each memref operand.
127  void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
128  Value vulkanRuntime);
129 
130  /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
131  void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
132 
133  /// Deduces a rank from the given 'launchCallArg`.
134  LogicalResult deduceMemRefRank(Value launchCallArg, uint32_t &rank);
135 
136  /// Returns a string representation from the given `type`.
137  StringRef stringifyType(Type type) {
138  if (isa<Float32Type>(type))
139  return "Float";
140  if (isa<Float16Type>(type))
141  return "Half";
142  if (auto intType = dyn_cast<IntegerType>(type)) {
143  if (intType.getWidth() == 32)
144  return "Int32";
145  if (intType.getWidth() == 16)
146  return "Int16";
147  if (intType.getWidth() == 8)
148  return "Int8";
149  }
150 
151  llvm_unreachable("unsupported type");
152  }
153 
154 public:
155  using Base::Base;
156 
157  void runOnOperation() override;
158 
159 private:
160  Type llvmFloatType;
161  Type llvmVoidType;
162  Type llvmPointerType;
163  Type llvmInt32Type;
164  Type llvmInt64Type;
165 
166  struct SPIRVAttributes {
167  StringAttr blob;
168  StringAttr entryPoint;
169  SmallVector<Type> elementTypes;
170  };
171 
172  // TODO: Use an associative array to support multiple vulkan launch calls.
173  SPIRVAttributes spirvAttributes;
174  /// The number of vulkan launch configuration operands, placed at the leading
175  /// positions of the operand list.
176  static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
177 };
178 
179 } // namespace
180 
181 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
182  initializeCachedTypes();
183 
184  // Collect SPIR-V attributes such as `spirv_blob` and
185  // `spirv_entry_point_name`.
186  getOperation().walk([this](LLVM::CallOp op) {
187  if (isVulkanLaunchCallOp(op))
188  collectSPIRVAttributes(op);
189  });
190 
191  // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
192  getOperation().walk([this](LLVM::CallOp op) {
193  if (isCInterfaceVulkanLaunchCallOp(op))
194  translateVulkanLaunchCall(op);
195  });
196 }
197 
198 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
199  LLVM::CallOp vulkanLaunchCallOp) {
200  // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
201  // for the given vulkan launch call.
202  auto spirvBlobAttr =
203  vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
204  if (!spirvBlobAttr) {
205  vulkanLaunchCallOp.emitError()
206  << "missing " << kSPIRVBlobAttrName << " attribute";
207  return signalPassFailure();
208  }
209 
210  auto spirvEntryPointNameAttr =
211  vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
212  if (!spirvEntryPointNameAttr) {
213  vulkanLaunchCallOp.emitError()
214  << "missing " << kSPIRVEntryPointAttrName << " attribute";
215  return signalPassFailure();
216  }
217 
218  auto spirvElementTypesAttr =
219  vulkanLaunchCallOp->getAttrOfType<ArrayAttr>(kSPIRVElementTypesAttrName);
220  if (!spirvElementTypesAttr) {
221  vulkanLaunchCallOp.emitError()
222  << "missing " << kSPIRVElementTypesAttrName << " attribute";
223  return signalPassFailure();
224  }
225  if (llvm::any_of(spirvElementTypesAttr,
226  [](Attribute attr) { return !isa<TypeAttr>(attr); })) {
227  vulkanLaunchCallOp.emitError()
228  << "expected " << spirvElementTypesAttr << " to be an array of types";
229  return signalPassFailure();
230  }
231 
232  spirvAttributes.blob = spirvBlobAttr;
233  spirvAttributes.entryPoint = spirvEntryPointNameAttr;
234  spirvAttributes.elementTypes =
235  llvm::to_vector(spirvElementTypesAttr.getAsValueRange<mlir::TypeAttr>());
236 }
237 
238 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
239  LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
240  if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
241  kVulkanLaunchNumConfigOperands)
242  return;
243  OpBuilder builder(cInterfaceVulkanLaunchCallOp);
244  Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
245 
246  // Create LLVM constant for the descriptor set index.
247  // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
248  // pass does.
249  Value descriptorSet =
250  builder.create<LLVM::ConstantOp>(loc, getInt32Type(), 0);
251 
252  for (auto [index, ptrToMemRefDescriptor] :
253  llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
254  kVulkanLaunchNumConfigOperands))) {
255  // Create LLVM constant for the descriptor binding index.
256  Value descriptorBinding =
257  builder.create<LLVM::ConstantOp>(loc, getInt32Type(), index);
258 
259  if (index >= spirvAttributes.elementTypes.size()) {
260  cInterfaceVulkanLaunchCallOp.emitError()
261  << kSPIRVElementTypesAttrName << " missing element type for "
262  << ptrToMemRefDescriptor;
263  return signalPassFailure();
264  }
265 
266  uint32_t rank = 0;
267  Type type = spirvAttributes.elementTypes[index];
268  if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
269  cInterfaceVulkanLaunchCallOp.emitError()
270  << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
271  return signalPassFailure();
272  }
273 
274  auto symbolName =
275  llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
276  // Create call to `bindMemRef`.
277  builder.create<LLVM::CallOp>(
278  loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()),
279  ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
280  ptrToMemRefDescriptor});
281  }
282 }
283 
284 LogicalResult
285 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
286  uint32_t &rank) {
287  // Deduce the rank from the type used to allocate the lowered MemRef.
288  auto alloca = launchCallArg.getDefiningOp<LLVM::AllocaOp>();
289  if (!alloca)
290  return failure();
291 
292  std::optional<Type> elementType = alloca.getElemType();
293  assert(elementType && "expected to work with opaque pointers");
294  auto llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
295  // template <typename Elem, size_t Rank>
296  // struct {
297  // Elem *allocated;
298  // Elem *aligned;
299  // int64_t offset;
300  // int64_t sizes[Rank]; // omitted when rank == 0
301  // int64_t strides[Rank]; // omitted when rank == 0
302  // };
303  if (!llvmDescriptorTy)
304  return failure();
305 
306  if (llvmDescriptorTy.getBody().size() == 3) {
307  rank = 0;
308  return success();
309  }
310  rank =
311  cast<LLVM::LLVMArrayType>(llvmDescriptorTy.getBody()[3]).getNumElements();
312  return success();
313 }
314 
315 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
316  ModuleOp module = getOperation();
317  auto builder = OpBuilder::atBlockEnd(module.getBody());
318 
319  if (!module.lookupSymbol(kSetEntryPoint)) {
320  builder.create<LLVM::LLVMFuncOp>(
321  loc, kSetEntryPoint,
322  LLVM::LLVMFunctionType::get(getVoidType(),
323  {getPointerType(), getPointerType()}));
324  }
325 
326  if (!module.lookupSymbol(kSetNumWorkGroups)) {
327  builder.create<LLVM::LLVMFuncOp>(
328  loc, kSetNumWorkGroups,
329  LLVM::LLVMFunctionType::get(getVoidType(),
330  {getPointerType(), getInt64Type(),
331  getInt64Type(), getInt64Type()}));
332  }
333 
334  if (!module.lookupSymbol(kSetBinaryShader)) {
335  builder.create<LLVM::LLVMFuncOp>(
336  loc, kSetBinaryShader,
338  getVoidType(),
339  {getPointerType(), getPointerType(), getInt32Type()}));
340  }
341 
342  if (!module.lookupSymbol(kRunOnVulkan)) {
343  builder.create<LLVM::LLVMFuncOp>(
344  loc, kRunOnVulkan,
345  LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
346  }
347 
348  for (unsigned i = 1; i <= 3; i++) {
349  SmallVector<Type, 5> types{
353  for (auto type : types) {
354  std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
355  std::string(stringifyType(type));
356  if (isa<Float16Type>(type))
357  type = IntegerType::get(&getContext(), 16);
358  if (!module.lookupSymbol(fnName)) {
359  auto fnType = LLVM::LLVMFunctionType::get(
360  getVoidType(),
361  {llvmPointerType, getInt32Type(), getInt32Type(), llvmPointerType},
362  /*isVarArg=*/false);
363  builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
364  }
365  }
366  }
367 
368  if (!module.lookupSymbol(kInitVulkan)) {
369  builder.create<LLVM::LLVMFuncOp>(
370  loc, kInitVulkan, LLVM::LLVMFunctionType::get(getPointerType(), {}));
371  }
372 
373  if (!module.lookupSymbol(kDeinitVulkan)) {
374  builder.create<LLVM::LLVMFuncOp>(
375  loc, kDeinitVulkan,
376  LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
377  }
378 }
379 
380 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
381  StringRef name, Location loc, OpBuilder &builder) {
382  SmallString<16> shaderName(name.begin(), name.end());
383  // Append `\0` to follow C style string given that LLVM::createGlobalString()
384  // won't handle this directly for us.
385  shaderName.push_back('\0');
386 
387  std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
388  return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
389  shaderName, LLVM::Linkage::Internal);
390 }
391 
392 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
393  LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
394  OpBuilder builder(cInterfaceVulkanLaunchCallOp);
395  Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
396  // Create call to `initVulkan`.
397  auto initVulkanCall = builder.create<LLVM::CallOp>(
398  loc, TypeRange{getPointerType()}, kInitVulkan);
399  // The result of `initVulkan` function is a pointer to Vulkan runtime, we
400  // need to pass that pointer to each Vulkan runtime call.
401  auto vulkanRuntime = initVulkanCall.getResult();
402 
403  // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
404  // that data to runtime call.
405  Value ptrToSPIRVBinary = LLVM::createGlobalString(
406  loc, builder, kSPIRVBinary, spirvAttributes.blob.getValue(),
407  LLVM::Linkage::Internal);
408 
409  // Create LLVM constant for the size of SPIR-V binary shader.
410  Value binarySize = builder.create<LLVM::ConstantOp>(
411  loc, getInt32Type(), spirvAttributes.blob.getValue().size());
412 
413  // Create call to `bindMemRef` for each memref operand.
414  createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
415 
416  // Create call to `setBinaryShader` runtime function with the given pointer to
417  // SPIR-V binary and binary size.
418  builder.create<LLVM::CallOp>(
419  loc, TypeRange(), kSetBinaryShader,
420  ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
421  // Create LLVM global with entry point name.
422  Value entryPointName = createEntryPointNameConstant(
423  spirvAttributes.entryPoint.getValue(), loc, builder);
424  // Create call to `setEntryPoint` runtime function with the given pointer to
425  // entry point name.
426  builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
427  ValueRange{vulkanRuntime, entryPointName});
428 
429  // Create number of local workgroup for each dimension.
430  builder.create<LLVM::CallOp>(
432  ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
433  cInterfaceVulkanLaunchCallOp.getOperand(1),
434  cInterfaceVulkanLaunchCallOp.getOperand(2)});
435 
436  // Create call to `runOnVulkan` runtime function.
437  builder.create<LLVM::CallOp>(loc, TypeRange(), kRunOnVulkan,
438  ValueRange{vulkanRuntime});
439 
440  // Create call to 'deinitVulkan' runtime function.
441  builder.create<LLVM::CallOp>(loc, TypeRange(), kDeinitVulkan,
442  ValueRange{vulkanRuntime});
443 
444  // Declare runtime functions.
445  declareVulkanFunctions(loc);
446 
447  cInterfaceVulkanLaunchCallOp.erase();
448 }
static constexpr const char * kVulkanLaunch
static constexpr const char * kCInterfaceVulkanLaunch
static constexpr const char * kInitVulkan
static constexpr const char * kSPIRVBinary
static constexpr const char * kSPIRVElementTypesAttrName
static constexpr const char * kDeinitVulkan
static constexpr const char * kSPIRVEntryPointAttrName
static constexpr const char * kSetEntryPoint
static constexpr const char * kSetBinaryShader
static constexpr const char * kSPIRVBlobAttrName
static constexpr const char * kSetNumWorkGroups
static constexpr const char * kRunOnVulkan
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
Definition: Builders.h:255
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, Linkage linkage)
Create an LLVM global containing the string "value" in the module surrounding the insertion point of the builder.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.