31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
36 #define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
37 #include "mlir/Conversion/Passes.h.inc"
50 return llvm::convertToSnakeFromCamelCase(
51 stringifyDecoration(spirv::Decoration::DescriptorSet));
56 return llvm::convertToSnakeFromCamelCase(
57 stringifyDecoration(spirv::Decoration::Binding));
66 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
67 return binding.getInt();
73 builder.
create<LLVM::MemcpyOp>(loc, dst, src, size,
false);
83 StringRef kernelModuleName) {
84 IntegerAttr descriptorSet =
86 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
87 return llvm::formatv(
"{0}_{1}_descriptor_set{2}_binding{3}",
88 kernelModuleName.str(), op.getSymName().str(),
89 std::to_string(descriptorSet.getInt()),
90 std::to_string(binding.getInt()));
96 IntegerAttr descriptorSet =
98 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
99 return descriptorSet && binding;
107 spirv::ModuleOp module,
109 auto entryPoints = module.getOps<spirv::EntryPointOp>();
110 if (!llvm::hasSingleElement(entryPoints)) {
111 return module.emitError(
112 "The module must contain exactly one entry point function");
114 auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
115 for (
auto globalOp : globalVariables) {
125 StringRef spvModuleName = module.getSymName().value_or(
kSPIRVModule);
130 auto entryPoints = module.getOps<spirv::EntryPointOp>();
131 if (!llvm::hasSingleElement(entryPoints)) {
132 return module.emitError(
133 "The module must contain exactly one entry point function");
135 spirv::EntryPointOp entryPoint = *entryPoints.begin();
136 StringRef funcName = entryPoint.getFn();
137 auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
138 StringAttr newFuncName =
167 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
169 auto *op = launchOp.getOperation();
171 auto module = launchOp->getParentOfType<ModuleOp>();
177 StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
178 std::string spvModuleName =
kSPIRVModule + kernelModuleName.str();
179 auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
182 return launchOp.emitOpError(
"SPIR-V kernel module '")
183 << spvModuleName <<
"' is not found";
192 StringRef kernelFuncName = launchOp.getKernelName().getValue();
193 std::string newKernelFuncName = spvModuleName +
"_" + kernelFuncName.str();
194 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
199 kernelFunc = rewriter.
create<LLVM::LLVMFuncOp>(
215 auto numKernelOperands = launchOp.getNumKernelOperands();
216 auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
219 auto memRefType = dyn_cast<MemRefType>(
220 launchOp.getKernelOperand(operand.index()).getType());
229 getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
232 Value src = descriptor.allocatedPtr(rewriter, loc);
237 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
239 cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
240 auto dstGlobalType = typeConverter->convertType(pointeeType);
246 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
250 dstGlobal = rewriter.
create<LLVM::GlobalOp>(
252 false, LLVM::Linkage::Linkonce, name,
Attribute(),
261 loc, typeConverter->convertType(spirvGlobal.getType()),
262 dstGlobal.getSymName());
263 copy(loc, dst, src, sizeBytes, rewriter);
268 info.size = sizeBytes;
269 copyInfo.push_back(info);
274 for (CopyInfo info : copyInfo)
275 copy(loc, info.src, info.dst, info.size, rewriter);
280 class LowerHostCodeToLLVM
281 :
public impl::LowerHostCodeToLLVMPassBase<LowerHostCodeToLLVM> {
285 void runOnOperation()
override {
286 ModuleOp module = getOperation();
289 for (
auto gpuModule :
290 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
294 for (
auto func : module.getOps<func::FuncOp>()) {
295 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
302 auto *context = module.getContext();
308 patterns.add<GPULaunchLowering>(typeConverter);
315 target.addLegalDialect<LLVM::LLVMDialect>();
321 for (
auto spvModule : module.getOps<spirv::ModuleOp>()) {
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static constexpr const char kSPIRVModule[]
static std::string createGlobalVariableWithBindName(spirv::GlobalVariableOp op, StringRef kernelModuleName)
Encodes the binding and descriptor set numbers into a new symbolic name.
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op)
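Calculates the index of the kernel's operand that is represented by the given global variable: the index is the variable's binding number (kernel operand i is assumed to be mapped to descriptor set 0, binding i).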
static LogicalResult encodeKernelName(spirv::ModuleOp module)
Encodes the SPIR-V module's symbolic name into the name of the entry point function.
static LogicalResult getKernelGlobalVariables(spirv::ModuleOp module, DenseMap< uint32_t, spirv::GlobalVariableOp > &globalVariableMap)
Fills globalVariableMap with SPIR-V global variables that represent kernel arguments from the given SPIR-V module; the module must contain exactly one entry point function.
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op)
Returns true if the given global variable has both a descriptor set number and a binding number.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationStorage.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replace-with-new-op).
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that are nested within the given operation 'from'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.