31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
36 #define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
37 #include "mlir/Conversion/Passes.h.inc"
50 return llvm::convertToSnakeFromCamelCase(
51 stringifyDecoration(spirv::Decoration::DescriptorSet));
56 return llvm::convertToSnakeFromCamelCase(
57 stringifyDecoration(spirv::Decoration::Binding));
67 return binding.getInt();
75 Value isVolatile = builder.
create<LLVM::ConstantOp>(
77 builder.
create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
87 StringRef kernelModuleName) {
88 IntegerAttr descriptorSet =
91 return llvm::formatv(
"{0}_{1}_descriptor_set{2}_binding{3}",
92 kernelModuleName.str(), op.getSymName().str(),
93 std::to_string(descriptorSet.getInt()),
94 std::to_string(binding.getInt()));
100 IntegerAttr descriptorSet =
103 return descriptorSet && binding;
111 spirv::ModuleOp module,
113 auto entryPoints = module.getOps<spirv::EntryPointOp>();
114 if (!llvm::hasSingleElement(entryPoints)) {
115 return module.emitError(
116 "The module must contain exactly one entry point function");
118 auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
119 for (
auto globalOp : globalVariables) {
129 StringRef spvModuleName = module.getSymName().value_or(
kSPIRVModule);
134 auto entryPoints = module.getOps<spirv::EntryPointOp>();
135 if (!llvm::hasSingleElement(entryPoints)) {
136 return module.emitError(
137 "The module must contain exactly one entry point function");
139 spirv::EntryPointOp entryPoint = *entryPoints.begin();
140 StringRef funcName = entryPoint.getFn();
141 auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
142 StringAttr newFuncName =
171 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
173 auto *op = launchOp.getOperation();
175 auto module = launchOp->getParentOfType<ModuleOp>();
181 StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
182 std::string spvModuleName =
kSPIRVModule + kernelModuleName.str();
183 auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
186 return launchOp.emitOpError(
"SPIR-V kernel module '")
187 << spvModuleName <<
"' is not found";
196 StringRef kernelFuncName = launchOp.getKernelName().getValue();
197 std::string newKernelFuncName = spvModuleName +
"_" + kernelFuncName.str();
198 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
203 kernelFunc = rewriter.
create<LLVM::LLVMFuncOp>(
219 auto numKernelOperands = launchOp.getNumKernelOperands();
220 auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
223 auto memRefType = dyn_cast<MemRefType>(
224 launchOp.getKernelOperand(operand.index()).getType());
233 getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
236 Value src = descriptor.allocatedPtr(rewriter, loc);
241 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
243 cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
244 auto dstGlobalType = typeConverter->convertType(pointeeType);
250 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
254 dstGlobal = rewriter.
create<LLVM::GlobalOp>(
256 false, LLVM::Linkage::Linkonce, name,
Attribute(),
265 loc, typeConverter->convertType(spirvGlobal.getType()),
266 dstGlobal.getSymName());
267 copy(loc, dst, src, sizeBytes, rewriter);
272 info.size = sizeBytes;
273 copyInfo.push_back(info);
278 for (CopyInfo info : copyInfo)
279 copy(loc, info.src, info.dst, info.size, rewriter);
284 class LowerHostCodeToLLVM
285 :
public impl::LowerHostCodeToLLVMPassBase<LowerHostCodeToLLVM> {
290 void runOnOperation()
override {
291 ModuleOp module = getOperation();
294 for (
auto gpuModule :
295 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
299 for (
auto func : module.getOps<func::FuncOp>()) {
300 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
306 options.useOpaquePointers = useOpaquePointers;
308 auto *context = module.getContext();
314 patterns.add<GPULaunchLowering>(typeConverter);
321 target.addLegalDialect<LLVM::LLVMDialect>();
327 for (
auto spvModule : module.getOps<spirv::ModuleOp>()) {
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static constexpr const char kSPIRVModule[]
static std::string createGlobalVariableWithBindName(spirv::GlobalVariableOp op, StringRef kernelModuleName)
Encodes the binding and descriptor set numbers into a new symbolic name.
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op)
Calculates the index of the kernel's operand that is represented by the given global variable with th...
static LogicalResult encodeKernelName(spirv::ModuleOp module)
Encodes the SPIR-V module's symbolic name into the name of the entry point function.
static LogicalResult getKernelGlobalVariables(spirv::ModuleOp module, DenseMap< uint32_t, spirv::GlobalVariableOp > &globalVariableMap)
Fills globalVariableMap with SPIR-V global variables that represent kernel arguments from the given S...
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op)
Returns true if the given global variable has both a descriptor set number and a binding number.
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
AttrClass getAttrOfType(StringAttr name)
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the Func dialect to LLVM.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter)
Populates type conversions with additional SPIR-V types.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.