24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/Support/FormatVariadic.h"
28 #define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
29 #include "mlir/Conversion/Passes.h.inc"
35 "_mlir_ciface_vulkanLaunch";
// Pass that finds vulkanLaunch calls, collects their SPIR-V attributes, and
// rewrites the C-interface launch call into calls to Vulkan runtime wrapper
// functions (see translateVulkanLaunchCall).
61 class VulkanLaunchFuncToVulkanCallsPass
62 :
public impl::ConvertVulkanLaunchFuncToVulkanCallsPassBase<
63 VulkanLaunchFuncToVulkanCallsPass> {
// Called first in runOnOperation; presumably initializes the llvm*Type members
// returned by the accessors below (body not visible here).
65 void initializeCachedTypes() {
// Literal struct {ptr, ptr, i64, array, array}. Presumably the lowered memref
// descriptor layout (allocated ptr, aligned ptr, offset, sizes, strides) --
// TODO confirm against the lines not visible here.
84 auto llvmArrayRankElementSizeType =
90 return LLVM::LLVMStructType::getLiteral(
92 {llvmPointerType, llvmPointerType, getInt64Type(),
93 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
// Accessors for the cached LLVM types.
96 Type getVoidType() {
return llvmVoidType; }
97 Type getPointerType() {
return llvmPointerType; }
98 Type getInt32Type() {
return llvmInt32Type; }
99 Type getInt64Type() {
return llvmInt64Type; }
// Creates a NUL-terminated global string holding `name` and returns a Value
// referring to it.
102 Value createEntryPointNameConstant(StringRef name,
Location loc,
// Declares the Vulkan runtime functions that the rewritten calls reference.
106 void declareVulkanFunctions(
Location loc);
// True if `callOp` has a callee equal to kVulkanLaunch and at least
// kVulkanLaunchNumConfigOperands operands.
109 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
110 return (callOp.getCallee() && *callOp.getCallee() ==
kVulkanLaunch &&
111 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
// Same check for the C-interface entry point. The callee comparison is not
// visible here; it presumably compares against kCInterfaceVulkanLaunch.
116 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
117 return (callOp.getCallee() &&
119 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
// Replaces a C-interface vulkanLaunch call with a sequence of runtime calls.
124 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
// Emits one bindMemRef* call per memref descriptor operand of the launch call.
127 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
128 Value vulkanRuntime);
// Reads the SPIR-V attributes of a vulkanLaunch call into `spirvAttributes`.
131 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
// Deduces the rank of the memref descriptor passed as `launchCallArg` and
// stores it in `rank`; returns failure if it cannot.
134 LogicalResult deduceMemRefRank(
Value launchCallArg, uint32_t &rank);
// Returns the element-type suffix used to build bindMemRef* symbol names.
// Handles f32, f16 and 32/16/8-bit integers; any other type reaches
// llvm_unreachable.
137 StringRef stringifyType(
Type type) {
138 if (isa<Float32Type>(type))
140 if (isa<Float16Type>(type))
142 if (
auto intType = dyn_cast<IntegerType>(type)) {
143 if (intType.getWidth() == 32)
145 if (intType.getWidth() == 16)
147 if (intType.getWidth() == 8)
151 llvm_unreachable(
"unsupported type");
157 void runOnOperation()
override;
162 Type llvmPointerType;
// SPIR-V data collected from a vulkanLaunch call (blob, entry point name,
// element types) and consumed when translating the C-interface launch call.
166 struct SPIRVAttributes {
168 StringAttr entryPoint;
173 SPIRVAttributes spirvAttributes;
// Leading launch-call operands treated as launch configuration; the remaining
// operands are memref descriptor pointers to bind.
176 static constexpr
unsigned kVulkanLaunchNumConfigOperands = 3;
// Pass entry point. Caches LLVM types, then walks the module twice: the first
// walk collects SPIR-V attributes from vulkanLaunch calls, and the second
// rewrites every C-interface vulkanLaunch call into Vulkan runtime calls.
181 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
182 initializeCachedTypes();
// Collect attributes first so the translation walk below can use them.
186 getOperation().walk([
this](LLVM::CallOp op) {
187 if (isVulkanLaunchCallOp(op))
188 collectSPIRVAttributes(op);
// Translate each C-interface launch call using the collected attributes.
192 getOperation().walk([
this](LLVM::CallOp op) {
193 if (isCInterfaceVulkanLaunchCallOp(op))
194 translateVulkanLaunchCall(op);
// Reads the SPIR-V blob, entry point name and element-types attributes from a
// vulkanLaunch call and stores them in `spirvAttributes`. If an attribute is
// missing, or the element-types array contains a non-TypeAttr entry, it emits
// an error on the call and signals pass failure.
// NOTE(review): `spirvAttributes` is a single member. When a module contains
// several vulkanLaunch calls, each call overwrites the previous values, so the
// translation walk uses only the last-collected attributes for every launch.
// Confirm that a single launch per module is assumed.
198 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
199 LLVM::CallOp vulkanLaunchCallOp) {
// The SPIR-V binary blob attribute is required.
204 if (!spirvBlobAttr) {
205 vulkanLaunchCallOp.emitError()
207 return signalPassFailure();
// The entry point name attribute is required.
210 auto spirvEntryPointNameAttr =
212 if (!spirvEntryPointNameAttr) {
213 vulkanLaunchCallOp.emitError()
215 return signalPassFailure();
// The element-types attribute is required and must contain only TypeAttrs.
218 auto spirvElementTypesAttr =
220 if (!spirvElementTypesAttr) {
221 vulkanLaunchCallOp.emitError()
223 return signalPassFailure();
225 if (llvm::any_of(spirvElementTypesAttr,
226 [](
Attribute attr) {
return !isa<TypeAttr>(attr); })) {
227 vulkanLaunchCallOp.emitError()
228 <<
"expected " << spirvElementTypesAttr <<
" to be an array of types";
229 return signalPassFailure();
// All attributes are validated; cache them for the translation step.
232 spirvAttributes.blob = spirvBlobAttr;
233 spirvAttributes.entryPoint = spirvEntryPointNameAttr;
234 spirvAttributes.elementTypes =
235 llvm::to_vector(spirvElementTypesAttr.getAsValueRange<mlir::TypeAttr>());
// Emits a bindMemRef<rank>D<type> call for every memref descriptor operand of
// the C-interface launch call. These are presumably the operands after the
// first kVulkanLaunchNumConfigOperands; the range expression is only partly
// visible. Memref number `index` is bound to descriptor set 0, binding
// `index`. Its element type is taken from spirvAttributes.elementTypes[index].
// Emits an error and signals pass failure when there are more memref operands
// than element types, or when the rank cannot be deduced.
238 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
239 LLVM::CallOp cInterfaceVulkanLaunchCallOp,
Value vulkanRuntime) {
// Only configuration operands: presumably nothing to bind (the early return is
// not visible here).
240 if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
241 kVulkanLaunchNumConfigOperands)
243 OpBuilder builder(cInterfaceVulkanLaunchCallOp);
244 Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
// All memrefs are bound to descriptor set 0.
249 Value descriptorSet =
250 builder.create<LLVM::ConstantOp>(loc, getInt32Type(), 0);
// The binding index is the memref's position among the non-configuration
// operands.
252 for (
auto [index, ptrToMemRefDescriptor] :
254 kVulkanLaunchNumConfigOperands))) {
256 Value descriptorBinding =
257 builder.create<LLVM::ConstantOp>(loc, getInt32Type(), index);
// Each binding needs a matching entry in the SPIR-V element-types attribute.
259 if (index >= spirvAttributes.elementTypes.size()) {
260 cInterfaceVulkanLaunchCallOp.emitError()
262 << ptrToMemRefDescriptor;
263 return signalPassFailure();
267 Type type = spirvAttributes.elementTypes[index];
// The rank selects which bindMemRef variant is called.
268 if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
269 cInterfaceVulkanLaunchCallOp.emitError()
270 <<
"invalid memref descriptor " << ptrToMemRefDescriptor.getType();
271 return signalPassFailure();
275 llvm::formatv(
"bindMemRef{0}D{1}", rank, stringifyType(type)).str();
// bindMemRef<rank>D<type>(runtime, set, binding, descriptor), no results.
277 builder.create<LLVM::CallOp>(
278 loc,
TypeRange(), StringRef(symbolName.data(), symbolName.size()),
279 ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
280 ptrToMemRefDescriptor});
// Deduces the rank of the memref descriptor that `launchCallArg` points to.
// It inspects the element type of the defining allocation (presumably an
// LLVM::AllocaOp; the lookup is not visible here).
// - An element type that is not an LLVM struct presumably yields failure.
// - A struct with exactly 3 fields presumably denotes rank 0.
// - Otherwise the rank is the element count of the array in field 3.
// NOTE(review): a struct with fewer than 3 fields would make getBody()[3]
// index out of range, and a field 3 that is not an LLVMArrayType would fail
// the cast. Confirm that the lines not visible here guard against both.
285 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(
Value launchCallArg,
292 std::optional<Type> elementType = alloca.getElemType();
293 assert(elementType &&
"expected to work with opaque pointers");
// A memref descriptor must be an LLVM struct type.
294 auto llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
303 if (!llvmDescriptorTy)
306 if (llvmDescriptorTy.getBody().size() == 3) {
311 cast<LLVM::LLVMArrayType>(llvmDescriptorTy.getBody()[3]).getNumElements();
// Declares LLVM functions for the Vulkan runtime entry points used by the
// rewritten launch calls, using `builder` (its construction is not visible
// here). It also declares bindMemRef<i>D<type> for ranks 1..3 and each element
// type in `types`, skipping any symbol that already exists in the module.
// NOTE(review): only ranks 1..3 are declared here, but deduceMemRefRank can
// report rank 0. createBindMemRefCalls would then emit a call to
// "bindMemRef0D..." that this function never declares. Confirm whether rank-0
// memrefs are declared elsewhere or rejected earlier.
315 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(
Location loc) {
316 ModuleOp module = getOperation();
// Parameters (ptr, ptr); presumably the entry-point setter (runtime, name).
320 builder.create<LLVM::LLVMFuncOp>(
323 {getPointerType(), getPointerType()}));
// Parameters (ptr, i64, i64, i64); presumably the work-group-count setter.
327 builder.create<LLVM::LLVMFuncOp>(
330 {getPointerType(), getInt64Type(),
331 getInt64Type(), getInt64Type()}));
// Parameters (ptr, ptr, i32); presumably the binary-shader setter
// (runtime, binary, size).
335 builder.create<LLVM::LLVMFuncOp>(
339 {getPointerType(), getPointerType(), getInt32Type()}));
343 builder.create<LLVM::LLVMFuncOp>(
// bindMemRef variants for ranks 1..3.
348 for (
unsigned i = 1; i <= 3; i++) {
353 for (
auto type : types) {
354 std::string fnName =
"bindMemRef" + std::to_string(i) +
"D" +
355 std::string(stringifyType(type));
// f16 is special-cased; the handled statement is not visible here.
356 if (isa<Float16Type>(type))
// Declare each variant at most once, with parameters (ptr runtime, i32 set,
// i32 binding, ptr descriptor) matching the calls built in
// createBindMemRefCalls.
358 if (!module.lookupSymbol(fnName)) {
361 {llvmPointerType, getInt32Type(), getInt32Type(), llvmPointerType},
363 builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
// Further runtime declarations (their names and signatures are not visible
// here).
369 builder.create<LLVM::LLVMFuncOp>(
374 builder.create<LLVM::LLVMFuncOp>(
// Creates an internal-linkage global string named "<name>_spv_entry_point_name"
// that holds `name` followed by an explicit NUL byte, and returns the Value
// produced for it.
380 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
// Append a NUL terminator; presumably so the runtime can read the name as a
// C string.
385 shaderName.push_back(
'\0');
// The global's symbol name is derived from the entry point name.
387 std::string entryPointGlobalName = (name +
"_spv_entry_point_name").str();
389 shaderName, LLVM::Linkage::Internal);
// Rewrites one C-interface vulkanLaunch call into a sequence of runtime calls
// built with a builder anchored at that call, then erases the original call.
// Visible steps:
//   1. call an init function (presumably kInitVulkan) and use its result as
//      the runtime handle;
//   2. emit the SPIR-V blob as an internal global named kSPIRVBinary, plus an
//      i32 constant holding its size;
//   3. bind each memref operand (createBindMemRefCalls);
//   4. pass (runtime, binary, size) to the runtime;
//   5. create the entry point name global;
//   6. pass (runtime, operand 0, operand 1, operand 2) to the runtime;
//   7. declare the runtime functions and erase the launch call.
// Callee names and any calls on lines not visible here are not documented.
392 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
393 LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
394 OpBuilder builder(cInterfaceVulkanLaunchCallOp);
395 Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
// Runtime handle that is passed as the first argument of the later calls.
397 auto initVulkanCall = builder.
create<LLVM::CallOp>(
401 auto vulkanRuntime = initVulkanCall.getResult();
406 loc, builder,
kSPIRVBinary, spirvAttributes.blob.getValue(),
407 LLVM::Linkage::Internal);
// Blob size as an i32 constant.
// NOTE(review): the blob size is stored in 32 bits, so a blob too large for i32
// would not be represented correctly.
410 Value binarySize = builder.
create<LLVM::ConstantOp>(
411 loc, getInt32Type(), spirvAttributes.blob.getValue().size());
// Bind every memref operand of the launch call.
414 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
// Hand the SPIR-V binary and its size to the runtime (presumably
// kSetBinaryShader).
418 builder.
create<LLVM::CallOp>(
420 ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
// Global holding the NUL-terminated entry point name.
422 Value entryPointName = createEntryPointNameConstant(
423 spirvAttributes.entryPoint.getValue(), loc, builder);
// Forward the three launch-configuration operands (presumably the work-group
// counts, via kSetNumWorkGroups).
430 builder.
create<LLVM::CallOp>(
432 ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
433 cInterfaceVulkanLaunchCallOp.getOperand(1),
434 cInterfaceVulkanLaunchCallOp.getOperand(2)});
// Ensure the runtime functions referenced above are declared in the module.
445 declareVulkanFunctions(loc);
// The launch call has been fully replaced.
447 cInterfaceVulkanLaunchCallOp.erase();
static constexpr const char * kVulkanLaunch
static constexpr const char * kCInterfaceVulkanLaunch
static constexpr const char * kInitVulkan
static constexpr const char * kSPIRVBinary
static constexpr const char * kSPIRVElementTypesAttrName
static constexpr const char * kDeinitVulkan
static constexpr const char * kSPIRVEntryPointAttrName
static constexpr const char * kSetEntryPoint
static constexpr const char * kSetBinaryShader
static constexpr const char * kSPIRVBlobAttrName
static constexpr const char * kSetNumWorkGroups
static constexpr const char * kRunOnVulkan
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, Linkage linkage)
Create an LLVM global containing the string "value" at the module containing surrounding the insertio...
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...