31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
36 #define GEN_PASS_DEF_LOWERHOSTCODETOLLVM
37 #include "mlir/Conversion/Passes.h.inc"
50 return llvm::convertToSnakeFromCamelCase(
51 stringifyDecoration(spirv::Decoration::DescriptorSet));
56 return llvm::convertToSnakeFromCamelCase(
57 stringifyDecoration(spirv::Decoration::Binding));
66 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
67 return binding.getInt();
74 auto llvmI1Type = IntegerType::get(context, 1);
75 Value isVolatile = builder.
create<LLVM::ConstantOp>(
77 builder.
create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
87 StringRef kernelModuleName) {
88 IntegerAttr descriptorSet =
90 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
91 return llvm::formatv(
"{0}_{1}_descriptor_set{2}_binding{3}",
92 kernelModuleName.str(), op.getSymName().str(),
93 std::to_string(descriptorSet.getInt()),
94 std::to_string(binding.getInt()));
100 IntegerAttr descriptorSet =
102 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
103 return descriptorSet && binding;
111 spirv::ModuleOp module,
113 auto entryPoints = module.getOps<spirv::EntryPointOp>();
114 if (!llvm::hasSingleElement(entryPoints)) {
115 return module.emitError(
116 "The module must contain exactly one entry point function");
118 auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
119 for (
auto globalOp : globalVariables) {
129 StringRef spvModuleName = module.getSymName().value_or(
kSPIRVModule);
134 auto entryPoints = module.getOps<spirv::EntryPointOp>();
135 if (!llvm::hasSingleElement(entryPoints)) {
136 return module.emitError(
137 "The module must contain exactly one entry point function");
139 spirv::EntryPointOp entryPoint = *entryPoints.begin();
140 StringRef funcName = entryPoint.getFn();
141 auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
142 StringAttr newFuncName =
143 StringAttr::get(module->getContext(), spvModuleName +
"_" + funcName);
171 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
173 auto *op = launchOp.getOperation();
175 auto module = launchOp->getParentOfType<ModuleOp>();
181 StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
182 std::string spvModuleName =
kSPIRVModule + kernelModuleName.str();
183 auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
184 StringAttr::get(context, spvModuleName));
186 return launchOp.emitOpError(
"SPIR-V kernel module '")
187 << spvModuleName <<
"' is not found";
196 StringRef kernelFuncName = launchOp.getKernelName().getValue();
197 std::string newKernelFuncName = spvModuleName +
"_" + kernelFuncName.str();
198 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
199 StringAttr::get(context, newKernelFuncName));
203 kernelFunc = rewriter.
create<LLVM::LLVMFuncOp>(
205 LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
219 auto numKernelOperands = launchOp.getNumKernelOperands();
220 auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
223 auto memRefType = launchOp.getKernelOperand(operand.index())
225 .dyn_cast<MemRefType>();
234 getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
237 Value src = descriptor.allocatedPtr(rewriter, loc);
242 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
245 auto dstGlobalType = typeConverter->convertType(pointeeType);
251 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
255 dstGlobal = rewriter.
create<LLVM::GlobalOp>(
257 false, LLVM::Linkage::Linkonce, name,
Attribute(),
265 Value dst = rewriter.
create<LLVM::AddressOfOp>(loc, dstGlobal);
266 copy(loc, dst, src, sizeBytes, rewriter);
271 info.size = sizeBytes;
272 copyInfo.push_back(info);
277 for (CopyInfo info : copyInfo)
278 copy(loc, info.src, info.dst, info.size, rewriter);
283 class LowerHostCodeToLLVM
284 :
public impl::LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> {
286 void runOnOperation()
override {
287 ModuleOp module = getOperation();
290 for (
auto gpuModule :
291 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
295 for (
auto func : module.getOps<func::FuncOp>()) {
296 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
297 UnitAttr::get(&getContext()));
302 auto *context = module.getContext();
308 patterns.add<GPULaunchLowering>(typeConverter);
315 target.addLegalDialect<LLVM::LLVMDialect>();
321 for (
auto spvModule : module.getOps<spirv::ModuleOp>()) {
331 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
333 return std::make_unique<LowerHostCodeToLLVM>();
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static constexpr const char kSPIRVModule[]
static std::string createGlobalVariableWithBindName(spirv::GlobalVariableOp op, StringRef kernelModuleName)
Encodes the binding and descriptor set numbers into a new symbolic name.
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op)
Calculates the index of the kernel's operand that is represented by the given global variable with th...
static LogicalResult encodeKernelName(spirv::ModuleOp module)
Encodes the SPIR-V module's symbolic name into the name of the entry point function.
static LogicalResult getKernelGlobalVariables(spirv::ModuleOp module, DenseMap< uint32_t, spirv::GlobalVariableOp > &globalVariableMap)
Fills globalVariableMap with SPIR-V global variables that represent kernel arguments from the given S...
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op)
Returns true if the given global variable has both a descriptor set number and a binding number.
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the Func dialect to LLVM.
std::unique_ptr< OperationPass< ModuleOp > > createLowerHostCodeToLLVMPass()
Creates a pass to emulate gpu.launch_func call in LLVM dialect and lower the host module code to LLVM...
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter)
Populates type conversions with additional SPIR-V types.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.