29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/StringExtras.h"
31#include "llvm/Support/FormatVariadic.h"
34#define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
35#include "mlir/Conversion/Passes.h.inc"
48 return llvm::convertToSnakeFromCamelCase(
49 stringifyDecoration(spirv::Decoration::DescriptorSet));
54 return llvm::convertToSnakeFromCamelCase(
55 stringifyDecoration(spirv::Decoration::Binding));
64 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
65 return binding.getInt();
71 LLVM::MemcpyOp::create(builder, loc, dst, src, size,
false);
81 StringRef kernelModuleName) {
82 IntegerAttr descriptorSet =
84 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
85 return llvm::formatv(
"{0}_{1}_descriptor_set{2}_binding{3}",
86 kernelModuleName.str(), op.getSymName().str(),
87 std::to_string(descriptorSet.getInt()),
88 std::to_string(binding.getInt()));
94 IntegerAttr descriptorSet =
96 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
97 return descriptorSet && binding;
105 spirv::ModuleOp module,
107 auto entryPoints =
module.getOps<spirv::EntryPointOp>();
108 if (!llvm::hasSingleElement(entryPoints)) {
109 return module.emitError(
110 "The module must contain exactly one entry point function");
112 auto globalVariables =
module.getOps<spirv::GlobalVariableOp>();
113 for (
auto globalOp : globalVariables) {
123 StringRef spvModuleName =
module.getSymName().value_or(kSPIRVModule);
128 auto entryPoints =
module.getOps<spirv::EntryPointOp>();
129 if (!llvm::hasSingleElement(entryPoints)) {
130 return module.emitError(
131 "The module must contain exactly one entry point function");
133 spirv::EntryPointOp entryPoint = *entryPoints.begin();
134 StringRef funcName = entryPoint.getFn();
135 auto funcOp =
module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
136 StringAttr newFuncName =
137 StringAttr::get(module->getContext(), spvModuleName +
"_" + funcName);
162 using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
165 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
166 ConversionPatternRewriter &rewriter)
const override {
167 auto *op = launchOp.getOperation();
168 MLIRContext *context = rewriter.getContext();
169 auto module = launchOp->getParentOfType<ModuleOp>();
175 StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
176 std::string spvModuleName =
kSPIRVModule + kernelModuleName.str();
177 auto spvModule =
module.lookupSymbol<spirv::ModuleOp>(
178 StringAttr::get(context, spvModuleName));
180 return launchOp.emitOpError(
"SPIR-V kernel module '")
181 << spvModuleName <<
"' is not found";
190 StringRef kernelFuncName = launchOp.getKernelName().getValue();
191 std::string newKernelFuncName = spvModuleName +
"_" + kernelFuncName.str();
192 auto kernelFunc =
module.lookupSymbol<LLVM::LLVMFuncOp>(
193 StringAttr::get(context, newKernelFuncName));
195 OpBuilder::InsertionGuard guard(rewriter);
196 rewriter.setInsertionPointToStart(module.getBody());
197 kernelFunc = LLVM::LLVMFuncOp::create(
198 rewriter, rewriter.getUnknownLoc(), newKernelFuncName,
199 LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
201 rewriter.setInsertionPoint(launchOp);
205 DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
211 Location loc = launchOp.getLoc();
212 SmallVector<CopyInfo, 4> copyInfo;
213 auto numKernelOperands = launchOp.getNumKernelOperands();
214 auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
215 for (
const auto &operand : llvm::enumerate(kernelOperands)) {
217 auto memRefType = dyn_cast<MemRefType>(
218 launchOp.getKernelOperand(operand.index()).getType());
224 SmallVector<Value, 4> sizes;
225 SmallVector<Value, 4> strides;
227 getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
229 MemRefDescriptor descriptor(operand.value());
230 Value src = descriptor.allocatedPtr(rewriter, loc);
235 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
237 cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
238 auto dstGlobalType = typeConverter->convertType(pointeeType);
244 auto dstGlobal =
module.lookupSymbol<LLVM::GlobalOp>(name);
246 OpBuilder::InsertionGuard guard(rewriter);
247 rewriter.setInsertionPointToStart(module.getBody());
248 dstGlobal = LLVM::GlobalOp::create(
249 rewriter, loc, dstGlobalType,
250 false, LLVM::Linkage::Linkonce, name, Attribute(),
252 rewriter.setInsertionPoint(launchOp);
258 Value dst = LLVM::AddressOfOp::create(
259 rewriter, loc, typeConverter->convertType(spirvGlobal.getType()),
260 dstGlobal.getSymName());
261 copy(loc, dst, src, sizeBytes, rewriter);
266 info.size = sizeBytes;
267 copyInfo.push_back(info);
270 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
272 for (CopyInfo info : copyInfo)
273 copy(loc, info.src, info.dst, info.size, rewriter);
278class LowerHostCodeToLLVM
283 void runOnOperation()
override {
284 ModuleOp module = getOperation();
287 for (
auto gpuModule :
288 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
292 for (
auto func : module.getOps<func::FuncOp>()) {
293 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
298 LowerToLLVMOptions
options(module.getContext());
300 auto *context =
module.getContext();
301 RewritePatternSet
patterns(context);
302 LLVMTypeConverter typeConverter(context,
options);
306 patterns.add<GPULaunchLowering>(typeConverter);
312 ConversionTarget
target(*context);
313 target.addLegalDialect<LLVM::LLVMDialect>();
319 for (
auto spvModule : module.getOps<spirv::ModuleOp>()) {
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static constexpr const char kSPIRVModule[]
static std::string createGlobalVariableWithBindName(spirv::GlobalVariableOp op, StringRef kernelModuleName)
Encodes the binding and descriptor set numbers into a new symbolic name.
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op)
Calculates the index of the kernel's operand that is represented by the given global variable with the bind attribute; the operand index is taken to be the variable's binding number.
static LogicalResult encodeKernelName(spirv::ModuleOp module)
Encodes the SPIR-V module's symbolic name into the name of the entry point function.
static LogicalResult getKernelGlobalVariables(spirv::ModuleOp module, DenseMap< uint32_t, spirv::GlobalVariableOp > &globalVariableMap)
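Fills globalVariableMap with SPIR-V global variables that represent kernel arguments from the given SPIR-V module. The module must contain exactly one entry point function; global variables that carry both a descriptor set and a binding decoration are treated as kernel arguments.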
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op)
Returns true if the given global variable has both a descriptor set number and a binding number.
static constexpr const char kSPIRVModule[]
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that are nested within the given operation 'from'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Include the generated interface declarations.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap