MLIR  18.0.0git
GPUOpsLowering.cpp
Go to the documentation of this file.
1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GPUOpsLowering.h"
10 
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "llvm/ADT/SmallVectorExtras.h"
17 #include "llvm/Support/FormatVariadic.h"
18 
19 using namespace mlir;
20 
22 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
23  ConversionPatternRewriter &rewriter) const {
24  Location loc = gpuFuncOp.getLoc();
25 
26  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
27  workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
28  for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
29  BlockArgument attribution = en.value();
30 
31  auto type = dyn_cast<MemRefType>(attribution.getType());
32  assert(type && type.hasStaticShape() && "unexpected type in attribution");
33 
34  uint64_t numElements = type.getNumElements();
35 
36  auto elementType =
37  cast<Type>(typeConverter->convertType(type.getElementType()));
38  auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
39  std::string name = std::string(
40  llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
41  uint64_t alignment = 0;
42  if (auto alignAttr =
43  dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr(
44  en.index(), LLVM::LLVMDialect::getAlignAttrName())))
45  alignment = alignAttr.getInt();
46  auto globalOp = rewriter.create<LLVM::GlobalOp>(
47  gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
48  LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
49  workgroupAddrSpace);
50  workgroupBuffers.push_back(globalOp);
51  }
52 
53  // Remap proper input types.
54  TypeConverter::SignatureConversion signatureConversion(
55  gpuFuncOp.front().getNumArguments());
56 
58  gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
59  getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
60  if (!funcType) {
61  return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
62  diag << "failed to convert function signature type for: "
63  << gpuFuncOp.getFunctionType();
64  });
65  }
66 
67  // Create the new function operation. Only copy those attributes that are
68  // not specific to function modeling.
70  ArrayAttr argAttrs;
71  for (const auto &attr : gpuFuncOp->getAttrs()) {
72  if (attr.getName() == SymbolTable::getSymbolAttrName() ||
73  attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
74  attr.getName() ==
75  gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
76  attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
77  attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName())
78  continue;
79  if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
80  argAttrs = gpuFuncOp.getArgAttrsAttr();
81  continue;
82  }
83  attributes.push_back(attr);
84  }
85  // Add a dialect specific kernel attribute in addition to GPU kernel
86  // attribute. The former is necessary for further translation while the
87  // latter is expected by gpu.launch_func.
88  if (gpuFuncOp.isKernel())
89  attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
90  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
91  gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
92  LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C,
93  /*comdat=*/nullptr, attributes);
94 
95  {
96  // Insert operations that correspond to converted workgroup and private
97  // memory attributions to the body of the function. This must operate on
98  // the original function, before the body region is inlined in the new
99  // function to maintain the relation between block arguments and the
100  // parent operation that assigns their semantics.
101  OpBuilder::InsertionGuard guard(rewriter);
102 
103  // Rewrite workgroup memory attributions to addresses of global buffers.
104  rewriter.setInsertionPointToStart(&gpuFuncOp.front());
105  unsigned numProperArguments = gpuFuncOp.getNumArguments();
106 
107  for (const auto &en : llvm::enumerate(workgroupBuffers)) {
108  LLVM::GlobalOp global = en.value();
109  Value address = rewriter.create<LLVM::AddressOfOp>(
110  loc,
111  getTypeConverter()->getPointerType(global.getType(),
112  global.getAddrSpace()),
113  global.getSymNameAttr());
114  auto elementType =
115  cast<LLVM::LLVMArrayType>(global.getType()).getElementType();
116  Value memory = rewriter.create<LLVM::GEPOp>(
117  loc,
118  getTypeConverter()->getPointerType(elementType,
119  global.getAddrSpace()),
120  global.getType(), address, ArrayRef<LLVM::GEPArg>{0, 0});
121 
122  // Build a memref descriptor pointing to the buffer to plug with the
123  // existing memref infrastructure. This may use more registers than
124  // otherwise necessary given that memref sizes are fixed, but we can try
125  // and canonicalize that away later.
126  Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
127  auto type = cast<MemRefType>(attribution.getType());
129  rewriter, loc, *getTypeConverter(), type, memory);
130  signatureConversion.remapInput(numProperArguments + en.index(), descr);
131  }
132 
133  // Rewrite private memory attributions to alloca'ed buffers.
134  unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
135  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
136  for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
137  Value attribution = en.value();
138  auto type = cast<MemRefType>(attribution.getType());
139  assert(type && type.hasStaticShape() && "unexpected type in attribution");
140 
141  // Explicitly drop memory space when lowering private memory
142  // attributions since NVVM models it as `alloca`s in the default
143  // memory space and does not support `alloca`s with addrspace(5).
144  Type elementType = typeConverter->convertType(type.getElementType());
145  auto ptrType =
146  getTypeConverter()->getPointerType(elementType, allocaAddrSpace);
147  Value numElements = rewriter.create<LLVM::ConstantOp>(
148  gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
149  uint64_t alignment = 0;
150  if (auto alignAttr =
151  dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
152  en.index(), LLVM::LLVMDialect::getAlignAttrName())))
153  alignment = alignAttr.getInt();
154  Value allocated = rewriter.create<LLVM::AllocaOp>(
155  gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
157  rewriter, loc, *getTypeConverter(), type, allocated);
158  signatureConversion.remapInput(
159  numProperArguments + numWorkgroupAttributions + en.index(), descr);
160  }
161  }
162 
163  // Move the region to the new function, update the entry block signature.
164  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
165  llvmFuncOp.end());
166  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
167  &signatureConversion)))
168  return failure();
169 
170  // If bare memref pointers are being used, remap them back to memref
171  // descriptors This must be done after signature conversion to get rid of the
172  // unrealized casts.
173  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
174  OpBuilder::InsertionGuard guard(rewriter);
175  rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
176  for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
177  auto memrefTy = dyn_cast<MemRefType>(en.value());
178  if (!memrefTy)
179  continue;
180  assert(memrefTy.hasStaticShape() &&
181  "Bare pointer convertion used with dynamically-shaped memrefs");
182  // Use a placeholder when replacing uses of the memref argument to prevent
183  // circular replacements.
184  auto remapping = signatureConversion.getInputMapping(en.index());
185  assert(remapping && remapping->size == 1 &&
186  "Type converter should produce 1-to-1 mapping for bare memrefs");
187  BlockArgument newArg =
188  llvmFuncOp.getBody().getArgument(remapping->inputNo);
189  auto placeholder = rewriter.create<LLVM::UndefOp>(
190  loc, getTypeConverter()->convertType(memrefTy));
191  rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
193  rewriter, loc, *getTypeConverter(), memrefTy, newArg);
194  rewriter.replaceOp(placeholder, {desc});
195  }
196  }
197 
198  // Get memref type from function arguments and set the noalias to
199  // pointer arguments.
200  for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
201  auto memrefTy = en.value().dyn_cast<MemRefType>();
202  NamedAttrList argAttr = argAttrs
203  ? argAttrs[en.index()].cast<DictionaryAttr>()
204  : NamedAttrList();
205 
206  auto copyPointerAttribute = [&](StringRef attrName) {
207  Attribute attr = argAttr.erase(attrName);
208 
209  // This is a proxy for the bare pointer calling convention.
210  if (!attr)
211  return;
212  auto remapping = signatureConversion.getInputMapping(en.index());
213  if (remapping->size > 1 &&
214  attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
215  emitWarning(llvmFuncOp.getLoc(),
216  "Cannot copy noalias with non-bare pointers.\n");
217  return;
218  }
219  for (size_t i = 0, e = remapping->size; i < e; ++i) {
220  if (llvmFuncOp.getArgument(remapping->inputNo + i)
221  .getType()
222  .isa<LLVM::LLVMPointerType>()) {
223  llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
224  }
225  }
226  };
227 
228  if (argAttr.empty())
229  continue;
230 
231  if (memrefTy) {
232  copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
233  copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
234  copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
235  copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
236  copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
237  copyPointerAttribute(
238  LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
239  }
240  }
241  rewriter.eraseOp(gpuFuncOp);
242  return success();
243 }
244 
245 static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
246  const char formatStringPrefix[] = "printfFormat_";
247  // Get a unique global name.
248  unsigned stringNumber = 0;
249  SmallString<16> stringConstName;
250  do {
251  stringConstName.clear();
252  (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
253  } while (moduleOp.lookupSymbol(stringConstName));
254  return stringConstName;
255 }
256 
257 template <typename T>
258 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
259  ConversionPatternRewriter &rewriter,
260  StringRef name,
261  LLVM::LLVMFunctionType type) {
262  LLVM::LLVMFuncOp ret;
263  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
264  ConversionPatternRewriter::InsertionGuard guard(rewriter);
265  rewriter.setInsertionPointToStart(moduleOp.getBody());
266  ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
267  LLVM::Linkage::External);
268  }
269  return ret;
270 }
271 
273  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
274  ConversionPatternRewriter &rewriter) const {
275  Location loc = gpuPrintfOp->getLoc();
276 
277  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
278  mlir::Type i8Ptr = getTypeConverter()->getPointerType(llvmI8);
279  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
280  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
281  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
282  // This ensures that global constants and declarations are placed within
283  // the device code, not the host code
284  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
285 
286  auto ocklBegin =
287  getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
288  LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
289  LLVM::LLVMFuncOp ocklAppendArgs;
290  if (!adaptor.getArgs().empty()) {
291  ocklAppendArgs = getOrDefineFunction(
292  moduleOp, loc, rewriter, "__ockl_printf_append_args",
294  llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
295  llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
296  }
297  auto ocklAppendStringN = getOrDefineFunction(
298  moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
300  llvmI64,
301  {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
302 
303  /// Start the printf hostcall
304  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
305  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
306  Value printfDesc = printfBeginCall.getResult();
307 
308  // Get a unique global name for the format.
309  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
310 
311  llvm::SmallString<20> formatString(adaptor.getFormat());
312  formatString.push_back('\0'); // Null terminate for C
313  size_t formatStringSize = formatString.size_in_bytes();
314 
315  auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
316  LLVM::GlobalOp global;
317  {
319  rewriter.setInsertionPointToStart(moduleOp.getBody());
320  global = rewriter.create<LLVM::GlobalOp>(
321  loc, globalType,
322  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
323  rewriter.getStringAttr(formatString));
324  }
325 
326  // Get a pointer to the format string's first element and pass it to printf()
327  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
328  loc,
329  getTypeConverter()->getPointerType(globalType, global.getAddrSpace()),
330  global.getSymNameAttr());
331  Value stringStart = rewriter.create<LLVM::GEPOp>(
332  loc, i8Ptr, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
333  Value stringLen =
334  rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
335 
336  Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
337  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
338 
339  auto appendFormatCall = rewriter.create<LLVM::CallOp>(
340  loc, ocklAppendStringN,
341  ValueRange{printfDesc, stringStart, stringLen,
342  adaptor.getArgs().empty() ? oneI32 : zeroI32});
343  printfDesc = appendFormatCall.getResult();
344 
345  // __ockl_printf_append_args takes 7 values per append call
346  constexpr size_t argsPerAppend = 7;
347  size_t nArgs = adaptor.getArgs().size();
348  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
349  size_t bound = std::min(group + argsPerAppend, nArgs);
350  size_t numArgsThisCall = bound - group;
351 
353  arguments.push_back(printfDesc);
354  arguments.push_back(
355  rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
356  for (size_t i = group; i < bound; ++i) {
357  Value arg = adaptor.getArgs()[i];
358  if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
359  if (!floatType.isF64())
360  arg = rewriter.create<LLVM::FPExtOp>(
361  loc, typeConverter->convertType(rewriter.getF64Type()), arg);
362  arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
363  }
364  if (arg.getType().getIntOrFloatBitWidth() != 64)
365  arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
366 
367  arguments.push_back(arg);
368  }
369  // Pad out to 7 arguments since the hostcall always needs 7
370  for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
371  arguments.push_back(zeroI64);
372  }
373 
374  auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
375  arguments.push_back(isLast);
376  auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
377  printfDesc = call.getResult();
378  }
379  rewriter.eraseOp(gpuPrintfOp);
380  return success();
381 }
382 
384  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
385  ConversionPatternRewriter &rewriter) const {
386  Location loc = gpuPrintfOp->getLoc();
387 
388  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
389  mlir::Type i8Ptr = getTypeConverter()->getPointerType(llvmI8, addressSpace);
390 
391  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
392  // This ensures that global constants and declarations are placed within
393  // the device code, not the host code
394  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
395 
396  auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr},
397  /*isVarArg=*/true);
398  LLVM::LLVMFuncOp printfDecl =
399  getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
400 
401  // Get a unique global name for the format.
402  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
403 
404  llvm::SmallString<20> formatString(adaptor.getFormat());
405  formatString.push_back('\0'); // Null terminate for C
406  auto globalType =
407  LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
408  LLVM::GlobalOp global;
409  {
411  rewriter.setInsertionPointToStart(moduleOp.getBody());
412  global = rewriter.create<LLVM::GlobalOp>(
413  loc, globalType,
414  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
415  rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
416  }
417 
418  // Get a pointer to the format string's first element
419  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
420  loc,
421  getTypeConverter()->getPointerType(globalType, global.getAddrSpace()),
422  global.getSymNameAttr());
423  Value stringStart = rewriter.create<LLVM::GEPOp>(
424  loc, i8Ptr, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
425 
426  // Construct arguments and function call
427  auto argsRange = adaptor.getArgs();
428  SmallVector<Value, 4> printfArgs;
429  printfArgs.reserve(argsRange.size() + 1);
430  printfArgs.push_back(stringStart);
431  printfArgs.append(argsRange.begin(), argsRange.end());
432 
433  rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
434  rewriter.eraseOp(gpuPrintfOp);
435  return success();
436 }
437 
439  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
440  ConversionPatternRewriter &rewriter) const {
441  Location loc = gpuPrintfOp->getLoc();
442 
443  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
444  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
445 
446  // Note: this is the GPUModule op, not the ModuleOp that surrounds it
447  // This ensures that global constants and declarations are placed within
448  // the device code, not the host code
449  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
450 
451  auto vprintfType =
452  LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
453  LLVM::LLVMFuncOp vprintfDecl =
454  getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
455 
456  // Get a unique global name for the format.
457  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
458 
459  llvm::SmallString<20> formatString(adaptor.getFormat());
460  formatString.push_back('\0'); // Null terminate for C
461  auto globalType =
462  LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
463  LLVM::GlobalOp global;
464  {
466  rewriter.setInsertionPointToStart(moduleOp.getBody());
467  global = rewriter.create<LLVM::GlobalOp>(
468  loc, globalType,
469  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
470  rewriter.getStringAttr(formatString), /*allignment=*/0);
471  }
472 
473  // Get a pointer to the format string's first element
474  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
475  Value stringStart = rewriter.create<LLVM::GEPOp>(
476  loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
477  SmallVector<Type> types;
478  SmallVector<Value> args;
479  // Promote and pack the arguments into a stack allocation.
480  for (Value arg : adaptor.getArgs()) {
481  Type type = arg.getType();
482  Value promotedArg = arg;
483  assert(type.isIntOrFloat());
484  if (isa<FloatType>(type)) {
485  type = rewriter.getF64Type();
486  promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
487  }
488  types.push_back(type);
489  args.push_back(promotedArg);
490  }
491  Type structType =
492  LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
493  Type structPtrType = LLVM::LLVMPointerType::get(structType);
494  Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
495  rewriter.getIndexAttr(1));
496  Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
497  /*alignment=*/0);
498  for (auto [index, arg] : llvm::enumerate(args)) {
499  Value ptr = rewriter.create<LLVM::GEPOp>(
500  loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
501  ArrayRef<LLVM::GEPArg>{0, index});
502  rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
503  }
504  tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
505  std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
506 
507  rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
508  rewriter.eraseOp(gpuPrintfOp);
509  return success();
510 }
511 
512 /// Unrolls op if it's operating on vectors.
514  ConversionPatternRewriter &rewriter,
515  const LLVMTypeConverter &converter) {
516  TypeRange operandTypes(operands);
517  if (llvm::none_of(operandTypes,
518  [](Type type) { return isa<VectorType>(type); })) {
519  return rewriter.notifyMatchFailure(op, "expected vector operand");
520  }
521  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
522  return rewriter.notifyMatchFailure(op, "expected no region/successor");
523  if (op->getNumResults() != 1)
524  return rewriter.notifyMatchFailure(op, "expected single result");
525  VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
526  if (!vectorType)
527  return rewriter.notifyMatchFailure(op, "expected vector result");
528 
529  Location loc = op->getLoc();
530  Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
531  Type indexType = converter.convertType(rewriter.getIndexType());
532  StringAttr name = op->getName().getIdentifier();
533  Type elementType = vectorType.getElementType();
534 
535  for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
536  Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
537  auto extractElement = [&](Value operand) -> Value {
538  if (!isa<VectorType>(operand.getType()))
539  return operand;
540  return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
541  };
542  auto scalarOperands = llvm::map_to_vector(operands, extractElement);
543  Operation *scalarOp =
544  rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
545  result = rewriter.create<LLVM::InsertElementOp>(
546  loc, result, scalarOp->getResult(0), index);
547  }
548 
549  rewriter.replaceOp(op, result);
550  return success();
551 }
552 
553 static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
554  return IntegerAttr::get(IntegerType::get(ctx, 64), space);
555 }
556 
558  TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
559  typeConverter.addTypeAttributeConversion(
560  [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
561  gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
562  unsigned addressSpace = mapping(memorySpace);
563  return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
564  addressSpace);
565  });
566 }
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, ConversionPatternRewriter &rewriter, StringRef name, LLVM::LLVMFunctionType type)
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space)
static SmallString< 16 > getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp)
static std::string diag(const llvm::Value &value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:137
This class represents an argument of a Block.
Definition: Value.h:310
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
IntegerType getI8Type()
Definition: Builders.cpp:79
FloatType getF64Type()
Definition: Builders.cpp:65
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LLVM::LLVMPointerType getPointerType(Type elementType, unsigned addressSpace=0) const
Creates an LLVM pointer type with the given element type and address space.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:502
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and stride information extracted from the type.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work to remain sorted.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumSuccessors()
Definition: Operation.h:668
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:635
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:469
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:59
This class provides all of the information necessary to convert a type signature.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:117
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter)
Unrolls op if it's operating on vectors.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
std::function< unsigned(gpu::AddressSpace)> MemorySpaceMapping
A function that maps a MemorySpace enum to a target-specific integer value.
Definition: GPUCommonPass.h:75
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26