MLIR  17.0.0git
ConvertLaunchFuncToLLVMCalls.cpp
Go to the documentation of this file.
1 //===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements passes to convert `gpu.launch_func` op into a sequence
10 // of LLVM calls that emulate the host and device sides.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/SymbolTable.h"
29 #include "mlir/Pass/Pass.h"
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
34 
35 namespace mlir {
36 #define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
37 #include "mlir/Conversion/Passes.h.inc"
38 } // namespace mlir
39 
40 using namespace mlir;
41 
42 static constexpr const char kSPIRVModule[] = "__spv__";
43 
44 //===----------------------------------------------------------------------===//
45 // Utility functions
46 //===----------------------------------------------------------------------===//
47 
48 /// Returns the string name of the `DescriptorSet` decoration.
49 static std::string descriptorSetName() {
50  return llvm::convertToSnakeFromCamelCase(
51  stringifyDecoration(spirv::Decoration::DescriptorSet));
52 }
53 
54 /// Returns the string name of the `Binding` decoration.
55 static std::string bindingName() {
56  return llvm::convertToSnakeFromCamelCase(
57  stringifyDecoration(spirv::Decoration::Binding));
58 }
59 
60 /// Calculates the index of the kernel's operand that is represented by the
61 /// given global variable with the `bind` attribute. We assume that the index of
62 /// each kernel's operand is mapped to (descriptorSet, binding) by the map:
63 /// i -> (0, i)
64 /// which is implemented under `LowerABIAttributesPass`.
65 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
66  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
67  return binding.getInt();
68 }
69 
70 /// Copies the given number of bytes from src to dst pointers.
71 static void copy(Location loc, Value dst, Value src, Value size,
72  OpBuilder &builder) {
73  MLIRContext *context = builder.getContext();
74  auto llvmI1Type = IntegerType::get(context, 1);
75  Value isVolatile = builder.create<LLVM::ConstantOp>(
76  loc, llvmI1Type, builder.getBoolAttr(false));
77  builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
78 }
79 
80 /// Encodes the binding and descriptor set numbers into a new symbolic name.
81 /// The name is specified by
82 /// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
83 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
84 /// binding numbers.
85 static std::string
86 createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
87  StringRef kernelModuleName) {
88  IntegerAttr descriptorSet =
89  op->getAttrOfType<IntegerAttr>(descriptorSetName());
90  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
91  return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
92  kernelModuleName.str(), op.getSymName().str(),
93  std::to_string(descriptorSet.getInt()),
94  std::to_string(binding.getInt()));
95 }
96 
97 /// Returns true if the given global variable has both a descriptor set number
98 /// and a binding number.
99 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
100  IntegerAttr descriptorSet =
101  op->getAttrOfType<IntegerAttr>(descriptorSetName());
102  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
103  return descriptorSet && binding;
104 }
105 
106 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
107 /// arguments from the given SPIR-V module. We assume that the module contains a
108 /// single entry point function. Hence, all `spirv.GlobalVariable`s with a bind
109 /// attribute are kernel arguments.
111  spirv::ModuleOp module,
112  DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
113  auto entryPoints = module.getOps<spirv::EntryPointOp>();
114  if (!llvm::hasSingleElement(entryPoints)) {
115  return module.emitError(
116  "The module must contain exactly one entry point function");
117  }
118  auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
119  for (auto globalOp : globalVariables) {
120  if (hasDescriptorSetAndBinding(globalOp))
121  globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp;
122  }
123  return success();
124 }
125 
126 /// Encodes the SPIR-V module's symbolic name into the name of the entry point
127 /// function.
128 static LogicalResult encodeKernelName(spirv::ModuleOp module) {
129  StringRef spvModuleName = module.getSymName().value_or(kSPIRVModule);
130  // We already know that the module contains exactly one entry point function
131  // based on `getKernelGlobalVariables()` call. Update this function's name
132  // to:
133  // {spv_module_name}_{function_name}
134  auto entryPoints = module.getOps<spirv::EntryPointOp>();
135  if (!llvm::hasSingleElement(entryPoints)) {
136  return module.emitError(
137  "The module must contain exactly one entry point function");
138  }
139  spirv::EntryPointOp entryPoint = *entryPoints.begin();
140  StringRef funcName = entryPoint.getFn();
141  auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
142  StringAttr newFuncName =
143  StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
144  if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
145  return failure();
146  SymbolTable::setSymbolName(funcOp, newFuncName);
147  return success();
148 }
149 
150 //===----------------------------------------------------------------------===//
151 // Conversion patterns
152 //===----------------------------------------------------------------------===//
153 
154 namespace {
155 
156 /// Structure to group information about the variables being copied.
157 struct CopyInfo {
158  Value dst;
159  Value src;
160  Value size;
161 };
162 
163 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we
164 /// copy the data to the global variable (emulating device side), call the
165 /// kernel as a normal void LLVM function, and copy the data back (emulating the
166 /// host side).
167 class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
169 
171  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
172  ConversionPatternRewriter &rewriter) const override {
173  auto *op = launchOp.getOperation();
174  MLIRContext *context = rewriter.getContext();
175  auto module = launchOp->getParentOfType<ModuleOp>();
176 
177  // Get the SPIR-V module that represents the gpu kernel module. The module
178  // is named:
179  // __spv__{kernel_module_name}
180  // based on GPU to SPIR-V conversion.
181  StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
182  std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
183  auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
184  StringAttr::get(context, spvModuleName));
185  if (!spvModule) {
186  return launchOp.emitOpError("SPIR-V kernel module '")
187  << spvModuleName << "' is not found";
188  }
189 
190  // Declare kernel function in the main module so that it later can be linked
191  // with its definition from the kernel module. We know that the kernel
192  // function would have no arguments and the data is passed via global
193  // variables. The name of the kernel will be
194  // {spv_module_name}_{kernel_function_name}
195  // to avoid symbolic name conflicts.
196  StringRef kernelFuncName = launchOp.getKernelName().getValue();
197  std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
198  auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
199  StringAttr::get(context, newKernelFuncName));
200  if (!kernelFunc) {
201  OpBuilder::InsertionGuard guard(rewriter);
202  rewriter.setInsertionPointToStart(module.getBody());
203  kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
204  rewriter.getUnknownLoc(), newKernelFuncName,
206  ArrayRef<Type>()));
207  rewriter.setInsertionPoint(launchOp);
208  }
209 
210  // Get all global variables associated with the kernel operands.
212  if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
213  return failure();
214 
215  // Traverse kernel operands that were converted to MemRefDescriptors. For
216  // each operand, create a global variable and copy data from operand to it.
217  Location loc = launchOp.getLoc();
218  SmallVector<CopyInfo, 4> copyInfo;
219  auto numKernelOperands = launchOp.getNumKernelOperands();
220  auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
221  for (const auto &operand : llvm::enumerate(kernelOperands)) {
222  // Check if the kernel's operand is a ranked memref.
223  auto memRefType = dyn_cast<MemRefType>(
224  launchOp.getKernelOperand(operand.index()).getType());
225  if (!memRefType)
226  return failure();
227 
228  // Calculate the size of the memref and get the pointer to the allocated
229  // buffer.
230  SmallVector<Value, 4> sizes;
231  SmallVector<Value, 4> strides;
232  Value sizeBytes;
233  getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
234  sizeBytes);
235  MemRefDescriptor descriptor(operand.value());
236  Value src = descriptor.allocatedPtr(rewriter, loc);
237 
238  // Get the global variable in the SPIR-V module that is associated with
239  // the kernel operand. Construct its new name and create a corresponding
240  // LLVM dialect global variable.
241  spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
242  auto pointeeType =
243  cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
244  auto dstGlobalType = typeConverter->convertType(pointeeType);
245  if (!dstGlobalType)
246  return failure();
247  std::string name =
248  createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
249  // Check if this variable has already been created.
250  auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
251  if (!dstGlobal) {
252  OpBuilder::InsertionGuard guard(rewriter);
253  rewriter.setInsertionPointToStart(module.getBody());
254  dstGlobal = rewriter.create<LLVM::GlobalOp>(
255  loc, dstGlobalType,
256  /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
257  /*alignment=*/0);
258  rewriter.setInsertionPoint(launchOp);
259  }
260 
261  // Copy the data from src operand pointer to dst global variable. Save
262  // src, dst and size so that we can copy data back after emulating the
263  // kernel call.
264  Value dst = rewriter.create<LLVM::AddressOfOp>(
265  loc, typeConverter->convertType(spirvGlobal.getType()),
266  dstGlobal.getSymName());
267  copy(loc, dst, src, sizeBytes, rewriter);
268 
269  CopyInfo info;
270  info.dst = dst;
271  info.src = src;
272  info.size = sizeBytes;
273  copyInfo.push_back(info);
274  }
275  // Create a call to the kernel and copy the data back.
276  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
277  ArrayRef<Value>());
278  for (CopyInfo info : copyInfo)
279  copy(loc, info.src, info.dst, info.size, rewriter);
280  return success();
281  }
282 };
283 
284 class LowerHostCodeToLLVM
285  : public impl::LowerHostCodeToLLVMPassBase<LowerHostCodeToLLVM> {
286 public:
287 
288  using Base::Base;
289 
290  void runOnOperation() override {
291  ModuleOp module = getOperation();
292 
293  // Erase the GPU module.
294  for (auto gpuModule :
295  llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
296  gpuModule.erase();
297 
298  // Request C wrapper emission.
299  for (auto func : module.getOps<func::FuncOp>()) {
300  func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
301  UnitAttr::get(&getContext()));
302  }
303 
304  // Specify options to lower to LLVM and pull in the conversion patterns.
305  LowerToLLVMOptions options(module.getContext());
306  options.useOpaquePointers = useOpaquePointers;
307 
308  auto *context = module.getContext();
309  RewritePatternSet patterns(context);
310  LLVMTypeConverter typeConverter(context, options);
312  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
313  populateFuncToLLVMConversionPatterns(typeConverter, patterns);
314  patterns.add<GPULaunchLowering>(typeConverter);
315 
316  // Pull in SPIR-V type conversion patterns to convert SPIR-V global
317  // variable's type to LLVM dialect type.
318  populateSPIRVToLLVMTypeConversion(typeConverter);
319 
320  ConversionTarget target(*context);
321  target.addLegalDialect<LLVM::LLVMDialect>();
322  if (failed(applyPartialConversion(module, target, std::move(patterns))))
323  signalPassFailure();
324 
325  // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
326  // conflicts.
327  for (auto spvModule : module.getOps<spirv::ModuleOp>()) {
328  if (failed(encodeKernelName(spvModule))) {
329  signalPassFailure();
330  return;
331  }
332  }
333  }
334 };
335 } // namespace
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static constexpr const char kSPIRVModule[]
static std::string createGlobalVariableWithBindName(spirv::GlobalVariableOp op, StringRef kernelModuleName)
Encodes the binding and descriptor set numbers into a new symbolic name.
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op)
Calculates the index of the kernel's operand that is represented by the given global variable with the bind attribute.
static LogicalResult encodeKernelName(spirv::ModuleOp module)
Encodes the SPIR-V module's symbolic name into the name of the entry point function.
static LogicalResult getKernelGlobalVariables(spirv::ModuleOp module, DenseMap< uint32_t, spirv::GlobalVariableOp > &globalVariableMap)
Fills globalVariableMap with SPIR-V global variables that represent kernel arguments from the given SPIR-V module.
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op)
Returns true if the given global variable has both a descriptor set number and a binding number.
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:114
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:329
This class helps build Operations.
Definition: Builders.h:202
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:412
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:379
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:433
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:511
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:514
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:262
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:720
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter)
Populates type conversions with additional SPIR-V types.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26