MLIR  19.0.0git
GPUOpsLowering.cpp
Go to the documentation of this file.
1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GPUOpsLowering.h"
10 
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "llvm/ADT/SmallVectorExtras.h"
17 #include "llvm/ADT/StringSet.h"
18 #include "llvm/Support/FormatVariadic.h"
19 
20 using namespace mlir;
21 
23 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
24  ConversionPatternRewriter &rewriter) const {
25  Location loc = gpuFuncOp.getLoc();
26 
27  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
28  workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
29  for (const auto [idx, attribution] :
30  llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
31  auto type = dyn_cast<MemRefType>(attribution.getType());
32  assert(type && type.hasStaticShape() && "unexpected type in attribution");
33 
34  uint64_t numElements = type.getNumElements();
35 
36  auto elementType =
37  cast<Type>(typeConverter->convertType(type.getElementType()));
38  auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
39  std::string name =
40  std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
41  uint64_t alignment = 0;
42  if (auto alignAttr =
43  dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr(
44  idx, LLVM::LLVMDialect::getAlignAttrName())))
45  alignment = alignAttr.getInt();
46  auto globalOp = rewriter.create<LLVM::GlobalOp>(
47  gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
48  LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
49  workgroupAddrSpace);
50  workgroupBuffers.push_back(globalOp);
51  }
52 
53  // Remap proper input types.
54  TypeConverter::SignatureConversion signatureConversion(
55  gpuFuncOp.front().getNumArguments());
56 
58  gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
59  getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
60  if (!funcType) {
61  return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
62  diag << "failed to convert function signature type for: "
63  << gpuFuncOp.getFunctionType();
64  });
65  }
66 
67  // Create the new function operation. Only copy those attributes that are
68  // not specific to function modeling.
70  ArrayAttr argAttrs;
71  for (const auto &attr : gpuFuncOp->getAttrs()) {
72  if (attr.getName() == SymbolTable::getSymbolAttrName() ||
73  attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
74  attr.getName() ==
75  gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
76  attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
77  attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName())
78  continue;
79  if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
80  argAttrs = gpuFuncOp.getArgAttrsAttr();
81  continue;
82  }
83  attributes.push_back(attr);
84  }
85  // Add a dialect specific kernel attribute in addition to GPU kernel
86  // attribute. The former is necessary for further translation while the
87  // latter is expected by gpu.launch_func.
88  if (gpuFuncOp.isKernel()) {
89  attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
90 
91  // Set the block size attribute if it is present.
92  if (kernelBlockSizeAttributeName.has_value()) {
93  std::optional<int32_t> dimX =
94  gpuFuncOp.getKnownBlockSize(gpu::Dimension::x);
95  std::optional<int32_t> dimY =
96  gpuFuncOp.getKnownBlockSize(gpu::Dimension::y);
97  std::optional<int32_t> dimZ =
98  gpuFuncOp.getKnownBlockSize(gpu::Dimension::z);
99  if (dimX.has_value() || dimY.has_value() || dimZ.has_value()) {
100  // If any of the dimensions are missing, fill them in with 1.
101  attributes.emplace_back(
102  kernelBlockSizeAttributeName.value(),
103  rewriter.getDenseI32ArrayAttr(
104  {dimX.value_or(1), dimY.value_or(1), dimZ.value_or(1)}));
105  }
106  }
107  }
108  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
109  gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
110  LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C,
111  /*comdat=*/nullptr, attributes);
112 
113  {
114  // Insert operations that correspond to converted workgroup and private
115  // memory attributions to the body of the function. This must operate on
116  // the original function, before the body region is inlined in the new
117  // function to maintain the relation between block arguments and the
118  // parent operation that assigns their semantics.
119  OpBuilder::InsertionGuard guard(rewriter);
120 
121  // Rewrite workgroup memory attributions to addresses of global buffers.
122  rewriter.setInsertionPointToStart(&gpuFuncOp.front());
123  unsigned numProperArguments = gpuFuncOp.getNumArguments();
124 
125  for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
126  auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
127  global.getAddrSpace());
128  Value address = rewriter.create<LLVM::AddressOfOp>(
129  loc, ptrType, global.getSymNameAttr());
130  Value memory =
131  rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), address,
132  ArrayRef<LLVM::GEPArg>{0, 0});
133 
134  // Build a memref descriptor pointing to the buffer to plug with the
135  // existing memref infrastructure. This may use more registers than
136  // otherwise necessary given that memref sizes are fixed, but we can try
137  // and canonicalize that away later.
138  Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
139  auto type = cast<MemRefType>(attribution.getType());
141  rewriter, loc, *getTypeConverter(), type, memory);
142  signatureConversion.remapInput(numProperArguments + idx, descr);
143  }
144 
145  // Rewrite private memory attributions to alloca'ed buffers.
146  unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
147  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
148  for (const auto [idx, attribution] :
149  llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
150  auto type = cast<MemRefType>(attribution.getType());
151  assert(type && type.hasStaticShape() && "unexpected type in attribution");
152 
153  // Explicitly drop memory space when lowering private memory
154  // attributions since NVVM models it as `alloca`s in the default
155  // memory space and does not support `alloca`s with addrspace(5).
156  Type elementType = typeConverter->convertType(type.getElementType());
157  auto ptrType =
158  LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
159  Value numElements = rewriter.create<LLVM::ConstantOp>(
160  gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
161  uint64_t alignment = 0;
162  if (auto alignAttr =
163  dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
164  idx, LLVM::LLVMDialect::getAlignAttrName())))
165  alignment = alignAttr.getInt();
166  Value allocated = rewriter.create<LLVM::AllocaOp>(
167  gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
169  rewriter, loc, *getTypeConverter(), type, allocated);
170  signatureConversion.remapInput(
171  numProperArguments + numWorkgroupAttributions + idx, descr);
172  }
173  }
174 
175  // Move the region to the new function, update the entry block signature.
176  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
177  llvmFuncOp.end());
178  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
179  &signatureConversion)))
180  return failure();
181 
182  // If bare memref pointers are being used, remap them back to memref
183  // descriptors This must be done after signature conversion to get rid of the
184  // unrealized casts.
185  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
186  OpBuilder::InsertionGuard guard(rewriter);
187  rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
188  for (const auto [idx, argTy] :
189  llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
190  auto memrefTy = dyn_cast<MemRefType>(argTy);
191  if (!memrefTy)
192  continue;
193  assert(memrefTy.hasStaticShape() &&
194  "Bare pointer convertion used with dynamically-shaped memrefs");
195  // Use a placeholder when replacing uses of the memref argument to prevent
196  // circular replacements.
197  auto remapping = signatureConversion.getInputMapping(idx);
198  assert(remapping && remapping->size == 1 &&
199  "Type converter should produce 1-to-1 mapping for bare memrefs");
200  BlockArgument newArg =
201  llvmFuncOp.getBody().getArgument(remapping->inputNo);
202  auto placeholder = rewriter.create<LLVM::UndefOp>(
203  loc, getTypeConverter()->convertType(memrefTy));
204  rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
206  rewriter, loc, *getTypeConverter(), memrefTy, newArg);
207  rewriter.replaceOp(placeholder, {desc});
208  }
209  }
210 
211  // Get memref type from function arguments and set the noalias to
212  // pointer arguments.
213  for (const auto [idx, argTy] :
214  llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
215  auto remapping = signatureConversion.getInputMapping(idx);
216  NamedAttrList argAttr =
217  argAttrs ? argAttrs[idx].cast<DictionaryAttr>() : NamedAttrList();
218  auto copyAttribute = [&](StringRef attrName) {
219  Attribute attr = argAttr.erase(attrName);
220  if (!attr)
221  return;
222  for (size_t i = 0, e = remapping->size; i < e; ++i)
223  llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
224  };
225  auto copyPointerAttribute = [&](StringRef attrName) {
226  Attribute attr = argAttr.erase(attrName);
227 
228  if (!attr)
229  return;
230  if (remapping->size > 1 &&
231  attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
232  emitWarning(llvmFuncOp.getLoc(),
233  "Cannot copy noalias with non-bare pointers.\n");
234  return;
235  }
236  for (size_t i = 0, e = remapping->size; i < e; ++i) {
237  if (llvmFuncOp.getArgument(remapping->inputNo + i)
238  .getType()
239  .isa<LLVM::LLVMPointerType>()) {
240  llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
241  }
242  }
243  };
244 
245  if (argAttr.empty())
246  continue;
247 
248  copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
249  copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
250  copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
251  bool lowersToPointer = false;
252  for (size_t i = 0, e = remapping->size; i < e; ++i) {
253  lowersToPointer |= isa<LLVM::LLVMPointerType>(
254  llvmFuncOp.getArgument(remapping->inputNo + i).getType());
255  }
256 
257  if (lowersToPointer) {
258  copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
259  copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
260  copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
261  copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
262  copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
263  copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
264  copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
265  copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
266  copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
267  copyPointerAttribute(
268  LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
269  }
270  }
271  rewriter.eraseOp(gpuFuncOp);
272  return success();
273 }
274 
275 static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
276  const char formatStringPrefix[] = "printfFormat_";
277  // Get a unique global name.
278  unsigned stringNumber = 0;
279  SmallString<16> stringConstName;
280  do {
281  stringConstName.clear();
282  (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
283  } while (moduleOp.lookupSymbol(stringConstName));
284  return stringConstName;
285 }
286 
287 template <typename T>
288 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
289  ConversionPatternRewriter &rewriter,
290  StringRef name,
291  LLVM::LLVMFunctionType type) {
292  LLVM::LLVMFuncOp ret;
293  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
294  ConversionPatternRewriter::InsertionGuard guard(rewriter);
295  rewriter.setInsertionPointToStart(moduleOp.getBody());
296  ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
297  LLVM::Linkage::External);
298  }
299  return ret;
300 }
301 
303  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
304  ConversionPatternRewriter &rewriter) const {
305  Location loc = gpuPrintfOp->getLoc();
306 
307  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
308  auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
309  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
310  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
311  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
312  // This ensures that global constants and declarations are placed within
313  // the device code, not the host code
314  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
315 
316  auto ocklBegin =
317  getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
318  LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
319  LLVM::LLVMFuncOp ocklAppendArgs;
320  if (!adaptor.getArgs().empty()) {
321  ocklAppendArgs = getOrDefineFunction(
322  moduleOp, loc, rewriter, "__ockl_printf_append_args",
324  llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
325  llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
326  }
327  auto ocklAppendStringN = getOrDefineFunction(
328  moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
330  llvmI64,
331  {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
332 
333  /// Start the printf hostcall
334  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
335  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
336  Value printfDesc = printfBeginCall.getResult();
337 
338  // Get a unique global name for the format.
339  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
340 
341  llvm::SmallString<20> formatString(adaptor.getFormat());
342  formatString.push_back('\0'); // Null terminate for C
343  size_t formatStringSize = formatString.size_in_bytes();
344 
345  auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
346  LLVM::GlobalOp global;
347  {
349  rewriter.setInsertionPointToStart(moduleOp.getBody());
350  global = rewriter.create<LLVM::GlobalOp>(
351  loc, globalType,
352  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
353  rewriter.getStringAttr(formatString));
354  }
355 
356  // Get a pointer to the format string's first element and pass it to printf()
357  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
358  loc,
359  LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
360  global.getSymNameAttr());
361  Value stringStart = rewriter.create<LLVM::GEPOp>(
362  loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
363  Value stringLen =
364  rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
365 
366  Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
367  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
368 
369  auto appendFormatCall = rewriter.create<LLVM::CallOp>(
370  loc, ocklAppendStringN,
371  ValueRange{printfDesc, stringStart, stringLen,
372  adaptor.getArgs().empty() ? oneI32 : zeroI32});
373  printfDesc = appendFormatCall.getResult();
374 
375  // __ockl_printf_append_args takes 7 values per append call
376  constexpr size_t argsPerAppend = 7;
377  size_t nArgs = adaptor.getArgs().size();
378  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
379  size_t bound = std::min(group + argsPerAppend, nArgs);
380  size_t numArgsThisCall = bound - group;
381 
383  arguments.push_back(printfDesc);
384  arguments.push_back(
385  rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
386  for (size_t i = group; i < bound; ++i) {
387  Value arg = adaptor.getArgs()[i];
388  if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
389  if (!floatType.isF64())
390  arg = rewriter.create<LLVM::FPExtOp>(
391  loc, typeConverter->convertType(rewriter.getF64Type()), arg);
392  arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
393  }
394  if (arg.getType().getIntOrFloatBitWidth() != 64)
395  arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
396 
397  arguments.push_back(arg);
398  }
399  // Pad out to 7 arguments since the hostcall always needs 7
400  for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
401  arguments.push_back(zeroI64);
402  }
403 
404  auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
405  arguments.push_back(isLast);
406  auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
407  printfDesc = call.getResult();
408  }
409  rewriter.eraseOp(gpuPrintfOp);
410  return success();
411 }
412 
414  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
415  ConversionPatternRewriter &rewriter) const {
416  Location loc = gpuPrintfOp->getLoc();
417 
418  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
419  mlir::Type ptrType =
420  LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
421 
422  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
423  // This ensures that global constants and declarations are placed within
424  // the device code, not the host code
425  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
426 
427  auto printfType =
428  LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
429  /*isVarArg=*/true);
430  LLVM::LLVMFuncOp printfDecl =
431  getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
432 
433  // Get a unique global name for the format.
434  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
435 
436  llvm::SmallString<20> formatString(adaptor.getFormat());
437  formatString.push_back('\0'); // Null terminate for C
438  auto globalType =
439  LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
440  LLVM::GlobalOp global;
441  {
443  rewriter.setInsertionPointToStart(moduleOp.getBody());
444  global = rewriter.create<LLVM::GlobalOp>(
445  loc, globalType,
446  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
447  rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
448  }
449 
450  // Get a pointer to the format string's first element
451  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
452  loc,
453  LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
454  global.getSymNameAttr());
455  Value stringStart = rewriter.create<LLVM::GEPOp>(
456  loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
457 
458  // Construct arguments and function call
459  auto argsRange = adaptor.getArgs();
460  SmallVector<Value, 4> printfArgs;
461  printfArgs.reserve(argsRange.size() + 1);
462  printfArgs.push_back(stringStart);
463  printfArgs.append(argsRange.begin(), argsRange.end());
464 
465  rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
466  rewriter.eraseOp(gpuPrintfOp);
467  return success();
468 }
469 
471  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
472  ConversionPatternRewriter &rewriter) const {
473  Location loc = gpuPrintfOp->getLoc();
474 
475  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
476  mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
477 
478  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
479  // This ensures that global constants and declarations are placed within
480  // the device code, not the host code
481  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
482 
483  auto vprintfType =
484  LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
485  LLVM::LLVMFuncOp vprintfDecl =
486  getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
487 
488  // Get a unique global name for the format.
489  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
490 
491  llvm::SmallString<20> formatString(adaptor.getFormat());
492  formatString.push_back('\0'); // Null terminate for C
493  auto globalType =
494  LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
495  LLVM::GlobalOp global;
496  {
498  rewriter.setInsertionPointToStart(moduleOp.getBody());
499  global = rewriter.create<LLVM::GlobalOp>(
500  loc, globalType,
501  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
502  rewriter.getStringAttr(formatString), /*allignment=*/0);
503  }
504 
505  // Get a pointer to the format string's first element
506  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
507  Value stringStart = rewriter.create<LLVM::GEPOp>(
508  loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
509  SmallVector<Type> types;
510  SmallVector<Value> args;
511  // Promote and pack the arguments into a stack allocation.
512  for (Value arg : adaptor.getArgs()) {
513  Type type = arg.getType();
514  Value promotedArg = arg;
515  assert(type.isIntOrFloat());
516  if (isa<FloatType>(type)) {
517  type = rewriter.getF64Type();
518  promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
519  }
520  types.push_back(type);
521  args.push_back(promotedArg);
522  }
523  Type structType =
524  LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
525  Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
526  rewriter.getIndexAttr(1));
527  Value tempAlloc =
528  rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
529  /*alignment=*/0);
530  for (auto [index, arg] : llvm::enumerate(args)) {
531  Value ptr = rewriter.create<LLVM::GEPOp>(
532  loc, ptrType, structType, tempAlloc,
533  ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
534  rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
535  }
536  std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
537 
538  rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
539  rewriter.eraseOp(gpuPrintfOp);
540  return success();
541 }
542 
543 /// Unrolls op if it's operating on vectors.
545  ConversionPatternRewriter &rewriter,
546  const LLVMTypeConverter &converter) {
547  TypeRange operandTypes(operands);
548  if (llvm::none_of(operandTypes,
549  [](Type type) { return isa<VectorType>(type); })) {
550  return rewriter.notifyMatchFailure(op, "expected vector operand");
551  }
552  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
553  return rewriter.notifyMatchFailure(op, "expected no region/successor");
554  if (op->getNumResults() != 1)
555  return rewriter.notifyMatchFailure(op, "expected single result");
556  VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
557  if (!vectorType)
558  return rewriter.notifyMatchFailure(op, "expected vector result");
559 
560  Location loc = op->getLoc();
561  Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
562  Type indexType = converter.convertType(rewriter.getIndexType());
563  StringAttr name = op->getName().getIdentifier();
564  Type elementType = vectorType.getElementType();
565 
566  for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
567  Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
568  auto extractElement = [&](Value operand) -> Value {
569  if (!isa<VectorType>(operand.getType()))
570  return operand;
571  return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
572  };
573  auto scalarOperands = llvm::map_to_vector(operands, extractElement);
574  Operation *scalarOp =
575  rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
576  result = rewriter.create<LLVM::InsertElementOp>(
577  loc, result, scalarOp->getResult(0), index);
578  }
579 
580  rewriter.replaceOp(op, result);
581  return success();
582 }
583 
584 static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
585  return IntegerAttr::get(IntegerType::get(ctx, 64), space);
586 }
587 
588 /// Generates a symbol with 0-sized array type for dynamic shared memory usage,
589 /// or uses existing symbol.
590 LLVM::GlobalOp
592  Operation *moduleOp, gpu::DynamicSharedMemoryOp op,
593  const LLVMTypeConverter *typeConverter,
594  MemRefType memrefType, unsigned alignmentBit) {
595  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
596 
597  FailureOr<unsigned> addressSpace =
598  typeConverter->getMemRefAddressSpace(memrefType);
599  if (failed(addressSpace)) {
600  op->emitError() << "conversion of memref memory space "
601  << memrefType.getMemorySpace()
602  << " to integer address space "
603  "failed. Consider adding memory space conversions.";
604  }
605 
606  // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
607  // LLVM::GlobalOp is suitable for shared memory, return it.
608  llvm::StringSet<> existingGlobalNames;
609  for (auto globalOp :
610  moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) {
611  existingGlobalNames.insert(globalOp.getSymName());
612  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
613  if (globalOp.getAddrSpace() == addressSpace.value() &&
614  arrayType.getNumElements() == 0 &&
615  globalOp.getAlignment().value_or(0) == alignmentByte) {
616  return globalOp;
617  }
618  }
619  }
620 
621  // Step 2. Find a unique symbol name
622  unsigned uniquingCounter = 0;
623  SmallString<128> symName = SymbolTable::generateSymbolName<128>(
624  "__dynamic_shmem_",
625  [&](StringRef candidate) {
626  return existingGlobalNames.contains(candidate);
627  },
628  uniquingCounter);
629 
630  // Step 3. Generate a global op
631  OpBuilder::InsertionGuard guard(rewriter);
632  rewriter.setInsertionPoint(&moduleOp->getRegion(0).front().front());
633 
634  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
635  typeConverter->convertType(memrefType.getElementType()), 0);
636 
637  return rewriter.create<LLVM::GlobalOp>(
638  op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
639  LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
640  addressSpace.value());
641 }
642 
644  gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
645  ConversionPatternRewriter &rewriter) const {
646  Location loc = op.getLoc();
647  MemRefType memrefType = op.getResultMemref().getType();
648  Type elementType = typeConverter->convertType(memrefType.getElementType());
649 
650  // Step 1: Generate a memref<0xi8> type
651  MemRefLayoutAttrInterface layout = {};
652  auto memrefType0sz =
653  MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
654 
655  // Step 2: Generate a global symbol or existing for the dynamic shared
656  // memory with memref<0xi8> type
657  LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
658  LLVM::GlobalOp shmemOp = {};
659  Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
661  rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
662 
663  // Step 3. Get address of the global symbol
664  OpBuilder::InsertionGuard guard(rewriter);
665  rewriter.setInsertionPoint(op);
666  auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
667  Type baseType = basePtr->getResultTypes().front();
668 
669  // Step 4. Generate GEP using offsets
670  SmallVector<LLVM::GEPArg> gepArgs = {0};
671  Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
672  basePtr, gepArgs);
673  // Step 5. Create a memref descriptor
674  SmallVector<Value> shape, strides;
675  Value sizeBytes;
676  getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
677  sizeBytes);
678  auto memRefDescriptor = this->createMemRefDescriptor(
679  loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
680 
681  // Step 5. Replace the op with memref descriptor
682  rewriter.replaceOp(op, {memRefDescriptor});
683  return success();
684 }
685 
687  TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
688  typeConverter.addTypeAttributeConversion(
689  [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
690  gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
691  unsigned addressSpace = mapping(memorySpace);
692  return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
693  addressSpace);
694  });
695 }
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, ConversionPatternRewriter &rewriter, StringRef name, LLVM::LLVMFunctionType type)
LLVM::GlobalOp getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter, Operation *moduleOp, gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter, MemRefType memrefType, unsigned alignmentBit)
Generates a symbol with 0-sized array type for dynamic shared memory usage, or uses existing symbol.
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space)
static SmallString< 16 > getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp)
static std::string diag(const llvm::Value &value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:138
This class represents an argument of a Block.
Definition: Value.h:315
Operation & front()
Definition: Block.h:150
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:190
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
IntegerType getI8Type()
Definition: Builders.cpp:79
FloatType getF64Type()
Definition: Builders.cpp:65
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:218
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:114
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:453
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumSuccessors()
Definition: Operation.h:702
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
Block & front()
Definition: Region.h:65
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:708
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides all of the information necessary to convert a type signature.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:119
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter)
Unrolls op if it's operating on vectors.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
std::function< unsigned(gpu::AddressSpace)> MemorySpaceMapping
A function that maps a MemorySpace enum to a target-specific integer value.
Definition: GPUCommonPass.h:75
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26