MLIR  20.0.0git
GPUOpsLowering.cpp
Go to the documentation of this file.
1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GPUOpsLowering.h"
10 
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "llvm/ADT/SmallVectorExtras.h"
17 #include "llvm/ADT/StringSet.h"
18 #include "llvm/Support/FormatVariadic.h"
19 
20 using namespace mlir;
21 
22 LogicalResult
23 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
24  ConversionPatternRewriter &rewriter) const {
25  Location loc = gpuFuncOp.getLoc();
26 
27  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
28  workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
29  for (const auto [idx, attribution] :
30  llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
31  auto type = dyn_cast<MemRefType>(attribution.getType());
32  assert(type && type.hasStaticShape() && "unexpected type in attribution");
33 
34  uint64_t numElements = type.getNumElements();
35 
36  auto elementType =
37  cast<Type>(typeConverter->convertType(type.getElementType()));
38  auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
39  std::string name =
40  std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
41  uint64_t alignment = 0;
42  if (auto alignAttr =
43  dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr(
44  idx, LLVM::LLVMDialect::getAlignAttrName())))
45  alignment = alignAttr.getInt();
46  auto globalOp = rewriter.create<LLVM::GlobalOp>(
47  gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
48  LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
49  workgroupAddrSpace);
50  workgroupBuffers.push_back(globalOp);
51  }
52 
53  // Remap proper input types.
54  TypeConverter::SignatureConversion signatureConversion(
55  gpuFuncOp.front().getNumArguments());
56 
58  gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
59  getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
60  if (!funcType) {
61  return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
62  diag << "failed to convert function signature type for: "
63  << gpuFuncOp.getFunctionType();
64  });
65  }
66 
67  // Create the new function operation. Only copy those attributes that are
68  // not specific to function modeling.
70  ArrayAttr argAttrs;
71  for (const auto &attr : gpuFuncOp->getAttrs()) {
72  if (attr.getName() == SymbolTable::getSymbolAttrName() ||
73  attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
74  attr.getName() ==
75  gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
76  attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
77  attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
78  attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
79  attr.getName() == gpuFuncOp.getKnownGridSizeAttrName())
80  continue;
81  if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
82  argAttrs = gpuFuncOp.getArgAttrsAttr();
83  continue;
84  }
85  attributes.push_back(attr);
86  }
87 
88  DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
89  DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
90  // Ensure we don't lose information if the function is lowered before its
91  // surrounding context.
92  auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
93  if (knownBlockSize)
94  attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
95  knownBlockSize);
96  if (knownGridSize)
97  attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
98  knownGridSize);
99 
100  // Add a dialect specific kernel attribute in addition to GPU kernel
101  // attribute. The former is necessary for further translation while the
102  // latter is expected by gpu.launch_func.
103  if (gpuFuncOp.isKernel()) {
104  attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
105  // Set the dialect-specific block size attribute if there is one.
106  if (kernelBlockSizeAttributeName.has_value() && knownBlockSize) {
107  attributes.emplace_back(kernelBlockSizeAttributeName.value(),
108  knownBlockSize);
109  }
110  }
111  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
112  gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
113  LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C,
114  /*comdat=*/nullptr, attributes);
115 
116  {
117  // Insert operations that correspond to converted workgroup and private
118  // memory attributions to the body of the function. This must operate on
119  // the original function, before the body region is inlined in the new
120  // function to maintain the relation between block arguments and the
121  // parent operation that assigns their semantics.
122  OpBuilder::InsertionGuard guard(rewriter);
123 
124  // Rewrite workgroup memory attributions to addresses of global buffers.
125  rewriter.setInsertionPointToStart(&gpuFuncOp.front());
126  unsigned numProperArguments = gpuFuncOp.getNumArguments();
127 
128  for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
129  auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
130  global.getAddrSpace());
131  Value address = rewriter.create<LLVM::AddressOfOp>(
132  loc, ptrType, global.getSymNameAttr());
133  Value memory =
134  rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), address,
135  ArrayRef<LLVM::GEPArg>{0, 0});
136 
137  // Build a memref descriptor pointing to the buffer to plug with the
138  // existing memref infrastructure. This may use more registers than
139  // otherwise necessary given that memref sizes are fixed, but we can try
140  // and canonicalize that away later.
141  Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
142  auto type = cast<MemRefType>(attribution.getType());
144  rewriter, loc, *getTypeConverter(), type, memory);
145  signatureConversion.remapInput(numProperArguments + idx, descr);
146  }
147 
148  // Rewrite private memory attributions to alloca'ed buffers.
149  unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
150  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
151  for (const auto [idx, attribution] :
152  llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
153  auto type = cast<MemRefType>(attribution.getType());
154  assert(type && type.hasStaticShape() && "unexpected type in attribution");
155 
156  // Explicitly drop memory space when lowering private memory
157  // attributions since NVVM models it as `alloca`s in the default
158  // memory space and does not support `alloca`s with addrspace(5).
159  Type elementType = typeConverter->convertType(type.getElementType());
160  auto ptrType =
161  LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
162  Value numElements = rewriter.create<LLVM::ConstantOp>(
163  gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
164  uint64_t alignment = 0;
165  if (auto alignAttr =
166  dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
167  idx, LLVM::LLVMDialect::getAlignAttrName())))
168  alignment = alignAttr.getInt();
169  Value allocated = rewriter.create<LLVM::AllocaOp>(
170  gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
172  rewriter, loc, *getTypeConverter(), type, allocated);
173  signatureConversion.remapInput(
174  numProperArguments + numWorkgroupAttributions + idx, descr);
175  }
176  }
177 
178  // Move the region to the new function, update the entry block signature.
179  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
180  llvmFuncOp.end());
181  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
182  &signatureConversion)))
183  return failure();
184 
185  // Get memref type from function arguments and set the noalias to
186  // pointer arguments.
187  for (const auto [idx, argTy] :
188  llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
189  auto remapping = signatureConversion.getInputMapping(idx);
190  NamedAttrList argAttr =
191  argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
192  auto copyAttribute = [&](StringRef attrName) {
193  Attribute attr = argAttr.erase(attrName);
194  if (!attr)
195  return;
196  for (size_t i = 0, e = remapping->size; i < e; ++i)
197  llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
198  };
199  auto copyPointerAttribute = [&](StringRef attrName) {
200  Attribute attr = argAttr.erase(attrName);
201 
202  if (!attr)
203  return;
204  if (remapping->size > 1 &&
205  attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
206  emitWarning(llvmFuncOp.getLoc(),
207  "Cannot copy noalias with non-bare pointers.\n");
208  return;
209  }
210  for (size_t i = 0, e = remapping->size; i < e; ++i) {
211  if (isa<LLVM::LLVMPointerType>(
212  llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
213  llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
214  }
215  }
216  };
217 
218  if (argAttr.empty())
219  continue;
220 
221  copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
222  copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
223  copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
224  bool lowersToPointer = false;
225  for (size_t i = 0, e = remapping->size; i < e; ++i) {
226  lowersToPointer |= isa<LLVM::LLVMPointerType>(
227  llvmFuncOp.getArgument(remapping->inputNo + i).getType());
228  }
229 
230  if (lowersToPointer) {
231  copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
232  copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
233  copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
234  copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
235  copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
236  copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
237  copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
238  copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
239  copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
240  copyPointerAttribute(
241  LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
242  }
243  }
244  rewriter.eraseOp(gpuFuncOp);
245  return success();
246 }
247 
248 static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
249  const char formatStringPrefix[] = "printfFormat_";
250  // Get a unique global name.
251  unsigned stringNumber = 0;
252  SmallString<16> stringConstName;
253  do {
254  stringConstName.clear();
255  (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
256  } while (moduleOp.lookupSymbol(stringConstName));
257  return stringConstName;
258 }
259 
260 template <typename T>
261 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
262  ConversionPatternRewriter &rewriter,
263  StringRef name,
264  LLVM::LLVMFunctionType type) {
265  LLVM::LLVMFuncOp ret;
266  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
267  ConversionPatternRewriter::InsertionGuard guard(rewriter);
268  rewriter.setInsertionPointToStart(moduleOp.getBody());
269  ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
270  LLVM::Linkage::External);
271  }
272  return ret;
273 }
274 
276  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
277  ConversionPatternRewriter &rewriter) const {
278  Location loc = gpuPrintfOp->getLoc();
279 
280  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
281  auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
282  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
283  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
284  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
285  // This ensures that global constants and declarations are placed within
286  // the device code, not the host code
287  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
288 
289  auto ocklBegin =
290  getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
291  LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
292  LLVM::LLVMFuncOp ocklAppendArgs;
293  if (!adaptor.getArgs().empty()) {
294  ocklAppendArgs = getOrDefineFunction(
295  moduleOp, loc, rewriter, "__ockl_printf_append_args",
297  llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
298  llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
299  }
300  auto ocklAppendStringN = getOrDefineFunction(
301  moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
303  llvmI64,
304  {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
305 
306  /// Start the printf hostcall
307  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
308  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
309  Value printfDesc = printfBeginCall.getResult();
310 
311  // Get a unique global name for the format.
312  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
313 
314  llvm::SmallString<20> formatString(adaptor.getFormat());
315  formatString.push_back('\0'); // Null terminate for C
316  size_t formatStringSize = formatString.size_in_bytes();
317 
318  auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
319  LLVM::GlobalOp global;
320  {
322  rewriter.setInsertionPointToStart(moduleOp.getBody());
323  global = rewriter.create<LLVM::GlobalOp>(
324  loc, globalType,
325  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
326  rewriter.getStringAttr(formatString));
327  }
328 
329  // Get a pointer to the format string's first element and pass it to printf()
330  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
331  loc,
332  LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
333  global.getSymNameAttr());
334  Value stringStart = rewriter.create<LLVM::GEPOp>(
335  loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
336  Value stringLen =
337  rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
338 
339  Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
340  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
341 
342  auto appendFormatCall = rewriter.create<LLVM::CallOp>(
343  loc, ocklAppendStringN,
344  ValueRange{printfDesc, stringStart, stringLen,
345  adaptor.getArgs().empty() ? oneI32 : zeroI32});
346  printfDesc = appendFormatCall.getResult();
347 
348  // __ockl_printf_append_args takes 7 values per append call
349  constexpr size_t argsPerAppend = 7;
350  size_t nArgs = adaptor.getArgs().size();
351  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
352  size_t bound = std::min(group + argsPerAppend, nArgs);
353  size_t numArgsThisCall = bound - group;
354 
356  arguments.push_back(printfDesc);
357  arguments.push_back(
358  rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
359  for (size_t i = group; i < bound; ++i) {
360  Value arg = adaptor.getArgs()[i];
361  if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
362  if (!floatType.isF64())
363  arg = rewriter.create<LLVM::FPExtOp>(
364  loc, typeConverter->convertType(rewriter.getF64Type()), arg);
365  arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
366  }
367  if (arg.getType().getIntOrFloatBitWidth() != 64)
368  arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
369 
370  arguments.push_back(arg);
371  }
372  // Pad out to 7 arguments since the hostcall always needs 7
373  for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
374  arguments.push_back(zeroI64);
375  }
376 
377  auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
378  arguments.push_back(isLast);
379  auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
380  printfDesc = call.getResult();
381  }
382  rewriter.eraseOp(gpuPrintfOp);
383  return success();
384 }
385 
387  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
388  ConversionPatternRewriter &rewriter) const {
389  Location loc = gpuPrintfOp->getLoc();
390 
391  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
392  mlir::Type ptrType =
393  LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
394 
395  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
396  // This ensures that global constants and declarations are placed within
397  // the device code, not the host code
398  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
399 
400  auto printfType =
401  LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
402  /*isVarArg=*/true);
403  LLVM::LLVMFuncOp printfDecl =
404  getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
405 
406  // Get a unique global name for the format.
407  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
408 
409  llvm::SmallString<20> formatString(adaptor.getFormat());
410  formatString.push_back('\0'); // Null terminate for C
411  auto globalType =
412  LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
413  LLVM::GlobalOp global;
414  {
416  rewriter.setInsertionPointToStart(moduleOp.getBody());
417  global = rewriter.create<LLVM::GlobalOp>(
418  loc, globalType,
419  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
420  rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
421  }
422 
423  // Get a pointer to the format string's first element
424  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
425  loc,
426  LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
427  global.getSymNameAttr());
428  Value stringStart = rewriter.create<LLVM::GEPOp>(
429  loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
430 
431  // Construct arguments and function call
432  auto argsRange = adaptor.getArgs();
433  SmallVector<Value, 4> printfArgs;
434  printfArgs.reserve(argsRange.size() + 1);
435  printfArgs.push_back(stringStart);
436  printfArgs.append(argsRange.begin(), argsRange.end());
437 
438  rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
439  rewriter.eraseOp(gpuPrintfOp);
440  return success();
441 }
442 
444  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
445  ConversionPatternRewriter &rewriter) const {
446  Location loc = gpuPrintfOp->getLoc();
447 
448  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
449  mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
450 
451  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
452  // This ensures that global constants and declarations are placed within
453  // the device code, not the host code
454  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
455 
456  auto vprintfType =
457  LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
458  LLVM::LLVMFuncOp vprintfDecl =
459  getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
460 
461  // Get a unique global name for the format.
462  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
463 
464  llvm::SmallString<20> formatString(adaptor.getFormat());
465  formatString.push_back('\0'); // Null terminate for C
466  auto globalType =
467  LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
468  LLVM::GlobalOp global;
469  {
471  rewriter.setInsertionPointToStart(moduleOp.getBody());
472  global = rewriter.create<LLVM::GlobalOp>(
473  loc, globalType,
474  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
475  rewriter.getStringAttr(formatString), /*allignment=*/0);
476  }
477 
478  // Get a pointer to the format string's first element
479  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
480  Value stringStart = rewriter.create<LLVM::GEPOp>(
481  loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
482  SmallVector<Type> types;
483  SmallVector<Value> args;
484  // Promote and pack the arguments into a stack allocation.
485  for (Value arg : adaptor.getArgs()) {
486  Type type = arg.getType();
487  Value promotedArg = arg;
488  assert(type.isIntOrFloat());
489  if (isa<FloatType>(type)) {
490  type = rewriter.getF64Type();
491  promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
492  }
493  types.push_back(type);
494  args.push_back(promotedArg);
495  }
496  Type structType =
497  LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
498  Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
499  rewriter.getIndexAttr(1));
500  Value tempAlloc =
501  rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
502  /*alignment=*/0);
503  for (auto [index, arg] : llvm::enumerate(args)) {
504  Value ptr = rewriter.create<LLVM::GEPOp>(
505  loc, ptrType, structType, tempAlloc,
506  ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
507  rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
508  }
509  std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
510 
511  rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
512  rewriter.eraseOp(gpuPrintfOp);
513  return success();
514 }
515 
516 /// Unrolls op if it's operating on vectors.
517 LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
518  ConversionPatternRewriter &rewriter,
519  const LLVMTypeConverter &converter) {
520  TypeRange operandTypes(operands);
521  if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
522  return rewriter.notifyMatchFailure(op, "expected vector operand");
523  }
524  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
525  return rewriter.notifyMatchFailure(op, "expected no region/successor");
526  if (op->getNumResults() != 1)
527  return rewriter.notifyMatchFailure(op, "expected single result");
528  VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
529  if (!vectorType)
530  return rewriter.notifyMatchFailure(op, "expected vector result");
531 
532  Location loc = op->getLoc();
533  Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
534  Type indexType = converter.convertType(rewriter.getIndexType());
535  StringAttr name = op->getName().getIdentifier();
536  Type elementType = vectorType.getElementType();
537 
538  for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
539  Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
540  auto extractElement = [&](Value operand) -> Value {
541  if (!isa<VectorType>(operand.getType()))
542  return operand;
543  return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
544  };
545  auto scalarOperands = llvm::map_to_vector(operands, extractElement);
546  Operation *scalarOp =
547  rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
548  result = rewriter.create<LLVM::InsertElementOp>(
549  loc, result, scalarOp->getResult(0), index);
550  }
551 
552  rewriter.replaceOp(op, result);
553  return success();
554 }
555 
556 static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
557  return IntegerAttr::get(IntegerType::get(ctx, 64), space);
558 }
559 
560 /// Generates a symbol with 0-sized array type for dynamic shared memory usage,
561 /// or uses existing symbol.
562 LLVM::GlobalOp
564  Operation *moduleOp, gpu::DynamicSharedMemoryOp op,
565  const LLVMTypeConverter *typeConverter,
566  MemRefType memrefType, unsigned alignmentBit) {
567  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
568 
569  FailureOr<unsigned> addressSpace =
570  typeConverter->getMemRefAddressSpace(memrefType);
571  if (failed(addressSpace)) {
572  op->emitError() << "conversion of memref memory space "
573  << memrefType.getMemorySpace()
574  << " to integer address space "
575  "failed. Consider adding memory space conversions.";
576  }
577 
578  // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
579  // LLVM::GlobalOp is suitable for shared memory, return it.
580  llvm::StringSet<> existingGlobalNames;
581  for (auto globalOp :
582  moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) {
583  existingGlobalNames.insert(globalOp.getSymName());
584  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
585  if (globalOp.getAddrSpace() == addressSpace.value() &&
586  arrayType.getNumElements() == 0 &&
587  globalOp.getAlignment().value_or(0) == alignmentByte) {
588  return globalOp;
589  }
590  }
591  }
592 
593  // Step 2. Find a unique symbol name
594  unsigned uniquingCounter = 0;
595  SmallString<128> symName = SymbolTable::generateSymbolName<128>(
596  "__dynamic_shmem_",
597  [&](StringRef candidate) {
598  return existingGlobalNames.contains(candidate);
599  },
600  uniquingCounter);
601 
602  // Step 3. Generate a global op
603  OpBuilder::InsertionGuard guard(rewriter);
604  rewriter.setInsertionPoint(&moduleOp->getRegion(0).front().front());
605 
606  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
607  typeConverter->convertType(memrefType.getElementType()), 0);
608 
609  return rewriter.create<LLVM::GlobalOp>(
610  op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
611  LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
612  addressSpace.value());
613 }
614 
616  gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
617  ConversionPatternRewriter &rewriter) const {
618  Location loc = op.getLoc();
619  MemRefType memrefType = op.getResultMemref().getType();
620  Type elementType = typeConverter->convertType(memrefType.getElementType());
621 
622  // Step 1: Generate a memref<0xi8> type
623  MemRefLayoutAttrInterface layout = {};
624  auto memrefType0sz =
625  MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
626 
627  // Step 2: Generate a global symbol or existing for the dynamic shared
628  // memory with memref<0xi8> type
629  LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
630  LLVM::GlobalOp shmemOp = {};
631  Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
633  rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
634 
635  // Step 3. Get address of the global symbol
636  OpBuilder::InsertionGuard guard(rewriter);
637  rewriter.setInsertionPoint(op);
638  auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
639  Type baseType = basePtr->getResultTypes().front();
640 
641  // Step 4. Generate GEP using offsets
642  SmallVector<LLVM::GEPArg> gepArgs = {0};
643  Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
644  basePtr, gepArgs);
645  // Step 5. Create a memref descriptor
646  SmallVector<Value> shape, strides;
647  Value sizeBytes;
648  getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
649  sizeBytes);
650  auto memRefDescriptor = this->createMemRefDescriptor(
651  loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
652 
653  // Step 5. Replace the op with memref descriptor
654  rewriter.replaceOp(op, {memRefDescriptor});
655  return success();
656 }
657 
659  gpu::ReturnOp op, OpAdaptor adaptor,
660  ConversionPatternRewriter &rewriter) const {
661  Location loc = op.getLoc();
662  unsigned numArguments = op.getNumOperands();
663  SmallVector<Value, 4> updatedOperands;
664 
665  bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
666  if (useBarePtrCallConv) {
667  // For the bare-ptr calling convention, extract the aligned pointer to
668  // be returned from the memref descriptor.
669  for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
670  Type oldTy = std::get<0>(it).getType();
671  Value newOperand = std::get<1>(it);
672  if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
673  cast<BaseMemRefType>(oldTy))) {
674  MemRefDescriptor memrefDesc(newOperand);
675  newOperand = memrefDesc.allocatedPtr(rewriter, loc);
676  } else if (isa<UnrankedMemRefType>(oldTy)) {
677  // Unranked memref is not supported in the bare pointer calling
678  // convention.
679  return failure();
680  }
681  updatedOperands.push_back(newOperand);
682  }
683  } else {
684  updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
685  (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
686  updatedOperands,
687  /*toDynamic=*/true);
688  }
689 
690  // If ReturnOp has 0 or 1 operand, create it and return immediately.
691  if (numArguments <= 1) {
692  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
693  op, TypeRange(), updatedOperands, op->getAttrs());
694  return success();
695  }
696 
697  // Otherwise, we need to pack the arguments into an LLVM struct type before
698  // returning.
699  auto packedType = getTypeConverter()->packFunctionResults(
700  op.getOperandTypes(), useBarePtrCallConv);
701  if (!packedType) {
702  return rewriter.notifyMatchFailure(op, "could not convert result types");
703  }
704 
705  Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
706  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
707  packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
708  }
709  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
710  op->getAttrs());
711  return success();
712 }
713 
715  TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
716  typeConverter.addTypeAttributeConversion(
717  [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
718  gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
719  unsigned addressSpace = mapping(memorySpace);
720  return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
721  addressSpace);
722  });
723 }
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, ConversionPatternRewriter &rewriter, StringRef name, LLVM::LLVMFunctionType type)
LLVM::GlobalOp getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter, Operation *moduleOp, gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter, MemRefType memrefType, unsigned alignmentBit)
Generates a symbol with 0-sized array type for dynamic shared memory usage, or uses existing symbol.
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space)
static SmallString< 16 > getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp)
static std::string diag(const llvm::Value &value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:144
Operation & front()
Definition: Block.h:151
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:191
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:128
UnitAttr getUnitAttr()
Definition: Builders.cpp:118
IntegerType getI64Type()
Definition: Builders.cpp:89
IntegerType getI32Type()
Definition: Builders.cpp:87
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:91
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:273
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:75
IntegerType getI8Type()
Definition: Builders.cpp:83
FloatType getF64Type()
Definition: Builders.cpp:69
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:218
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:114
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise).
Definition: Pattern.cpp:247
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
const LowerToLLVMOptions & getOptions() const
Definition: TypeConverter.h:94
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type `type`, or failure if the memory space cannot be converted to an integer.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:453
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents `type` and populates it with static shape and stride information extracted from the type.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work to remain sorted.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:434
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
type_range getTypes() const
Definition: ValueRange.cpp:26
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumSuccessors()
Definition: Operation.h:702
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
Block & front()
Definition: Region.h:65
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides all of the information necessary to convert a type signature.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:120
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter)
Unrolls op if it's operating on vectors.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
std::function< unsigned(gpu::AddressSpace)> MemorySpaceMapping
A function that maps a MemorySpace enum to a target-specific integer value.
Definition: GPUCommonPass.h:70
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction code.
LogicalResult matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override