// MLIR 23.0.0git — GPUOpsLowering.cpp
// (Doxygen export banner: "Go to the documentation of this file.")
1//===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "GPUOpsLowering.h"
10
14#include "mlir/IR/Attributes.h"
15#include "mlir/IR/Builders.h"
17#include "llvm/ADT/SmallVectorExtras.h"
18#include "llvm/ADT/StringSet.h"
19#include "llvm/Support/FormatVariadic.h"
20
21using namespace mlir;
22
23LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc,
24 OpBuilder &b, StringRef name,
25 LLVM::LLVMFunctionType type) {
26 auto existing = dyn_cast_or_null<LLVM::LLVMFuncOp>(
27 SymbolTable::lookupSymbolIn(moduleOp, name));
28 if (existing)
29 return existing;
30
32 b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
33 return LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
34}
35
37 StringRef prefix) {
38 // Get a unique global name.
39 unsigned stringNumber = 0;
40 SmallString<16> stringConstName;
41 do {
42 stringConstName.clear();
43 (prefix + Twine(stringNumber++)).toStringRef(stringConstName);
44 } while (SymbolTable::lookupSymbolIn(moduleOp, stringConstName));
45 return stringConstName;
46}
47
49 Operation *moduleOp, Type llvmI8,
50 StringRef namePrefix,
51 StringRef str,
52 uint64_t alignment,
53 unsigned addrSpace) {
54 llvm::SmallString<20> nullTermStr(str);
55 nullTermStr.push_back('\0'); // Null terminate for C
56 auto globalType =
57 LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
58 StringAttr attr = b.getStringAttr(nullTermStr);
59
60 // Try to find existing global.
61 for (auto globalOp : moduleOp->getRegion(0).getOps<LLVM::GlobalOp>())
62 if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
63 globalOp.getValueAttr() == attr &&
64 globalOp.getAlignment().value_or(0) == alignment &&
65 globalOp.getAddrSpace() == addrSpace)
66 return globalOp;
67
68 // Not found: create new global.
70 b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
71 SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
72 return LLVM::GlobalOp::create(b, loc, globalType,
73 /*isConstant=*/true, LLVM::Linkage::Internal,
74 name, attr, alignment, addrSpace);
75}
76
77LogicalResult
78GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
79 ConversionPatternRewriter &rewriter) const {
80 Location loc = gpuFuncOp.getLoc();
81
82 SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
83 if (encodeWorkgroupAttributionsAsArguments) {
84 // Append an `llvm.ptr` argument to the function signature to encode
85 // workgroup attributions.
86
87 ArrayRef<BlockArgument> workgroupAttributions =
88 gpuFuncOp.getWorkgroupAttributions();
89 size_t numAttributions = workgroupAttributions.size();
90
91 // Insert all arguments at the end.
92 unsigned index = gpuFuncOp.getNumArguments();
93 SmallVector<unsigned> argIndices(numAttributions, index);
94
95 // New arguments will simply be `llvm.ptr` with the correct address space
96 Type workgroupPtrType =
97 rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
98 SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
99
100 // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
101 std::array attrs{
102 rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
103 rewriter.getUnitAttr()),
104 rewriter.getNamedAttr(
105 getDialect().getWorkgroupAttributionAttrHelper().getName(),
106 rewriter.getUnitAttr()),
107 };
109 for (BlockArgument attribution : workgroupAttributions) {
110 auto attributionType = cast<MemRefType>(attribution.getType());
111 IntegerAttr numElements =
112 rewriter.getI64IntegerAttr(attributionType.getNumElements());
113 Type llvmElementType =
114 getTypeConverter()->convertType(attributionType.getElementType());
115 if (!llvmElementType)
116 return failure();
117 TypeAttr type = TypeAttr::get(llvmElementType);
118 attrs.back().setValue(
119 rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
120 argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
121 }
122
123 // Location match function location
124 SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc());
125
126 // Perform signature modification
127 rewriter.modifyOpInPlace(
128 gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
129 LogicalResult inserted =
130 static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
131 argIndices, argTypes, argAttrs, argLocs);
132 (void)inserted;
133 assert(succeeded(inserted) &&
134 "expected GPU funcs to support inserting any argument");
135 });
136 } else {
137 workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
138 for (auto [idx, attribution] :
139 llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
140 auto type = dyn_cast<MemRefType>(attribution.getType());
141 assert(type && type.hasStaticShape() && "unexpected type in attribution");
142
143 uint64_t numElements = type.getNumElements();
144
145 auto elementType =
146 cast<Type>(typeConverter->convertType(type.getElementType()));
147 auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
148 std::string name =
149 std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
150 uint64_t alignment = 0;
151 if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(
152 gpuFuncOp.getWorkgroupAttributionAttr(
153 idx, LLVM::LLVMDialect::getAlignAttrName())))
154 alignment = alignAttr.getInt();
155 auto globalOp = LLVM::GlobalOp::create(
156 rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
157 LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
158 workgroupAddrSpace);
159 workgroupBuffers.push_back(globalOp);
160 }
161 }
162
163 // Remap proper input types.
164 TypeConverter::SignatureConversion signatureConversion(
165 gpuFuncOp.front().getNumArguments());
166
168 gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
169 getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
170 if (!funcType) {
171 return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
172 diag << "failed to convert function signature type for: "
173 << gpuFuncOp.getFunctionType();
174 });
175 }
176
177 // Create the new function operation. Only copy those attributes that are
178 // not specific to function modeling.
180 ArrayAttr argAttrs;
181 for (const auto &attr : gpuFuncOp->getAttrs()) {
182 if (attr.getName() == SymbolTable::getSymbolAttrName() ||
183 attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
184 attr.getName() ==
185 gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
186 attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
187 attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
188 attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
189 attr.getName() == gpuFuncOp.getKnownGridSizeAttrName() ||
190 attr.getName() == gpuFuncOp.getKnownClusterSizeAttrName())
191 continue;
192 if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
193 argAttrs = gpuFuncOp.getArgAttrsAttr();
194 continue;
195 }
196 attributes.push_back(attr);
197 }
198
199 DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
200 DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
201 DenseI32ArrayAttr knownClusterSize = gpuFuncOp.getKnownClusterSizeAttr();
202 // Ensure we don't lose information if the function is lowered before its
203 // surrounding context.
204 auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
205 if (knownBlockSize)
206 attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
207 knownBlockSize);
208 if (knownGridSize)
209 attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
210 knownGridSize);
211 if (knownClusterSize)
212 attributes.emplace_back(
213 gpuDialect->getKnownClusterSizeAttrHelper().getName(),
214 knownClusterSize);
215
216 // Add a dialect specific kernel attribute in addition to GPU kernel
217 // attribute. The former is necessary for further translation while the
218 // latter is expected by gpu.launch_func.
219 if (gpuFuncOp.isKernel()) {
220 if (kernelAttributeName)
221 attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
222 // Set the dialect-specific block size attribute if there is one.
223 if (kernelBlockSizeAttributeName && knownBlockSize) {
224 attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
225 }
226 // Set the dialect-specific cluster size attribute if there is one.
227 if (kernelClusterSizeAttributeName && knownClusterSize) {
228 attributes.emplace_back(kernelClusterSizeAttributeName, knownClusterSize);
229 }
230 }
231 LLVM::CConv callingConvention = gpuFuncOp.isKernel()
232 ? kernelCallingConvention
233 : nonKernelCallingConvention;
234 auto llvmFuncOp = LLVM::LLVMFuncOp::create(
235 rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
236 LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
237 /*comdat=*/nullptr, attributes);
238
239 {
240 // Insert operations that correspond to converted workgroup and private
241 // memory attributions to the body of the function. This must operate on
242 // the original function, before the body region is inlined in the new
243 // function to maintain the relation between block arguments and the
244 // parent operation that assigns their semantics.
245 OpBuilder::InsertionGuard guard(rewriter);
246
247 // Rewrite workgroup memory attributions to addresses of global buffers.
248 rewriter.setInsertionPointToStart(&gpuFuncOp.front());
249 unsigned numProperArguments = gpuFuncOp.getNumArguments();
250
251 if (encodeWorkgroupAttributionsAsArguments) {
252 // Build a MemRefDescriptor with each of the arguments added above.
253
254 unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
255 assert(numProperArguments >= numAttributions &&
256 "Expecting attributions to be encoded as arguments already");
257
258 // Arguments encoding workgroup attributions will be in positions
259 // [numProperArguments, numProperArguments+numAttributions)
260 ArrayRef<BlockArgument> attributionArguments =
261 gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
262 numAttributions);
263 for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
264 gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
265 auto [attribution, arg] = vals;
266 auto type = cast<MemRefType>(attribution.getType());
267
268 // Arguments are of llvm.ptr type and attributions are of memref type:
269 // we need to wrap them in memref descriptors.
271 rewriter, loc, *getTypeConverter(), type, arg);
272
273 // And remap the arguments
274 signatureConversion.remapInput(numProperArguments + idx, descr);
275 }
276 } else {
277 for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
278 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
279 global.getAddrSpace());
280 Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType,
281 global.getSymNameAttr());
282 Value memory =
283 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(),
284 address, ArrayRef<LLVM::GEPArg>{0, 0});
285
286 // Build a memref descriptor pointing to the buffer to plug with the
287 // existing memref infrastructure. This may use more registers than
288 // otherwise necessary given that memref sizes are fixed, but we can try
289 // and canonicalize that away later.
290 Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
291 auto type = cast<MemRefType>(attribution.getType());
293 rewriter, loc, *getTypeConverter(), type, memory);
294 signatureConversion.remapInput(numProperArguments + idx, descr);
295 }
296 }
297
298 // Rewrite private memory attributions to alloca'ed buffers.
299 unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
300 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
301 for (const auto [idx, attribution] :
302 llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
303 auto type = cast<MemRefType>(attribution.getType());
304 assert(type && type.hasStaticShape() && "unexpected type in attribution");
305
306 // Explicitly drop memory space when lowering private memory
307 // attributions since NVVM models it as `alloca`s in the default
308 // memory space and does not support `alloca`s with addrspace(5).
309 Type elementType = typeConverter->convertType(type.getElementType());
310 auto ptrType =
311 LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
312 Value numElements = LLVM::ConstantOp::create(
313 rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
314 uint64_t alignment = 0;
315 if (auto alignAttr =
316 dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
317 idx, LLVM::LLVMDialect::getAlignAttrName())))
318 alignment = alignAttr.getInt();
319 Value allocated =
320 LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType,
321 elementType, numElements, alignment);
323 rewriter, loc, *getTypeConverter(), type, allocated);
324 signatureConversion.remapInput(
325 numProperArguments + numWorkgroupAttributions + idx, descr);
326 }
327 }
328
329 // Move the region to the new function, update the entry block signature.
330 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
331 llvmFuncOp.end());
332 if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
333 &signatureConversion)))
334 return failure();
335
336 // Get memref type from function arguments and set the noalias to
337 // pointer arguments.
338 for (const auto [idx, argTy] :
339 llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
340 auto remapping = signatureConversion.getInputMapping(idx);
341 NamedAttrList argAttr =
342 argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
343 auto copyAttribute = [&](StringRef attrName) {
344 Attribute attr = argAttr.erase(attrName);
345 if (!attr)
346 return;
347 for (size_t i = 0, e = remapping->size; i < e; ++i)
348 llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
349 };
350 auto copyPointerAttribute = [&](StringRef attrName) {
351 Attribute attr = argAttr.erase(attrName);
352
353 if (!attr)
354 return;
355 if (remapping->size > 1 &&
356 attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
357 emitWarning(llvmFuncOp.getLoc(),
358 "Cannot copy noalias with non-bare pointers.\n");
359 return;
360 }
361 for (size_t i = 0, e = remapping->size; i < e; ++i) {
362 if (isa<LLVM::LLVMPointerType>(
363 llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
364 llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
365 }
366 }
367 };
368
369 if (argAttr.empty())
370 continue;
371
372 copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
373 copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
374 copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
375 bool lowersToPointer = false;
376 for (size_t i = 0, e = remapping->size; i < e; ++i) {
377 lowersToPointer |= isa<LLVM::LLVMPointerType>(
378 llvmFuncOp.getArgument(remapping->inputNo + i).getType());
379 }
380
381 if (lowersToPointer) {
382 copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
383 copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
384 copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
385 copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
386 copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
387 copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
388 copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
389 copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
390 copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
391 copyPointerAttribute(
392 LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
393 copyPointerAttribute(
394 LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
395 }
396 }
397 rewriter.eraseOp(gpuFuncOp);
398 return success();
399}
400
402 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
403 ConversionPatternRewriter &rewriter) const {
404 Location loc = gpuPrintfOp->getLoc();
405
406 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
407 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
408 mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
409 mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
410
411 Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
412 if (!moduleOp)
413 return rewriter.notifyMatchFailure(gpuPrintfOp,
414 "Couldn't find a parent module");
415
416 auto ocklBegin =
417 getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
418 LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
419 LLVM::LLVMFuncOp ocklAppendArgs;
420 if (!adaptor.getArgs().empty()) {
421 ocklAppendArgs = getOrDefineFunction(
422 moduleOp, loc, rewriter, "__ockl_printf_append_args",
423 LLVM::LLVMFunctionType::get(
424 llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
425 llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
426 }
427 auto ocklAppendStringN = getOrDefineFunction(
428 moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
429 LLVM::LLVMFunctionType::get(
430 llvmI64,
431 {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
432
433 /// Start the printf hostcall
434 Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0);
435 auto printfBeginCall =
436 LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64);
437 Value printfDesc = printfBeginCall.getResult();
438
439 // Create the global op or find an existing one.
440 LLVM::GlobalOp global = getOrCreateStringConstant(
441 rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
442
443 // Get a pointer to the format string's first element and pass it to printf()
444 Value globalPtr = LLVM::AddressOfOp::create(
445 rewriter, loc,
446 LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
447 global.getSymNameAttr());
448 Value stringStart =
449 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
450 globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
451 Value stringLen = LLVM::ConstantOp::create(
452 rewriter, loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
453
454 Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1);
455 Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0);
456
457 auto appendFormatCall = LLVM::CallOp::create(
458 rewriter, loc, ocklAppendStringN,
459 ValueRange{printfDesc, stringStart, stringLen,
460 adaptor.getArgs().empty() ? oneI32 : zeroI32});
461 printfDesc = appendFormatCall.getResult();
462
463 // __ockl_printf_append_args takes 7 values per append call
464 constexpr size_t argsPerAppend = 7;
465 size_t nArgs = adaptor.getArgs().size();
466 for (size_t group = 0; group < nArgs; group += argsPerAppend) {
467 size_t bound = std::min(group + argsPerAppend, nArgs);
468 size_t numArgsThisCall = bound - group;
469
471 arguments.push_back(printfDesc);
472 arguments.push_back(
473 LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall));
474 for (size_t i = group; i < bound; ++i) {
475 Value arg = adaptor.getArgs()[i];
476 if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
477 if (!floatType.isF64())
478 arg = LLVM::FPExtOp::create(
479 rewriter, loc, typeConverter->convertType(rewriter.getF64Type()),
480 arg);
481 arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg);
482 }
483 if (arg.getType().getIntOrFloatBitWidth() != 64)
484 arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg);
485
486 arguments.push_back(arg);
487 }
488 // Pad out to 7 arguments since the hostcall always needs 7
489 for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
490 arguments.push_back(zeroI64);
491 }
492
493 auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
494 arguments.push_back(isLast);
495 auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments);
496 printfDesc = call.getResult();
497 }
498 rewriter.eraseOp(gpuPrintfOp);
499 return success();
500}
501
503 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
504 ConversionPatternRewriter &rewriter) const {
505 Location loc = gpuPrintfOp->getLoc();
506
507 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
508 mlir::Type ptrType =
509 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
510
511 Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
512 if (!moduleOp)
513 return rewriter.notifyMatchFailure(gpuPrintfOp,
514 "Couldn't find a parent module");
515
516 auto printfType =
517 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
518 /*isVarArg=*/true);
519 LLVM::LLVMFuncOp printfDecl =
520 getOrDefineFunction(moduleOp, loc, rewriter, funcName, printfType);
521 printfDecl.setCConv(callingConvention);
522
523 // Create the global op or find an existing one.
524 LLVM::GlobalOp global = getOrCreateStringConstant(
525 rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
526 /*alignment=*/0, addressSpace);
527
528 // Get a pointer to the format string's first element
529 Value globalPtr = LLVM::AddressOfOp::create(
530 rewriter, loc,
531 LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
532 global.getSymNameAttr());
533 Value stringStart =
534 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
535 globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
536
537 // Construct arguments and function call
538 auto argsRange = adaptor.getArgs();
539 SmallVector<Value, 4> printfArgs;
540 printfArgs.reserve(argsRange.size() + 1);
541 printfArgs.push_back(stringStart);
542 printfArgs.append(argsRange.begin(), argsRange.end());
543
544 auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
545 call.setCConv(callingConvention);
546 rewriter.eraseOp(gpuPrintfOp);
547 return success();
548}
549
551 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
552 ConversionPatternRewriter &rewriter) const {
553 Location loc = gpuPrintfOp->getLoc();
554
555 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
556 mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
557
558 Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
559 if (!moduleOp)
560 return rewriter.notifyMatchFailure(gpuPrintfOp,
561 "Couldn't find a parent module");
562
563 // Create a valid global location removing any metadata attached to the
564 // location as debug info metadata inside of a function cannot be used outside
565 // of that function.
567
568 auto vprintfType =
569 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
570 LLVM::LLVMFuncOp vprintfDecl = getOrDefineFunction(
571 moduleOp, globalLoc, rewriter, "vprintf", vprintfType);
572
573 // Create the global op or find an existing one.
574 LLVM::GlobalOp global =
575 getOrCreateStringConstant(rewriter, globalLoc, moduleOp, llvmI8,
576 "printfFormat_", adaptor.getFormat());
577
578 // Get a pointer to the format string's first element
579 Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global);
580 Value stringStart =
581 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
582 globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
583 SmallVector<Type> types;
585 // Promote and pack the arguments into a stack allocation.
586 for (Value arg : adaptor.getArgs()) {
587 Type type = arg.getType();
588 Value promotedArg = arg;
589 assert(type.isIntOrFloat());
590 if (isa<FloatType>(type)) {
591 type = rewriter.getF64Type();
592 promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg);
593 }
594 types.push_back(type);
595 args.push_back(promotedArg);
596 }
597 Type structType =
598 LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
599 Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
600 rewriter.getIndexAttr(1));
601 Value tempAlloc =
602 LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one,
603 /*alignment=*/0);
604 for (auto [index, arg] : llvm::enumerate(args)) {
605 Value ptr = LLVM::GEPOp::create(
606 rewriter, loc, ptrType, structType, tempAlloc,
607 ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
608 LLVM::StoreOp::create(rewriter, loc, arg, ptr);
609 }
610 std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
611
612 LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs);
613 rewriter.eraseOp(gpuPrintfOp);
614 return success();
615}
616
617/// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
618/// Used either directly (for ops on 1D vectors) or as the callback passed to
619/// detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
621 Type llvm1DVectorTy,
622 ConversionPatternRewriter &rewriter,
623 const LLVMTypeConverter &converter) {
624 TypeRange operandTypes(operands);
625 VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
626 Location loc = op->getLoc();
627 Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType);
628 Type indexType = converter.convertType(rewriter.getIndexType());
629 StringAttr name = op->getName().getIdentifier();
630 Type elementType = vectorType.getElementType();
631
632 for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
633 Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i);
634 auto extractElement = [&](Value operand) -> Value {
635 if (!isa<VectorType>(operand.getType()))
636 return operand;
637 return LLVM::ExtractElementOp::create(rewriter, loc, operand, index);
638 };
639 auto scalarOperands = llvm::map_to_vector(operands, extractElement);
640 Operation *scalarOp =
641 rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
642 result = LLVM::InsertElementOp::create(rewriter, loc, result,
643 scalarOp->getResult(0), index);
644 }
645 return result;
646}
647
648/// Unrolls op to array/vector elements.
649LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
650 ConversionPatternRewriter &rewriter,
651 const LLVMTypeConverter &converter) {
652 TypeRange operandTypes(operands);
653 if (llvm::any_of(operandTypes, llvm::IsaPred<VectorType>)) {
654 VectorType vectorType =
655 cast<VectorType>(converter.convertType(op->getResultTypes()[0]));
656 rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType,
657 rewriter, converter));
658 return success();
659 }
660
661 if (llvm::any_of(operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
663 op, operands, converter,
664 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
665 return scalarizeVectorOpHelper(op, operands, llvm1DVectorTy, rewriter,
666 converter);
667 },
668 rewriter);
669 }
670
671 return rewriter.notifyMatchFailure(op, "no llvm.array or vector to unroll");
672}
673
674static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
675 return IntegerAttr::get(IntegerType::get(ctx, 64), space);
676}
677
678/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
679/// or uses existing symbol.
680static LLVM::GlobalOp getDynamicSharedMemorySymbol(
681 ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp,
682 gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
683 MemRefType memrefType, unsigned alignmentBit) {
684 uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
685
686 FailureOr<unsigned> addressSpace =
687 typeConverter->getMemRefAddressSpace(memrefType);
688 if (failed(addressSpace)) {
689 op->emitError() << "conversion of memref memory space "
690 << memrefType.getMemorySpace()
691 << " to integer address space "
692 "failed. Consider adding memory space conversions.";
693 }
694
695 // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
696 // LLVM::GlobalOp is suitable for shared memory, return it.
697 llvm::StringSet<> existingGlobalNames;
698 for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>()) {
699 existingGlobalNames.insert(globalOp.getSymName());
700 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
701 if (globalOp.getAddrSpace() == addressSpace.value() &&
702 arrayType.getNumElements() == 0 &&
703 globalOp.getAlignment().value_or(0) == alignmentByte) {
704 return globalOp;
705 }
706 }
707 }
708
709 // Step 2. Find a unique symbol name
710 unsigned uniquingCounter = 0;
712 "__dynamic_shmem_",
713 [&](StringRef candidate) {
714 return existingGlobalNames.contains(candidate);
715 },
716 uniquingCounter);
717
718 // Step 3. Generate a global op
719 OpBuilder::InsertionGuard guard(rewriter);
720 rewriter.setInsertionPointToStart(moduleOp.getBody());
721
722 auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
723 typeConverter->convertType(memrefType.getElementType()), 0);
724
725 return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType,
726 /*isConstant=*/false, LLVM::Linkage::Internal,
727 symName, /*value=*/Attribute(), alignmentByte,
728 addressSpace.value());
729}
730
732 gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
733 ConversionPatternRewriter &rewriter) const {
734 Location loc = op.getLoc();
735 MemRefType memrefType = op.getResultMemref().getType();
736 Type elementType = typeConverter->convertType(memrefType.getElementType());
737
738 // Step 1: Generate a memref<0xi8> type
739 MemRefLayoutAttrInterface layout = {};
740 auto memrefType0sz =
741 MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
742
743 // Step 2: Generate a global symbol or existing for the dynamic shared
744 // memory with memref<0xi8> type
745 auto moduleOp = op->getParentOfType<gpu::GPUModuleOp>();
746 LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
747 rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
748
749 // Step 3. Get address of the global symbol
750 OpBuilder::InsertionGuard guard(rewriter);
751 rewriter.setInsertionPoint(op);
752 auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp);
753 Type baseType = basePtr->getResultTypes().front();
754
755 // Step 4. Generate GEP using offsets
756 SmallVector<LLVM::GEPArg> gepArgs = {0};
757 Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType,
758 basePtr, gepArgs);
759 // Step 5. Create a memref descriptor
760 SmallVector<Value> shape, strides;
761 Value sizeBytes;
762 getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
763 sizeBytes);
764 auto memRefDescriptor = this->createMemRefDescriptor(
765 loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
766
767 // Step 5. Replace the op with memref descriptor
768 rewriter.replaceOp(op, {memRefDescriptor});
769 return success();
770}
771
773 gpu::ReturnOp op, OpAdaptor adaptor,
774 ConversionPatternRewriter &rewriter) const {
775 Location loc = op.getLoc();
776 unsigned numArguments = op.getNumOperands();
777 SmallVector<Value, 4> updatedOperands;
778
779 bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
780 if (useBarePtrCallConv) {
781 // For the bare-ptr calling convention, extract the aligned pointer to
782 // be returned from the memref descriptor.
783 for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
784 Type oldTy = std::get<0>(it).getType();
785 Value newOperand = std::get<1>(it);
786 if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
787 cast<BaseMemRefType>(oldTy))) {
788 MemRefDescriptor memrefDesc(newOperand);
789 newOperand = memrefDesc.allocatedPtr(rewriter, loc);
790 } else if (isa<UnrankedMemRefType>(oldTy)) {
791 // Unranked memref is not supported in the bare pointer calling
792 // convention.
793 return failure();
794 }
795 updatedOperands.push_back(newOperand);
796 }
797 } else {
798 updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
799 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
800 updatedOperands,
801 /*toDynamic=*/true);
802 }
803
804 // If ReturnOp has 0 or 1 operand, create it and return immediately.
805 if (numArguments <= 1) {
806 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
807 op, TypeRange(), updatedOperands, op->getAttrs());
808 return success();
809 }
810
811 // Otherwise, we need to pack the arguments into an LLVM struct type before
812 // returning.
813 auto packedType = getTypeConverter()->packFunctionResults(
814 op.getOperandTypes(), useBarePtrCallConv);
815 if (!packedType) {
816 return rewriter.notifyMatchFailure(op, "could not convert result types");
817 }
818
819 Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
820 for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
821 packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
822 }
823 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
824 op->getAttrs());
825 return success();
826}
827
829 TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
830 typeConverter.addTypeAttributeConversion(
831 [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
832 gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
833 unsigned addressSpace = mapping(memorySpace);
834 return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
835 addressSpace);
836 });
837}
return success()
static SmallString< 16 > getUniqueSymbolName(Operation *moduleOp, StringRef prefix)
static LLVM::GlobalOp getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp, gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter, MemRefType memrefType, unsigned alignmentBit)
Generates a symbol with 0-sized array type for dynamic shared memory usage, or uses existing symbol.
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space)
static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands, Type llvm1DVectorTy, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter)
Helper for impl::scalarizeVectorOp.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
Returns failure if copies could not be generated due to yet unimplemented cases. copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point), since `begin` can itself be invalidated due to the memref rewriting done from this method.
static std::string diag(const llvm::Value &value)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
This class represents an argument of a Block.
Definition Value.h:309
typename gpu::GPUFuncOp::Adaptor OpAdaptor
Definition Pattern.h:218
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition Pattern.cpp:190
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition Pattern.cpp:88
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:27
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition Pattern.cpp:32
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-alloc...
Definition Pattern.cpp:278
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Conversion from types to the LLVM IR dialect.
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false, SmallVector< SmallVector< Type > > *groupedTypes=nullptr, int64_t *numConvertedTypes=nullptr) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
const LowerToLLVMOptions & getOptions() const
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LocationAttr findInstanceOfOrUnknown()
Return an instance of the given location type if one is nested under the current location else return...
Definition Location.h:60
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
A trait used to provide symbol table functionalities to a region operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_type_range getResultTypes()
Definition Operation.h:428
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
static SmallString< N > generateSymbolName(StringRef name, UniqueChecker uniqueChecker, unsigned &uniquingCounter)
Generate a unique symbol name.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter)
Unrolls op to array/vector elements.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
std::function< unsigned(gpu::AddressSpace)> MemorySpaceMapping
A function that maps a MemorySpace enum to a target-specific integer value.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc, Operation *moduleOp, Type llvmI8, StringRef namePrefix, StringRef str, uint64_t alignment=0, unsigned addrSpace=0)
Create a global that contains the given string.
LogicalResult matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override