MLIR  14.0.0git
GPUOpsLowering.cpp
Go to the documentation of this file.
1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GPUOpsLowering.h"
12 #include "mlir/IR/Builders.h"
13 #include "llvm/Support/FormatVariadic.h"
14 
15 using namespace mlir;
16 
18 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
19  ConversionPatternRewriter &rewriter) const {
20  Location loc = gpuFuncOp.getLoc();
21 
22  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
23  workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
24  for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
25  Value attribution = en.value();
26 
27  auto type = attribution.getType().dyn_cast<MemRefType>();
28  assert(type && type.hasStaticShape() && "unexpected type in attribution");
29 
30  uint64_t numElements = type.getNumElements();
31 
32  auto elementType =
33  typeConverter->convertType(type.getElementType()).template cast<Type>();
34  auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
35  std::string name = std::string(
36  llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
37  auto globalOp = rewriter.create<LLVM::GlobalOp>(
38  gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
39  LLVM::Linkage::Internal, name, /*value=*/Attribute(),
40  /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace());
41  workgroupBuffers.push_back(globalOp);
42  }
43 
44  // Rewrite the original GPU function to an LLVM function.
45  auto funcType = typeConverter->convertType(gpuFuncOp.getType())
46  .template cast<LLVM::LLVMPointerType>()
47  .getElementType();
48 
49  // Remap proper input types.
50  TypeConverter::SignatureConversion signatureConversion(
51  gpuFuncOp.front().getNumArguments());
53  gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion);
54 
55  // Create the new function operation. Only copy those attributes that are
56  // not specific to function modeling.
58  for (const auto &attr : gpuFuncOp->getAttrs()) {
59  if (attr.getName() == SymbolTable::getSymbolAttrName() ||
60  attr.getName() == FunctionOpInterface::getTypeAttrName() ||
61  attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
62  continue;
63  attributes.push_back(attr);
64  }
65  // Add a dialect specific kernel attribute in addition to GPU kernel
66  // attribute. The former is necessary for further translation while the
67  // latter is expected by gpu.launch_func.
68  if (gpuFuncOp.isKernel())
69  attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
70  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
71  gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
72  LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
73 
74  {
75  // Insert operations that correspond to converted workgroup and private
76  // memory attributions to the body of the function. This must operate on
77  // the original function, before the body region is inlined in the new
78  // function to maintain the relation between block arguments and the
79  // parent operation that assigns their semantics.
80  OpBuilder::InsertionGuard guard(rewriter);
81 
82  // Rewrite workgroup memory attributions to addresses of global buffers.
83  rewriter.setInsertionPointToStart(&gpuFuncOp.front());
84  unsigned numProperArguments = gpuFuncOp.getNumArguments();
85  auto i32Type = IntegerType::get(rewriter.getContext(), 32);
86 
87  Value zero = nullptr;
88  if (!workgroupBuffers.empty())
89  zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
90  rewriter.getI32IntegerAttr(0));
91  for (const auto &en : llvm::enumerate(workgroupBuffers)) {
92  LLVM::GlobalOp global = en.value();
93  Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
94  auto elementType =
95  global.getType().cast<LLVM::LLVMArrayType>().getElementType();
96  Value memory = rewriter.create<LLVM::GEPOp>(
97  loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()),
98  address, ArrayRef<Value>{zero, zero});
99 
100  // Build a memref descriptor pointing to the buffer to plug with the
101  // existing memref infrastructure. This may use more registers than
102  // otherwise necessary given that memref sizes are fixed, but we can try
103  // and canonicalize that away later.
104  Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
105  auto type = attribution.getType().cast<MemRefType>();
107  rewriter, loc, *getTypeConverter(), type, memory);
108  signatureConversion.remapInput(numProperArguments + en.index(), descr);
109  }
110 
111  // Rewrite private memory attributions to alloca'ed buffers.
112  unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
113  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
114  for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
115  Value attribution = en.value();
116  auto type = attribution.getType().cast<MemRefType>();
117  assert(type && type.hasStaticShape() && "unexpected type in attribution");
118 
119  // Explicitly drop memory space when lowering private memory
120  // attributions since NVVM models it as `alloca`s in the default
121  // memory space and does not support `alloca`s with addrspace(5).
122  auto ptrType = LLVM::LLVMPointerType::get(
123  typeConverter->convertType(type.getElementType())
124  .template cast<Type>(),
125  allocaAddrSpace);
126  Value numElements = rewriter.create<LLVM::ConstantOp>(
127  gpuFuncOp.getLoc(), int64Ty,
128  rewriter.getI64IntegerAttr(type.getNumElements()));
129  Value allocated = rewriter.create<LLVM::AllocaOp>(
130  gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
132  rewriter, loc, *getTypeConverter(), type, allocated);
133  signatureConversion.remapInput(
134  numProperArguments + numWorkgroupAttributions + en.index(), descr);
135  }
136  }
137 
138  // Move the region to the new function, update the entry block signature.
139  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
140  llvmFuncOp.end());
141  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
142  &signatureConversion)))
143  return failure();
144 
145  rewriter.eraseOp(gpuFuncOp);
146  return success();
147 }
148 
// Prefix of the symbol names given to the global constants that hold printf
// format strings; a numeric suffix is appended to make each name unique.
149 static const char formatStringPrefix[] = "printfFormat_";
150 
151 template <typename T>
152 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
153  ConversionPatternRewriter &rewriter,
154  StringRef name,
155  LLVM::LLVMFunctionType type) {
156  LLVM::LLVMFuncOp ret;
157  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
158  ConversionPatternRewriter::InsertionGuard guard(rewriter);
159  rewriter.setInsertionPointToStart(moduleOp.getBody());
160  ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
161  LLVM::Linkage::External);
162  }
163  return ret;
164 }
165 
167  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
168  ConversionPatternRewriter &rewriter) const {
169  Location loc = gpuPrintfOp->getLoc();
170 
171  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
172  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
173  mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
174  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
175  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
176  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
177  // This ensures that global constants and declarations are placed within
178  // the device code, not the host code
179  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
180 
181  auto ocklBegin =
182  getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
183  LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
184  LLVM::LLVMFuncOp ocklAppendArgs;
185  if (!adaptor.args().empty()) {
186  ocklAppendArgs = getOrDefineFunction(
187  moduleOp, loc, rewriter, "__ockl_printf_append_args",
189  llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
190  llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
191  }
192  auto ocklAppendStringN = getOrDefineFunction(
193  moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
195  llvmI64,
196  {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
197 
198  /// Start the printf hostcall
199  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(
200  loc, llvmI64, rewriter.getI64IntegerAttr(0));
201  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
202  Value printfDesc = printfBeginCall.getResult(0);
203 
204  // Create a global constant for the format string
205  unsigned stringNumber = 0;
206  SmallString<16> stringConstName;
207  do {
208  stringConstName.clear();
209  (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
210  } while (moduleOp.lookupSymbol(stringConstName));
211 
212  llvm::SmallString<20> formatString(adaptor.format());
213  formatString.push_back('\0'); // Null terminate for C
214  size_t formatStringSize = formatString.size_in_bytes();
215 
216  auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
217  LLVM::GlobalOp global;
218  {
220  rewriter.setInsertionPointToStart(moduleOp.getBody());
221  global = rewriter.create<LLVM::GlobalOp>(
222  loc, globalType,
223  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
224  rewriter.getStringAttr(formatString));
225  }
226 
227  // Get a pointer to the format string's first element and pass it to printf()
228  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
229  Value zero = rewriter.create<LLVM::ConstantOp>(
230  loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
231  Value stringStart = rewriter.create<LLVM::GEPOp>(
232  loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
233  Value stringLen = rewriter.create<LLVM::ConstantOp>(
234  loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize));
235 
236  Value oneI32 = rewriter.create<LLVM::ConstantOp>(
237  loc, llvmI32, rewriter.getI32IntegerAttr(1));
238  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(
239  loc, llvmI32, rewriter.getI32IntegerAttr(0));
240 
241  auto appendFormatCall = rewriter.create<LLVM::CallOp>(
242  loc, ocklAppendStringN,
243  ValueRange{printfDesc, stringStart, stringLen,
244  adaptor.args().empty() ? oneI32 : zeroI32});
245  printfDesc = appendFormatCall.getResult(0);
246 
247  // __ockl_printf_append_args takes 7 values per append call
248  constexpr size_t argsPerAppend = 7;
249  size_t nArgs = adaptor.args().size();
250  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
251  size_t bound = std::min(group + argsPerAppend, nArgs);
252  size_t numArgsThisCall = bound - group;
253 
255  arguments.push_back(printfDesc);
256  arguments.push_back(rewriter.create<LLVM::ConstantOp>(
257  loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall)));
258  for (size_t i = group; i < bound; ++i) {
259  Value arg = adaptor.args()[i];
260  if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
261  if (!floatType.isF64())
262  arg = rewriter.create<LLVM::FPExtOp>(
263  loc, typeConverter->convertType(rewriter.getF64Type()), arg);
264  arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
265  }
266  if (arg.getType().getIntOrFloatBitWidth() != 64)
267  arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
268 
269  arguments.push_back(arg);
270  }
271  // Pad out to 7 arguments since the hostcall always needs 7
272  for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
273  arguments.push_back(zeroI64);
274  }
275 
276  auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
277  arguments.push_back(isLast);
278  auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
279  printfDesc = call.getResult(0);
280  }
281  rewriter.eraseOp(gpuPrintfOp);
282  return success();
283 }
284 
286  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
287  ConversionPatternRewriter &rewriter) const {
288  Location loc = gpuPrintfOp->getLoc();
289 
290  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
291  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace);
292  mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
293 
294  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
295  // This ensures that global constants and declarations are placed within
296  // the device code, not the host code
297  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
298 
299  auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr},
300  /*isVarArg=*/true);
301  LLVM::LLVMFuncOp printfDecl =
302  getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
303 
304  // Create a global constant for the format string
305  unsigned stringNumber = 0;
306  SmallString<16> stringConstName;
307  do {
308  stringConstName.clear();
309  (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
310  } while (moduleOp.lookupSymbol(stringConstName));
311 
312  llvm::SmallString<20> formatString(adaptor.format());
313  formatString.push_back('\0'); // Null terminate for C
314  auto globalType =
315  LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
316  LLVM::GlobalOp global;
317  {
319  rewriter.setInsertionPointToStart(moduleOp.getBody());
320  global = rewriter.create<LLVM::GlobalOp>(
321  loc, globalType,
322  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
323  rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
324  }
325 
326  // Get a pointer to the format string's first element
327  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
328  Value zero = rewriter.create<LLVM::ConstantOp>(
329  loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
330  Value stringStart = rewriter.create<LLVM::GEPOp>(
331  loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
332 
333  // Construct arguments and function call
334  auto argsRange = adaptor.args();
335  SmallVector<Value, 4> printfArgs;
336  printfArgs.reserve(argsRange.size() + 1);
337  printfArgs.push_back(stringStart);
338  printfArgs.append(argsRange.begin(), argsRange.end());
339 
340  rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
341  rewriter.eraseOp(gpuPrintfOp);
342  return success();
343 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:55
MLIRContext * getContext() const
Definition: Builders.h:54
static Value min(ImplicitLocOpBuilder &builder, Value a, Value b)
static LLVMArrayType get(Type elementType, unsigned numElements)
Gets or creates an instance of LLVM dialect array type containing numElements of elementType, in the same context as elementType.
Definition: LLVMTypes.cpp:39
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
LLVM dialect function type.
Definition: LLVMTypes.h:123
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
static LLVMPointerType get(Type pointee, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
Definition: LLVMTypes.cpp:165
static LLVMFunctionType get(Type result, ArrayRef< Type > arguments, bool isVarArg=false)
Gets or creates an instance of LLVM dialect function in the same context as the result type...
Definition: LLVMTypes.cpp:101
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:148
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
U dyn_cast() const
Definition: Types.h:244
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:99
UnitAttr getUnitAttr()
Definition: Builders.cpp:85
Attributes are known-constant values of operations.
Definition: Attributes.h:24
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
IntegerType getI8Type()
Definition: Builders.cpp:52
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, ConversionPatternRewriter &rewriter, StringRef name, LLVM::LLVMFunctionType type)
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
LLVM dialect array type.
Definition: LLVMTypes.h:74
IntegerType getI64Type()
Definition: Builders.cpp:56
TypeConverter * typeConverter
An optional type converter for use by this pattern.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
Type getType() const
Return the type of this value.
Definition: Value.h:117
IndexType getIndexType()
Definition: Builders.cpp:48
LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FloatType getF64Type()
Definition: Builders.cpp:42
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, SignatureConversion &result)
Convert a function type.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
This class provides an abstraction over the different types of ranges over Values.
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
IntegerType getI32Type()
Definition: Builders.cpp:54
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
U cast() const
Definition: Types.h:250
static const char formatStringPrefix[]