MLIR  15.0.0git
ConvertLaunchFuncToLLVMCalls.cpp
Go to the documentation of this file.
1 //===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements passes to convert `gpu.launch_func` op into a sequence
10 // of LLVM calls that emulate the host and device sides.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
26 #include "mlir/IR/BuiltinOps.h"
27 #include "mlir/IR/SymbolTable.h"
29 
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/Support/FormatVariadic.h"
33 
34 using namespace mlir;
35 
36 static constexpr const char kSPIRVModule[] = "__spv__";
37 
38 //===----------------------------------------------------------------------===//
39 // Utility functions
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns the string name of the `DescriptorSet` decoration.
43 static std::string descriptorSetName() {
44  return llvm::convertToSnakeFromCamelCase(
45  stringifyDecoration(spirv::Decoration::DescriptorSet));
46 }
47 
48 /// Returns the string name of the `Binding` decoration.
49 static std::string bindingName() {
50  return llvm::convertToSnakeFromCamelCase(
51  stringifyDecoration(spirv::Decoration::Binding));
52 }
53 
54 /// Calculates the index of the kernel's operand that is represented by the
55 /// given global variable with the `bind` attribute. We assume that the index of
56 /// each kernel's operand is mapped to (descriptorSet, binding) by the map:
57 /// i -> (0, i)
58 /// which is implemented under `LowerABIAttributesPass`.
59 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
60  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
61  return binding.getInt();
62 }
63 
64 /// Copies the given number of bytes from src to dst pointers.
65 static void copy(Location loc, Value dst, Value src, Value size,
66  OpBuilder &builder) {
67  MLIRContext *context = builder.getContext();
68  auto llvmI1Type = IntegerType::get(context, 1);
69  Value isVolatile = builder.create<LLVM::ConstantOp>(
70  loc, llvmI1Type, builder.getBoolAttr(false));
71  builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
72 }
73 
74 /// Encodes the binding and descriptor set numbers into a new symbolic name.
75 /// The name is specified by
76 /// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
77 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
78 /// binding numbers.
79 static std::string
80 createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
81  StringRef kernelModuleName) {
82  IntegerAttr descriptorSet =
83  op->getAttrOfType<IntegerAttr>(descriptorSetName());
84  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
85  return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
86  kernelModuleName.str(), op.sym_name().str(),
87  std::to_string(descriptorSet.getInt()),
88  std::to_string(binding.getInt()));
89 }
90 
91 /// Returns true if the given global variable has both a descriptor set number
92 /// and a binding number.
93 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
94  IntegerAttr descriptorSet =
95  op->getAttrOfType<IntegerAttr>(descriptorSetName());
96  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
97  return descriptorSet && binding;
98 }
99 
100 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
101 /// arguments from the given SPIR-V module. We assume that the module contains a
102 /// single entry point function. Hence, all `spv.GlobalVariable`s with a bind
103 /// attribute are kernel arguments.
105  spirv::ModuleOp module,
106  DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
107  auto entryPoints = module.getOps<spirv::EntryPointOp>();
108  if (!llvm::hasSingleElement(entryPoints)) {
109  return module.emitError(
110  "The module must contain exactly one entry point function");
111  }
112  auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
113  for (auto globalOp : globalVariables) {
114  if (hasDescriptorSetAndBinding(globalOp))
115  globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp;
116  }
117  return success();
118 }
119 
120 /// Encodes the SPIR-V module's symbolic name into the name of the entry point
121 /// function.
122 static LogicalResult encodeKernelName(spirv::ModuleOp module) {
123  StringRef spvModuleName = module.sym_name().getValue();
124  // We already know that the module contains exactly one entry point function
125  // based on `getKernelGlobalVariables()` call. Update this function's name
126  // to:
127  // {spv_module_name}_{function_name}
128  auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
129  StringRef funcName = entryPoint.fn();
130  auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr());
131  StringAttr newFuncName =
132  StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
133  if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
134  return failure();
135  SymbolTable::setSymbolName(funcOp, newFuncName);
136  return success();
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Conversion patterns
141 //===----------------------------------------------------------------------===//
142 
143 namespace {
144 
145 /// Structure to group information about the variables being copied.
146 struct CopyInfo {
147  Value dst;
148  Value src;
149  Value size;
150 };
151 
152 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we
153 /// copy the data to the global variable (emulating device side), call the
154 /// kernel as a normal void LLVM function, and copy the data back (emulating the
155 /// host side).
156 class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
158 
160  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
161  ConversionPatternRewriter &rewriter) const override {
162  auto *op = launchOp.getOperation();
163  MLIRContext *context = rewriter.getContext();
164  auto module = launchOp->getParentOfType<ModuleOp>();
165 
166  // Get the SPIR-V module that represents the gpu kernel module. The module
167  // is named:
168  // __spv__{kernel_module_name}
169  // based on GPU to SPIR-V conversion.
170  StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
171  std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
172  auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
173  StringAttr::get(context, spvModuleName));
174  if (!spvModule) {
175  return launchOp.emitOpError("SPIR-V kernel module '")
176  << spvModuleName << "' is not found";
177  }
178 
179  // Declare kernel function in the main module so that it later can be linked
180  // with its definition from the kernel module. We know that the kernel
181  // function would have no arguments and the data is passed via global
182  // variables. The name of the kernel will be
183  // {spv_module_name}_{kernel_function_name}
184  // to avoid symbolic name conflicts.
185  StringRef kernelFuncName = launchOp.getKernelName().getValue();
186  std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
187  auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
188  StringAttr::get(context, newKernelFuncName));
189  if (!kernelFunc) {
190  OpBuilder::InsertionGuard guard(rewriter);
191  rewriter.setInsertionPointToStart(module.getBody());
192  kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
193  rewriter.getUnknownLoc(), newKernelFuncName,
194  LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
195  ArrayRef<Type>()));
196  rewriter.setInsertionPoint(launchOp);
197  }
198 
199  // Get all global variables associated with the kernel operands.
201  if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
202  return failure();
203 
204  // Traverse kernel operands that were converted to MemRefDescriptors. For
205  // each operand, create a global variable and copy data from operand to it.
206  Location loc = launchOp.getLoc();
207  SmallVector<CopyInfo, 4> copyInfo;
208  auto numKernelOperands = launchOp.getNumKernelOperands();
209  auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
210  for (const auto &operand : llvm::enumerate(kernelOperands)) {
211  // Check if the kernel's operand is a ranked memref.
212  auto memRefType = launchOp.getKernelOperand(operand.index())
213  .getType()
214  .dyn_cast<MemRefType>();
215  if (!memRefType)
216  return failure();
217 
218  // Calculate the size of the memref and get the pointer to the allocated
219  // buffer.
220  SmallVector<Value, 4> sizes;
221  SmallVector<Value, 4> strides;
222  Value sizeBytes;
223  getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
224  sizeBytes);
225  MemRefDescriptor descriptor(operand.value());
226  Value src = descriptor.allocatedPtr(rewriter, loc);
227 
228  // Get the global variable in the SPIR-V module that is associated with
229  // the kernel operand. Construct its new name and create a corresponding
230  // LLVM dialect global variable.
231  spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
232  auto pointeeType =
233  spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
234  auto dstGlobalType = typeConverter->convertType(pointeeType);
235  if (!dstGlobalType)
236  return failure();
237  std::string name =
238  createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
239  // Check if this variable has already been created.
240  auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
241  if (!dstGlobal) {
242  OpBuilder::InsertionGuard guard(rewriter);
243  rewriter.setInsertionPointToStart(module.getBody());
244  dstGlobal = rewriter.create<LLVM::GlobalOp>(
245  loc, dstGlobalType,
246  /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
247  /*alignment=*/0);
248  rewriter.setInsertionPoint(launchOp);
249  }
250 
251  // Copy the data from src operand pointer to dst global variable. Save
252  // src, dst and size so that we can copy data back after emulating the
253  // kernel call.
254  Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal);
255  copy(loc, dst, src, sizeBytes, rewriter);
256 
257  CopyInfo info;
258  info.dst = dst;
259  info.src = src;
260  info.size = sizeBytes;
261  copyInfo.push_back(info);
262  }
263  // Create a call to the kernel and copy the data back.
264  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
265  ArrayRef<Value>());
266  for (CopyInfo info : copyInfo)
267  copy(loc, info.src, info.dst, info.size, rewriter);
268  return success();
269  }
270 };
271 
272 class LowerHostCodeToLLVM
273  : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> {
274 public:
275  void runOnOperation() override {
276  ModuleOp module = getOperation();
277 
278  // Erase the GPU module.
279  for (auto gpuModule :
280  llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
281  gpuModule.erase();
282 
283  // Specify options to lower to LLVM and pull in the conversion patterns.
284  LowerToLLVMOptions options(module.getContext());
285  options.emitCWrappers = true;
286  auto *context = module.getContext();
287  RewritePatternSet patterns(context);
288  LLVMTypeConverter typeConverter(context, options);
290  patterns);
291  populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
292  populateFuncToLLVMConversionPatterns(typeConverter, patterns);
293  patterns.add<GPULaunchLowering>(typeConverter);
294 
295  // Pull in SPIR-V type conversion patterns to convert SPIR-V global
296  // variable's type to LLVM dialect type.
297  populateSPIRVToLLVMTypeConversion(typeConverter);
298 
299  ConversionTarget target(*context);
300  target.addLegalDialect<LLVM::LLVMDialect>();
301  if (failed(applyPartialConversion(module, target, std::move(patterns))))
302  signalPassFailure();
303 
304  // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
305  // conflicts.
306  for (auto spvModule : module.getOps<spirv::ModuleOp>())
307  (void)encodeKernelName(spvModule);
308  }
309 };
310 } // namespace
311 
312 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
314  return std::make_unique<LowerHostCodeToLLVM>();
315 }
static LogicalResult getKernelGlobalVariables(spirv::ModuleOp module, DenseMap< uint32_t, spirv::GlobalVariableOp > &globalVariableMap)
Fills globalVariableMap with SPIR-V global variables that represent kernel arguments from the given SPIR-V module.
Location getUnknownLoc()
Definition: Builders.cpp:26
Include the generated interface declarations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:132
MLIRContext * getContext() const
Definition: Builders.h:54
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
static std::string bindingName()
Returns the string name of the Binding decoration.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:329
static std::string createGlobalVariableWithBindName(spirv::GlobalVariableOp op, StringRef kernelModuleName)
Encodes the binding and descriptor set numbers into a new symbolic name.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
void populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static LLVMFunctionType get(Type result, ArrayRef< Type > arguments, bool isVarArg=false)
Gets or creates an instance of LLVM dialect function in the same context as the result type...
Definition: LLVMTypes.cpp:101
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
void populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Attributes are known-constant values of operations.
Definition: Attributes.h:24
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:234
std::unique_ptr< OperationPass< ModuleOp > > createLowerHostCodeToLLVMPass()
Creates a pass to emulate gpu.launch_func call in LLVM dialect and lower the host module code to LLVM...
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op)
Returns true if the given global variable has both a descriptor set number and a binding number...
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op)
Calculates the index of the kernel's operand that is represented by the given global variable with the `bind` attribute.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
static llvm::ManagedStatic< PassManagerOptions > options
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Options to control the LLVM lowering.
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static LogicalResult encodeKernelName(spirv::ModuleOp module)
Encodes the SPIR-V module's symbolic name into the name of the entry point function.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:662
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter)
Populates type conversions with additional SPIR-V types.
This class helps build Operations.
Definition: Builders.h:177
static constexpr const char kSPIRVModule[]