132 ConversionPatternRewriter &rewriter)
const {
136 if (encodeWorkgroupAttributionsAsArguments) {
141 gpuFuncOp.getWorkgroupAttributionBBArgs();
142 size_t numAttributions = workgroupAttributions.size();
145 unsigned index = gpuFuncOp.getNumArguments();
149 Type workgroupPtrType =
150 rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
155 rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
156 rewriter.getUnitAttr()),
157 rewriter.getNamedAttr(
158 getDialect().getWorkgroupAttributionAttrHelper().getName(),
159 rewriter.getUnitAttr()),
163 auto attributionType = cast<MemRefType>(attribution.getType());
164 IntegerAttr numElements =
165 rewriter.getI64IntegerAttr(attributionType.getNumElements());
166 Type llvmElementType =
168 if (!llvmElementType)
170 TypeAttr type = TypeAttr::get(llvmElementType);
171 attrs.back().setValue(
172 rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
173 argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
180 rewriter.modifyOpInPlace(
181 gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
183 static_cast<FunctionOpInterface
>(gpuFuncOp).insertArguments(
184 argIndices, argTypes, argAttrs, argLocs);
187 "expected GPU funcs to support inserting any argument");
190 workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
191 for (
auto [idx, attribution] :
192 llvm::enumerate(gpuFuncOp.getWorkgroupAttributionBBArgs())) {
193 auto type = dyn_cast<MemRefType>(attribution.getType());
194 assert(type && type.hasStaticShape() &&
"unexpected type in attribution");
196 uint64_t numElements = type.getNumElements();
199 cast<Type>(typeConverter->convertType(type.getElementType()));
200 auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
202 std::string(llvm::formatv(
"__wg_{0}_{1}", gpuFuncOp.getName(), idx));
203 uint64_t alignment = 0;
204 if (
auto alignAttr = dyn_cast_or_null<IntegerAttr>(
205 gpuFuncOp.getWorkgroupAttributionAttr(
206 idx, LLVM::LLVMDialect::getAlignAttrName())))
207 alignment = alignAttr.getInt();
208 auto globalOp = LLVM::GlobalOp::create(
209 rewriter, gpuFuncOp.getLoc(), arrayType,
false,
210 LLVM::Linkage::Internal, name,
Attribute(), alignment,
212 workgroupBuffers.push_back(globalOp);
217 TypeConverter::SignatureConversion signatureConversion(
218 gpuFuncOp.front().getNumArguments());
221 gpuFuncOp.getFunctionType(),
false,
224 return rewriter.notifyMatchFailure(gpuFuncOp, [&](
Diagnostic &
diag) {
225 diag <<
"failed to convert function signature type for: "
226 << gpuFuncOp.getFunctionType();
230 ArrayAttr argAttrs = gpuFuncOp.getArgAttrsAttr();
232 FailureOr<LoweredLLVMFuncAttrs> loweredAttrs =
234 if (failed(loweredAttrs))
235 return rewriter.notifyMatchFailure(gpuFuncOp,
236 "failed to lower func attributes");
238 auto llvmFuncOp = LLVM::LLVMFuncOp::create(rewriter, gpuFuncOp.getLoc(),
239 loweredAttrs->properties,
240 loweredAttrs->discardableAttrs);
251 rewriter.setInsertionPointToStart(&gpuFuncOp.front());
252 unsigned numProperArguments = gpuFuncOp.getNumArguments();
254 if (encodeWorkgroupAttributionsAsArguments) {
257 unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
258 assert(numProperArguments >= numAttributions &&
259 "Expecting attributions to be encoded as arguments already");
264 gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
266 for (
auto [idx, vals] : llvm::enumerate(
267 llvm::zip_equal(gpuFuncOp.getWorkgroupAttributionBBArgs(),
268 attributionArguments))) {
269 auto [attribution, arg] = vals;
270 auto type = cast<MemRefType>(attribution.getType());
278 signatureConversion.remapInput(numProperArguments + idx, descr);
281 for (
const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
282 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
283 global.getAddrSpace());
284 Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType,
285 global.getSymNameAttr());
287 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(),
294 Value attribution = gpuFuncOp.getWorkgroupAttributionBBArgs()[idx];
295 auto type = cast<MemRefType>(attribution.
getType());
298 signatureConversion.remapInput(numProperArguments + idx, descr);
303 unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
304 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
305 for (
const auto [idx, attribution] :
306 llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
307 auto type = cast<MemRefType>(attribution.getType());
308 assert(type && type.hasStaticShape() &&
"unexpected type in attribution");
313 Type elementType = typeConverter->convertType(type.getElementType());
315 LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
316 Value numElements = LLVM::ConstantOp::create(
317 rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
318 uint64_t alignment = 0;
320 dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
321 idx, LLVM::LLVMDialect::getAlignAttrName())))
322 alignment = alignAttr.getInt();
324 LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType,
325 elementType, numElements, alignment);
328 signatureConversion.remapInput(
329 numProperArguments + numWorkgroupAttributions + idx, descr);
334 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
336 if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
337 &signatureConversion)))
342 for (
const auto [idx, argTy] :
343 llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
344 auto remapping = signatureConversion.getInputMapping(idx);
346 argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) :
NamedAttrList();
347 auto copyAttribute = [&](StringRef attrName) {
351 for (
size_t i = 0, e = remapping->size; i < e; ++i)
352 llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
354 auto copyPointerAttribute = [&](StringRef attrName) {
359 if (remapping->size > 1 &&
360 attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
362 "Cannot copy noalias with non-bare pointers.\n");
365 for (
size_t i = 0, e = remapping->size; i < e; ++i) {
366 if (isa<LLVM::LLVMPointerType>(
367 llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
368 llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
376 copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
377 copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
378 copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
379 bool lowersToPointer =
false;
380 for (
size_t i = 0, e = remapping->size; i < e; ++i) {
381 lowersToPointer |= isa<LLVM::LLVMPointerType>(
382 llvmFuncOp.getArgument(remapping->inputNo + i).getType());
385 if (lowersToPointer) {
386 copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
387 copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
388 copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
389 copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
390 copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
391 copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
392 copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
393 copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
394 copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
395 copyPointerAttribute(
396 LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
397 copyPointerAttribute(
398 LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
401 rewriter.eraseOp(gpuFuncOp);
// Fragment of a rewrite method that lowers `gpuPrintfOp` into a chain of calls
// to functions named __ockl_printf_*.
// NOTE(review): the line naming the enclosing pattern/method and several
// statements are not visible in this excerpt; the comments below describe only
// what the visible lines do, and mark anything inferred as an assumption.
//
// Parameters (as visible):
//   gpuPrintfOp - the op being rewritten; erased at the end of the fragment.
//   adaptor     - supplies the format (getFormat) and the operands (getArgs).
//   rewriter    - used to create every new op and to report match failure.
406 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
407 ConversionPatternRewriter &rewriter)
const {
408 Location loc = gpuPrintfOp->getLoc();
// Types used for the callee signatures and constants below; the integer types
// go through the type converter, the pointer type is built directly.
410 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
411 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
412 mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
413 mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
// Match-failure exit with the message below. The guarding condition is not
// visible here; presumably it checks the `moduleOp` lookup - TODO confirm.
417 return rewriter.notifyMatchFailure(gpuPrintfOp,
418 "Couldn't find a parent module");
// Function type i64 (i64). Presumably the signature of the callee bound to
// `ocklBegin`, which is called with a single i64 operand below - TODO confirm.
422 LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
// The argument-append callee is only set up when the op carries operands;
// otherwise `ocklAppendArgs` is left default-constructed.
423 LLVM::LLVMFuncOp ocklAppendArgs;
424 if (!adaptor.getArgs().empty()) {
// Signature for "__ockl_printf_append_args":
//   i64 (i64, i32, i64 x 7, i32)
426 moduleOp, loc, rewriter,
"__ockl_printf_append_args",
427 LLVM::LLVMFunctionType::get(
428 llvmI64, {llvmI64, llvmI32, llvmI64, llvmI64, llvmI64,
429 llvmI64, llvmI64, llvmI64, llvmI64, llvmI32}));
// Parameter list for "__ockl_printf_append_string_n": (i64, ptr, i64, i32).
// The line carrying its return type is not visible in this excerpt.
432 moduleOp, loc, rewriter,
"__ockl_printf_append_string_n",
433 LLVM::LLVMFunctionType::get(
435 {llvmI64, ptrType, llvmI64, llvmI32}));
// Open the sequence: call `ocklBegin` with constant 0 and keep its result as
// the running descriptor that is threaded through every later call.
438 Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0);
439 auto printfBeginCall =
440 LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64);
441 Value printfDesc = printfBeginCall.getResult();
// Operands of a helper call whose callee name is not visible. Presumably its
// result is `global`, a global holding the format string whose symbol name
// uses the prefix passed here - TODO confirm against the helper.
445 rewriter, loc, moduleOp, llvmI8,
"printfFormat_", adaptor.getFormat());
// Address of the format global, typed as a pointer in the global's own
// address space.
448 Value globalPtr = LLVM::AddressOfOp::create(
450 LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
451 global.getSymNameAttr());
// GEP over the global's type; its base and indices are not visible here.
// Presumably this defines `stringStart`, used below - TODO confirm.
453 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
// Length operand: the size of the global's value attribute, read as a
// StringAttr.
455 Value stringLen = LLVM::ConstantOp::create(
456 rewriter, loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
// i32 constants 1 and 0, used as the trailing flag operand of the append calls.
458 Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1);
459 Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0);
// Append the format string. The trailing operand is 1 when the op has no
// operands (so no further append call follows) and 0 otherwise.
461 auto appendFormatCall = LLVM::CallOp::create(
462 rewriter, loc, ocklAppendStringN,
463 ValueRange{printfDesc, stringStart, stringLen,
464 adaptor.getArgs().empty() ? oneI32 : zeroI32});
465 printfDesc = appendFormatCall.getResult();
// Operands are appended in groups of at most `argsPerAppend` (7) per call,
// matching the seven i64 slots in the append_args signature above.
468 constexpr size_t argsPerAppend = 7;
469 size_t nArgs = adaptor.getArgs().size();
470 for (
size_t group = 0; group < nArgs; group += argsPerAppend) {
// [group, bound) is the slice of operands handled by this call.
471 size_t bound = std::min(group + argsPerAppend, nArgs);
472 size_t numArgsThisCall = bound - group;
// Operand order built in `arguments` (its declaration is not visible):
// descriptor, i32 count for this call, the operands, zero padding, last flag.
475 arguments.push_back(printfDesc);
477 LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall));
478 for (
size_t i = group; i < bound; ++i) {
479 Value arg = adaptor.getArgs()[i];
// Float operands: extend to f64 unless already f64, then bitcast to i64 so
// the value fits an i64 slot.
480 if (
auto floatType = dyn_cast<FloatType>(arg.
getType())) {
481 if (!floatType.isF64())
482 arg = LLVM::FPExtOp::create(
483 rewriter, loc, typeConverter->convertType(rewriter.getF64Type()),
485 arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg);
// Zero-extension to i64. The condition guarding this line is not visible;
// presumably it applies to non-float operands narrower than 64 bits -
// TODO confirm.
488 arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg);
490 arguments.push_back(arg);
// Fill the unused slots, up to `argsPerAppend`, with the zero i64 constant.
493 for (
size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
494 arguments.push_back(zeroI64);
// The flag is 1 only for the group that reaches the end of the operand list.
497 auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
498 arguments.push_back(isLast);
// The call's result replaces the descriptor for the next iteration.
499 auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments);
500 printfDesc = call.getResult();
// The original op is erased once all calls have been created.
502 rewriter.eraseOp(gpuPrintfOp);
// Fragment of a rewrite method that lowers `gpuPrintfOp` into a single call to
// the function held in `printfDecl`, passing the format-string pointer
// followed by the op's operands unchanged.
// NOTE(review): the line naming the enclosing pattern/method, the declaration
// lookup and several other statements are not visible in this excerpt;
// anything inferred below is marked as an assumption.
//
// Parameters (as visible):
//   gpuPrintfOp - the op being rewritten; erased at the end of the fragment.
//   adaptor     - supplies the format (getFormat) and the operands (getArgs).
//   rewriter    - used to create every new op and to report match failure.
507 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
508 ConversionPatternRewriter &rewriter)
const {
509 Location loc = gpuPrintfOp->getLoc();
// i8 type, passed below to the helper that receives the format string.
511 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
// Pointer type in `addressSpace` (defined outside this excerpt). Presumably
// this initializes `ptrType`, used below - TODO confirm.
513 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
// Match-failure exit with the message below. The guarding condition is not
// visible here; presumably it checks the `moduleOp` lookup - TODO confirm.
517 return rewriter.notifyMatchFailure(gpuPrintfOp,
518 "Couldn't find a parent module");
// Function type returning i32 and taking one pointer parameter. The call has
// a further argument on a line that is not visible.
521 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
// The declaration is stamped with `callingConvention`; how `printfDecl` is
// obtained is not visible in this excerpt.
523 LLVM::LLVMFuncOp printfDecl =
525 printfDecl.setCConv(callingConvention);
// Operands of a helper call whose callee name is not visible. Presumably its
// result is `global`, a global holding the format string whose symbol name
// uses the prefix passed here - TODO confirm. Trailing operands are cut off.
529 rewriter, loc, moduleOp, llvmI8,
"printfFormat_", adaptor.getFormat(),
// Address of the format global, typed as a pointer in the global's own
// address space.
533 Value globalPtr = LLVM::AddressOfOp::create(
535 LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
536 global.getSymNameAttr());
// GEP over the global's type; its base and indices are not visible here.
// Presumably this defines `stringStart`, used below - TODO confirm.
538 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
// Call operands: the format-string pointer first, then every operand of the op
// in order. Capacity is reserved up front for all of them.
542 auto argsRange = adaptor.getArgs();
544 printfArgs.reserve(argsRange.size() + 1);
545 printfArgs.push_back(stringStart);
546 printfArgs.append(argsRange.begin(), argsRange.end());
// The call site gets the same calling convention as the declaration above.
548 auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
549 call.setCConv(callingConvention);
// The original op is erased once the call has been created.
550 rewriter.eraseOp(gpuPrintfOp);
// Fragment of a rewrite method that lowers `gpuPrintfOp` into a call to a
// function declared under the name "vprintf": the operands are stored into a
// stack-allocated literal struct and the call receives the format-string
// pointer plus the pointer to that struct.
// NOTE(review): the line naming the enclosing pattern/method, the closing
// lines and several statements are not visible in this excerpt; anything
// inferred below is marked as an assumption.
//
// Parameters (as visible):
//   gpuPrintfOp - the op being rewritten; erased at the end of the fragment.
//   adaptor     - supplies the format (getFormat) and the operands (getArgs).
//   rewriter    - used to create every new op and to report match failure.
555 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
556 ConversionPatternRewriter &rewriter)
const {
557 Location loc = gpuPrintfOp->getLoc();
// i8 type and a pointer type built without an explicit address space.
559 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
560 mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
// Match-failure exit with the message below. The guarding condition is not
// visible here; presumably it checks the `moduleOp` lookup - TODO confirm.
564 return rewriter.notifyMatchFailure(gpuPrintfOp,
565 "Couldn't find a parent module");
// Function type i32 (ptr, ptr). Presumably this initializes `vprintfType`,
// used just below - TODO confirm.
573 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
// Operands of the helper call that yields the "vprintf" declaration; note it
// is given `globalLoc` (defined outside this excerpt) rather than `loc`.
575 moduleOp, globalLoc, rewriter,
"vprintf", vprintfType);
// Global for the format string; the helper producing it is not visible, only
// its trailing operands (name prefix and format).
578 LLVM::GlobalOp global =
580 "printfFormat_", adaptor.getFormat());
// Address of the format global, followed by a GEP over the global's type whose
// base and indices are not visible. Presumably the GEP defines `stringStart`,
// used in the final call - TODO confirm.
583 Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global);
585 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
// Collect each operand and its type; these become the struct's field values
// and field types respectively.
590 for (
Value arg : adaptor.getArgs()) {
591 Type type = arg.getType();
592 Value promotedArg = arg;
// Every float operand is recorded as f64 and passed through an FPExtOp.
// NOTE(review): no isF64 check is visible in this excerpt (the __ockl-based
// fragment earlier in this file has one) - confirm how f64 operands are
// expected to be handled here.
594 if (isa<FloatType>(type)) {
595 type = rewriter.getF64Type();
596 promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg);
598 types.push_back(type);
599 args.push_back(promotedArg);
// Literal struct type with one field per collected operand type. Presumably
// this initializes `structType`, used below - TODO confirm.
602 LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
// Element count (1) for the alloca; built with an i64 result type and an
// index attribute.
603 Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
604 rewriter.getIndexAttr(1));
// Stack slot for one `structType` value; trailing operands are not visible.
// Presumably the result is `tempAlloc` - TODO confirm.
606 LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one,
// Store each collected operand through a pointer derived from `tempAlloc`;
// the op creating `ptr` and its indices are only partly visible.
608 for (
auto [
index, arg] : llvm::enumerate(args)) {
610 rewriter, loc, ptrType, structType, tempAlloc,
612 LLVM::StoreOp::create(rewriter, loc, arg,
ptr);
// Exactly two call operands: the format-string pointer and the struct pointer.
614 std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
616 LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs);
// The original op is erased once the call has been created.
617 rewriter.eraseOp(gpuPrintfOp);