MLIR 22.0.0git
ConvertLaunchFuncToLLVMCalls.cpp
Go to the documentation of this file.
1//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements passes to convert `gpu.launch_func` op into a sequence
10// of LLVM calls that emulate the host and device sides.
11//
12//===----------------------------------------------------------------------===//
13
26#include "mlir/IR/BuiltinOps.h"
27#include "mlir/IR/SymbolTable.h"
28#include "mlir/Pass/Pass.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/Support/FormatVariadic.h"
33
34namespace mlir {
35#define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
36#include "mlir/Conversion/Passes.h.inc"
37} // namespace mlir
38
39using namespace mlir;
40
/// Symbol-name prefix of SPIR-V modules produced from GPU kernel modules
/// (`__spv__{kernel_module_name}`). Also the fallback prefix used when an
/// unnamed SPIR-V module's entry point is renamed.
static constexpr char kSPIRVModule[] = "__spv__";
42
43//===----------------------------------------------------------------------===//
44// Utility functions
45//===----------------------------------------------------------------------===//
46
47/// Returns the string name of the `DescriptorSet` decoration.
48static std::string descriptorSetName() {
49 return spirv::getDecorationString(spirv::Decoration::DescriptorSet);
50}
51
52/// Returns the string name of the `Binding` decoration.
53static std::string bindingName() {
54 return spirv::getDecorationString(spirv::Decoration::Binding);
55}
56
57/// Calculates the index of the kernel's operand that is represented by the
58/// given global variable with the `bind` attribute. We assume that the index of
59/// each kernel's operand is mapped to (descriptorSet, binding) by the map:
60/// i -> (0, i)
61/// which is implemented under `LowerABIAttributesPass`.
62static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
63 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
64 return binding.getInt();
65}
66
67/// Copies the given number of bytes from src to dst pointers.
68static void copy(Location loc, Value dst, Value src, Value size,
69 OpBuilder &builder) {
70 LLVM::MemcpyOp::create(builder, loc, dst, src, size, /*isVolatile=*/false);
71}
72
73/// Encodes the binding and descriptor set numbers into a new symbolic name.
74/// The name is specified by
75/// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
76/// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
77/// binding numbers.
78static std::string
79createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
80 StringRef kernelModuleName) {
81 IntegerAttr descriptorSet =
82 op->getAttrOfType<IntegerAttr>(descriptorSetName());
83 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
84 return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
85 kernelModuleName.str(), op.getSymName().str(),
86 std::to_string(descriptorSet.getInt()),
87 std::to_string(binding.getInt()));
88}
89
90/// Returns true if the given global variable has both a descriptor set number
91/// and a binding number.
92static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
93 IntegerAttr descriptorSet =
94 op->getAttrOfType<IntegerAttr>(descriptorSetName());
95 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
96 return descriptorSet && binding;
97}
98
99/// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
100/// arguments from the given SPIR-V module. We assume that the module contains a
101/// single entry point function. Hence, all `spirv.GlobalVariable`s with a bind
102/// attribute are kernel arguments.
103static LogicalResult getKernelGlobalVariables(
104 spirv::ModuleOp module,
106 auto entryPoints = module.getOps<spirv::EntryPointOp>();
107 if (!llvm::hasSingleElement(entryPoints)) {
108 return module.emitError(
109 "The module must contain exactly one entry point function");
110 }
111 auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
112 for (auto globalOp : globalVariables) {
113 if (hasDescriptorSetAndBinding(globalOp))
114 globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp;
115 }
116 return success();
117}
118
119/// Encodes the SPIR-V module's symbolic name into the name of the entry point
120/// function.
121static LogicalResult encodeKernelName(spirv::ModuleOp module) {
122 StringRef spvModuleName = module.getSymName().value_or(kSPIRVModule);
123 // We already know that the module contains exactly one entry point function
124 // based on `getKernelGlobalVariables()` call. Update this function's name
125 // to:
126 // {spv_module_name}_{function_name}
127 auto entryPoints = module.getOps<spirv::EntryPointOp>();
128 if (!llvm::hasSingleElement(entryPoints)) {
129 return module.emitError(
130 "The module must contain exactly one entry point function");
131 }
132 spirv::EntryPointOp entryPoint = *entryPoints.begin();
133 StringRef funcName = entryPoint.getFn();
134 auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
135 StringAttr newFuncName =
136 StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
137 if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
138 return failure();
139 SymbolTable::setSymbolName(funcOp, newFuncName);
140 return success();
141}
142
143//===----------------------------------------------------------------------===//
144// Conversion patterns
145//===----------------------------------------------------------------------===//
146
147namespace {
148
149/// Structure to group information about the variables being copied.
150struct CopyInfo {
151 Value dst;
152 Value src;
153 Value size;
154};
155
156/// This pattern emulates a call to the kernel in LLVM dialect. For that, we
157/// copy the data to the global variable (emulating device side), call the
158/// kernel as a normal void LLVM function, and copy the data back (emulating the
159/// host side).
160class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
161 using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
162
163 LogicalResult
164 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
165 ConversionPatternRewriter &rewriter) const override {
166 auto *op = launchOp.getOperation();
167 MLIRContext *context = rewriter.getContext();
168 auto module = launchOp->getParentOfType<ModuleOp>();
169
170 // Get the SPIR-V module that represents the gpu kernel module. The module
171 // is named:
172 // __spv__{kernel_module_name}
173 // based on GPU to SPIR-V conversion.
174 StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
175 std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
176 auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
177 StringAttr::get(context, spvModuleName));
178 if (!spvModule) {
179 return launchOp.emitOpError("SPIR-V kernel module '")
180 << spvModuleName << "' is not found";
181 }
182
183 // Declare kernel function in the main module so that it later can be linked
184 // with its definition from the kernel module. We know that the kernel
185 // function would have no arguments and the data is passed via global
186 // variables. The name of the kernel will be
187 // {spv_module_name}_{kernel_function_name}
188 // to avoid symbolic name conflicts.
189 StringRef kernelFuncName = launchOp.getKernelName().getValue();
190 std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
191 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
192 StringAttr::get(context, newKernelFuncName));
193 if (!kernelFunc) {
194 OpBuilder::InsertionGuard guard(rewriter);
195 rewriter.setInsertionPointToStart(module.getBody());
196 kernelFunc = LLVM::LLVMFuncOp::create(
197 rewriter, rewriter.getUnknownLoc(), newKernelFuncName,
198 LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
199 ArrayRef<Type>()));
200 rewriter.setInsertionPoint(launchOp);
201 }
202
203 // Get all global variables associated with the kernel operands.
204 DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
205 if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
206 return failure();
207
208 // Traverse kernel operands that were converted to MemRefDescriptors. For
209 // each operand, create a global variable and copy data from operand to it.
210 Location loc = launchOp.getLoc();
211 SmallVector<CopyInfo, 4> copyInfo;
212 auto numKernelOperands = launchOp.getNumKernelOperands();
213 auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
214 for (const auto &operand : llvm::enumerate(kernelOperands)) {
215 // Check if the kernel's operand is a ranked memref.
216 auto memRefType = dyn_cast<MemRefType>(
217 launchOp.getKernelOperand(operand.index()).getType());
218 if (!memRefType)
219 return failure();
220
221 // Calculate the size of the memref and get the pointer to the allocated
222 // buffer.
223 SmallVector<Value, 4> sizes;
224 SmallVector<Value, 4> strides;
225 Value sizeBytes;
226 getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
227 sizeBytes);
228 MemRefDescriptor descriptor(operand.value());
229 Value src = descriptor.allocatedPtr(rewriter, loc);
230
231 // Get the global variable in the SPIR-V module that is associated with
232 // the kernel operand. Construct its new name and create a corresponding
233 // LLVM dialect global variable.
234 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
235 auto pointeeType =
236 cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
237 auto dstGlobalType = typeConverter->convertType(pointeeType);
238 if (!dstGlobalType)
239 return failure();
240 std::string name =
241 createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
242 // Check if this variable has already been created.
243 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
244 if (!dstGlobal) {
245 OpBuilder::InsertionGuard guard(rewriter);
246 rewriter.setInsertionPointToStart(module.getBody());
247 dstGlobal = LLVM::GlobalOp::create(
248 rewriter, loc, dstGlobalType,
249 /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
250 /*alignment=*/0);
251 rewriter.setInsertionPoint(launchOp);
252 }
253
254 // Copy the data from src operand pointer to dst global variable. Save
255 // src, dst and size so that we can copy data back after emulating the
256 // kernel call.
257 Value dst = LLVM::AddressOfOp::create(
258 rewriter, loc, typeConverter->convertType(spirvGlobal.getType()),
259 dstGlobal.getSymName());
260 copy(loc, dst, src, sizeBytes, rewriter);
261
262 CopyInfo info;
263 info.dst = dst;
264 info.src = src;
265 info.size = sizeBytes;
266 copyInfo.push_back(info);
267 }
268 // Create a call to the kernel and copy the data back.
269 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
270 ArrayRef<Value>());
271 for (CopyInfo info : copyInfo)
272 copy(loc, info.src, info.dst, info.size, rewriter);
273 return success();
274 }
275};
276
277class LowerHostCodeToLLVM
278 : public impl::LowerHostCodeToLLVMPassBase<LowerHostCodeToLLVM> {
279public:
280 using Base::Base;
281
282 void runOnOperation() override {
283 ModuleOp module = getOperation();
284
285 // Erase the GPU module.
286 for (auto gpuModule :
287 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
288 gpuModule.erase();
289
290 // Request C wrapper emission.
291 for (auto func : module.getOps<func::FuncOp>()) {
292 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
293 UnitAttr::get(&getContext()));
294 }
295
296 // Specify options to lower to LLVM and pull in the conversion patterns.
297 LowerToLLVMOptions options(module.getContext());
298
299 auto *context = module.getContext();
300 RewritePatternSet patterns(context);
301 LLVMTypeConverter typeConverter(context, options);
305 patterns.add<GPULaunchLowering>(typeConverter);
306
307 // Pull in SPIR-V type conversion patterns to convert SPIR-V global
308 // variable's type to LLVM dialect type.
310
311 ConversionTarget target(*context);
312 target.addLegalDialect<LLVM::LLVMDialect>();
313 if (failed(applyPartialConversion(module, target, std::move(patterns))))
314 signalPassFailure();
315
316 // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
317 // conflicts.
318 for (auto spvModule : module.getOps<spirv::ModuleOp>()) {
319 if (failed(encodeKernelName(spvModule))) {
320 signalPassFailure();
321 return;
322 }
323 }
324 }
325};
326} // namespace
return success()
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static constexpr const char kSPIRVModule[]
static std::string createGlobalVariableWithBindName(spirv::GlobalVariableOp op, StringRef kernelModuleName)
Encodes the binding and descriptor set numbers into a new symbolic name.
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op)
Calculates the index of the kernel's operand that is represented by the given global variable with th...
static LogicalResult encodeKernelName(spirv::ModuleOp module)
Encodes the SPIR-V module's symbolic name into the name of the entry point function.
static LogicalResult getKernelGlobalVariables(spirv::ModuleOp module, DenseMap< uint32_t, spirv::GlobalVariableOp > &globalVariableMap)
Fills globalVariableMap with SPIR-V global variables that represent kernel arguments from the given S...
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op)
Returns true if the given global variable has both a descriptor set number and a binding number.
static constexpr const char kSPIRVModule[]
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:216
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
std::string getDecorationString(spirv::Decoration decor)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Definition Pattern.h:29
Include the generated interface declarations.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126