GPUOpsLowering.cpp
1//===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "GPUOpsLowering.h"
10
14#include "mlir/IR/Attributes.h"
15#include "mlir/IR/Builders.h"
17#include "llvm/ADT/SmallVectorExtras.h"
18#include "llvm/ADT/StringSet.h"
19#include "llvm/Support/FormatVariadic.h"
20
21using namespace mlir;
22
23LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc,
24 OpBuilder &b, StringRef name,
25 LLVM::LLVMFunctionType type) {
26 auto existing = dyn_cast_or_null<LLVM::LLVMFuncOp>(
27 SymbolTable::lookupSymbolIn(moduleOp, name));
28 if (existing)
29 return existing;
30
31 OpBuilder::InsertionGuard guard(b);
32 b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
33 return LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
34}
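// Illustrative usage (a sketch, not part of this file's patterns): a lowering
// can lazily declare a runtime function and then call it, e.g.
//   auto fnTy = LLVM::LLVMFunctionType::get(i32Ty, {ptrTy}, /*isVarArg=*/true);
//   LLVM::LLVMFuncOp fn =
//       getOrDefineFunction(moduleOp, loc, rewriter, "printf", fnTy);
//   LLVM::CallOp::create(rewriter, loc, fn, args);
// where i32Ty, ptrTy, and args stand for suitably constructed values. The
// declaration is created at most once per module; later lookups return the
// existing symbol.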
35
36static SmallString<16> getUniqueSymbolName(Operation *moduleOp,
37 StringRef prefix) {
38 // Get a unique global name.
39 unsigned stringNumber = 0;
40 SmallString<16> stringConstName;
41 do {
42 stringConstName.clear();
43 (prefix + Twine(stringNumber++)).toStringRef(stringConstName);
44 } while (SymbolTable::lookupSymbolIn(moduleOp, stringConstName));
45 return stringConstName;
46}
47
48LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
49 Operation *moduleOp, Type llvmI8,
50 StringRef namePrefix,
51 StringRef str,
52 uint64_t alignment,
53 unsigned addrSpace) {
54 llvm::SmallString<20> nullTermStr(str);
55 nullTermStr.push_back('\0'); // Null terminate for C
56 auto globalType =
57 LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
58 StringAttr attr = b.getStringAttr(nullTermStr);
59
60 // Try to find existing global.
61 for (auto globalOp : moduleOp->getRegion(0).getOps<LLVM::GlobalOp>())
62 if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
63 globalOp.getValueAttr() == attr &&
64 globalOp.getAlignment().value_or(0) == alignment &&
65 globalOp.getAddrSpace() == addrSpace)
66 return globalOp;
67
68 // Not found: create new global.
69 OpBuilder::InsertionGuard guard(b);
70 b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
71 SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
72 return LLVM::GlobalOp::create(b, loc, globalType,
73 /*isConstant=*/true, LLVM::Linkage::Internal,
74 name, attr, alignment, addrSpace);
75}
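// For illustration, a call such as
//   getOrCreateStringConstant(b, loc, moduleOp, llvmI8, "printfFormat_", "%d\n")
// produces roughly the following module-level global, which is reused by later
// calls with identical contents, alignment, and address space:
//   llvm.mlir.global internal constant @printfFormat_0("%d\0A\00")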
76
77LogicalResult
78GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
79 ConversionPatternRewriter &rewriter) const {
80 Location loc = gpuFuncOp.getLoc();
81
82 SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
83 if (encodeWorkgroupAttributionsAsArguments) {
84 // Append an `llvm.ptr` argument to the function signature to encode
85 // workgroup attributions.
86
87 ArrayRef<BlockArgument> workgroupAttributions =
88 gpuFuncOp.getWorkgroupAttributions();
89 size_t numAttributions = workgroupAttributions.size();
90
91 // Insert all arguments at the end.
92 unsigned index = gpuFuncOp.getNumArguments();
93 SmallVector<unsigned> argIndices(numAttributions, index);
94
95 // New arguments will simply be `llvm.ptr` with the correct address space
96 Type workgroupPtrType =
97 rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
98 SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
99
100 // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
101 std::array attrs{
102 rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
103 rewriter.getUnitAttr()),
104 rewriter.getNamedAttr(
105 getDialect().getWorkgroupAttributionAttrHelper().getName(),
106 rewriter.getUnitAttr()),
107 };
108 SmallVector<DictionaryAttr> argAttrs;
109 for (BlockArgument attribution : workgroupAttributions) {
110 auto attributionType = cast<MemRefType>(attribution.getType());
111 IntegerAttr numElements =
112 rewriter.getI64IntegerAttr(attributionType.getNumElements());
113 Type llvmElementType =
114 getTypeConverter()->convertType(attributionType.getElementType());
115 if (!llvmElementType)
116 return failure();
117 TypeAttr type = TypeAttr::get(llvmElementType);
118 attrs.back().setValue(
119 rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
120 argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
121 }
122
123 // Argument locations match the function location.
124 SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc());
125
126 // Perform signature modification
127 rewriter.modifyOpInPlace(
128 gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
129 LogicalResult inserted =
130 static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
131 argIndices, argTypes, argAttrs, argLocs);
132 (void)inserted;
133 assert(succeeded(inserted) &&
134 "expected GPU funcs to support inserting any argument");
135 });
136 } else {
137 workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
138 for (auto [idx, attribution] :
139 llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
140 auto type = dyn_cast<MemRefType>(attribution.getType());
141 assert(type && type.hasStaticShape() && "unexpected type in attribution");
142
143 uint64_t numElements = type.getNumElements();
144
145 auto elementType =
146 cast<Type>(typeConverter->convertType(type.getElementType()));
147 auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
148 std::string name =
149 std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
150 uint64_t alignment = 0;
151 if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(
152 gpuFuncOp.getWorkgroupAttributionAttr(
153 idx, LLVM::LLVMDialect::getAlignAttrName())))
154 alignment = alignAttr.getInt();
155 auto globalOp = LLVM::GlobalOp::create(
156 rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
157 LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
158 workgroupAddrSpace);
159 workgroupBuffers.push_back(globalOp);
160 }
161 }
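// To illustrate the two strategies above: for a function with one workgroup
// attribution, e.g. (schematically)
//   gpu.func @kernel(%arg0: f32) workgroup(%buf: memref<32xf32, #workgroup>)
// the argument encoding appends a trailing `llvm.ptr` in the workgroup address
// space to the signature, while the global encoding creates a module-level
// buffer such as
//   llvm.mlir.global internal @__wg_kernel_0() : !llvm.array<32 x f32>
// whose address is wrapped in a memref descriptor further below.
// (#workgroup stands for the workgroup address-space attribute; illustrative.)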
162
163 // Remap proper input types.
164 TypeConverter::SignatureConversion signatureConversion(
165 gpuFuncOp.front().getNumArguments());
166
167 Type funcType = getTypeConverter()->convertFunctionSignature(
168 gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
169 getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
170 if (!funcType) {
171 return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
172 diag << "failed to convert function signature type for: "
173 << gpuFuncOp.getFunctionType();
174 });
175 }
176
177 // Create the new function operation. Only copy those attributes that are
178 // not specific to function modeling.
179 SmallVector<NamedAttribute> attributes;
180 ArrayAttr argAttrs;
181 for (const auto &attr : gpuFuncOp->getAttrs()) {
182 if (attr.getName() == SymbolTable::getSymbolAttrName() ||
183 attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
184 attr.getName() ==
185 gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
186 attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
187 attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
188 attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
189 attr.getName() == gpuFuncOp.getKnownGridSizeAttrName())
190 continue;
191 if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
192 argAttrs = gpuFuncOp.getArgAttrsAttr();
193 continue;
194 }
195 attributes.push_back(attr);
196 }
197
198 DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
199 DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
200 // Ensure we don't lose information if the function is lowered before its
201 // surrounding context.
202 auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
203 if (knownBlockSize)
204 attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
205 knownBlockSize);
206 if (knownGridSize)
207 attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
208 knownGridSize);
209
210 // Add a dialect-specific kernel attribute in addition to the GPU kernel
211 // attribute. The former is necessary for further translation while the
212 // latter is expected by gpu.launch_func.
213 if (gpuFuncOp.isKernel()) {
214 if (kernelAttributeName)
215 attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
216 // Set the dialect-specific block size attribute if there is one.
217 if (kernelBlockSizeAttributeName && knownBlockSize) {
218 attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
219 }
220 }
221 LLVM::CConv callingConvention = gpuFuncOp.isKernel()
222 ? kernelCallingConvention
223 : nonKernelCallingConvention;
224 auto llvmFuncOp = LLVM::LLVMFuncOp::create(
225 rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
226 LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
227 /*comdat=*/nullptr, attributes);
228
229 {
230 // Insert operations that correspond to converted workgroup and private
231 // memory attributions to the body of the function. This must operate on
232 // the original function, before the body region is inlined in the new
233 // function to maintain the relation between block arguments and the
234 // parent operation that assigns their semantics.
235 OpBuilder::InsertionGuard guard(rewriter);
236
237 // Rewrite workgroup memory attributions to addresses of global buffers.
238 rewriter.setInsertionPointToStart(&gpuFuncOp.front());
239 unsigned numProperArguments = gpuFuncOp.getNumArguments();
240
241 if (encodeWorkgroupAttributionsAsArguments) {
242 // Build a MemRefDescriptor with each of the arguments added above.
243
244 unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
245 assert(numProperArguments >= numAttributions &&
246 "Expecting attributions to be encoded as arguments already");
247
248 // Arguments encoding workgroup attributions will be in positions
249 // [numProperArguments, numProperArguments+numAttributions)
250 ArrayRef<BlockArgument> attributionArguments =
251 gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
252 numAttributions);
253 for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
254 gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
255 auto [attribution, arg] = vals;
256 auto type = cast<MemRefType>(attribution.getType());
257
258 // Arguments are of llvm.ptr type and attributions are of memref type:
259 // we need to wrap them in memref descriptors.
260 Value descr = MemRefDescriptor::fromStaticShape(
261 rewriter, loc, *getTypeConverter(), type, arg);
262
263 // And remap the arguments
264 signatureConversion.remapInput(numProperArguments + idx, descr);
265 }
266 } else {
267 for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
268 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
269 global.getAddrSpace());
270 Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType,
271 global.getSymNameAttr());
272 Value memory =
273 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(),
274 address, ArrayRef<LLVM::GEPArg>{0, 0});
275
276 // Build a memref descriptor pointing to the buffer to plug with the
277 // existing memref infrastructure. This may use more registers than
278 // otherwise necessary given that memref sizes are fixed, but we can try
279 // and canonicalize that away later.
280 Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
281 auto type = cast<MemRefType>(attribution.getType());
282 Value descr = MemRefDescriptor::fromStaticShape(
283 rewriter, loc, *getTypeConverter(), type, memory);
284 signatureConversion.remapInput(numProperArguments + idx, descr);
285 }
286 }
287
288 // Rewrite private memory attributions to alloca'ed buffers.
289 unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
290 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
291 for (const auto [idx, attribution] :
292 llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
293 auto type = cast<MemRefType>(attribution.getType());
294 assert(type && type.hasStaticShape() && "unexpected type in attribution");
295
296 // Explicitly drop memory space when lowering private memory
297 // attributions since NVVM models it as `alloca`s in the default
298 // memory space and does not support `alloca`s with addrspace(5).
299 Type elementType = typeConverter->convertType(type.getElementType());
300 auto ptrType =
301 LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
302 Value numElements = LLVM::ConstantOp::create(
303 rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
304 uint64_t alignment = 0;
305 if (auto alignAttr =
306 dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
307 idx, LLVM::LLVMDialect::getAlignAttrName())))
308 alignment = alignAttr.getInt();
309 Value allocated =
310 LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType,
311 elementType, numElements, alignment);
312 Value descr = MemRefDescriptor::fromStaticShape(
313 rewriter, loc, *getTypeConverter(), type, allocated);
314 signatureConversion.remapInput(
315 numProperArguments + numWorkgroupAttributions + idx, descr);
316 }
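// As a concrete (simplified) example, a private attribution such as
//   private(%p : memref<4xf32, #private>)
// becomes an alloca in the default/alloca address space:
//   %n = llvm.mlir.constant(4 : i64) : i64
//   %a = llvm.alloca %n x f32 : (i64) -> !llvm.ptr
// which is then wrapped in a memref descriptor as shown above.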
317 }
318
319 // Move the region to the new function, update the entry block signature.
320 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
321 llvmFuncOp.end());
322 if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
323 &signatureConversion)))
324 return failure();
325
326 // Copy argument attributes from the original function to the converted one;
327 // pointer-specific attributes (e.g. noalias) only go on pointer arguments.
328 for (const auto [idx, argTy] :
329 llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
330 auto remapping = signatureConversion.getInputMapping(idx);
331 NamedAttrList argAttr =
332 argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
333 auto copyAttribute = [&](StringRef attrName) {
334 Attribute attr = argAttr.erase(attrName);
335 if (!attr)
336 return;
337 for (size_t i = 0, e = remapping->size; i < e; ++i)
338 llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
339 };
340 auto copyPointerAttribute = [&](StringRef attrName) {
341 Attribute attr = argAttr.erase(attrName);
342
343 if (!attr)
344 return;
345 if (remapping->size > 1 &&
346 attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
347 emitWarning(llvmFuncOp.getLoc(),
348 "Cannot copy noalias with non-bare pointers.\n");
349 return;
350 }
351 for (size_t i = 0, e = remapping->size; i < e; ++i) {
352 if (isa<LLVM::LLVMPointerType>(
353 llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
354 llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
355 }
356 }
357 };
358
359 if (argAttr.empty())
360 continue;
361
362 copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
363 copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
364 copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
365 bool lowersToPointer = false;
366 for (size_t i = 0, e = remapping->size; i < e; ++i) {
367 lowersToPointer |= isa<LLVM::LLVMPointerType>(
368 llvmFuncOp.getArgument(remapping->inputNo + i).getType());
369 }
370
371 if (lowersToPointer) {
372 copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
373 copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
374 copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
375 copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
376 copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
377 copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
378 copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
379 copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
380 copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
381 copyPointerAttribute(
382 LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
383 copyPointerAttribute(
384 LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
385 }
386 }
387 rewriter.eraseOp(gpuFuncOp);
388 return success();
389}
390
391LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
392 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
393 ConversionPatternRewriter &rewriter) const {
394 Location loc = gpuPrintfOp->getLoc();
395
396 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
397 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
398 mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
399 mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
400
401 Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
402 if (!moduleOp)
403 return rewriter.notifyMatchFailure(gpuPrintfOp,
404 "Couldn't find a parent module");
405
406 auto ocklBegin =
407 getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
408 LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
409 LLVM::LLVMFuncOp ocklAppendArgs;
410 if (!adaptor.getArgs().empty()) {
411 ocklAppendArgs = getOrDefineFunction(
412 moduleOp, loc, rewriter, "__ockl_printf_append_args",
413 LLVM::LLVMFunctionType::get(
414 llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
415 llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
416 }
417 auto ocklAppendStringN = getOrDefineFunction(
418 moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
419 LLVM::LLVMFunctionType::get(
420 llvmI64,
421 {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
422
423 /// Start the printf hostcall
424 Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0);
425 auto printfBeginCall =
426 LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64);
427 Value printfDesc = printfBeginCall.getResult();
428
429 // Create the global op or find an existing one.
430 LLVM::GlobalOp global = getOrCreateStringConstant(
431 rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
432
433 // Get a pointer to the format string's first element and pass it to printf()
434 Value globalPtr = LLVM::AddressOfOp::create(
435 rewriter, loc,
436 LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
437 global.getSymNameAttr());
438 Value stringStart =
439 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
440 globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
441 Value stringLen = LLVM::ConstantOp::create(
442 rewriter, loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
443
444 Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1);
445 Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0);
446
447 auto appendFormatCall = LLVM::CallOp::create(
448 rewriter, loc, ocklAppendStringN,
449 ValueRange{printfDesc, stringStart, stringLen,
450 adaptor.getArgs().empty() ? oneI32 : zeroI32});
451 printfDesc = appendFormatCall.getResult();
452
453 // __ockl_printf_append_args takes 7 values per append call
454 constexpr size_t argsPerAppend = 7;
455 size_t nArgs = adaptor.getArgs().size();
456 for (size_t group = 0; group < nArgs; group += argsPerAppend) {
457 size_t bound = std::min(group + argsPerAppend, nArgs);
458 size_t numArgsThisCall = bound - group;
459
460 SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
461 arguments.push_back(printfDesc);
462 arguments.push_back(
463 LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall));
464 for (size_t i = group; i < bound; ++i) {
465 Value arg = adaptor.getArgs()[i];
466 if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
467 if (!floatType.isF64())
468 arg = LLVM::FPExtOp::create(
469 rewriter, loc, typeConverter->convertType(rewriter.getF64Type()),
470 arg);
471 arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg);
472 }
473 if (arg.getType().getIntOrFloatBitWidth() != 64)
474 arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg);
475
476 arguments.push_back(arg);
477 }
478 // Pad out to 7 arguments since the hostcall always needs 7
479 for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
480 arguments.push_back(zeroI64);
481 }
482
483 auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
484 arguments.push_back(isLast);
485 auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments);
486 printfDesc = call.getResult();
487 }
488 rewriter.eraseOp(gpuPrintfOp);
489 return success();
490}
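// Schematically, `gpu.printf "%f %d" %f0, %i0 : f32, i32` lowers to a chain of
// hostcall helpers (simplified; arguments are widened to i64 as above):
//   %d0 = llvm.call @__ockl_printf_begin(%c0) : (i64) -> i64
//   %d1 = llvm.call @__ockl_printf_append_string_n(%d0, %fmt, %len, %c0_i32)
//   %d2 = llvm.call @__ockl_printf_append_args(%d1, %c2_i32, %f0w, %i0w,
//                                              %z, %z, %z, %z, %z, %c1_i32)
// where %z is the i64 zero used for padding, the trailing i32 marks the last
// append, and formats with more than 7 arguments emit additional append calls.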
491
492LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
493 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
494 ConversionPatternRewriter &rewriter) const {
495 Location loc = gpuPrintfOp->getLoc();
496
497 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
498 mlir::Type ptrType =
499 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
500
501 Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
502 if (!moduleOp)
503 return rewriter.notifyMatchFailure(gpuPrintfOp,
504 "Couldn't find a parent module");
505
506 auto printfType =
507 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
508 /*isVarArg=*/true);
509 LLVM::LLVMFuncOp printfDecl =
510 getOrDefineFunction(moduleOp, loc, rewriter, funcName, printfType);
511 printfDecl.setCConv(callingConvention);
512
513 // Create the global op or find an existing one.
514 LLVM::GlobalOp global = getOrCreateStringConstant(
515 rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
516 /*alignment=*/0, addressSpace);
517
518 // Get a pointer to the format string's first element
519 Value globalPtr = LLVM::AddressOfOp::create(
520 rewriter, loc,
521 LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
522 global.getSymNameAttr());
523 Value stringStart =
524 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
525 globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
526
527 // Construct arguments and function call
528 auto argsRange = adaptor.getArgs();
529 SmallVector<Value, 4> printfArgs;
530 printfArgs.reserve(argsRange.size() + 1);
531 printfArgs.push_back(stringStart);
532 printfArgs.append(argsRange.begin(), argsRange.end());
533
534 auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
535 call.setCConv(callingConvention);
536 rewriter.eraseOp(gpuPrintfOp);
537 return success();
538}
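// In this variant the op becomes a plain vararg call to the configured
// funcName (e.g. "printf"); schematically:
//   %fmt = llvm.mlir.addressof @printfFormat_0 : !llvm.ptr
//   %str = llvm.getelementptr %fmt[0, 0] : ... -> !llvm.ptr
//   %ret = llvm.call @printf(%str, %arg0) vararg(!llvm.func<i32 (ptr, ...)>)
// with the calling convention supplied by the target-specific instantiation.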
539
540LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
541 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
542 ConversionPatternRewriter &rewriter) const {
543 Location loc = gpuPrintfOp->getLoc();
544
545 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
546 mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
547
548 Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
549 if (!moduleOp)
550 return rewriter.notifyMatchFailure(gpuPrintfOp,
551 "Couldn't find a parent module");
552
553 // Create a valid location for the global: strip any metadata attached to
554 // the location, since debug info metadata created inside a function cannot
555 // be used outside of that function.
556 Location globalLoc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
557
558 auto vprintfType =
559 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
560 LLVM::LLVMFuncOp vprintfDecl = getOrDefineFunction(
561 moduleOp, globalLoc, rewriter, "vprintf", vprintfType);
562
563 // Create the global op or find an existing one.
564 LLVM::GlobalOp global =
565 getOrCreateStringConstant(rewriter, globalLoc, moduleOp, llvmI8,
566 "printfFormat_", adaptor.getFormat());
567
568 // Get a pointer to the format string's first element
569 Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global);
570 Value stringStart =
571 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
572 globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
573 SmallVector<Type> types;
574 SmallVector<Value> args;
575 // Promote and pack the arguments into a stack allocation.
576 for (Value arg : adaptor.getArgs()) {
577 Type type = arg.getType();
578 Value promotedArg = arg;
579 assert(type.isIntOrFloat());
580 if (isa<FloatType>(type)) {
581 type = rewriter.getF64Type();
582 promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg);
583 }
584 types.push_back(type);
585 args.push_back(promotedArg);
586 }
587 Type structType =
588 LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
589 Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
590 rewriter.getIndexAttr(1));
591 Value tempAlloc =
592 LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one,
593 /*alignment=*/0);
594 for (auto [index, arg] : llvm::enumerate(args)) {
595 Value ptr = LLVM::GEPOp::create(
596 rewriter, loc, ptrType, structType, tempAlloc,
597 ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
598 LLVM::StoreOp::create(rewriter, loc, arg, ptr);
599 }
600 std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
601
602 LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs);
603 rewriter.eraseOp(gpuPrintfOp);
604 return success();
605}
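// NVVM's vprintf expects the format string plus a single pointer to a packed
// argument buffer. For `gpu.printf "%f %d" %f0, %i0 : f32, i32` this produces,
// roughly:
//   %buf = llvm.alloca %c1 x !llvm.struct<(f64, i32)> : (i64) -> !llvm.ptr
//   ... llvm.store of the (f64-promoted) %f0 and of %i0 into %buf ...
//   %r = llvm.call @vprintf(%fmt, %buf) : (!llvm.ptr, !llvm.ptr) -> i32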
606
607/// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
608/// Used either directly (for ops on 1D vectors) or as the callback passed to
609/// detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
610static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
611 Type llvm1DVectorTy,
612 ConversionPatternRewriter &rewriter,
613 const LLVMTypeConverter &converter) {
614 TypeRange operandTypes(operands);
615 VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
616 Location loc = op->getLoc();
617 Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType);
618 Type indexType = converter.convertType(rewriter.getIndexType());
619 StringAttr name = op->getName().getIdentifier();
620 Type elementType = vectorType.getElementType();
621
622 for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
623 Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i);
624 auto extractElement = [&](Value operand) -> Value {
625 if (!isa<VectorType>(operand.getType()))
626 return operand;
627 return LLVM::ExtractElementOp::create(rewriter, loc, operand, index);
628 };
629 auto scalarOperands = llvm::map_to_vector(operands, extractElement);
630 Operation *scalarOp =
631 rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
632 result = LLVM::InsertElementOp::create(rewriter, loc, result,
633 scalarOp->getResult(0), index);
634 }
635 return result;
636}
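// For example, an elementwise op such as `%r = math.exp %v : vector<2xf32>`
// is unrolled into per-element scalar ops (schematically):
//   %p  = llvm.mlir.poison : vector<2xf32>
//   %e0 = llvm.extractelement %v[%c0 : i64] : vector<2xf32>
//   %s0 = math.exp %e0 : f32
//   %r0 = llvm.insertelement %s0, %p[%c0 : i64] : vector<2xf32>
//   ... and likewise for element 1 ...
// leaving the scalar `math.exp` ops to be converted by other patterns.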
637
638/// Unrolls op to array/vector elements.
639LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
640 ConversionPatternRewriter &rewriter,
641 const LLVMTypeConverter &converter) {
642 TypeRange operandTypes(operands);
643 if (llvm::any_of(operandTypes, llvm::IsaPred<VectorType>)) {
644 VectorType vectorType =
645 cast<VectorType>(converter.convertType(op->getResultTypes()[0]));
646 rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType,
647 rewriter, converter));
648 return success();
649 }
650
651 if (llvm::any_of(operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
652 return LLVM::detail::handleMultidimensionalVectors(
653 op, operands, converter,
654 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
655 return scalarizeVectorOpHelper(op, operands, llvm1DVectorTy, rewriter,
656 converter);
657 },
658 rewriter);
659 }
660
661 return rewriter.notifyMatchFailure(op, "no llvm.array or vector to unroll");
662}
663
664static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
665 return IntegerAttr::get(IntegerType::get(ctx, 64), space);
666}
667
668/// Generates a symbol with a 0-sized array type for dynamic shared memory
669/// usage, or reuses an existing suitable symbol.
670static LLVM::GlobalOp getDynamicSharedMemorySymbol(
671 ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp,
672 gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
673 MemRefType memrefType, unsigned alignmentBit) {
674 uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
675
676 FailureOr<unsigned> addressSpace =
677 typeConverter->getMemRefAddressSpace(memrefType);
678 if (failed(addressSpace)) {
679 op->emitError() << "conversion of memref memory space "
680 << memrefType.getMemorySpace()
681 << " to integer address space "
682 "failed. Consider adding memory space conversions.";
683 }
684
685 // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
686 // LLVM::GlobalOp is suitable for shared memory, return it.
687 llvm::StringSet<> existingGlobalNames;
688 for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>()) {
689 existingGlobalNames.insert(globalOp.getSymName());
690 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
691 if (globalOp.getAddrSpace() == addressSpace.value() &&
692 arrayType.getNumElements() == 0 &&
693 globalOp.getAlignment().value_or(0) == alignmentByte) {
694 return globalOp;
695 }
696 }
697 }
698
699 // Step 2. Find a unique symbol name
700 unsigned uniquingCounter = 0;
702 "__dynamic_shmem_",
703 [&](StringRef candidate) {
704 return existingGlobalNames.contains(candidate);
705 },
706 uniquingCounter);
707
708 // Step 3. Generate a global op
709 OpBuilder::InsertionGuard guard(rewriter);
710 rewriter.setInsertionPointToStart(moduleOp.getBody());
711
712 auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
713 typeConverter->convertType(memrefType.getElementType()), 0);
714
715 return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType,
716 /*isConstant=*/false, LLVM::Linkage::Internal,
717 symName, /*value=*/Attribute(), alignmentByte,
718 addressSpace.value());
719}
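// The generated symbol looks roughly like
//   llvm.mlir.global internal @__dynamic_shmem_0()
//       {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
// (illustrative values): a zero-sized array in the workgroup address space.
// The actual amount of dynamic shared memory is supplied at kernel launch
// time, not by this symbol.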
720
721LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
722 gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
723 ConversionPatternRewriter &rewriter) const {
724 Location loc = op.getLoc();
725 MemRefType memrefType = op.getResultMemref().getType();
726 Type elementType = typeConverter->convertType(memrefType.getElementType());
727
728 // Step 1: Generate a memref<0xi8> type
729 MemRefLayoutAttrInterface layout = {};
730 auto memrefType0sz =
731 MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
732
733 // Step 2: Generate a global symbol (or reuse an existing one) for the
734 // dynamic shared memory, with memref<0xi8> type
735 auto moduleOp = op->getParentOfType<gpu::GPUModuleOp>();
736 LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
737 rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
738
739 // Step 3. Get address of the global symbol
740 OpBuilder::InsertionGuard guard(rewriter);
741 rewriter.setInsertionPoint(op);
742 auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp);
743 Type baseType = basePtr->getResultTypes().front();
744
745 // Step 4. Generate GEP using offsets
746 SmallVector<LLVM::GEPArg> gepArgs = {0};
747 Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType,
748 basePtr, gepArgs);
749 // Step 5. Create a memref descriptor
750 SmallVector<Value> shape, strides;
751 Value sizeBytes;
752 getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
753 sizeBytes);
754 auto memRefDescriptor = this->createMemRefDescriptor(
755 loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
756
758 // Step 6. Replace the op with the memref descriptor
758 rewriter.replaceOp(op, {memRefDescriptor});
759 return success();
760}
761
762LogicalResult GPUReturnOpLowering::matchAndRewrite(
763 gpu::ReturnOp op, OpAdaptor adaptor,
764 ConversionPatternRewriter &rewriter) const {
765 Location loc = op.getLoc();
766 unsigned numArguments = op.getNumOperands();
767 SmallVector<Value, 4> updatedOperands;
768
769 bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
770 if (useBarePtrCallConv) {
771 // For the bare-ptr calling convention, extract the aligned pointer to
772 // be returned from the memref descriptor.
773 for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
774 Type oldTy = std::get<0>(it).getType();
775 Value newOperand = std::get<1>(it);
776 if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
777 cast<BaseMemRefType>(oldTy))) {
778 MemRefDescriptor memrefDesc(newOperand);
779 newOperand = memrefDesc.allocatedPtr(rewriter, loc);
780 } else if (isa<UnrankedMemRefType>(oldTy)) {
781 // Unranked memref is not supported in the bare pointer calling
782 // convention.
783 return failure();
784 }
785 updatedOperands.push_back(newOperand);
786 }
787 } else {
788 updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
789 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
790 updatedOperands,
791 /*toDynamic=*/true);
792 }
793
794 // If ReturnOp has 0 or 1 operand, create it and return immediately.
795 if (numArguments <= 1) {
796 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
797 op, TypeRange(), updatedOperands, op->getAttrs());
798 return success();
799 }
800
801 // Otherwise, we need to pack the arguments into an LLVM struct type before
802 // returning.
803 auto packedType = getTypeConverter()->packFunctionResults(
804 op.getOperandTypes(), useBarePtrCallConv);
805 if (!packedType) {
806 return rewriter.notifyMatchFailure(op, "could not convert result types");
807 }
808
809 Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
810 for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
811 packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
812 }
813 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
814 op->getAttrs());
815 return success();
816}
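// For a multi-result return such as `gpu.return %a, %b : i32, f32`, the
// packing above yields (schematically):
//   %0 = llvm.mlir.poison : !llvm.struct<(i32, f32)>
//   %1 = llvm.insertvalue %a, %0[0] : !llvm.struct<(i32, f32)>
//   %2 = llvm.insertvalue %b, %1[1] : !llvm.struct<(i32, f32)>
//   llvm.return %2 : !llvm.struct<(i32, f32)>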
817
818void mlir::populateGpuMemorySpaceAttributeConversions(
819 TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
820 typeConverter.addTypeAttributeConversion(
821 [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
822 gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
823 unsigned addressSpace = mapping(memorySpace);
824 return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
825 addressSpace);
826 });
827}
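// Typical usage from a target lowering, sketched here with NVVM-style numeric
// address spaces (the exact values are target-defined, illustrative only):
//   populateGpuMemorySpaceAttributeConversions(
//       converter, [](gpu::AddressSpace space) -> unsigned {
//         switch (space) {
//         case gpu::AddressSpace::Global:
//           return 1;
//         case gpu::AddressSpace::Workgroup:
//           return 3;
//         case gpu::AddressSpace::Private:
//           return 5;
//         }
//         llvm_unreachable("unknown address space");
//       });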