MLIR  15.0.0git
GPUOpsLowering.cpp
Go to the documentation of this file.
1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GPUOpsLowering.h"
11 #include "mlir/IR/Builders.h"
12 #include "llvm/Support/FormatVariadic.h"
13 
14 using namespace mlir;
15 
17 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
18  ConversionPatternRewriter &rewriter) const {
19  Location loc = gpuFuncOp.getLoc();
20 
21  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
22  workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
23  for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
24  Value attribution = en.value();
25 
26  auto type = attribution.getType().dyn_cast<MemRefType>();
27  assert(type && type.hasStaticShape() && "unexpected type in attribution");
28 
29  uint64_t numElements = type.getNumElements();
30 
31  auto elementType =
32  typeConverter->convertType(type.getElementType()).template cast<Type>();
33  auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
34  std::string name = std::string(
35  llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
36  auto globalOp = rewriter.create<LLVM::GlobalOp>(
37  gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
38  LLVM::Linkage::Internal, name, /*value=*/Attribute(),
39  /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace());
40  workgroupBuffers.push_back(globalOp);
41  }
42 
43  // Rewrite the original GPU function to an LLVM function.
44  auto funcType = typeConverter->convertType(gpuFuncOp.getFunctionType())
45  .template cast<LLVM::LLVMPointerType>()
46  .getElementType();
47 
48  // Remap proper input types.
49  TypeConverter::SignatureConversion signatureConversion(
50  gpuFuncOp.front().getNumArguments());
52  gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion);
53 
54  // Create the new function operation. Only copy those attributes that are
55  // not specific to function modeling.
57  for (const auto &attr : gpuFuncOp->getAttrs()) {
58  if (attr.getName() == SymbolTable::getSymbolAttrName() ||
59  attr.getName() == FunctionOpInterface::getTypeAttrName() ||
60  attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
61  continue;
62  attributes.push_back(attr);
63  }
64  // Add a dialect specific kernel attribute in addition to GPU kernel
65  // attribute. The former is necessary for further translation while the
66  // latter is expected by gpu.launch_func.
67  if (gpuFuncOp.isKernel())
68  attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
69  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
70  gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
71  LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C,
72  attributes);
73 
74  {
75  // Insert operations that correspond to converted workgroup and private
76  // memory attributions to the body of the function. This must operate on
77  // the original function, before the body region is inlined in the new
78  // function to maintain the relation between block arguments and the
79  // parent operation that assigns their semantics.
80  OpBuilder::InsertionGuard guard(rewriter);
81 
82  // Rewrite workgroup memory attributions to addresses of global buffers.
83  rewriter.setInsertionPointToStart(&gpuFuncOp.front());
84  unsigned numProperArguments = gpuFuncOp.getNumArguments();
85  auto i32Type = IntegerType::get(rewriter.getContext(), 32);
86 
87  Value zero = nullptr;
88  if (!workgroupBuffers.empty())
89  zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
90  rewriter.getI32IntegerAttr(0));
91  for (const auto &en : llvm::enumerate(workgroupBuffers)) {
92  LLVM::GlobalOp global = en.value();
93  Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
94  auto elementType =
95  global.getType().cast<LLVM::LLVMArrayType>().getElementType();
96  Value memory = rewriter.create<LLVM::GEPOp>(
97  loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()),
98  address, ArrayRef<Value>{zero, zero});
99 
100  // Build a memref descriptor pointing to the buffer to plug with the
101  // existing memref infrastructure. This may use more registers than
102  // otherwise necessary given that memref sizes are fixed, but we can try
103  // and canonicalize that away later.
104  Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
105  auto type = attribution.getType().cast<MemRefType>();
107  rewriter, loc, *getTypeConverter(), type, memory);
108  signatureConversion.remapInput(numProperArguments + en.index(), descr);
109  }
110 
111  // Rewrite private memory attributions to alloca'ed buffers.
112  unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
113  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
114  for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
115  Value attribution = en.value();
116  auto type = attribution.getType().cast<MemRefType>();
117  assert(type && type.hasStaticShape() && "unexpected type in attribution");
118 
119  // Explicitly drop memory space when lowering private memory
120  // attributions since NVVM models it as `alloca`s in the default
121  // memory space and does not support `alloca`s with addrspace(5).
122  auto ptrType = LLVM::LLVMPointerType::get(
123  typeConverter->convertType(type.getElementType())
124  .template cast<Type>(),
125  allocaAddrSpace);
126  Value numElements = rewriter.create<LLVM::ConstantOp>(
127  gpuFuncOp.getLoc(), int64Ty,
128  rewriter.getI64IntegerAttr(type.getNumElements()));
129  Value allocated = rewriter.create<LLVM::AllocaOp>(
130  gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
132  rewriter, loc, *getTypeConverter(), type, allocated);
133  signatureConversion.remapInput(
134  numProperArguments + numWorkgroupAttributions + en.index(), descr);
135  }
136  }
137 
138  // Move the region to the new function, update the entry block signature.
139  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
140  llvmFuncOp.end());
141  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
142  &signatureConversion)))
143  return failure();
144 
145  rewriter.eraseOp(gpuFuncOp);
146  return success();
147 }
148 
// Symbol-name prefix for the internal globals that hold printf format
// strings; a numeric suffix is appended to keep the names unique.
static constexpr char formatStringPrefix[] = "printfFormat_";
150 
151 template <typename T>
152 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
153  ConversionPatternRewriter &rewriter,
154  StringRef name,
155  LLVM::LLVMFunctionType type) {
156  LLVM::LLVMFuncOp ret;
157  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
158  ConversionPatternRewriter::InsertionGuard guard(rewriter);
159  rewriter.setInsertionPointToStart(moduleOp.getBody());
160  ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
161  LLVM::Linkage::External);
162  }
163  return ret;
164 }
165 
167  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
168  ConversionPatternRewriter &rewriter) const {
169  Location loc = gpuPrintfOp->getLoc();
170 
171  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
172  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
173  mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
174  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
175  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
176  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
177  // This ensures that global constants and declarations are placed within
178  // the device code, not the host code
179  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
180 
181  auto ocklBegin =
182  getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
183  LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
184  LLVM::LLVMFuncOp ocklAppendArgs;
185  if (!adaptor.args().empty()) {
186  ocklAppendArgs = getOrDefineFunction(
187  moduleOp, loc, rewriter, "__ockl_printf_append_args",
189  llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
190  llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
191  }
192  auto ocklAppendStringN = getOrDefineFunction(
193  moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
195  llvmI64,
196  {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
197 
198  /// Start the printf hostcall
199  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(
200  loc, llvmI64, rewriter.getI64IntegerAttr(0));
201  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
202  Value printfDesc = printfBeginCall.getResult(0);
203 
204  // Create a global constant for the format string
205  unsigned stringNumber = 0;
206  SmallString<16> stringConstName;
207  do {
208  stringConstName.clear();
209  (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
210  } while (moduleOp.lookupSymbol(stringConstName));
211 
212  llvm::SmallString<20> formatString(adaptor.format());
213  formatString.push_back('\0'); // Null terminate for C
214  size_t formatStringSize = formatString.size_in_bytes();
215 
216  auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
217  LLVM::GlobalOp global;
218  {
220  rewriter.setInsertionPointToStart(moduleOp.getBody());
221  global = rewriter.create<LLVM::GlobalOp>(
222  loc, globalType,
223  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
224  rewriter.getStringAttr(formatString));
225  }
226 
227  // Get a pointer to the format string's first element and pass it to printf()
228  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
229  Value zero = rewriter.create<LLVM::ConstantOp>(
230  loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
231  Value stringStart = rewriter.create<LLVM::GEPOp>(
232  loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
233  Value stringLen = rewriter.create<LLVM::ConstantOp>(
234  loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize));
235 
236  Value oneI32 = rewriter.create<LLVM::ConstantOp>(
237  loc, llvmI32, rewriter.getI32IntegerAttr(1));
238  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(
239  loc, llvmI32, rewriter.getI32IntegerAttr(0));
240 
241  auto appendFormatCall = rewriter.create<LLVM::CallOp>(
242  loc, ocklAppendStringN,
243  ValueRange{printfDesc, stringStart, stringLen,
244  adaptor.args().empty() ? oneI32 : zeroI32});
245  printfDesc = appendFormatCall.getResult(0);
246 
247  // __ockl_printf_append_args takes 7 values per append call
248  constexpr size_t argsPerAppend = 7;
249  size_t nArgs = adaptor.args().size();
250  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
251  size_t bound = std::min(group + argsPerAppend, nArgs);
252  size_t numArgsThisCall = bound - group;
253 
255  arguments.push_back(printfDesc);
256  arguments.push_back(rewriter.create<LLVM::ConstantOp>(
257  loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall)));
258  for (size_t i = group; i < bound; ++i) {
259  Value arg = adaptor.args()[i];
260  if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
261  if (!floatType.isF64())
262  arg = rewriter.create<LLVM::FPExtOp>(
263  loc, typeConverter->convertType(rewriter.getF64Type()), arg);
264  arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
265  }
266  if (arg.getType().getIntOrFloatBitWidth() != 64)
267  arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
268 
269  arguments.push_back(arg);
270  }
271  // Pad out to 7 arguments since the hostcall always needs 7
272  for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
273  arguments.push_back(zeroI64);
274  }
275 
276  auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
277  arguments.push_back(isLast);
278  auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
279  printfDesc = call.getResult(0);
280  }
281  rewriter.eraseOp(gpuPrintfOp);
282  return success();
283 }
284 
286  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
287  ConversionPatternRewriter &rewriter) const {
288  Location loc = gpuPrintfOp->getLoc();
289 
290  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
291  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace);
292  mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
293 
294  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
295  // This ensures that global constants and declarations are placed within
296  // the device code, not the host code
297  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
298 
299  auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr},
300  /*isVarArg=*/true);
301  LLVM::LLVMFuncOp printfDecl =
302  getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
303 
304  // Create a global constant for the format string
305  unsigned stringNumber = 0;
306  SmallString<16> stringConstName;
307  do {
308  stringConstName.clear();
309  (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
310  } while (moduleOp.lookupSymbol(stringConstName));
311 
312  llvm::SmallString<20> formatString(adaptor.format());
313  formatString.push_back('\0'); // Null terminate for C
314  auto globalType =
315  LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
316  LLVM::GlobalOp global;
317  {
319  rewriter.setInsertionPointToStart(moduleOp.getBody());
320  global = rewriter.create<LLVM::GlobalOp>(
321  loc, globalType,
322  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
323  rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
324  }
325 
326  // Get a pointer to the format string's first element
327  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
328  Value zero = rewriter.create<LLVM::ConstantOp>(
329  loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
330  Value stringStart = rewriter.create<LLVM::GEPOp>(
331  loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
332 
333  // Construct arguments and function call
334  auto argsRange = adaptor.args();
335  SmallVector<Value, 4> printfArgs;
336  printfArgs.reserve(argsRange.size() + 1);
337  printfArgs.push_back(stringStart);
338  printfArgs.append(argsRange.begin(), argsRange.end());
339 
340  rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
341  rewriter.eraseOp(gpuPrintfOp);
342  return success();
343 }
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:55
MLIRContext * getContext() const
Definition: Builders.h:54
static LLVMArrayType get(Type elementType, unsigned numElements)
Gets or creates an instance of LLVM dialect array type containing numElements of elementType, in the same context as elementType.
Definition: LLVMTypes.cpp:39
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:688
LLVM dialect function type.
Definition: LLVMTypes.h:128
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
static LLVMFunctionType get(Type result, ArrayRef< Type > arguments, bool isVarArg=false)
Gets or creates an instance of LLVM dialect function in the same context as the result type...
Definition: LLVMTypes.cpp:107
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:148
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
U dyn_cast() const
Definition: Types.h:256
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:99
UnitAttr getUnitAttr()
Definition: Builders.cpp:85
Attributes are known-constant values of operations.
Definition: Attributes.h:24
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:234
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:331
IntegerType getI8Type()
Definition: Builders.cpp:52
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, ConversionPatternRewriter &rewriter, StringRef name, LLVM::LLVMFunctionType type)
static LLVMPointerType get(MLIRContext *context, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
Definition: LLVMTypes.cpp:188
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
LLVM dialect array type.
Definition: LLVMTypes.h:75
IntegerType getI64Type()
Definition: Builders.cpp:56
TypeConverter * typeConverter
An optional type converter for use by this pattern.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:369
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:286
Type getType() const
Return the type of this value.
Definition: Value.h:118
IndexType getIndexType()
Definition: Builders.cpp:48
LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:28
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FloatType getF64Type()
Definition: Builders.cpp:42
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, SignatureConversion &result)
Convert a function type.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
This class provides an abstraction over the different types of ranges over Values.
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
IntegerType getI32Type()
Definition: Builders.cpp:54
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
U cast() const
Definition: Types.h:262
static const char formatStringPrefix[]