29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/Support/FormatVariadic.h"
34 #define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
35 #include "mlir/Conversion/Passes.h.inc"
48 return llvm::convertToSnakeFromCamelCase(
49 stringifyDecoration(spirv::Decoration::DescriptorSet));
54 return llvm::convertToSnakeFromCamelCase(
55 stringifyDecoration(spirv::Decoration::Binding));
64 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
65 return binding.getInt();
71 LLVM::MemcpyOp::create(builder, loc, dst, src, size,
false);
81 StringRef kernelModuleName) {
82 IntegerAttr descriptorSet =
84 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
85 return llvm::formatv(
"{0}_{1}_descriptor_set{2}_binding{3}",
86 kernelModuleName.str(), op.getSymName().str(),
87 std::to_string(descriptorSet.getInt()),
88 std::to_string(binding.getInt()));
94 IntegerAttr descriptorSet =
96 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
bindingName());
97 return descriptorSet && binding;
105 spirv::ModuleOp module,
107 auto entryPoints = module.getOps<spirv::EntryPointOp>();
108 if (!llvm::hasSingleElement(entryPoints)) {
109 return module.emitError(
110 "The module must contain exactly one entry point function");
112 auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
113 for (
auto globalOp : globalVariables) {
123 StringRef spvModuleName = module.getSymName().value_or(
kSPIRVModule);
128 auto entryPoints = module.getOps<spirv::EntryPointOp>();
129 if (!llvm::hasSingleElement(entryPoints)) {
130 return module.emitError(
131 "The module must contain exactly one entry point function");
133 spirv::EntryPointOp entryPoint = *entryPoints.begin();
134 StringRef funcName = entryPoint.getFn();
135 auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
136 StringAttr newFuncName =
165 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
167 auto *op = launchOp.getOperation();
169 auto module = launchOp->getParentOfType<ModuleOp>();
175 StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
176 std::string spvModuleName =
kSPIRVModule + kernelModuleName.str();
177 auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
180 return launchOp.emitOpError(
"SPIR-V kernel module '")
181 << spvModuleName <<
"' is not found";
190 StringRef kernelFuncName = launchOp.getKernelName().getValue();
191 std::string newKernelFuncName = spvModuleName +
"_" + kernelFuncName.str();
192 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
197 kernelFunc = LLVM::LLVMFuncOp::create(
213 auto numKernelOperands = launchOp.getNumKernelOperands();
214 auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
217 auto memRefType = dyn_cast<MemRefType>(
218 launchOp.getKernelOperand(operand.index()).getType());
227 getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
230 Value src = descriptor.allocatedPtr(rewriter, loc);
235 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
237 cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
238 auto dstGlobalType = typeConverter->convertType(pointeeType);
244 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
248 dstGlobal = LLVM::GlobalOp::create(
249 rewriter, loc, dstGlobalType,
250 false, LLVM::Linkage::Linkonce, name,
Attribute(),
258 Value dst = LLVM::AddressOfOp::create(
259 rewriter, loc, typeConverter->convertType(spirvGlobal.getType()),
260 dstGlobal.getSymName());
261 copy(loc, dst, src, sizeBytes, rewriter);
266 info.size = sizeBytes;
267 copyInfo.push_back(info);
273 for (CopyInfo info : copyInfo)
274 copy(loc, info.src, info.dst, info.size, rewriter);
279 class LowerHostCodeToLLVM
280 :
public impl::LowerHostCodeToLLVMPassBase<LowerHostCodeToLLVM> {
284 void runOnOperation()
override {
285 ModuleOp module = getOperation();
288 for (
auto gpuModule :
289 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
293 for (
auto func : module.getOps<func::FuncOp>()) {
294 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
301 auto *context = module.getContext();
307 patterns.add<GPULaunchLowering>(typeConverter);
314 target.addLegalDialect<LLVM::LLVMDialect>();
320 for (
auto spvModule : module.getOps<spirv::ModuleOp>()) {
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static constexpr const char kSPIRVModule[]
static std::string createGlobalVariableWithBindName(spirv::GlobalVariableOp op, StringRef kernelModuleName)
Encodes the binding and descriptor set numbers into a new symbolic name.
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op)
Calculates the index of the kernel's operand that is represented by the given global variable with th...
static LogicalResult encodeKernelName(spirv::ModuleOp module)
Encodes the SPIR-V module's symbolic name into the name of the entry point function.
static LogicalResult getKernelGlobalVariables(spirv::ModuleOp module, DenseMap< uint32_t, spirv::GlobalVariableOp > &globalVariableMap)
Fills globalVariableMap with SPIR-V global variables that represent kernel arguments from the given S...
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op)
Returns true if the given global variable has both a descriptor set number and a binding number.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.