25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Frontend/OpenMP/OMPConstants.h"
29 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
30 #include "llvm/IR/DebugInfoMetadata.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/ReplaceConstant.h"
33 #include "llvm/Support/FileSystem.h"
34 #include "llvm/TargetParser/Triple.h"
35 #include "llvm/Transforms/Utils/ModuleUtils.h"
// Maps an optional omp schedule clause kind to the OpenMPIRBuilder schedule
// kind; an absent clause (no value) maps to OMP_SCHEDULE_Default.
// NOTE(review): this copy is corrupted — every line carries an embedded
// source line number and lines 60, 62 and 64 are missing. Restore the file
// from upstream version control rather than hand-patching it.
47 static llvm::omp::ScheduleKind
48 convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
49 if (!schedKind.has_value())
50 return llvm::omp::OMP_SCHEDULE_Default;
51 switch (schedKind.value()) {
52 case omp::ClauseScheduleKind::Static:
53 return llvm::omp::OMP_SCHEDULE_Static;
54 case omp::ClauseScheduleKind::Dynamic:
55 return llvm::omp::OMP_SCHEDULE_Dynamic;
56 case omp::ClauseScheduleKind::Guided:
57 return llvm::omp::OMP_SCHEDULE_Guided;
58 case omp::ClauseScheduleKind::Auto:
59 return llvm::omp::OMP_SCHEDULE_Auto;
// NOTE(review): no `case` label precedes the next return (line 60 is
// missing), so as written it is unreachable and no case handles the Runtime
// kind; the lost line is presumably `case omp::ClauseScheduleKind::Runtime:`.
61 return llvm::omp::OMP_SCHEDULE_Runtime;
// Intended as the fall-out path after the switch; the switch's closing brace
// (line 62) and the function's closing brace (line 64) are missing here.
63 llvm_unreachable(
"unhandled schedule clause argument");
// Stack frame that records an insertion point for allocas
// (`allocaInsertPoint`), set from the constructor argument.
// NOTE(review): lines 69-72 and 76 onward (base-class list, braces, access
// specifier) are missing from this copy; restore from upstream.
68 class OpenMPAllocaStackFrame
73 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
74 : allocaInsertPoint(allocaIP) {}
75 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
// Stack frame type whose name suggests it carries variable mappings.
// NOTE(review): only fragments survive (lines 80, 82 and 86); the base list,
// constructor parameters and members are missing — restore from upstream.
80 class OpenMPVarMappingStackFrame
82 OpenMPVarMappingStackFrame> {
86 explicit OpenMPVarMappingStackFrame(
// Returns the insertion point at which allocas should be created.
// It first looks for an OpenMPAllocaStackFrame and copies its recorded
// insertion point (the walk that invokes the lambda below is on missing
// lines). Otherwise it falls back to the first insertion point of the current
// function's entry block.
// NOTE(review): the function name/parameters (lines 97-101) and lines 103,
// 106-108, 110-117 and 127-128 are missing from this copy.
96 static llvm::OpenMPIRBuilder::InsertPointTy
102 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
104 [&](
const OpenMPAllocaStackFrame &frame) {
105 allocaInsertPoint = frame.allocaInsertPoint;
// Presumably returned only when a frame was found (guard is on missing lines).
109 return allocaInsertPoint;
// If code is still being emitted into the entry block, move code generation
// into a fresh "entry" successor so the real entry block is left for allocas.
118 if (builder.GetInsertBlock() ==
119 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
120 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
121 "Assuming end of basic block");
122 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
123 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
124 builder.GetInsertBlock()->getNextNode());
125 builder.CreateBr(entryBB);
126 builder.SetInsertPoint(entryBB);
129 llvm::BasicBlock &funcEntryBlock =
130 builder.GetInsertBlock()->getParent()->getEntryBlock();
131 return llvm::OpenMPIRBuilder::InsertPointTy(
132 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
// Converts every block of `region` into a fresh LLVM basic block named
// `blockName` and wires them between the current insertion block and a new
// "omp.region.cont" continuation block.
// Values yielded by omp.yield terminators become PHIs at the top of the
// continuation block. The function returns the continuation block.
// If a block fails to convert, it sets bodyGenStatus to failure() and returns
// early.
// NOTE(review): this copy is corrupted. The function header (before line 141)
// and many interior lines are missing (the embedded line numbers are
// non-contiguous). `Region ®ion` looks like a mis-encoded `Region &region`
// (HTML-entity damage). Restore from upstream before relying on this code.
141 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
// Split the current block; the tail becomes the region's continuation.
144 llvm::BasicBlock *continuationBlock =
145 splitBB(builder,
true,
"omp.region.cont");
146 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
148 llvm::LLVMContext &llvmContext = builder.getContext();
// Create (and map) one LLVM block per MLIR block up front so branches between
// region blocks can be resolved during conversion.
149 for (
Block &bb : region) {
150 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
151 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
152 builder.GetInsertBlock()->getNextNode());
153 moduleTranslation.
mapBlock(&bb, llvmBB);
156 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
// Record the LLVM types of the first omp.yield's operands; every later yield
// must agree in count and type (checked by the asserts below).
162 bool operandsProcessed =
false;
163 unsigned numYields = 0;
165 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
166 if (!operandsProcessed) {
167 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
168 continuationBlockPHITypes.push_back(
169 moduleTranslation.
convertType(yield->getOperand(i).getType()));
171 operandsProcessed =
true;
173 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
174 "mismatching number of values yielded from the region");
175 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
176 llvm::Type *operandType =
177 moduleTranslation.
convertType(yield->getOperand(i).getType());
179 assert(continuationBlockPHITypes[i] == operandType &&
180 "values of mismatching types yielded from the region");
189 if (!continuationBlockPHITypes.empty())
191 continuationBlockPHIs &&
192 "expected continuation block PHIs if converted regions yield values");
// Create one PHI per yielded value at the start of the continuation block,
// reserving `numYields` incoming edges each.
193 if (continuationBlockPHIs) {
194 llvm::IRBuilderBase::InsertPointGuard guard(builder);
195 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
196 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
197 for (llvm::Type *ty : continuationBlockPHITypes)
198 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
204 for (
Block *bb : blocks) {
205 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
// Redirect the source block's branch (which splitBB pointed at the
// continuation) to the region's entry block.
208 if (bb->isEntryBlock()) {
209 assert(sourceTerminator->getNumSuccessors() == 1 &&
210 "provided entry block has multiple successors");
211 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
212 "ContinuationBlock is not the successor of the entry block");
213 sourceTerminator->setSuccessor(0, llvmBB);
216 llvm::IRBuilderBase::InsertPointGuard guard(builder);
218 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder))) {
219 bodyGenStatus = failure();
220 return continuationBlock;
// omp.terminator / omp.yield become a branch to the continuation block, and
// yielded operands feed the matching continuation PHIs.
230 Operation *terminator = bb->getTerminator();
231 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
232 builder.CreateBr(continuationBlock);
234 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
235 (*continuationBlockPHIs)[i]->addIncoming(
249 return continuationBlock;
// Maps omp::ClauseProcBindKind values to llvm::omp::ProcBindKind.
// NOTE(review): the function header, any case(s) before `Close` (lines before
// 255), line 263 and the closing brace are missing from this copy.
255 case omp::ClauseProcBindKind::Close:
256 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
257 case omp::ClauseProcBindKind::Master:
258 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
259 case omp::ClauseProcBindKind::Primary:
260 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
261 case omp::ClauseProcBindKind::Spread:
262 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
264 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
// Translates an omp.masked op. The body callback inlines the op's region at
// codeGenIP, and the finalization callback does nothing.
// The filter value is the translated filtered-thread-id operand when present.
// The fallback for an absent operand sits on missing lines 292-296 and
// presumably builds a constant using `llvmContext` (TODO confirm).
// NOTE(review): `auto ®ion` looks like a mis-encoded `auto &region`; many
// interior lines (273-274, 276, 278, 281-286, 288, 292, 294-296, 299) are
// missing from this copy.
271 auto maskedOp = cast<omp::MaskedOp>(opInst);
272 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
275 LogicalResult bodyGenStatus = success();
277 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
279 auto ®ion = maskedOp.getRegion();
280 builder.restoreIP(codeGenIP);
287 auto finiCB = [&](InsertPointTy codeGenIP) {};
289 llvm::Value *filterVal =
nullptr;
290 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
291 filterVal = moduleTranslation.
lookupValue(filterVar);
293 llvm::LLVMContext &llvmContext = builder.getContext();
297 assert(filterVal !=
nullptr);
298 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
300 ompLoc, bodyGenCB, finiCB, filterVal));
// Translates an omp.master op. The body callback inlines the op's region at
// codeGenIP, and the finalization callback does nothing.
// NOTE(review): `auto ®ion` looks like a mis-encoded `auto &region`; lines
// 309-310, 312, 314, 317-322, 324 and 326 are missing from this copy.
308 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
311 LogicalResult bodyGenStatus = success();
313 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
315 auto ®ion = cast<omp::MasterOp>(opInst).getRegion();
316 builder.restoreIP(codeGenIP);
323 auto finiCB = [&](InsertPointTy codeGenIP) {};
325 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
327 ompLoc, bodyGenCB, finiCB));
// Translates an omp.critical op; the body callback inlines its region.
// If the op names a critical section, the hint is an i32 constant taken from
// the referenced omp.critical.declare op's hint value. Otherwise the hint
// stays nullptr. The critical name passed on is the op's name, or "" when it
// has none.
// NOTE(review): `auto ®ion` looks like a mis-encoded `auto &region`; many
// interior lines (e.g. 356-357, 359-360, 364-365, 368-369) are missing.
335 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
336 auto criticalOp = cast<omp::CriticalOp>(opInst);
339 LogicalResult bodyGenStatus = success();
341 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
343 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
344 builder.restoreIP(codeGenIP);
346 moduleTranslation, bodyGenStatus);
351 auto finiCB = [&](InsertPointTy codeGenIP) {};
353 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
354 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
355 llvm::Constant *hint =
nullptr;
358 if (criticalOp.getNameAttr()) {
361 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
362 auto criticalDeclareOp =
363 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
366 llvm::Type::getInt32Ty(llvmContext),
367 static_cast<int>(criticalDeclareOp.getHintVal()));
370 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint));
// For each symbol in `loop`'s reductions attribute, appends the
// omp.declare_reduction op it resolves to onto `reductions`.
// NOTE(review): lines 376-378, 380-382 (presumably an early exit when the
// attribute is absent — TODO confirm) and 387 onward are missing.
375 template <
typename T>
379 std::optional<ArrayAttr> attr = loop.getReductions();
383 reductions.reserve(reductions.size() + loop.getNumReductionVars());
384 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
385 reductions.push_back(
386 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
// Inlines `region` at the builder's current position.
// A single-block region is converted directly into the current insertion
// block: any existing terminator is detached first and re-attached afterwards
// (at the front if the block ended up empty, otherwise after its last
// instruction).
// Multi-block regions go through the general region conversion (line 448).
// Its continuation PHIs are appended to `continuationBlockArgs` when that is
// non-null, and the builder is then moved to the continuation block's first
// insertion point.
// NOTE(review): the header and many interior lines are missing, and
// `Region ®ion` / `®ion.` look like mis-encoded `&region`.
397 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
405 if (llvm::hasSingleElement(region)) {
406 llvm::Instruction *potentialTerminator =
407 builder.GetInsertBlock()->empty() ? nullptr
408 : &builder.GetInsertBlock()->back();
// Detach the terminator so the converted block's instructions land before it.
410 if (potentialTerminator && potentialTerminator->isTerminator())
411 potentialTerminator->removeFromParent();
412 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
415 region.
front(),
true, builder)))
419 if (continuationBlockArgs)
421 *continuationBlockArgs,
// Re-attach the detached terminator at the end of the (possibly new) block.
428 if (potentialTerminator && potentialTerminator->isTerminator()) {
429 llvm::BasicBlock *block = builder.GetInsertBlock();
430 if (block->empty()) {
436 potentialTerminator->insertInto(block, block->begin());
438 potentialTerminator->insertAfter(&block->back());
445 LogicalResult bodyGenStatus = success();
448 region, blockName, builder, moduleTranslation, bodyGenStatus, &phis);
449 if (failed(bodyGenStatus))
451 if (continuationBlockArgs)
452 llvm::append_range(*continuationBlockArgs, phis);
453 builder.SetInsertPoint(continuationBlock,
454 continuationBlock->getFirstInsertionPt());
// std::function-based callback types for the non-atomic and atomic reduction
// generators built below; being std::function, they can own captured state.
// NOTE(review): the tails of both parameter lists (lines 463 and 467) are
// missing from this copy.
461 using OwningReductionGen = std::function<llvm::OpenMPIRBuilder::InsertPointTy(
462 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
464 using OwningAtomicReductionGen =
465 std::function<llvm::OpenMPIRBuilder::InsertPointTy(
466 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
// Builds a callback that inlines `decl`'s reduction region at `insertPoint`
// as "omp.reduction.nonatomic.body". The region must yield exactly one value
// (asserted). The callback returns the builder's resulting insertion point,
// or an empty InsertPointTy if inlining fails.
// `decl` is captured by value and everything else by reference. The stored
// callback therefore must not outlive the builder and translation objects it
// references (TODO confirm at call sites).
// NOTE(review): the parameter list (474-478), lines 484-485, 487-488 and 493
// (presumably assigning phis[0] to `result`) are missing from this copy.
473 static OwningReductionGen
479 OwningReductionGen gen =
480 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
481 llvm::Value *lhs, llvm::Value *rhs,
482 llvm::Value *&result)
mutable {
483 Region &reductionRegion = decl.getReductionRegion();
486 builder.restoreIP(insertPoint);
489 "omp.reduction.nonatomic.body",
490 builder, moduleTranslation, &phis)))
491 return llvm::OpenMPIRBuilder::InsertPointTy();
492 assert(phis.size() == 1);
494 return builder.saveIP();
// Builds a callback that inlines `decl`'s atomic reduction region at
// `insertPoint` as "omp.reduction.atomic.body".
// Returns an empty callback when the declaration has no atomic region, so
// callers can test for its presence. The region must yield nothing
// (asserted). The callback returns an empty InsertPointTy if inlining fails.
// NOTE(review): lines 504, 506, 509-512, 517-518 and 520-521 are missing.
503 static OwningAtomicReductionGen
505 llvm::IRBuilderBase &builder,
507 if (decl.getAtomicReductionRegion().empty())
508 return OwningAtomicReductionGen();
513 OwningAtomicReductionGen atomicGen =
514 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
515 llvm::Value *lhs, llvm::Value *rhs)
mutable {
516 Region &atomicRegion = decl.getAtomicReductionRegion();
519 builder.restoreIP(insertPoint);
522 "omp.reduction.atomic.body", builder,
523 moduleTranslation, &phis)))
524 return llvm::OpenMPIRBuilder::InsertPointTy();
525 assert(phis.empty());
526 return builder.saveIP();
// Translates a stand-alone omp.ordered op with depend semantics.
// The depend vector values are consumed in chunks of `numLoops`, and one
// createOrderedDepend call is emitted per chunk.
// NOTE(review): the depend type and loop count optionals are dereferenced
// unchecked; this assumes the op verifier guarantees both are present
// (TODO confirm).
535 auto orderedOp = cast<omp::OrderedOp>(opInst);
537 omp::ClauseDepend dependType = *orderedOp.getDependTypeVal();
538 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
539 unsigned numLoops = *orderedOp.getNumLoopsVal();
541 moduleTranslation.
lookupValues(orderedOp.getDependVecVars());
543 size_t indexVecValues = 0;
544 while (indexVecValues < vecValues.size()) {
546 storeValues.reserve(numLoops);
// NOTE(review): line 549 (presumably `indexVecValues++;`) is missing from
// this copy; without it neither loop advances and the while never ends.
547 for (
unsigned i = 0; i < numLoops; i++) {
548 storeValues.push_back(vecValues[indexVecValues]);
551 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
553 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
554 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
555 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
// Translates an omp.ordered region. The body callback inlines the op's region,
// and the finalization callback does nothing. The final argument passed on is
// the negation of the op's simd flag.
// NOTE(review): the statement guarded by `if (getSimd())` (lines 570-573) is
// missing, as are other interior lines. `auto ®ion` looks like a
// mis-encoded `auto &region`.
565 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
566 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
569 if (orderedRegionOp.getSimd())
574 LogicalResult bodyGenStatus = success();
576 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
578 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
579 builder.restoreIP(codeGenIP);
581 moduleTranslation, bodyGenStatus);
586 auto finiCB = [&](InsertPointTy codeGenIP) {};
588 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
591 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getSimd()));
592 return bodyGenStatus;
// Allocates one private alloca per reduction variable, placed just before the
// terminator of allocaIP's block.
// Both the reduction block argument and the original reduction variable are
// mapped to that alloca. The InsertPointGuard restores the builder's position
// on scope exit.
// NOTE(review): lines 609-610 are missing; they presumably skip by-reference
// reductions (TODO confirm).
596 template <
typename T>
600 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
605 llvm::IRBuilderBase::InsertPointGuard guard(builder);
606 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
608 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
611 llvm::Value *var = builder.CreateAlloca(
612 moduleTranslation.
convertType(reductionDecls[i].getType()));
613 moduleTranslation.
mapValue(reductionArgs[i], var);
614 privateReductionVariables[i] = var;
615 reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
// For reduction `i`, looks up the LLVM value of the i-th reduction variable.
// It presumably maps the initializer region's single block argument to it;
// the mapping itself is on missing lines after 634 (TODO confirm).
// NOTE(review): lines 621-625, 628-629 and 631 are missing from this copy.
620 template <
typename T>
626 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
627 Region &initializerRegion = reduction.getInitializerRegion();
630 "the initialization region has one argument");
632 mlir::Value mlirSource = loop.getReductionVars()[i];
633 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
634 assert(llvmSource &&
"lookup reduction var");
// Builds the per-reduction generator callbacks, then one ReductionInfo entry
// per reduction variable. Each entry uses the Scalar eval kind, the owning
// non-atomic generator, and the atomic generator when one exists
// (otherwise nullptr).
// NOTE(review): lines 640, 642-647, 649, 652, 654-657 and 666 are missing.
639 template <
typename T>
641 T loop, llvm::IRBuilderBase &builder,
648 unsigned numReductions = loop.getNumReductionVars();
650 for (
unsigned i = 0; i < numReductions; ++i) {
651 owningReductionGens.push_back(
653 owningAtomicReductionGens.push_back(
658 reductionInfos.reserve(numReductions);
659 for (
unsigned i = 0; i < numReductions; ++i) {
// An empty owning atomic generator means "no atomic form"; keep nullptr.
660 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen =
nullptr;
661 if (owningAtomicReductionGens[i])
662 atomicGen = owningAtomicReductionGens[i];
663 llvm::Value *variable =
664 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
665 reductionInfos.push_back(
667 privateReductionVariables[i],
668 llvm::OpenMPIRBuilder::EvalKind::Scalar,
669 owningReductionGens[i],
670 nullptr, atomicGen});
// Inlines cleanup regions before the current block's terminator, if any.
// Empty cleanup regions are skipped. The value passed as the region argument
// is a load from the private variable when `shouldLoadCleanupRegionArg` is
// true, and the private variable itself otherwise.
// NOTE(review): the header and lines 681, 683-687 and 696-697 are missing.
// `prviateVarValue` is a typo for `privateVarValue` (local name only).
679 llvm::IRBuilderBase &builder, StringRef regionName,
680 bool shouldLoadCleanupRegionArg =
true) {
682 if (cleanupRegion->empty())
688 llvm::Instruction *potentialTerminator =
689 builder.GetInsertBlock()->empty() ? nullptr
690 : &builder.GetInsertBlock()->back();
691 if (potentialTerminator && potentialTerminator->isTerminator())
692 builder.SetInsertPoint(potentialTerminator);
693 llvm::Value *prviateVarValue =
694 shouldLoadCleanupRegionArg
695 ? builder.CreateLoad(
698 : privateVariables[i];
// Emits the reduction combination for `op` (if it has reduction variables),
// then a barrier, then inlines each declaration's cleanup region as
// "omp.reduction.cleanup". It reports an op error if createReductions fails.
// NOTE(review): the header and several interior lines (e.g. 723-732, 735-738)
// are missing from this copy.
716 OP op, llvm::IRBuilderBase &builder,
718 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
722 if (op.getNumReductionVars() == 0)
733 owningReductionGens, owningAtomicReductionGens,
734 privateReductionVariables, reductionInfos);
// A temporary `unreachable` gives the current block a terminator to position
// against while the reductions are emitted; it is erased afterwards.
739 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
740 builder.SetInsertPoint(tempTerminator);
741 llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
742 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
743 isByRef, op.getNowait());
744 if (!contInsertPoint.getBlock())
745 return op->
emitOpError() <<
"failed to convert reductions";
// NOTE(review): the barrier below is emitted unconditionally, even when
// op.getNowait() is true — confirm this is intended.
746 auto nextInsertionPoint =
747 ompBuilder->createBarrier(contInsertPoint, llvm::omp::OMPD_for);
748 tempTerminator->eraseFromParent();
749 builder.restoreIP(nextInsertionPoint);
753 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
754 [](omp::DeclareReductionOp reductionDecl) {
755 return &reductionDecl.getCleanupRegion();
758 moduleTranslation, builder,
759 "omp.reduction.cleanup");
// Allocates the private reduction variables and initializes them.
// Each declaration's initializer region is inlined as
// "omp.reduction.neutral" and must yield exactly one value (asserted), which
// is stored into the private variable.
// In the branch starting at line 804, a fresh alloca is created; the region
// argument and the original variable are then mapped to the yielded value.
// Afterwards the initializer region's mappings are forgotten.
// NOTE(review): many interior lines (e.g. 780-782, 785-788, 790-795, 801-803,
// 806-807, 809, 813-814, 816-821) are missing.
770 template <
typename OP>
774 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
779 if (op.getNumReductionVars() == 0)
783 allocaIP, reductionDecls, privateReductionVariables,
784 reductionVariableMap, isByRef);
789 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
796 "omp.reduction.neutral", builder,
797 moduleTranslation, &phis)))
799 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
800 "reduction neutral element declaration region");
804 llvm::Value *var = builder.CreateAlloca(
805 moduleTranslation.
convertType(reductionDecls[i].getType()));
808 builder.CreateStore(phis[0], var);
810 privateReductionVariables[i] = var;
811 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
812 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
815 builder.CreateStore(phis[0], privateReductionVariables[i]);
822 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
// Translates omp.sections; the allocate clause is rejected as unsupported.
// Reduction variables are allocated and initialized first. Then one body
// callback is built per nested omp.section: it maps the section's block
// arguments to the LLVM values of the enclosing sections op's arguments and
// inlines the section region.
// Finally the sections are emitted and the reductions are combined and
// cleaned up.
// NOTE(review): `§ionsOp` and `®ion` look like mis-encoded `&sectionsOp`
// and `&region` (HTML-entity damage); many interior lines are missing.
831 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
832 using StorableBodyGenCallbackTy =
833 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
835 auto sectionsOp = cast<omp::SectionsOp>(opInst);
839 if (!sectionsOp.getAllocateVars().empty() ||
840 !sectionsOp.getAllocatorsVars().empty())
842 <<
"allocate clause is not supported for sections construct";
845 assert(isByRef.size() == sectionsOp.getNumReductionVars());
849 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
853 sectionsOp.getNumReductionVars());
857 sectionsOp.getRegion().getArguments();
860 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
861 reductionDecls, privateReductionVariables, reductionVariableMap,
869 moduleTranslation, reductionVariableMap);
871 LogicalResult bodyGenStatus = success();
875 auto sectionOp = dyn_cast<omp::SectionOp>(op);
879 Region ®ion = sectionOp.getRegion();
// Captures are by reference; the callbacks are consumed within this function.
880 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation,
881 &bodyGenStatus](InsertPointTy allocaIP,
882 InsertPointTy codeGenIP) {
883 builder.restoreIP(codeGenIP);
890 sectionsOp.getRegion().getNumArguments());
891 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
892 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
893 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
895 moduleTranslation.mapValue(sectionArg, llvmVal);
899 moduleTranslation, bodyGenStatus);
901 sectionCBs.push_back(sectionCB);
907 if (sectionCBs.empty())
910 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
// Privatization callback: keeps the original pointer (no privatization here).
915 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
917 llvm::Value *&replacementValue) -> InsertPointTy {
918 replacementValue = &vPtr;
924 auto finiCB = [&](InsertPointTy codeGenIP) {};
927 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
929 ompLoc, allocaIP, sectionCBs, privCB, finiCB,
false,
930 sectionsOp.getNowait()));
932 if (failed(bodyGenStatus))
933 return bodyGenStatus;
937 allocaIP, reductionDecls,
938 privateReductionVariables, isByRef);
// Translates omp.single. The body callback inlines the region, and the
// finalization callback does nothing.
// Copyprivate variables and their copy functions are translated pairwise and
// passed along with the op's nowait flag.
// NOTE(review): `*cpFuncs` is dereferenced without a has_value() check; this
// is safe only if a non-empty copyprivate var list guarantees the funcs
// attribute is present (TODO confirm with the op verifier). Lines 950, 952,
// 954-956, 958-959 and 965-968 are missing.
945 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
946 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
947 LogicalResult bodyGenStatus = success();
948 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
949 builder.restoreIP(codegenIP);
951 moduleTranslation, bodyGenStatus);
953 auto finiCB = [&](InsertPointTy codeGenIP) {};
957 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateFuncs();
960 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
961 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
962 auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
963 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
964 llvmCPFuncs.push_back(
969 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars, llvmCPFuncs));
970 return bodyGenStatus;
// Translates omp.teams; ops with allocator vars or reductions are rejected.
// The optional num_teams lower/upper bounds, thread_limit and if operands
// are translated when present and left as nullptr otherwise.
// NOTE(review): unlike the sections translation, this guard checks only
// getAllocatorsVars(), not getAllocateVars() — confirm that is intended.
// Lines 981, 983, 986, 988-989, 993, 997, 1001, 1005 and 1007 are missing.
977 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
978 LogicalResult bodyGenStatus = success();
979 if (!op.getAllocatorsVars().empty() || op.getReductions())
980 return op.
emitError(
"unhandled clauses for translation to LLVM IR");
982 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
984 moduleTranslation, allocaIP);
985 builder.restoreIP(codegenIP);
987 moduleTranslation, bodyGenStatus);
990 llvm::Value *numTeamsLower =
nullptr;
991 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
992 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
994 llvm::Value *numTeamsUpper =
nullptr;
995 if (
Value numTeamsUpperVar = op.getNumTeamsUpper())
996 numTeamsUpper = moduleTranslation.
lookupValue(numTeamsUpperVar);
998 llvm::Value *threadLimit =
nullptr;
999 if (
Value threadLimitVar = op.getThreadLimit())
1000 threadLimit = moduleTranslation.
lookupValue(threadLimitVar);
1002 llvm::Value *ifExpr =
nullptr;
1003 if (
Value ifExprVar = op.getIfExpr())
1004 ifExpr = moduleTranslation.
lookupValue(ifExprVar);
1006 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1008 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr));
1009 return bodyGenStatus;
// Converts task depend clauses into OpenMPIRBuilder DependData entries, one
// per (variable, kind) pair. `in` maps to DepIn; `out` and `inout` both map
// to DepInOut.
// NOTE(review): llvm::zip stops at the shorter range, so a length mismatch
// between dependVars and the depend kinds would be dropped silently;
// llvm::zip_equal would assert instead. Lines 1017, 1020, 1024-1027 and
// 1031-1032 are missing.
1016 if (dependVars.empty())
1018 for (
auto dep : llvm::zip(dependVars, depends->getValue())) {
1019 llvm::omp::RTLDependenceKindTy type;
1021 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
1022 case mlir::omp::ClauseTaskDepend::taskdependin:
1023 type = llvm::omp::RTLDependenceKindTy::DepIn;
1028 case mlir::omp::ClauseTaskDepend::taskdependout:
1029 case mlir::omp::ClauseTaskDepend::taskdependinout:
1030 type = llvm::omp::RTLDependenceKindTy::DepInOut;
1033 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
1034 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
1035 dds.emplace_back(dd);
// Translates omp.task. Clauses without translation support (untied,
// mergeable, in_reduction, priority, allocate) are rejected up front.
// The body callback inlines the region. Depend data, final and if
// expressions are passed through to task creation.
// NOTE(review): since untied is rejected above, `!taskOp.getUntied()` below
// is always true here. Lines 1040-1041, 1048, 1050-1052, 1054, 1056,
// 1058-1061, 1063, 1065 and 1067 are missing.
1039 static LogicalResult
1042 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1043 LogicalResult bodyGenStatus = success();
1044 if (taskOp.getUntiedAttr() || taskOp.getMergeableAttr() ||
1045 taskOp.getInReductions() || taskOp.getPriority() ||
1046 !taskOp.getAllocateVars().empty()) {
1047 return taskOp.emitError(
"unhandled clauses for translation to LLVM IR");
1049 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1053 moduleTranslation, allocaIP);
1055 builder.restoreIP(codegenIP);
1057 moduleTranslation, bodyGenStatus);
1062 moduleTranslation, dds);
1064 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1066 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1068 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
1069 moduleTranslation.
lookupValue(taskOp.getFinalExpr()),
1070 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dds));
1071 return bodyGenStatus;
// Translates omp.taskgroup; task_reduction and allocate clauses are rejected.
// The body callback inlines the region at codegenIP.
// NOTE(review): lines 1076-1077, 1082, 1085, 1087-1088 and 1090 are missing.
1075 static LogicalResult
1078 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1079 LogicalResult bodyGenStatus = success();
1080 if (!tgOp.getTaskReductionVars().empty() || !tgOp.getAllocateVars().empty()) {
1081 return tgOp.emitError(
"unhandled clauses for translation to LLVM IR");
1083 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1084 builder.restoreIP(codegenIP);
1086 moduleTranslation, bodyGenStatus);
1089 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1091 ompLoc, allocaIP, bodyCB));
1092 return bodyGenStatus;
// Translates omp.wsloop wrapping an omp.loop_nest.
// Steps:
// - Set up reductions.
// - Build one canonical loop per nested loop dimension, with the body
//   inlined only into the innermost one.
// - Collapse the loops and apply the worksharing schedule.
// - Combine the reductions after the loop.
// The schedule defaults to Static. A schedule chunk is sign-extended or
// truncated to the type of the first loop's step.
// NOTE(review): many interior lines are missing from this copy.
1095 static LogicalResult
1098 auto wsloopOp = cast<omp::WsloopOp>(opInst);
1102 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
1105 assert(isByRef.size() == wsloopOp.getNumReductionVars());
1109 wsloopOp.getScheduleVal().value_or(omp::ClauseScheduleKind::Static);
// The induction-variable type is taken from the first dimension's step.
1112 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getStep()[0]);
1113 llvm::Type *ivType = step->getType();
1114 llvm::Value *chunk =
nullptr;
1115 if (wsloopOp.getScheduleChunkVar()) {
1116 llvm::Value *chunkVar =
1117 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunkVar());
1118 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
1123 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1127 wsloopOp.getNumReductionVars());
1131 wsloopOp.getRegion().getArguments();
1134 wsloopOp, reductionArgs, builder, moduleTranslation, allocaIP,
1135 reductionDecls, privateReductionVariables, reductionVariableMap,
1143 moduleTranslation, reductionVariableMap);
1146 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1153 LogicalResult bodyGenStatus = success();
// Body callback: maps the next loop-nest block argument to the IV, records
// the body insertion point, and inlines the region only for the innermost
// dimension.
1154 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
1157 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
1162 bodyInsertPoints.push_back(ip);
1164 if (loopInfos.size() != loopOp.getNumLoops() - 1)
1168 builder.restoreIP(ip);
1170 moduleTranslation, bodyGenStatus);
// NOTE(review): the loop-local `step` below shadows the outer `step` (1112).
1179 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
1180 llvm::Value *lowerBound =
1181 moduleTranslation.
lookupValue(loopOp.getLowerBound()[i]);
1182 llvm::Value *upperBound =
1183 moduleTranslation.
lookupValue(loopOp.getUpperBound()[i]);
1184 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getStep()[i]);
1189 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
1190 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
// Inner loops are created at the enclosing body's insertion point, and their
// bounds are computed in the outermost loop's preheader.
1192 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
1193 computeIP = loopInfos.front()->getPreheaderIP();
1195 loopInfos.push_back(ompBuilder->createCanonicalLoop(
1196 loc, bodyGen, lowerBound, upperBound, step,
1197 true, loopOp.getInclusive(), computeIP));
1199 if (failed(bodyGenStatus))
1205 llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
1206 llvm::CanonicalLoopInfo *loopInfo =
1207 ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
1212 bool isOrdered = wsloopOp.getOrderedVal().has_value();
1213 std::optional<omp::ScheduleModifier> scheduleModifier =
1214 wsloopOp.getScheduleModifier();
1215 bool isSimd = wsloopOp.getSimdModifier();
1217 ompBuilder->applyWorkshareLoop(
1218 ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
1219 convertToScheduleKind(schedule), chunk, isSimd,
1220 scheduleModifier == omp::ScheduleModifier::monotonic,
1221 scheduleModifier == omp::ScheduleModifier::nonmonotonic, isOrdered);
1227 builder.restoreIP(afterIP);
1231 allocaIP, reductionDecls,
1232 privateReductionVariables, isByRef);
// Helper whose constructor and destructor walk the private block arguments
// of the parallel region in lockstep with the op's private variables.
// Private arguments are assumed to start right after the reduction arguments
// (privateArgBeginIdx = number of reduction vars). The per-argument actions
// (lines 1252 and 1261) are missing. They presumably swap uses between the
// private vars and the region arguments, and back in the destructor
// (TODO confirm).
1245 : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
1246 privateArgBeginIdx(opInst.getNumReductionVars()),
// Relies on member declaration order: privateArgBeginIdx (1268) is declared
// before privateArgEndIdx (1269), and privateVars presumably earlier still.
1247 privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
1248 auto privateVarsIt = privateVars.begin();
1250 for (
size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1251 ++argIdx, ++privateVarsIt)
1253 *privateVarsIt, region);
1257 auto privateVarsIt = privateVars.begin();
1259 for (
size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1260 ++argIdx, ++privateVarsIt)
1262 region.getArgument(argIdx), region);
1268 unsigned privateArgBeginIdx;
1269 unsigned privateArgEndIdx;
// Translates omp.parallel through OpenMPIRBuilder::createParallel.
// - bodyGenCB allocates and initializes reduction variables, inlines the
//   region, and combines reductions at the region exit.
// - privCB clones and inlines the matching omp.private privatizer's alloc
//   region (plus copy region for firstprivate).
// - finiCB inlines reduction cleanup and privatizer dealloc regions.
// The privatizer clones are erased once the parallel region has been built.
// NOTE(review): this copy is corrupted (many interior lines are missing).
// `©Region` looks like a mis-encoded `&copyRegion` (HTML-entity damage).
1273 static LogicalResult
1276 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1279 assert(isByRef.size() == opInst.getNumReductionVars());
1283 LogicalResult bodyGenStatus = success();
1290 opInst.getNumReductionVars());
1292 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
// NOTE(review): this slice puts reduction block args after the allocate and
// allocators args. The conversion manager instead starts private args right
// after the reduction args (index = number of reduction vars). The two
// layouts agree only when there are no allocate/allocators vars — confirm.
1297 opInst.getRegion().getArguments().slice(
1298 opInst.getNumAllocateVars() + opInst.getNumAllocatorsVars(),
1299 opInst.getNumReductionVars());
1302 allocaIP, reductionDecls, privateReductionVariables,
1303 reductionVariableMap, isByRef);
// Split off an "omp.reduction.init" block after the alloca point so that
// initializers are emitted after the allocas.
1306 builder.restoreIP(allocaIP);
1307 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1309 InsertPointTy(allocaIP.getBlock(),
1310 allocaIP.getBlock()->getTerminator()->getIterator());
1312 for (
unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
1316 byRefVars[i] = builder.CreateAlloca(
1317 moduleTranslation.
convertType(reductionDecls[i].getType()));
1321 builder.SetInsertPoint(initBlock->getFirstNonPHIOrDbgOrAlloca());
1323 for (
unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
1329 reductionDecls[i].getInitializerRegion(),
"omp.reduction.neutral",
1330 builder, moduleTranslation, &phis)))
1331 bodyGenStatus = failure();
1332 assert(phis.size() == 1 &&
1333 "expected one value to be yielded from the "
1334 "reduction neutral element declaration region");
1338 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
1343 builder.CreateStore(phis[0], byRefVars[i]);
1345 privateReductionVariables[i] = byRefVars[i];
1346 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1347 reductionVariableMap.try_emplace(opInst.getReductionVars()[i], phis[0]);
1350 builder.CreateStore(phis[0], privateReductionVariables[i]);
1356 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
1363 moduleTranslation, reductionVariableMap);
1368 moduleTranslation, allocaIP);
1371 builder.restoreIP(codeGenIP);
1374 moduleTranslation, bodyGenStatus);
1377 if (opInst.getNumReductionVars() > 0) {
1383 owningReductionGens, owningAtomicReductionGens,
1384 privateReductionVariables, reductionInfos);
1387 builder.SetInsertPoint(regionBlock->getTerminator());
// A temporary `unreachable` provides a terminator to position against while
// the reductions are emitted; it is erased afterwards.
1390 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1391 builder.SetInsertPoint(tempTerminator);
1393 llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
1394 ompBuilder->createReductions(builder.saveIP(), allocaIP,
1395 reductionInfos, isByRef,
false);
1396 if (!contInsertPoint.getBlock()) {
1397 bodyGenStatus = opInst->emitOpError() <<
"failed to convert reductions";
1401 tempTerminator->eraseFromParent();
1402 builder.restoreIP(contInsertPoint);
// Privatization callback. By default the original pointer is kept. If vPtr is
// the translation of one of the op's private vars, the matching privatizer is
// cloned under a fresh symbol name and its alloc region is inlined at
// allocaIP. The yielded value becomes the replacement.
1412 auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1413 llvm::Value &, llvm::Value &vPtr,
1414 llvm::Value *&replacementValue) -> InsertPointTy {
1415 replacementValue = &vPtr;
1420 auto [privVar, privatizerClone] =
1421 [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1422 if (!opInst.getPrivateVars().empty()) {
1423 auto privVars = opInst.getPrivateVars();
1424 auto privatizers = opInst.getPrivatizers();
1426 for (
auto [privVar, privatizerAttr] :
1427 llvm::zip_equal(privVars, *privatizers)) {
1430 llvm::Value *llvmPrivVar = moduleTranslation.
lookupValue(privVar);
1431 if (llvmPrivVar != &vPtr)
1434 SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1435 omp::PrivateClauseOp privatizer =
1436 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1447 auto clone = llvm::cast<mlir::omp::PrivateClauseOp>(
1448 opCloner.
clone(*privatizer));
// Pick a symbol name for the clone that does not collide with an existing
// symbol visible from opInst.
1451 unsigned counter = 0;
1453 privatizer.getSymName(),
1454 [&](llvm::StringRef candidate) {
1455 return SymbolTable::lookupNearestSymbolFrom(
1456 opInst, StringAttr::get(&context, candidate)) !=
1461 clone.setSymName(cloneName);
1462 return {privVar,
clone};
1470 Region &allocRegion = privatizerClone.getAllocRegion();
// For firstprivate, the clone's copy region is appended after the alloc
// region's last block, fed with (alloc arg, allocated value), and the alloc
// yield is removed.
1473 if (privatizerClone.getDataSharingType() ==
1474 omp::DataSharingClauseType::FirstPrivate) {
1475 auto oldAllocBackBlock = std::prev(allocRegion.
end());
1476 omp::YieldOp oldAllocYieldOp =
1477 llvm::cast<omp::YieldOp>(oldAllocBackBlock->getTerminator());
1479 Region ©Region = privatizerClone.getCopyRegion();
1486 auto newCopyRegionFrontBlock = std::next(oldAllocBackBlock);
1494 &*newCopyRegionFrontBlock, &*oldAllocBackBlock,
1495 {allocRegion.
getArgument(0), oldAllocYieldOp.getOperand(0)});
1499 oldAllocYieldOp.erase();
1508 auto oldIP = builder.saveIP();
1509 builder.restoreIP(allocaIP);
1513 moduleTranslation, &yieldedValues))) {
1514 opInst.emitError(
"failed to inline `alloc` region of an `omp.private` "
1515 "op in the parallel region");
1516 bodyGenStatus = failure();
// NOTE(review): if this failure path does not leave privCB before line 1525
// (line 1518 is missing), the erased clone is pushed into privatizerClones
// and erased again after createParallel — confirm.
1517 privatizerClone.erase();
1519 assert(yieldedValues.size() == 1);
1520 replacementValue = yieldedValues.front();
1524 privateVariables.push_back(replacementValue);
1525 privatizerClones.push_back(privatizerClone);
1528 builder.restoreIP(oldIP);
// Finalization: inline reduction cleanup regions, then privatizer dealloc
// regions (passed the private values directly, without a load), restoring
// the builder's insertion point afterwards.
1536 auto finiCB = [&](InsertPointTy codeGenIP) {
1537 InsertPointTy oldIP = builder.saveIP();
1538 builder.restoreIP(codeGenIP);
1543 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
1544 [](omp::DeclareReductionOp reductionDecl) {
1545 return &reductionDecl.getCleanupRegion();
1548 reductionCleanupRegions, privateReductionVariables,
1549 moduleTranslation, builder,
"omp.reduction.cleanup")))
1550 bodyGenStatus = failure();
1553 llvm::transform(privatizerClones, std::back_inserter(privateCleanupRegions),
1554 [](omp::PrivateClauseOp privatizer) {
1555 return &privatizer.getDeallocRegion();
1559 privateCleanupRegions, privateVariables, moduleTranslation, builder,
1560 "omp.private.dealloc",
false)))
1561 bodyGenStatus = failure();
1563 builder.restoreIP(oldIP);
1566 llvm::Value *ifCond =
nullptr;
1567 if (
auto ifExprVar = opInst.getIfExpr())
1568 ifCond = moduleTranslation.
lookupValue(ifExprVar);
1569 llvm::Value *numThreads =
nullptr;
1570 if (
auto numThreadsVar = opInst.getNumThreadsVar())
1571 numThreads = moduleTranslation.
lookupValue(numThreadsVar);
1572 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1573 if (
auto bind = opInst.getProcBindVal())
1576 bool isCancellable =
false;
1578 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1580 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1583 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
1584 ifCond, numThreads, pbKind, isCancellable));
// The cloned privatizers were only needed during region construction.
1586 for (mlir::omp::PrivateClauseOp privatizerClone : privatizerClones)
1587 privatizerClone.erase();
1589 return bodyGenStatus;
1593 static llvm::omp::OrderKind
1596 return llvm::omp::OrderKind::OMP_ORDER_unknown;
1598 case omp::ClauseOrderKind::Concurrent:
1599 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
1601 llvm_unreachable(
"Unknown ClauseOrderKind kind");
1605 static LogicalResult
1608 auto simdOp = cast<omp::SimdOp>(opInst);
1609 auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());
1611 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1618 LogicalResult bodyGenStatus = success();
1619 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
1622 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
1627 bodyInsertPoints.push_back(ip);
1629 if (loopInfos.size() != loopOp.getNumLoops() - 1)
1633 builder.restoreIP(ip);
1635 moduleTranslation, bodyGenStatus);
1644 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
1645 llvm::Value *lowerBound =
1646 moduleTranslation.
lookupValue(loopOp.getLowerBound()[i]);
1647 llvm::Value *upperBound =
1648 moduleTranslation.
lookupValue(loopOp.getUpperBound()[i]);
1649 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getStep()[i]);
1654 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
1655 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
1657 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
1659 computeIP = loopInfos.front()->getPreheaderIP();
1661 loopInfos.push_back(ompBuilder->createCanonicalLoop(
1662 loc, bodyGen, lowerBound, upperBound, step,
1663 true,
true, computeIP));
1665 if (failed(bodyGenStatus))
1670 llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
1671 llvm::CanonicalLoopInfo *loopInfo =
1672 ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
1674 llvm::ConstantInt *simdlen =
nullptr;
1675 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
1676 simdlen = builder.getInt64(simdlenVar.value());
1678 llvm::ConstantInt *safelen =
nullptr;
1679 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
1680 safelen = builder.getInt64(safelenVar.value());
1682 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
1684 ompBuilder->applySimd(loopInfo, alignedVars,
1686 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
1688 order, simdlen, safelen);
1690 builder.restoreIP(afterIP);
1695 static llvm::AtomicOrdering
1698 return llvm::AtomicOrdering::Monotonic;
1701 case omp::ClauseMemoryOrderKind::Seq_cst:
1702 return llvm::AtomicOrdering::SequentiallyConsistent;
1703 case omp::ClauseMemoryOrderKind::Acq_rel:
1704 return llvm::AtomicOrdering::AcquireRelease;
1705 case omp::ClauseMemoryOrderKind::Acquire:
1706 return llvm::AtomicOrdering::Acquire;
1707 case omp::ClauseMemoryOrderKind::Release:
1708 return llvm::AtomicOrdering::Release;
1709 case omp::ClauseMemoryOrderKind::Relaxed:
1710 return llvm::AtomicOrdering::Monotonic;
1712 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
1716 static LogicalResult
1720 auto readOp = cast<omp::AtomicReadOp>(opInst);
1723 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1726 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
1727 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
1729 llvm::Type *elementType =
1730 moduleTranslation.
convertType(readOp.getElementType());
1732 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
1733 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
1734 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO));
1739 static LogicalResult
1742 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
1745 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1747 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
1748 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
1749 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
1750 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
1752 builder.restoreIP(ompBuilder->createAtomicWrite(ompLoc, x, expr, ao));
1760 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
1761 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
1762 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
1763 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
1764 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
1765 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
1766 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
1767 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
1768 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
1769 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
1773 static LogicalResult
1775 llvm::IRBuilderBase &builder,
1780 auto &innerOpList = opInst.getRegion().front().getOperations();
1781 bool isXBinopExpr{
false};
1782 llvm::AtomicRMWInst::BinOp binop;
1784 llvm::Value *llvmExpr =
nullptr;
1785 llvm::Value *llvmX =
nullptr;
1786 llvm::Type *llvmXElementType =
nullptr;
1787 if (innerOpList.size() == 2) {
1793 opInst.getRegion().getArgument(0))) {
1794 return opInst.emitError(
"no atomic update operation with region argument"
1795 " as operand found inside atomic.update region");
1798 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
1800 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
1804 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
1806 llvmX = moduleTranslation.
lookupValue(opInst.getX());
1808 opInst.getRegion().getArgument(0).getType());
1809 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
1813 llvm::AtomicOrdering atomicOrdering =
1817 LogicalResult updateGenStatus = success();
1818 auto updateFn = [&opInst, &moduleTranslation, &updateGenStatus](
1819 llvm::Value *atomicx,
1820 llvm::IRBuilder<> &builder) -> llvm::Value * {
1822 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
1823 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
1824 if (failed(moduleTranslation.
convertBlock(bb,
true, builder))) {
1825 updateGenStatus = (opInst.emitError()
1826 <<
"unable to convert update operation to llvm IR");
1829 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
1830 assert(yieldop && yieldop.getResults().size() == 1 &&
1831 "terminator must be omp.yield op and it must have exactly one "
1833 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
1838 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1839 builder.restoreIP(ompBuilder->createAtomicUpdate(
1840 ompLoc, allocaIP, llvmAtomicX, llvmExpr, atomicOrdering, binop, updateFn,
1842 return updateGenStatus;
1845 static LogicalResult
1847 llvm::IRBuilderBase &builder,
1851 bool isXBinopExpr =
false, isPostfixUpdate =
false;
1852 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
1854 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
1855 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
1857 assert((atomicUpdateOp || atomicWriteOp) &&
1858 "internal op must be an atomic.update or atomic.write op");
1860 if (atomicWriteOp) {
1861 isPostfixUpdate =
true;
1862 mlirExpr = atomicWriteOp.getExpr();
1864 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
1865 atomicCaptureOp.getAtomicUpdateOp().getOperation();
1866 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
1867 bool isRegionArgUsed{
false};
1870 for (
Operation &innerOp : innerOpList) {
1871 if (innerOp.getNumOperands() == 2) {
1873 if (!llvm::is_contained(innerOp.getOperands(),
1874 atomicUpdateOp.getRegion().getArgument(0)))
1876 isRegionArgUsed =
true;
1878 innerOp.getNumOperands() > 0 &&
1879 innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
1881 (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
1885 if (!isRegionArgUsed)
1886 return atomicUpdateOp.emitError(
1887 "no atomic update operation with region argument"
1888 " as operand found inside atomic.update region");
1891 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
1892 llvm::Value *llvmX =
1893 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
1894 llvm::Value *llvmV =
1895 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
1896 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
1897 atomicCaptureOp.getAtomicReadOp().getElementType());
1898 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
1901 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
1905 llvm::AtomicOrdering atomicOrdering =
1908 LogicalResult updateGenStatus = success();
1909 auto updateFn = [&](llvm::Value *atomicx,
1910 llvm::IRBuilder<> &builder) -> llvm::Value * {
1912 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
1913 Block &bb = *atomicUpdateOp.getRegion().
begin();
1914 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
1916 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
1917 if (failed(moduleTranslation.
convertBlock(bb,
true, builder))) {
1918 updateGenStatus = (atomicUpdateOp.emitError()
1919 <<
"unable to convert update operation to llvm IR");
1922 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
1923 assert(yieldop && yieldop.getResults().size() == 1 &&
1924 "terminator must be omp.yield op and it must have exactly one "
1926 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
1931 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1932 builder.restoreIP(ompBuilder->createAtomicCapture(
1933 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
1934 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr));
1935 return updateGenStatus;
1940 static LogicalResult
1943 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1944 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
1946 Value symAddr = threadprivateOp.getSymAddr();
1948 if (!isa<LLVM::AddressOfOp>(symOp))
1949 return opInst.
emitError(
"Addressing symbol not found");
1950 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
1952 LLVM::GlobalOp global =
1953 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
1954 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
1955 llvm::Type *type = globalValue->getValueType();
1956 llvm::TypeSize typeSize =
1957 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
1959 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
1960 llvm::StringRef suffix = llvm::StringRef(
".cache", 6);
1961 std::string cacheName = (Twine(global.getSymName()).concat(suffix)).str();
1962 llvm::Value *callInst =
1964 ompLoc, globalValue, size, cacheName);
1969 static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
1971 switch (deviceClause) {
1972 case mlir::omp::DeclareTargetDeviceType::host:
1973 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
1975 case mlir::omp::DeclareTargetDeviceType::nohost:
1976 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
1978 case mlir::omp::DeclareTargetDeviceType::any:
1979 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
1982 llvm_unreachable(
"unhandled device clause");
1985 static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
1987 mlir::omp::DeclareTargetCaptureClause captureClasue) {
1988 switch (captureClasue) {
1989 case mlir::omp::DeclareTargetCaptureClause::to:
1990 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
1991 case mlir::omp::DeclareTargetCaptureClause::link:
1992 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
1993 case mlir::omp::DeclareTargetCaptureClause::enter:
1994 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
1996 llvm_unreachable(
"unhandled capture clause");
2001 llvm::OpenMPIRBuilder &ompBuilder) {
2003 llvm::raw_svector_ostream os(suffix);
2005 auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
2006 auto fileInfoCallBack = [&loc]() {
2007 return std::pair<std::string, uint64_t>(
2008 llvm::StringRef(loc.getFilename()), loc.getLine());
2012 "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
2014 os <<
"_decl_tgt_ref_ptr";
2020 if (
auto addressOfOp =
2021 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.
getDefiningOp())) {
2022 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
2023 Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
2024 if (
auto declareTargetGlobal =
2025 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
2026 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
2027 mlir::omp::DeclareTargetCaptureClause::link)
2036 static llvm::Value *
2043 if (
auto addressOfOp =
2044 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.
getDefiningOp())) {
2045 if (
auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
2046 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
2047 addressOfOp.getGlobalName()))) {
2049 if (
auto declareTargetGlobal =
2050 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
2051 gOp.getOperation())) {
2055 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
2056 mlir::omp::DeclareTargetCaptureClause::link) ||
2057 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
2058 mlir::omp::DeclareTargetCaptureClause::to &&
2059 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
2063 if (gOp.getSymName().contains(suffix))
2068 (gOp.getSymName().str() + suffix.str()).str());
2100 llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
2105 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
2106 arrTy.getElementType()))
2122 Operation *clauseOp, llvm::Value *basePointer,
2123 llvm::Type *baseType, llvm::IRBuilderBase &builder,
2125 if (
auto memberClause =
2126 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
2131 if (!memberClause.getBounds().empty()) {
2132 llvm::Value *elementCount = builder.getInt64(1);
2133 for (
auto bounds : memberClause.getBounds()) {
2134 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2135 bounds.getDefiningOp())) {
2140 elementCount = builder.CreateMul(
2144 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
2145 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
2146 builder.getInt64(1)));
2153 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
2161 return builder.CreateMul(elementCount,
2162 builder.getInt64(underlyingTypeSzInBits / 8));
2173 llvm::IRBuilderBase &builder) {
2175 if (
auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2176 mapValue.getDefiningOp())) {
2178 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
2183 if (llvm::Value *refPtr =
2185 moduleTranslation)) {
2187 mapData.BasePointers.push_back(refPtr);
2190 mapData.BasePointers.push_back(mapData.
OriginalValue.back());
2194 moduleTranslation.
convertType(mapOp.getVarType()));
2195 mapData.Sizes.push_back(
2196 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
2197 mapData.
BaseType.back(), builder, moduleTranslation));
2198 mapData.
MapClause.push_back(mapOp.getOperation());
2199 mapData.Types.push_back(
2200 llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
2203 mapData.DevicePointers.push_back(
2215 if (
auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2216 mapValue.getDefiningOp())) {
2217 for (
auto member : map.getMembers()) {
2218 if (member == mapOp) {
2229 mlir::omp::MapInfoOp memberOp) {
2230 auto *res = llvm::find(mapData.
MapClause, memberOp);
2231 assert(res != mapData.
MapClause.end() &&
2232 "MapInfoOp for member not found in MapData, cannot return index");
2233 return std::distance(mapData.
MapClause.begin(), res);
2236 static mlir::omp::MapInfoOp
2241 if (indexAttr.size() == 1)
2242 if (
auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
2243 mapInfo.getMembers()[0].getDefiningOp()))
2248 std::iota(indices.begin(), indices.end(), 0);
2250 llvm::sort(indices.begin(), indices.end(),
2251 [&](
const size_t a,
const size_t b) {
2252 auto indexValues = indexAttr.getValues<int32_t>();
2253 for (int i = 0; i < shape[1]; ++i) {
2254 int aIndex = indexValues[a * shape[1] + i];
2255 int bIndex = indexValues[b * shape[1] + i];
2257 if (aIndex == bIndex)
2260 if (aIndex != -1 && bIndex == -1)
2263 if (aIndex == -1 && bIndex != -1)
2267 if (aIndex < bIndex)
2270 if (bIndex < aIndex)
2280 return llvm::cast<mlir::omp::MapInfoOp>(
2281 mapInfo.getMembers()[indices.front()].getDefiningOp());
2303 std::vector<llvm::Value *>
2305 llvm::IRBuilderBase &builder,
bool isArrayTy,
2307 std::vector<llvm::Value *> idx;
2318 idx.push_back(builder.getInt64(0));
2319 for (
int i = bounds.size() - 1; i >= 0; --i) {
2320 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2321 bounds[i].getDefiningOp())) {
2322 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
2344 std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
2345 for (
size_t i = 1; i < bounds.size(); ++i) {
2346 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2347 bounds[i].getDefiningOp())) {
2348 dimensionIndexSizeOffset.push_back(builder.CreateMul(
2349 moduleTranslation.
lookupValue(boundOp.getExtent()),
2350 dimensionIndexSizeOffset[i - 1]));
2358 for (
int i = bounds.size() - 1; i >= 0; --i) {
2359 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2360 bounds[i].getDefiningOp())) {
2362 idx.emplace_back(builder.CreateMul(
2363 moduleTranslation.
lookupValue(boundOp.getLowerBound()),
2364 dimensionIndexSizeOffset[i]));
2366 idx.back() = builder.CreateAdd(
2367 idx.back(), builder.CreateMul(moduleTranslation.
lookupValue(
2368 boundOp.getLowerBound()),
2369 dimensionIndexSizeOffset[i]));
2394 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl,
2396 uint64_t mapDataIndex,
bool isTargetParams) {
2398 combinedInfo.Types.emplace_back(
2400 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
2401 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
2402 combinedInfo.DevicePointers.emplace_back(
2405 mapData.
MapClause[mapDataIndex]->getLoc(), ompBuilder));
2406 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2416 llvm::cast<mlir::omp::MapInfoOp>(mapData.
MapClause[mapDataIndex]);
2418 llvm::Value *lowAddr, *highAddr;
2419 if (!parentClause.getPartialMap()) {
2420 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
2421 builder.getPtrTy());
2422 highAddr = builder.CreatePointerCast(
2423 builder.CreateConstGEP1_32(mapData.
BaseType[mapDataIndex],
2424 mapData.Pointers[mapDataIndex], 1),
2425 builder.getPtrTy());
2426 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
2429 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.
MapClause[mapDataIndex]);
2432 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
2433 builder.getPtrTy());
2436 highAddr = builder.CreatePointerCast(
2437 builder.CreateGEP(mapData.
BaseType[lastMemberIdx],
2438 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
2439 builder.getPtrTy());
2440 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
2443 llvm::Value *size = builder.CreateIntCast(
2444 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
2445 builder.getInt64Ty(),
2447 combinedInfo.Sizes.push_back(size);
2454 llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2455 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
2457 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
2458 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
2459 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
2467 if (!parentClause.getPartialMap()) {
2468 combinedInfo.Types.emplace_back(mapFlag);
2469 combinedInfo.DevicePointers.emplace_back(
2472 mapData.
MapClause[mapDataIndex]->getLoc(), ompBuilder));
2473 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2474 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
2475 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
2477 return memberOfFlag;
2489 if (mapOp.getVarPtrPtr())
2504 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl,
2506 uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
2509 llvm::cast<mlir::omp::MapInfoOp>(mapData.
MapClause[mapDataIndex]);
2511 for (
auto mappedMembers : parentClause.getMembers()) {
2513 llvm::cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp());
2516 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
2521 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType().value());
2522 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2523 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
2524 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
2526 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
2528 combinedInfo.Types.emplace_back(mapFlag);
2529 combinedInfo.DevicePointers.emplace_back(
2531 combinedInfo.Names.emplace_back(
2533 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2534 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
2535 combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
2542 bool isTargetParams,
int mapDataParentIdx = -1) {
2546 auto mapFlag = mapData.Types[mapDataIdx];
2548 llvm::cast<mlir::omp::MapInfoOp>(mapData.
MapClause[mapDataIdx]);
2552 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
2555 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2557 if (mapInfoOp.getMapCaptureType().value() ==
2558 mlir::omp::VariableCaptureKind::ByCopy &&
2560 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
2565 if (mapDataParentIdx >= 0)
2566 combinedInfo.BasePointers.emplace_back(
2567 mapData.BasePointers[mapDataParentIdx]);
2569 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
2571 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
2572 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
2573 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
2574 combinedInfo.Types.emplace_back(mapFlag);
2575 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
2580 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl,
2582 uint64_t mapDataIndex,
bool isTargetParams) {
2584 llvm::cast<mlir::omp::MapInfoOp>(mapData.
MapClause[mapDataIndex]);
2589 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
2590 auto memberClause = llvm::cast<mlir::omp::MapInfoOp>(
2591 parentClause.getMembers()[0].getDefiningOp());
2608 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
2610 combinedInfo, mapData, mapDataIndex, isTargetParams);
2612 combinedInfo, mapData, mapDataIndex,
2613 memberOfParentFlag);
2623 llvm::IRBuilderBase &builder) {
2624 for (
size_t i = 0; i < mapData.
MapClause.size(); ++i) {
2628 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(mapData.
MapClause[i]);
2629 mlir::omp::VariableCaptureKind captureKind =
2630 mapOp.getMapCaptureType().value_or(
2631 mlir::omp::VariableCaptureKind::ByRef);
2641 switch (captureKind) {
2642 case mlir::omp::VariableCaptureKind::ByRef: {
2643 llvm::Value *newV = mapData.Pointers[i];
2645 moduleTranslation, builder, mapData.
BaseType[i]->isArrayTy(),
2648 newV = builder.CreateLoad(builder.getPtrTy(), newV);
2650 if (!offsetIdx.empty())
2651 newV = builder.CreateInBoundsGEP(mapData.
BaseType[i], newV, offsetIdx,
2653 mapData.Pointers[i] = newV;
2655 case mlir::omp::VariableCaptureKind::ByCopy: {
2656 llvm::Type *type = mapData.
BaseType[i];
2658 if (mapData.Pointers[i]->getType()->isPointerTy())
2659 newV = builder.CreateLoad(type, mapData.Pointers[i]);
2661 newV = mapData.Pointers[i];
2664 auto curInsert = builder.saveIP();
2666 auto *memTempAlloc =
2667 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
2668 builder.restoreIP(curInsert);
2670 builder.CreateStore(newV, memTempAlloc);
2671 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
2674 mapData.Pointers[i] = newV;
2675 mapData.BasePointers[i] = newV;
2677 case mlir::omp::VariableCaptureKind::This:
2678 case mlir::omp::VariableCaptureKind::VLAType:
2679 mapData.
MapClause[i]->emitOpError(
"Unhandled capture kind");
2694 bool isTargetParams =
false) {
2711 auto fail = [&combinedInfo]() ->
void {
2712 combinedInfo.BasePointers.clear();
2713 combinedInfo.Pointers.clear();
2714 combinedInfo.DevicePointers.clear();
2715 combinedInfo.Sizes.clear();
2716 combinedInfo.Types.clear();
2717 combinedInfo.Names.clear();
2725 for (
size_t i = 0; i < mapData.
MapClause.size(); ++i) {
2731 auto mapInfoOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.
MapClause[i]);
2732 if (!mapInfoOp.getMembers().empty()) {
2734 combinedInfo, mapData, i, isTargetParams);
2741 auto findMapInfo = [&combinedInfo](llvm::Value *val,
unsigned &index) {
2743 for (llvm::Value *basePtr : combinedInfo.BasePointers) {
2751 auto addDevInfos = [&, fail](
auto devOperands,
auto devOpType) ->
void {
2752 for (
const auto &devOp : devOperands) {
2754 if (!isa<LLVM::LLVMPointerType>(devOp.getType()))
2757 llvm::Value *mapOpValue = moduleTranslation.
lookupValue(devOp);
2761 if (findMapInfo(mapOpValue, infoIndex)) {
2762 combinedInfo.Types[infoIndex] |=
2763 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
2764 combinedInfo.DevicePointers[infoIndex] = devOpType;
2766 combinedInfo.BasePointers.emplace_back(mapOpValue);
2767 combinedInfo.Pointers.emplace_back(mapOpValue);
2768 combinedInfo.DevicePointers.emplace_back(devOpType);
2769 combinedInfo.Names.emplace_back(
2771 combinedInfo.Types.emplace_back(
2772 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
2773 combinedInfo.Sizes.emplace_back(builder.getInt64(0));
2778 addDevInfos(devPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
2779 addDevInfos(devAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
2782 static LogicalResult
2785 llvm::Value *ifCond =
nullptr;
2786 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
2790 llvm::omp::RuntimeFunction RTLFn;
2795 LogicalResult result =
2797 .Case([&](omp::TargetDataOp dataOp) {
2798 if (
auto ifExprVar = dataOp.getIfExpr())
2799 ifCond = moduleTranslation.
lookupValue(ifExprVar);
2801 if (
auto devId = dataOp.getDevice())
2803 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
2804 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
2805 deviceID = intAttr.getInt();
2807 mapOperands = dataOp.getMapOperands();
2808 useDevPtrOperands = dataOp.getUseDevicePtr();
2809 useDevAddrOperands = dataOp.getUseDeviceAddr();
2812 .Case([&](omp::TargetEnterDataOp enterDataOp) {
2813 if (enterDataOp.getNowait())
2814 return (LogicalResult)(enterDataOp.emitError(
2815 "`nowait` is not supported yet"));
2817 if (
auto ifExprVar = enterDataOp.getIfExpr())
2818 ifCond = moduleTranslation.
lookupValue(ifExprVar);
2820 if (
auto devId = enterDataOp.getDevice())
2822 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
2823 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
2824 deviceID = intAttr.getInt();
2825 RTLFn = llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
2826 mapOperands = enterDataOp.getMapOperands();
2829 .Case([&](omp::TargetExitDataOp exitDataOp) {
2830 if (exitDataOp.getNowait())
2831 return (LogicalResult)(exitDataOp.emitError(
2832 "`nowait` is not supported yet"));
2834 if (
auto ifExprVar = exitDataOp.getIfExpr())
2835 ifCond = moduleTranslation.
lookupValue(ifExprVar);
2837 if (
auto devId = exitDataOp.getDevice())
2839 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
2840 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
2841 deviceID = intAttr.getInt();
2843 RTLFn = llvm::omp::OMPRTL___tgt_target_data_end_mapper;
2844 mapOperands = exitDataOp.getMapOperands();
2847 .Case([&](omp::TargetUpdateOp updateDataOp) {
2848 if (updateDataOp.getNowait())
2849 return (LogicalResult)(updateDataOp.emitError(
2850 "`nowait` is not supported yet"));
2852 if (
auto ifExprVar = updateDataOp.getIfExpr())
2853 ifCond = moduleTranslation.
lookupValue(ifExprVar);
2855 if (
auto devId = updateDataOp.getDevice())
2857 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
2858 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
2859 deviceID = intAttr.getInt();
2861 RTLFn = llvm::omp::OMPRTL___tgt_target_data_update_mapper;
2862 mapOperands = updateDataOp.getMapOperands();
2866 return op->
emitError(
"unsupported OpenMP operation: ")
2873 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2883 builder.restoreIP(codeGenIP);
2884 if (
auto dataOp = dyn_cast<omp::TargetDataOp>(op)) {
2885 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
2886 useDevPtrOperands, useDevAddrOperands);
2888 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
2890 return combinedInfo;
2893 llvm::OpenMPIRBuilder::TargetDataInfo info(
true,
2896 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
2897 LogicalResult bodyGenStatus = success();
2898 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) {
2899 assert(isa<omp::TargetDataOp>(op) &&
2900 "BodyGen requested for non TargetDataOp");
2901 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
2902 switch (bodyGenType) {
2903 case BodyGenTy::Priv:
2905 if (!info.DevicePtrInfoMap.empty()) {
2906 builder.restoreIP(codeGenIP);
2907 unsigned argIndex = 0;
2908 for (
auto &devPtrOp : useDevPtrOperands) {
2909 llvm::Value *mapOpValue = moduleTranslation.
lookupValue(devPtrOp);
2912 info.DevicePtrInfoMap[mapOpValue].second);
2916 for (
auto &devAddrOp : useDevAddrOperands) {
2917 llvm::Value *mapOpValue = moduleTranslation.
lookupValue(devAddrOp);
2919 auto *LI = builder.CreateLoad(
2920 builder.getPtrTy(), info.DevicePtrInfoMap[mapOpValue].second);
2921 moduleTranslation.
mapValue(arg, LI);
2926 builder, moduleTranslation);
2929 case BodyGenTy::DupNoPriv:
2931 case BodyGenTy::NoPriv:
2933 if (info.DevicePtrInfoMap.empty()) {
2934 builder.restoreIP(codeGenIP);
2936 builder, moduleTranslation);
2940 return builder.saveIP();
2943 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2944 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2946 if (isa<omp::TargetDataOp>(op)) {
2947 builder.restoreIP(ompBuilder->createTargetData(
2948 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
2949 info, genMapInfoCB,
nullptr, bodyGenCB));
2951 builder.restoreIP(ompBuilder->createTargetData(
2952 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
2953 info, genMapInfoCB, &RTLFn));
2956 return bodyGenStatus;
2964 if (!cast<mlir::ModuleOp>(op))
2969 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
2970 attribute.getOpenmpDeviceVersion());
2972 if (attribute.getNoGpuLib())
2975 ompBuilder->createGlobalFlag(
2976 attribute.getDebugKind() ,
2977 "__omp_rtl_debug_kind");
2978 ompBuilder->createGlobalFlag(
2980 .getAssumeTeamsOversubscription()
2982 "__omp_rtl_assume_teams_oversubscription");
2983 ompBuilder->createGlobalFlag(
2985 .getAssumeThreadsOversubscription()
2987 "__omp_rtl_assume_threads_oversubscription");
2988 ompBuilder->createGlobalFlag(
2989 attribute.getAssumeNoThreadState() ,
2990 "__omp_rtl_assume_no_thread_state");
2991 ompBuilder->createGlobalFlag(
2993 .getAssumeNoNestedParallelism()
2995 "__omp_rtl_assume_no_nested_parallelism");
3000 omp::TargetOp targetOp,
3001 llvm::StringRef parentName =
"") {
3002 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
3004 assert(fileLoc &&
"No file found from location");
3005 StringRef fileName = fileLoc.getFilename().getValue();
3007 llvm::sys::fs::UniqueID id;
3008 if (
auto ec = llvm::sys::fs::getUniqueID(fileName,
id)) {
3009 targetOp.emitError(
"Unable to get unique ID for file");
3013 uint64_t line = fileLoc.getLine();
3014 targetInfo = llvm::TargetRegionEntryInfo(parentName,
id.getDevice(),
3015 id.getFile(), line);
3020 auto targetOp = cast<omp::TargetOp>(opInst);
3021 if (targetOp.getIfExpr()) {
3022 opInst.
emitError(
"If clause not yet supported");
3026 if (targetOp.getDevice()) {
3027 opInst.
emitError(
"Device clause not yet supported");
3031 if (targetOp.getThreadLimit()) {
3032 opInst.
emitError(
"Thread limit clause not yet supported");
3036 if (targetOp.getNowait()) {
3037 opInst.
emitError(
"Nowait clause not yet supported");
3047 llvm::IRBuilderBase &builder, llvm::Function *func) {
3048 for (
size_t i = 0; i < mapData.
MapClause.size(); ++i) {
3068 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.
OriginalValue[i]))
3069 convertUsersOfConstantsToInstructions(constant, func,
false);
3077 userVec.push_back(user);
3079 for (llvm::User *user : userVec) {
3080 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
3081 if (insn->getFunction() == func) {
3082 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
3083 mapData.BasePointers[i]);
3084 load->moveBefore(insn);
3132 static llvm::IRBuilderBase::InsertPoint
3134 llvm::Value *input, llvm::Value *&retVal,
3135 llvm::IRBuilderBase &builder,
3136 llvm::OpenMPIRBuilder &ompBuilder,
3138 llvm::IRBuilderBase::InsertPoint allocaIP,
3139 llvm::IRBuilderBase::InsertPoint codeGenIP) {
3140 builder.restoreIP(allocaIP);
3142 mlir::omp::VariableCaptureKind capture =
3143 mlir::omp::VariableCaptureKind::ByRef;
3146 for (
size_t i = 0; i < mapData.
MapClause.size(); ++i)
3148 if (
auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
3150 capture = mapOp.getMapCaptureType().value_or(
3151 mlir::omp::VariableCaptureKind::ByRef);
3157 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
3158 unsigned int defaultAS =
3159 ompBuilder.M.getDataLayout().getProgramAddressSpace();
3162 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
3164 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
3165 v = builder.CreatePointerBitCastOrAddrSpaceCast(
3166 v, arg.getType()->getPointerTo(defaultAS));
3168 builder.CreateStore(&arg, v);
3170 builder.restoreIP(codeGenIP);
3173 case mlir::omp::VariableCaptureKind::ByCopy: {
3177 case mlir::omp::VariableCaptureKind::ByRef: {
3178 retVal = builder.CreateAlignedLoad(
3180 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
3183 case mlir::omp::VariableCaptureKind::This:
3184 case mlir::omp::VariableCaptureKind::VLAType:
3185 assert(
false &&
"Currently unsupported capture kind");
3189 return builder.saveIP();
3192 static LogicalResult
3200 auto targetOp = cast<omp::TargetOp>(opInst);
3201 auto &targetRegion = targetOp.getRegion();
3204 llvm::Function *llvmOutlinedFn =
nullptr;
3206 LogicalResult bodyGenStatus = success();
3207 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3208 auto bodyCB = [&](InsertPointTy allocaIP,
3209 InsertPointTy codeGenIP) -> InsertPointTy {
3212 llvm::Function *llvmParentFn =
3214 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
3215 assert(llvmParentFn && llvmOutlinedFn &&
3216 "Both parent and outlined functions must exist at this point");
3218 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
3219 attr.isStringAttribute())
3220 llvmOutlinedFn->addFnAttr(attr);
3222 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
3223 attr.isStringAttribute())
3224 llvmOutlinedFn->addFnAttr(attr);
3226 builder.restoreIP(codeGenIP);
3227 unsigned argIndex = 0;
3228 for (
auto &mapOp : mapOperands) {
3230 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp());
3231 llvm::Value *mapOpValue =
3232 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
3233 const auto &arg = targetRegion.front().getArgument(argIndex);
3234 moduleTranslation.
mapValue(arg, mapOpValue);
3238 targetRegion,
"omp.target", builder, moduleTranslation, bodyGenStatus);
3239 builder.SetInsertPoint(exitBlock);
3240 return builder.saveIP();
3243 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3244 StringRef parentName = parentFn.getName();
3246 llvm::TargetRegionEntryInfo entryInfo;
3251 int32_t defaultValTeams = -1;
3252 int32_t defaultValThreads = 0;
3254 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3262 auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
3264 builder.restoreIP(codeGenIP);
3265 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, {}, {},
3267 return combinedInfos;
3270 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
3271 llvm::Value *&retVal, InsertPointTy allocaIP,
3272 InsertPointTy codeGenIP) {
3280 if (!ompBuilder->Config.isTargetDevice()) {
3281 retVal = cast<llvm::Value>(&arg);
3286 *ompBuilder, moduleTranslation,
3287 allocaIP, codeGenIP);
3291 for (
size_t i = 0; i < mapOperands.size(); ++i) {
3303 moduleTranslation, dds);
3306 ompLoc, allocaIP, builder.saveIP(), entryInfo, defaultValTeams,
3307 defaultValThreads, kernelInput, genMapInfoCB, bodyCB, argAccessorCB,
3316 return bodyGenStatus;
3319 static LogicalResult
3329 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
3330 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
3332 if (!offloadMod.getIsTargetDevice())
3335 omp::DeclareTargetDeviceType declareType =
3336 attribute.getDeviceType().getValue();
3338 if (declareType == omp::DeclareTargetDeviceType::host) {
3339 llvm::Function *llvmFunc =
3341 llvmFunc->dropAllReferences();
3342 llvmFunc->eraseFromParent();
3348 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
3349 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
3350 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
3352 bool isDeclaration = gOp.isDeclaration();
3353 bool isExternallyVisible =
3356 llvm::StringRef mangledName = gOp.getSymName();
3357 auto captureClause =
3363 std::vector<llvm::GlobalVariable *> generatedRefs;
3365 std::vector<llvm::Triple> targetTriple;
3366 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
3368 LLVM::LLVMDialect::getTargetTripleAttrName()));
3369 if (targetTripleAttr)
3370 targetTriple.emplace_back(targetTripleAttr.data());
3372 auto fileInfoCallBack = [&loc]() {
3373 std::string filename =
"";
3374 std::uint64_t lineNo = 0;
3377 filename = loc.getFilename().str();
3378 lineNo = loc.getLine();
3381 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
3385 ompBuilder->registerTargetGlobalVariable(
3386 captureClause, deviceClause, isDeclaration, isExternallyVisible,
3387 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
3388 generatedRefs,
false, targetTriple,
3390 gVal->getType(), gVal);
3392 if (ompBuilder->Config.isTargetDevice() &&
3393 (attribute.getCaptureClause().getValue() !=
3394 mlir::omp::DeclareTargetCaptureClause::to ||
3395 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
3396 ompBuilder->getAddrOfDeclareTargetVar(
3397 captureClause, deviceClause, isDeclaration, isExternallyVisible,
3398 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
3399 generatedRefs,
false, targetTriple, gVal->getType(),
3417 if (
auto declareTargetIface =
3418 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
3419 parentFn.getOperation()))
3420 if (declareTargetIface.isDeclareTarget() &&
3421 declareTargetIface.getDeclareTargetDeviceType() !=
3422 mlir::omp::DeclareTargetDeviceType::host)
3430 static LogicalResult
3437 .Case([&](omp::BarrierOp) {
3438 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
3441 .Case([&](omp::TaskwaitOp) {
3442 ompBuilder->createTaskwait(builder.saveIP());
3445 .Case([&](omp::TaskyieldOp) {
3446 ompBuilder->createTaskyield(builder.saveIP());
3449 .Case([&](omp::FlushOp) {
3458 ompBuilder->createFlush(builder.saveIP());
3461 .Case([&](omp::ParallelOp op) {
3464 .Case([&](omp::MaskedOp) {
3467 .Case([&](omp::MasterOp) {
3470 .Case([&](omp::CriticalOp) {
3473 .Case([&](omp::OrderedRegionOp) {
3476 .Case([&](omp::OrderedOp) {
3479 .Case([&](omp::WsloopOp) {
3482 .Case([&](omp::SimdOp) {
3485 .Case([&](omp::AtomicReadOp) {
3488 .Case([&](omp::AtomicWriteOp) {
3491 .Case([&](omp::AtomicUpdateOp op) {
3494 .Case([&](omp::AtomicCaptureOp op) {
3497 .Case([&](omp::SectionsOp) {
3500 .Case([&](omp::SingleOp op) {
3503 .Case([&](omp::TeamsOp op) {
3506 .Case([&](omp::TaskOp op) {
3509 .Case([&](omp::TaskgroupOp op) {
3512 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareReductionOp,
3513 omp::CriticalDeclareOp>([](
auto op) {
3524 .Case([&](omp::ThreadprivateOp) {
3527 .Case<omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
3528 omp::TargetUpdateOp>([&](
auto op) {
3531 .Case([&](omp::TargetOp) {
3534 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
3542 return inst->
emitError(
"unsupported OpenMP operation: ")
3547 static LogicalResult
3553 static LogicalResult
3556 if (isa<omp::TargetOp>(op))
3558 if (isa<omp::TargetDataOp>(op))
3562 if (isa<omp::TargetOp>(oper)) {
3564 return WalkResult::interrupt();
3565 return WalkResult::skip();
3567 if (isa<omp::TargetDataOp>(oper)) {
3569 return WalkResult::interrupt();
3570 return WalkResult::skip();
3572 return WalkResult::advance();
3573 }).wasInterrupted();
3574 return failure(interrupted);
3581 class OpenMPDialectLLVMIRTranslationInterface
3602 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
3608 .Case(
"omp.is_target_device",
3610 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
3611 llvm::OpenMPIRBuilderConfig &config =
3613 config.setIsTargetDevice(deviceAttr.getValue());
3620 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
3621 llvm::OpenMPIRBuilderConfig &config =
3623 config.setIsGPU(gpuAttr.getValue());
3628 .Case(
"omp.host_ir_filepath",
3630 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
3631 llvm::OpenMPIRBuilder *ompBuilder =
3633 ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
3640 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
3644 .Case(
"omp.version",
3646 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
3647 llvm::OpenMPIRBuilder *ompBuilder =
3649 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
3650 versionAttr.getVersion());
3655 .Case(
"omp.declare_target",
3657 if (
auto declareTargetAttr =
3658 dyn_cast<omp::DeclareTargetAttr>(attr))
3663 .Case(
"omp.requires",
3665 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
3666 using Requires = omp::ClauseRequires;
3667 Requires flags = requiresAttr.getValue();
3668 llvm::OpenMPIRBuilderConfig &config =
3670 config.setHasRequiresReverseOffload(
3671 bitEnumContainsAll(flags, Requires::reverse_offload));
3672 config.setHasRequiresUnifiedAddress(
3673 bitEnumContainsAll(flags, Requires::unified_address));
3674 config.setHasRequiresUnifiedSharedMemory(
3675 bitEnumContainsAll(flags, Requires::unified_shared_memory));
3676 config.setHasRequiresDynamicAllocators(
3677 bitEnumContainsAll(flags, Requires::dynamic_allocators));
3692 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
3693 Operation *op, llvm::IRBuilderBase &builder,
3697 if (ompBuilder->Config.isTargetDevice()) {
3708 registry.
insert<omp::OpenMPDialect>();
3710 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static llvm::Value * getRefPtrIfDeclareTarget(mlir::Value value, LLVM::ModuleTranslation &moduleTranslation)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, const LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, const SmallVector< Value > &devPtrOperands={}, const SmallVector< Value > &devAddrOperands={}, bool isTargetParams=false)
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static bool isTargetDeviceOp(Operation *op)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, mlir::OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
void collectMapDataFromMapOperands(MapInfoData &mapData, llvm::SmallVectorImpl< Value > &mapOperands, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams)
static LogicalResult inlineConvertOmpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef)
LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClasue)
uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void mapInitializationArg(T loop, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, unsigned i)
Map input argument to all reduction initialization regions.
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, bool isTargetParams, int mapDataParentIdx=-1)
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool >> attr)
static bool targetOpSupported(Operation &opInst)
static int getMapDataMemberIdx(MapInfoData &mapData, mlir::omp::MapInfoOp memberOp)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void buildDependData(std::optional< ArrayAttr > depends, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void allocByValReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static LogicalResult convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static void collectReductionDecls(T loop, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given loop.
static mlir::omp::MapInfoOp getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void collectReductionInfo(T loop, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< OwningReductionGen > &owningReductionGens, SmallVectorImpl< OwningAtomicReductionGen > &owningAtomicReductionGens, const ArrayRef< llvm::Value * > privateReductionVariables, SmallVectorImpl< llvm::OpenMPIRBuilder::ReductionInfo > &reductionInfos)
Collect reduction info.
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Given an OpenMP MLIR operation, create the corresponding LLVM IR (including OpenMP runtime calls).
static bool checkIfPointerMap(mlir::omp::MapInfoOp mapOp)
static LogicalResult convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::BasicBlock * convertOmpOpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, LogicalResult &bodyGenStatus, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isDeclareTargetLink(mlir::Value value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
A RAII class that on construction replaces the region arguments of the parallel op (which correspond ...
~OmpParallelOpConversionManager()
OmpParallelOpConversionManager(omp::ParallelOp opInst)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense integer vector or tensor object.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Base class for dialect interfaces providing translation to LLVM IR.
virtual LogicalResult amendOperation(Operation *op, ArrayRef< llvm::Instruction * > instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to act on an operation that has dialect attributes from the derive...
virtual LogicalResult convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to provide translation of the operations to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
Concrete CRTP base class for ModuleTranslation stack frames.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up remapped a list of remapped values.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
WalkResult stackWalk(llvm::function_ref< WalkResult(const T &)> callback) const
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void forgetMapping(Region ®ion)
Removes the mapping for blocks contained in the region and values defined in these blocks.
MLIRContext & getContext()
Returns the MLIR context of the module being translated.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry ®istry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
BlockListType & getBlocks()
BlockArgument getArgument(unsigned i)
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Include the generated interface declarations.
void connectPHINodes(Region ®ion, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Runtime
Potential runtimes for AMD GPU kernels.
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion)
Replace all uses of orig within the given region with replacement.
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry ®istry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry;.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallVector< bool, 4 > IsAMember
llvm::SmallVector< llvm::Value *, 4 > OriginalValue
llvm::SmallVector< bool, 4 > IsDeclareTarget
llvm::SmallVector< llvm::Type *, 4 > BaseType
void append(MapInfoData &CurInfo)
Append arrays in CurInfo.
llvm::SmallVector< mlir::Operation *, 4 > MapClause
RAII object calling stackPush/stackPop on construction/destruction.