30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/Support/FormatVariadic.h"
35#define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
36#include "mlir/Conversion/Passes.h.inc"
63 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
64 return binding.getInt();
70 LLVM::MemcpyOp::create(builder, loc, dst, src, size,
false);
80 StringRef kernelModuleName) {
81 IntegerAttr descriptorSet =
83 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
84 return llvm::formatv(
"{0}_{1}_descriptor_set{2}_binding{3}",
85 kernelModuleName.str(), op.getSymName().str(),
86 std::to_string(descriptorSet.getInt()),
87 std::to_string(binding.getInt()));
93 IntegerAttr descriptorSet =
95 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
96 return descriptorSet && binding;
104 spirv::ModuleOp module,
106 auto entryPoints =
module.getOps<spirv::EntryPointOp>();
107 if (!llvm::hasSingleElement(entryPoints)) {
108 return module.emitError(
109 "The module must contain exactly one entry point function");
111 auto globalVariables =
module.getOps<spirv::GlobalVariableOp>();
112 for (
auto globalOp : globalVariables) {
122 StringRef spvModuleName =
module.getSymName().value_or(kSPIRVModule);
127 auto entryPoints =
module.getOps<spirv::EntryPointOp>();
128 if (!llvm::hasSingleElement(entryPoints)) {
129 return module.emitError(
130 "The module must contain exactly one entry point function");
132 spirv::EntryPointOp entryPoint = *entryPoints.begin();
133 StringRef funcName = entryPoint.getFn();
134 auto funcOp =
module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
135 StringAttr newFuncName =
136 StringAttr::get(module->getContext(), spvModuleName +
"_" + funcName);
161 using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
164 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
165 ConversionPatternRewriter &rewriter)
const override {
166 auto *op = launchOp.getOperation();
167 MLIRContext *context = rewriter.getContext();
168 auto module = launchOp->getParentOfType<ModuleOp>();
174 StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
175 std::string spvModuleName =
kSPIRVModule + kernelModuleName.str();
176 auto spvModule =
module.lookupSymbol<spirv::ModuleOp>(
177 StringAttr::get(context, spvModuleName));
179 return launchOp.emitOpError(
"SPIR-V kernel module '")
180 << spvModuleName <<
"' is not found";
189 StringRef kernelFuncName = launchOp.getKernelName().getValue();
190 std::string newKernelFuncName = spvModuleName +
"_" + kernelFuncName.str();
191 auto kernelFunc =
module.lookupSymbol<LLVM::LLVMFuncOp>(
192 StringAttr::get(context, newKernelFuncName));
194 OpBuilder::InsertionGuard guard(rewriter);
195 rewriter.setInsertionPointToStart(module.getBody());
196 kernelFunc = LLVM::LLVMFuncOp::create(
197 rewriter, rewriter.getUnknownLoc(), newKernelFuncName,
198 LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
200 rewriter.setInsertionPoint(launchOp);
204 DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
210 Location loc = launchOp.getLoc();
211 SmallVector<CopyInfo, 4> copyInfo;
212 auto numKernelOperands = launchOp.getNumKernelOperands();
213 auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
214 for (
const auto &operand : llvm::enumerate(kernelOperands)) {
216 auto memRefType = dyn_cast<MemRefType>(
217 launchOp.getKernelOperand(operand.index()).getType());
223 SmallVector<Value, 4> sizes;
224 SmallVector<Value, 4> strides;
226 getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
228 MemRefDescriptor descriptor(operand.value());
229 Value src = descriptor.allocatedPtr(rewriter, loc);
234 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
236 cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
237 auto dstGlobalType = typeConverter->convertType(pointeeType);
243 auto dstGlobal =
module.lookupSymbol<LLVM::GlobalOp>(name);
245 OpBuilder::InsertionGuard guard(rewriter);
246 rewriter.setInsertionPointToStart(module.getBody());
247 dstGlobal = LLVM::GlobalOp::create(
248 rewriter, loc, dstGlobalType,
249 false, LLVM::Linkage::Linkonce, name, Attribute(),
251 rewriter.setInsertionPoint(launchOp);
257 Value dst = LLVM::AddressOfOp::create(
258 rewriter, loc, typeConverter->convertType(spirvGlobal.getType()),
259 dstGlobal.getSymName());
260 copy(loc, dst, src, sizeBytes, rewriter);
265 info.size = sizeBytes;
266 copyInfo.push_back(info);
269 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
271 for (CopyInfo info : copyInfo)
272 copy(loc, info.src, info.dst, info.size, rewriter);
277class LowerHostCodeToLLVM
282 void runOnOperation()
override {
283 ModuleOp module = getOperation();
286 for (
auto gpuModule :
287 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
291 for (
auto func : module.getOps<func::FuncOp>()) {
292 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
297 LowerToLLVMOptions
options(module.getContext());
299 auto *context =
module.getContext();
300 RewritePatternSet
patterns(context);
301 LLVMTypeConverter typeConverter(context,
options);
305 patterns.add<GPULaunchLowering>(typeConverter);
311 ConversionTarget
target(*context);
312 target.addLegalDialect<LLVM::LLVMDialect>();
318 for (
auto spvModule : module.getOps<spirv::ModuleOp>()) {
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static constexpr const char kSPIRVModule[]
static std::string createGlobalVariableWithBindName(spirv::GlobalVariableOp op, StringRef kernelModuleName)
Encodes the binding and descriptor set numbers into a new symbolic name.
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op)
Calculates the index of the kernel's operand that is represented by the given global variable with the bind attribute.
static LogicalResult encodeKernelName(spirv::ModuleOp module)
Encodes the SPIR-V module's symbolic name into the name of the entry point function.
static LogicalResult getKernelGlobalVariables(spirv::ModuleOp module, DenseMap< uint32_t, spirv::GlobalVariableOp > &globalVariableMap)
Fills globalVariableMap with SPIR-V global variables that represent kernel arguments from the given SPIR-V module.
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op)
Returns true if the given global variable has both a descriptor set number and a binding number.
static constexpr const char kSPIRVModule[]
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that are nested within the given operation 'from'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
std::string getDecorationString(spirv::Decoration decor)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attributes.
Include the generated interface declarations.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap