24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Frontend/OpenMP/OMPConstants.h"
28 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29 #include "llvm/IR/Constants.h"
30 #include "llvm/IR/DebugInfoMetadata.h"
31 #include "llvm/IR/DerivedTypes.h"
32 #include "llvm/IR/IRBuilder.h"
33 #include "llvm/IR/MDBuilder.h"
34 #include "llvm/IR/ReplaceConstant.h"
35 #include "llvm/Support/FileSystem.h"
36 #include "llvm/TargetParser/Triple.h"
37 #include "llvm/Transforms/Utils/ModuleUtils.h"
// Maps the optional MLIR `schedule` clause kind onto the matching
// llvm::omp::ScheduleKind enumerator. When no clause kind is present,
// OMP_SCHEDULE_Default is returned.
// NOTE(review): this listing elides some source lines (see the embedded line
// numbers); the `case` label that guards the OMP_SCHEDULE_Runtime return is
// not visible here — confirm against the full file.
48 static llvm::omp::ScheduleKind
49 convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
50 if (!schedKind.has_value())
51 return llvm::omp::OMP_SCHEDULE_Default;
52 switch (schedKind.value()) {
53 case omp::ClauseScheduleKind::Static:
54 return llvm::omp::OMP_SCHEDULE_Static;
55 case omp::ClauseScheduleKind::Dynamic:
56 return llvm::omp::OMP_SCHEDULE_Dynamic;
57 case omp::ClauseScheduleKind::Guided:
58 return llvm::omp::OMP_SCHEDULE_Guided;
59 case omp::ClauseScheduleKind::Auto:
60 return llvm::omp::OMP_SCHEDULE_Auto;
62 return llvm::omp::OMP_SCHEDULE_Runtime;
// Reached only if the switch above did not return.
64 llvm_unreachable(
"unhandled schedule clause argument");
// Stack frame that stores an OpenMPIRBuilder insertion point in
// `allocaInsertPoint`. The base class and access specifiers are not visible
// in this listing.
69 class OpenMPAllocaStackFrame
// Captures the given insertion point; `explicit` prevents implicit conversion
// from an InsertPointTy.
74 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
75 : allocaInsertPoint(allocaIP) {}
76 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
// Stack frame that holds a pointer to a llvm::CanonicalLoopInfo. The pointer
// defaults to null. The base class is not visible in this listing.
82 class OpenMPLoopInfoStackFrame
86 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
// llvm::ErrorInfo subclass (CRTP on itself). Elsewhere in this file it is
// created with llvm::make_error when block conversion fails, and handled by
// setting a LogicalResult to failure().
// NOTE(review): the bodies of `log` and `convertToErrorCode` are partly
// elided in this listing; only the message string below is visible.
105 class PreviouslyReportedError
106 :
public llvm::ErrorInfo<PreviouslyReportedError> {
108 void log(raw_ostream &)
const override {
112 std::error_code convertToErrorCode()
const override {
114 "PreviouslyReportedError doesn't support ECError conversion");
// Helper that emits LLVM IR for variables of a `linear` clause. It keeps
// parallel per-variable lists (original allocas, precondition copies, loop
// body temporaries, steps) and the basic blocks used for finalization.
// NOTE(review): several member declarations and closing braces are elided in
// this listing; the container types of the per-variable lists are not visible.
132 class LinearClauseProcessor {
// Blocks produced by outlineLinearFinalizationBB().
140 llvm::BasicBlock *linearFinalizationBB;
141 llvm::BasicBlock *linearExitBB;
142 llvm::BasicBlock *linearLastIterExitBB;
// Registers one linear variable. Only acts when the translated value of
// `linearVar` is an llvm::AllocaInst: two allocas of the same allocated type
// are created (".linear_var" and ".linear_result") and recorded together with
// the original value/alloca.
146 void createLinearVar(llvm::IRBuilderBase &builder,
147 LLVM::ModuleTranslation &moduleTranslation,
149 if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
150 moduleTranslation.lookupValue(linearVar))) {
151 linearPreconditionVars.push_back(builder.CreateAlloca(
152 linearVarAlloca->getAllocatedType(),
nullptr,
".linear_var"));
153 llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
154 linearVarAlloca->getAllocatedType(),
nullptr,
".linear_result");
155 linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
156 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
157 linearOrigVars.push_back(linearVarAlloca);
// Records the translated LLVM value of a linear step.
162 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
164 linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
// Before the terminator of `loopPreHeader`, copies the current value of each
// original variable into its precondition copy, then emits a barrier through
// the OpenMPIRBuilder and returns the resulting insertion point (or error).
168 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
169 initLinearVar(llvm::IRBuilderBase &builder,
170 LLVM::ModuleTranslation &moduleTranslation,
171 llvm::BasicBlock *loopPreHeader) {
172 builder.SetInsertPoint(loopPreHeader->getTerminator());
173 for (
size_t index = 0; index < linearOrigVars.size(); index++) {
174 llvm::LoadInst *linearVarLoad = builder.CreateLoad(
175 linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
176 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
178 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
179 moduleTranslation.getOpenMPBuilder()->createBarrier(
180 builder.saveIP(), llvm::omp::OMPD_barrier);
181 return afterBarrierIP;
// Before the terminator of `loopBody`, stores
// `start + loopInductionVar * step` into each loop body temporary, where
// `start` is loaded from the precondition copy.
185 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
186 llvm::Value *loopInductionVar) {
187 builder.SetInsertPoint(loopBody->getTerminator());
188 for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
190 llvm::LoadInst *linearVarStart =
191 builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
193 linearPreconditionVars[index]);
194 auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
195 auto addInst = builder.CreateAdd(linearVarStart, mulInst);
196 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
// Splits `loopExit` at its terminator into a finalization block, then splits
// that block twice more to obtain the exit and last-iteration-exit blocks.
202 void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
203 llvm::BasicBlock *loopExit) {
204 linearFinalizationBB = loopExit->splitBasicBlock(
205 loopExit->getTerminator(),
"omp_loop.linear_finalization");
206 linearExitBB = linearFinalizationBB->splitBasicBlock(
207 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
208 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
209 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
// Emits the finalization logic: loads `lastIter` as an i32 and compares it
// against an i32 zero with ICMP_NE; in the last-iteration block copies each
// loop body temporary back into its original variable; branches on the
// comparison; finally emits a barrier in the exit block and returns its
// result.
// NOTE(review): the line creating the zero constant is partly elided here.
213 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
214 finalizeLinearVar(llvm::IRBuilderBase &builder,
215 LLVM::ModuleTranslation &moduleTranslation,
216 llvm::Value *lastIter) {
218 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
219 llvm::Value *loopLastIterLoad = builder.CreateLoad(
220 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
221 llvm::Value *isLast =
222 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
224 llvm::Type::getInt32Ty(builder.getContext()), 0));
226 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
227 for (
size_t index = 0; index < linearOrigVars.size(); index++) {
228 llvm::LoadInst *linearVarTemp =
229 builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
230 linearLoopBodyTemps[index]);
231 builder.CreateStore(linearVarTemp, linearOrigVars[index]);
// The conditional branch is inserted before the existing terminator of the
// finalization block; the old terminator is then erased.
237 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
238 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
239 linearFinalizationBB->getTerminator()->eraseFromParent();
241 builder.SetInsertPoint(linearExitBB->getTerminator());
242 return moduleTranslation.getOpenMPBuilder()->createBarrier(
243 builder.saveIP(), llvm::omp::OMPD_barrier);
// For the variable at `varIndex`, replaces uses of the original value with
// the loop body temporary in every instruction user whose parent block is
// named exactly `BBName`. Users are first copied into `users`, so the
// replacement does not run while iterating the live use list.
248 void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
251 for (llvm::User *user : linearOrigVal[varIndex]->users())
252 users.push_back(user);
253 for (
auto *user : users) {
254 if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
255 if (userInst->getParent()->getName().str() == BBName)
256 user->replaceUsesOfWith(linearOrigVal[varIndex],
257 linearLoopBodyTemps[varIndex]);
// Looks up an omp::PrivateClauseOp through
// SymbolTable::lookupNearestSymbolFrom, starting from `from`, and asserts
// that it was found.
// NOTE(review): the first line of the signature, the remaining lookup
// arguments and the return statement are elided in this listing.
268 SymbolRefAttr symbolName) {
269 omp::PrivateClauseOp privatizer =
270 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
272 assert(privatizer &&
"privatizer not found in the symbol table");
// Body of a check that reports OpenMP clauses which are not yet supported.
// A set of small generic lambdas each inspect one clause on an operation and,
// when it is present, overwrite `result` with the diagnostic produced by
// `todo`. A type switch at the end selects which checks run per operation.
// NOTE(review): the enclosing function signature, several `if` conditions,
// closing braces and some `.Case` entries are elided in this listing.

// Emits "not yet implemented: Unhandled clause <clauseName> in <op name>" as
// an error on `op`.
283 auto todo = [&op](StringRef clauseName) {
284 return op.
emitError() <<
"not yet implemented: Unhandled clause "
285 << clauseName <<
" in " << op.
getName()
// Flags any allocate/allocator variables.
289 auto checkAllocate = [&todo](
auto op, LogicalResult &result) {
290 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
291 result = todo(
"allocate");
// Flags `ompx_bare` (condition elided in this listing).
293 auto checkBare = [&todo](
auto op, LogicalResult &result) {
295 result = todo(
"ompx_bare");
// For a Taskgroup cancel directive, flags the case where `parent` is an
// omp::TaskloopOp (how `parent` is obtained is elided in this listing).
297 auto checkCancelDirective = [&todo](
auto op, LogicalResult &result) {
298 omp::ClauseCancellationConstructType cancelledDirective =
299 op.getCancelDirective();
302 if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
309 if (isa_and_nonnull<omp::TaskloopOp>(parent))
310 result = todo(
"cancel directive inside of taskloop");
// Flags any depend variables or depend kinds.
313 auto checkDepend = [&todo](
auto op, LogicalResult &result) {
314 if (!op.getDependVars().empty() || op.getDependKinds())
315 result = todo(
"depend");
// Flags `device` (condition elided in this listing).
317 auto checkDevice = [&todo](
auto op, LogicalResult &result) {
319 result = todo(
"device");
// Flags dist_schedule when a chunk size is given.
321 auto checkDistSchedule = [&todo](
auto op, LogicalResult &result) {
322 if (op.getDistScheduleChunkSize())
323 result = todo(
"dist_schedule with chunk_size");
// Does not capture `todo` and ignores the result parameter; its body is
// elided in this listing.
325 auto checkHint = [](
auto op, LogicalResult &) {
// Flags any in_reduction variables, by-ref flags or symbols.
329 auto checkInReduction = [&todo](
auto op, LogicalResult &result) {
330 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
331 op.getInReductionSyms())
332 result = todo(
"in_reduction");
// Flags any is_device_ptr variables.
334 auto checkIsDevicePtr = [&todo](
auto op, LogicalResult &result) {
335 if (!op.getIsDevicePtrVars().empty())
336 result = todo(
"is_device_ptr");
// Flags any linear variables or linear step variables.
338 auto checkLinear = [&todo](
auto op, LogicalResult &result) {
339 if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
340 result = todo(
"linear");
// Flags `nowait` (condition elided in this listing).
342 auto checkNowait = [&todo](
auto op, LogicalResult &result) {
344 result = todo(
"nowait");
// Flags an order clause or order modifier.
346 auto checkOrder = [&todo](
auto op, LogicalResult &result) {
347 if (op.getOrder() || op.getOrderMod())
348 result = todo(
"order");
// Flags the parallelization-level (simd) setting.
350 auto checkParLevelSimd = [&todo](
auto op, LogicalResult &result) {
351 if (op.getParLevelSimd())
352 result = todo(
"parallelization-level");
// Flags a priority clause.
354 auto checkPriority = [&todo](
auto op, LogicalResult &result) {
355 if (op.getPriority())
356 result = todo(
"priority");
// For omp::TargetOp (selected at compile time), flags privatization combined
// with nowait; the other visible check flags any private variables or
// symbols.
// NOTE(review): the `else` separating the two branches is elided here.
358 auto checkPrivate = [&todo](
auto op, LogicalResult &result) {
359 if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
361 if (!op.getPrivateVars().empty() && op.getNowait())
362 result = todo(
"privatization for deferred target tasks");
364 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
365 result = todo(
"privatization");
// Flags reductions only when `op` is an omp::TeamsOp; a reduction modifier
// other than `defaultmod` is flagged for every operation checked.
368 auto checkReduction = [&todo](
auto op, LogicalResult &result) {
369 if (isa<omp::TeamsOp>(op))
370 if (!op.getReductionVars().empty() || op.getReductionByref() ||
371 op.getReductionSyms())
372 result = todo(
"reduction");
373 if (op.getReductionMod() &&
374 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
375 result = todo(
"reduction with modifier");
// Flags any task_reduction variables, by-ref flags or symbols.
377 auto checkTaskReduction = [&todo](
auto op, LogicalResult &result) {
378 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
379 op.getTaskReductionSyms())
380 result = todo(
"task_reduction");
// Flags `untied` (condition elided in this listing).
382 auto checkUntied = [&todo](
auto op, LogicalResult &result) {
384 result = todo(
"untied");
// Starts as success; each failing check above overwrites it.
387 LogicalResult result = success();
// Per-operation dispatch: each case runs the checks relevant to that op.
389 .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
390 .Case([&](omp::CancellationPointOp op) {
391 checkCancelDirective(op, result);
393 .Case([&](omp::DistributeOp op) {
394 checkAllocate(op, result);
395 checkDistSchedule(op, result);
396 checkOrder(op, result);
398 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
399 .Case([&](omp::SectionsOp op) {
400 checkAllocate(op, result);
401 checkPrivate(op, result);
402 checkReduction(op, result);
404 .Case([&](omp::SingleOp op) {
405 checkAllocate(op, result);
406 checkPrivate(op, result);
408 .Case([&](omp::TeamsOp op) {
409 checkAllocate(op, result);
410 checkPrivate(op, result);
412 .Case([&](omp::TaskOp op) {
413 checkAllocate(op, result);
414 checkInReduction(op, result);
416 .Case([&](omp::TaskgroupOp op) {
417 checkAllocate(op, result);
418 checkTaskReduction(op, result);
420 .Case([&](omp::TaskwaitOp op) {
421 checkDepend(op, result);
422 checkNowait(op, result);
424 .Case([&](omp::TaskloopOp op) {
426 checkUntied(op, result);
427 checkPriority(op, result);
429 .Case([&](omp::WsloopOp op) {
430 checkAllocate(op, result);
431 checkLinear(op, result);
432 checkOrder(op, result);
433 checkReduction(op, result);
435 .Case([&](omp::ParallelOp op) {
436 checkAllocate(op, result);
437 checkReduction(op, result);
439 .Case([&](omp::SimdOp op) {
440 checkLinear(op, result);
441 checkReduction(op, result);
// The four atomic ops share the hint check.
443 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
444 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op, result); })
// The target data-movement ops share the depend check.
445 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
446 [&](
auto op) { checkDepend(op, result); })
447 .Case([&](omp::TargetOp op) {
448 checkAllocate(op, result);
449 checkBare(op, result);
450 checkDevice(op, result);
451 checkInReduction(op, result);
452 checkIsDevicePtr(op, result);
453 checkPrivate(op, result);
// Converts an llvm::Error into a LogicalResult via llvm::handleAllErrors.
// `result` starts as success(). A PreviouslyReportedError only flips it to
// failure(); no further action is visible in that handler.
// NOTE(review): the function signature, the error argument passed to
// handleAllErrors and the body of the generic ErrorInfoBase handler are
// elided in this listing.
463 LogicalResult result = success();
465 llvm::handleAllErrors(
467 [&](
const PreviouslyReportedError &) { result = failure(); },
468 [&](
const llvm::ErrorInfoBase &err) {
475 template <
typename T>
// Determines an insertion point for allocas.
//  1. Walks the OpenMPAllocaStackFrame stack and takes the insertion point of
//     a visited frame; it is returned when its block belongs to the same
//     function as the builder's current insert block (the other conditions of
//     that check are elided in this listing).
//  2. Otherwise falls back to the first insertion point of the current
//     function's entry block. If the builder is currently inserting into that
//     entry block, a new block named "entry" is created right after it, a
//     branch to it is emitted, and the builder is moved there first.
// NOTE(review): the function name line and parts of the walk callback are
// elided in this listing.
485 static llvm::OpenMPIRBuilder::InsertPointTy
487 LLVM::ModuleTranslation &moduleTranslation) {
491 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
492 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
493 [&](OpenMPAllocaStackFrame &frame) {
494 allocaInsertPoint = frame.allocaInsertPoint;
502 allocaInsertPoint.getBlock()->getParent() ==
503 builder.GetInsertBlock()->getParent())
504 return allocaInsertPoint;
513 if (builder.GetInsertBlock() ==
514 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
// The split below relies on the builder sitting at the end of the block.
515 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
516 "Assuming end of basic block");
517 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
518 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
519 builder.GetInsertBlock()->getNextNode());
520 builder.CreateBr(entryBB);
521 builder.SetInsertPoint(entryBB);
524 llvm::BasicBlock &funcEntryBlock =
525 builder.GetInsertBlock()->getParent()->getEntryBlock();
526 return llvm::OpenMPIRBuilder::InsertPointTy(
527 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
// Walks the OpenMPLoopInfoStackFrame stack and copies the `loopInfo` pointer
// of a visited frame; the local starts as null.
// NOTE(review): the function name line, the walk callback's return value and
// the final return statement are elided in this listing.
533 static llvm::CanonicalLoopInfo *
535 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
536 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
537 [&](OpenMPLoopInfoStackFrame &frame) {
538 loopInfo = frame.loopInfo;
// Converts the blocks of an MLIR region into LLVM basic blocks placed between
// the builder's current block and a continuation block "omp.region.cont"
// obtained from splitBB. Returns the continuation block on success, or a
// PreviouslyReportedError when converting a block fails.
// Outline of the visible steps:
//  - create one LLVM block per MLIR block (named `blockName`) and map it;
//  - unless the parent op is a loop wrapper, collect the converted operand
//    types of omp::YieldOp terminators and assert that all yields agree;
//  - create one PHI per yielded type in the continuation block;
//  - redirect the source terminator to the region's entry block, convert each
//    block, and branch to the continuation block from omp terminators/yields,
//    feeding the yielded values into the PHIs;
//  - drop the region's value/block mappings.
// NOTE(review): `®ion` on the first line looks like a mis-encoded `&region`
// — confirm against the original file. The function name line, the
// `continuationBlockPHIs` parameter declaration, the block ordering that
// produces `blocks`, and several braces are elided in this listing.
550 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
551 LLVM::ModuleTranslation &moduleTranslation,
553 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
555 llvm::BasicBlock *continuationBlock =
556 splitBB(builder,
true,
"omp.region.cont");
557 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
559 llvm::LLVMContext &llvmContext = builder.getContext();
560 for (
Block &bb : region) {
561 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
562 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
563 builder.GetInsertBlock()->getNextNode());
564 moduleTranslation.mapBlock(&bb, llvmBB);
567 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
// Used below as the reserved incoming-value count of each PHI.
574 unsigned numYields = 0;
576 if (!isLoopWrapper) {
577 bool operandsProcessed =
false;
579 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
// The first yield seen defines the PHI types; later yields are only checked.
580 if (!operandsProcessed) {
581 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
582 continuationBlockPHITypes.push_back(
583 moduleTranslation.convertType(yield->getOperand(i).getType()));
585 operandsProcessed =
true;
587 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
588 "mismatching number of values yielded from the region");
589 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
590 llvm::Type *operandType =
591 moduleTranslation.convertType(yield->getOperand(i).getType());
593 assert(continuationBlockPHITypes[i] == operandType &&
594 "values of mismatching types yielded from the region");
604 if (!continuationBlockPHITypes.empty())
606 continuationBlockPHIs &&
607 "expected continuation block PHIs if converted regions yield values");
608 if (continuationBlockPHIs) {
// The guard restores the builder's insertion point when this scope ends.
609 llvm::IRBuilderBase::InsertPointGuard guard(builder);
610 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
611 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
612 for (llvm::Type *ty : continuationBlockPHITypes)
613 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
619 for (
Block *bb : blocks) {
620 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
// Re-route the branch created by the split so that it enters the region.
623 if (bb->isEntryBlock()) {
624 assert(sourceTerminator->getNumSuccessors() == 1 &&
625 "provided entry block has multiple successors");
626 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
627 "ContinuationBlock is not the successor of the entry block");
628 sourceTerminator->setSuccessor(0, llvmBB);
631 llvm::IRBuilderBase::InsertPointGuard guard(builder);
633 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
634 return llvm::make_error<PreviouslyReportedError>();
639 builder.CreateBr(continuationBlock);
650 Operation *terminator = bb->getTerminator();
651 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
652 builder.CreateBr(continuationBlock);
654 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
655 (*continuationBlockPHIs)[i]->addIncoming(
656 moduleTranslation.lookupValue(terminator->
getOperand(i)), llvmBB);
667 moduleTranslation.forgetMapping(region);
669 return continuationBlock;
// Switch mapping omp::ClauseProcBindKind values (Close, Master, Primary,
// Spread) onto the matching llvm::omp::ProcBindKind enumerators.
// NOTE(review): the enclosing function signature and the `switch` header are
// elided in this listing.
675 case omp::ClauseProcBindKind::Close:
676 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
677 case omp::ClauseProcBindKind::Master:
678 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
679 case omp::ClauseProcBindKind::Primary:
680 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
681 case omp::ClauseProcBindKind::Spread:
682 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
// Reached only if no case above returned.
684 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
// For every (variable, block argument) pair reported by
// `blockArgIface.getBlockArgsPairs`, maps the block argument to the LLVM
// value already associated with the variable, so that both resolve to the
// same LLVM value.
// NOTE(review): the declaration of `blockArgsPairs` is elided in this
// listing.
693 static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
694 omp::BlockArgOpenMPOpInterface blockArgIface) {
696 blockArgIface.getBlockArgsPairs(blockArgsPairs);
697 for (
auto [var, arg] : blockArgsPairs)
698 moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
// Translates an omp::MaskedOp through OpenMPIRBuilder::createMasked.
//  - `bodyGenCB` restores the builder to `codeGenIP` before the region is
//    converted (the conversion call itself is elided in this listing);
//  - `finiCB` does nothing and reports success;
//  - the filter value is looked up from the op's filtered thread id when one
//    is present, and is asserted to be non-null before the call;
//  - afterwards the builder is restored to the returned insertion point.
// NOTE(review): `®ion` looks like a mis-encoded `&region` — confirm against
// the original file. The function header, the fallback that sets `filterVal`
// when no filter operand exists, and the error check on `afterIP` are elided
// in this listing.
704 LLVM::ModuleTranslation &moduleTranslation) {
705 auto maskedOp = cast<omp::MaskedOp>(opInst);
706 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
711 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
713 auto ®ion = maskedOp.getRegion();
714 builder.restoreIP(codeGenIP);
722 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
724 llvm::Value *filterVal =
nullptr;
725 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
726 filterVal = moduleTranslation.lookupValue(filterVar);
728 llvm::LLVMContext &llvmContext = builder.getContext();
732 assert(filterVal !=
nullptr);
733 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
734 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
735 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
741 builder.restoreIP(*afterIP);
// Translates an omp::MasterOp through OpenMPIRBuilder::createMaster.
// `bodyGenCB` restores the builder to `codeGenIP` before the region is
// converted; `finiCB` does nothing and reports success. Afterwards the
// builder is restored to the returned insertion point.
// NOTE(review): `®ion` looks like a mis-encoded `&region` — confirm against
// the original file. The function header, the region conversion call and the
// error check on `afterIP` are elided in this listing.
748 LLVM::ModuleTranslation &moduleTranslation) {
749 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
750 auto masterOp = cast<omp::MasterOp>(opInst);
755 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
757 auto ®ion = masterOp.getRegion();
758 builder.restoreIP(codeGenIP);
766 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
768 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
769 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
770 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
776 builder.restoreIP(*afterIP);
// Translates an omp::CriticalOp through OpenMPIRBuilder::createCritical.
// `bodyGenCB` restores the builder to `codeGenIP` before the region is
// converted; `finiCB` does nothing and reports success. When the op carries a
// name attribute, the matching omp::CriticalDeclareOp is looked up and its
// hint is cast to int to build the `hint` constant; otherwise `hint` stays
// null. The critical section name defaults to the empty string.
// NOTE(review): `®ion` looks like a mis-encoded `&region` — confirm against
// the original file. The function header, the constant construction around
// the hint cast, and the error check on `afterIP` are elided in this listing.
783 LLVM::ModuleTranslation &moduleTranslation) {
784 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
785 auto criticalOp = cast<omp::CriticalOp>(opInst);
790 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
792 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
793 builder.restoreIP(codeGenIP);
801 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
803 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
804 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
805 llvm::Constant *hint =
nullptr;
808 if (criticalOp.getNameAttr()) {
811 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
812 auto criticalDeclareOp =
813 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
817 static_cast<int>(criticalDeclareOp.getHint()));
819 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
820 moduleTranslation.getOpenMPBuilder()->createCritical(
821 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
826 builder.restoreIP(*afterIP);
// Gathers privatization data for an operation `op` of type OP: the private
// block arguments obtained through omp::BlockArgOpenMPOpInterface, the MLIR
// private variables, and the privatizer declarations named by the op's
// private symbols.
// NOTE(review): the enclosing type's header, its member declarations, the
// loop that fills `mlirVars`, and the body of the symbol loop are elided in
// this listing.
833 template <
typename OP>
836 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
// Pre-size the per-variable lists to the number of private block arguments.
837 mlirVars.reserve(blockArgs.size());
838 llvmVars.reserve(blockArgs.size());
839 collectPrivatizationDecls<OP>(op);
842 mlirVars.push_back(privateVar);
// Reads the optional private symbol array of `op` and reserves room in
// `privatizers` for one entry per symbol before iterating the symbols.
854 void collectPrivatizationDecls(OP op) {
855 std::optional<ArrayAttr> attr = op.getPrivateSyms();
859 privatizers.reserve(privatizers.size() + attr->size());
860 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
// Appends to `reductions` the omp::DeclareReductionOp found through
// SymbolTable::lookupNearestSymbolFrom for each symbol in the op's reduction
// symbol array. Capacity is reserved for `op.getNumReductionVars()`
// additional entries.
// NOTE(review): the function signature, the handling of an absent `attr`,
// and the lookup arguments are elided in this listing.
867 template <
typename T>
871 std::optional<ArrayAttr> attr = op.getReductionSyms();
875 reductions.reserve(reductions.size() + op.getNumReductionVars());
876 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
877 reductions.push_back(
878 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
// Converts a region at the builder's current position.
// Visible fast path: the last instruction of the current block is detached if
// it is a terminator, the region's front block is mapped onto the current
// LLVM block and converted in place, the region mapping is forgotten, and the
// detached terminator is re-inserted at the end of the block the builder ends
// up in.
// Visible general path: PHIs are appended to `continuationBlockArgs` when
// that pointer is non-null, and the builder is positioned at the first
// insertion point of the continuation block.
// NOTE(review): both `®ion` occurrences look like a mis-encoded `&region`
// — confirm against the original file. The function header, the condition
// selecting the fast path, the return statements and the calls producing
// `phis`/`continuationBlock` are elided in this listing.
889 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
890 LLVM::ModuleTranslation &moduleTranslation,
898 llvm::Instruction *potentialTerminator =
899 builder.GetInsertBlock()->empty() ? nullptr
900 : &builder.GetInsertBlock()->back();
// removeFromParent() unlinks the instruction without deleting it, so it can
// be re-inserted below.
902 if (potentialTerminator && potentialTerminator->isTerminator())
903 potentialTerminator->removeFromParent();
904 moduleTranslation.mapBlock(®ion.
front(), builder.GetInsertBlock());
906 if (
failed(moduleTranslation.convertBlock(
907 region.
front(),
true, builder)))
911 if (continuationBlockArgs)
913 *continuationBlockArgs,
918 moduleTranslation.forgetMapping(region);
920 if (potentialTerminator && potentialTerminator->isTerminator()) {
921 llvm::BasicBlock *block = builder.GetInsertBlock();
// An empty block has no last instruction to insert after.
922 if (block->empty()) {
928 potentialTerminator->insertInto(block, block->begin());
930 potentialTerminator->insertAfter(&block->back());
944 if (continuationBlockArgs)
945 llvm::append_range(*continuationBlockArgs, phis);
946 builder.SetInsertPoint(*continuationBlock,
947 (*continuationBlock)->getFirstInsertionPt());
// Owning (std::function) callable types for the reduction generators.
// Both return an OpenMPIRBuilder::InsertPointOrErrorTy and take an insertion
// point first; the non-atomic variant then takes two values, the atomic one a
// type followed by a value.
// NOTE(review): the trailing parameters of both signatures are elided in this
// listing.
954 using OwningReductionGen =
955 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
956 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
958 using OwningAtomicReductionGen =
959 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
960 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
// Builds the non-atomic reduction generator for a reduction declaration.
// The returned callable maps the declaration's LHS/RHS block arguments to the
// given LLVM values, restores the builder to `insertPoint`, inlines the
// combiner region under the block name "omp.reduction.nonatomic.body", and
// stores the single yielded value in `result`. A string error is returned
// when inlining fails.
// The lambda captures `decl` by copy and everything else by reference, so
// `builder` and `moduleTranslation` are referenced, not copied.
// NOTE(review): the function name line, the declaration of `phis`, the
// inlining call and the final return are elided in this listing.
967 static OwningReductionGen
969 LLVM::ModuleTranslation &moduleTranslation) {
973 OwningReductionGen gen =
974 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
975 llvm::Value *lhs, llvm::Value *rhs,
976 llvm::Value *&result)
mutable
977 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
978 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
979 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
980 builder.restoreIP(insertPoint);
983 "omp.reduction.nonatomic.body", builder,
984 moduleTranslation, &phis)))
985 return llvm::createStringError(
986 "failed to inline `combiner` region of `omp.declare_reduction`");
// getSingleElement expects exactly one yielded value.
987 result = llvm::getSingleElement(phis);
988 return builder.saveIP();
// Builds the atomic reduction generator for a reduction declaration.
// Returns an empty (default-constructed) generator when the declaration has
// no atomic reduction region. Otherwise the returned callable maps the atomic
// LHS/RHS block arguments to the given LLVM values, restores the builder to
// `insertPoint`, and inlines the atomic region under the block name
// "omp.reduction.atomic.body"; the region is asserted to yield no values.
// The llvm::Type* parameter is unnamed and unused by the lambda.
// NOTE(review): the function name line, the declaration of `phis`, the
// inlining call and the final return are elided in this listing.
997 static OwningAtomicReductionGen
999 llvm::IRBuilderBase &builder,
1000 LLVM::ModuleTranslation &moduleTranslation) {
1001 if (decl.getAtomicReductionRegion().empty())
1002 return OwningAtomicReductionGen();
1007 OwningAtomicReductionGen atomicGen =
1008 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1009 llvm::Value *lhs, llvm::Value *rhs)
mutable
1010 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1011 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
1012 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
1013 builder.restoreIP(insertPoint);
1016 "omp.reduction.atomic.body", builder,
1017 moduleTranslation, &phis)))
1018 return llvm::createStringError(
1019 "failed to inline `atomic` region of `omp.declare_reduction`");
1020 assert(phis.empty());
1021 return builder.saveIP();
// Translates an omp::OrderedOp carrying doacross dependence information.
// The translated dependence values are consumed in groups of `numLoops`; for
// each group OpenMPIRBuilder::createOrderedDepend is called and the builder
// is restored to the insertion point it returns.
// The optional depend type and loop count are dereferenced unconditionally.
// NOTE(review): presumably they are guaranteed present by an earlier check
// or by op verification — confirm. The function name line, the increment of
// `indexVecValues`, the declarations of `vecValues`/`storeValues`, the
// computation of `allocaIP` and the return are elided in this listing.
1027 static LogicalResult
1029 LLVM::ModuleTranslation &moduleTranslation) {
1030 auto orderedOp = cast<omp::OrderedOp>(opInst);
1035 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1036 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1037 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1039 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1041 size_t indexVecValues = 0;
1042 while (indexVecValues < vecValues.size()) {
1044 storeValues.reserve(numLoops);
1045 for (
unsigned i = 0; i < numLoops; i++) {
1046 storeValues.push_back(vecValues[indexVecValues]);
1049 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1051 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1052 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1053 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
// Translates an omp::OrderedRegionOp through
// OpenMPIRBuilder::createOrderedThreadsSimd. `bodyGenCB` restores the builder
// to `codeGenIP` before the region is converted; `finiCB` does nothing and
// reports success. The last argument of the call is the negation of the op's
// par-level-simd flag. Afterwards the builder is restored to the returned
// insertion point.
// NOTE(review): `®ion` looks like a mis-encoded `&region` — confirm against
// the original file. The function name line, the region conversion call, the
// error check on `afterIP` and the return are elided in this listing.
1060 static LogicalResult
1062 LLVM::ModuleTranslation &moduleTranslation) {
1063 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1064 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1069 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1071 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1072 builder.restoreIP(codeGenIP);
1080 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1082 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1083 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1084 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1085 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1090 builder.restoreIP(*afterIP);
// A (value, address) pair describing a store that is emitted later; elsewhere
// in this file the pairs are turned into `CreateStore(value, address)` calls.
// NOTE(review): the declaration of the `value` member is elided in this
// listing.
1096 struct DeferredStore {
1097 DeferredStore(llvm::Value *value, llvm::Value *address)
1098 : value(value), address(address) {}
1101 llvm::Value *address;
// Allocates the private storage for each reduction variable of `loop`,
// inserting before the terminator of the alloca block. The builder's
// insertion point is restored on exit by the guard.
// Two visible paths per variable:
//  - alloc-region path: the declaration's `alloc` region is inlined and must
//    yield exactly one value; an alloca of the reduction type is created,
//    both pointers are cast to the builder's pointer type, a deferred store
//    of the yielded pointer into the alloca is queued, and the block argument
//    and reduction map are bound to the yielded pointer;
//  - by-value path: the alloc region must be empty; an alloca is created,
//    cast, and bound to the block argument, private variable list and map.
// NOTE(review): the function name line, remaining parameters, the by-ref
// conditions selecting between the paths, the `continue` for an empty alloc
// region and the return are elided in this listing.
1108 template <
typename T>
1109 static LogicalResult
1111 llvm::IRBuilderBase &builder,
1112 LLVM::ModuleTranslation &moduleTranslation,
1113 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1119 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1120 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1123 deferredStores.reserve(loop.getNumReductionVars());
1125 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1126 Region &allocRegion = reductionDecls[i].getAllocRegion();
1128 if (allocRegion.
empty())
1133 builder, moduleTranslation, &phis)))
1134 return loop.emitError(
1135 "failed to inline `alloc` region of `omp.declare_reduction`");
1137 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
// Inlining the region may have moved the builder; return to the alloca block.
1138 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1142 llvm::Value *var = builder.CreateAlloca(
1143 moduleTranslation.convertType(reductionDecls[i].getType()));
1145 llvm::Type *ptrTy = builder.getPtrTy();
1146 llvm::Value *castVar =
1147 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1148 llvm::Value *castPhi =
1149 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1151 deferredStores.emplace_back(castPhi, castVar);
1153 privateReductionVariables[i] = castVar;
1154 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1155 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1157 assert(allocRegion.
empty() &&
1158 "allocaction is implicit for by-val reduction");
1159 llvm::Value *var = builder.CreateAlloca(
1160 moduleTranslation.convertType(reductionDecls[i].getType()));
1162 llvm::Type *ptrTy = builder.getPtrTy();
1163 llvm::Value *castVar =
1164 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1166 moduleTranslation.mapValue(reductionArgs[i], castVar);
1167 privateReductionVariables[i] = castVar;
1168 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
// Binds the block arguments of the i-th reduction declaration's initializer
// region: the mold argument is mapped to the LLVM value of the original
// reduction variable (asserted non-null), and the alloc argument to the
// entry recorded for that variable in `reductionVariableMap`.
// NOTE(review): the function signature and the conditions guarding these
// mappings are elided in this listing.
1176 template <
typename T>
1183 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1184 Region &initializerRegion = reduction.getInitializerRegion();
1187 mlir::Value mlirSource = loop.getReductionVars()[i];
1188 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1189 assert(llvmSource &&
"lookup reduction var");
1190 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
1193 llvm::Value *allocation =
1194 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1195 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
// Positions the builder inside `block` (defaulting to the builder's current
// insert block when `block` is null): at the end of the block when it is
// empty or has no terminator, otherwise immediately before its terminator.
// NOTE(review): the first line of the signature and the `else` keyword
// between the two SetInsertPoint calls are elided in this listing.
1201 llvm::BasicBlock *block =
nullptr) {
1202 if (block ==
nullptr)
1203 block = builder.GetInsertBlock();
1205 if (block->empty() || block->getTerminator() ==
nullptr)
1206 builder.SetInsertPoint(block);
1208 builder.SetInsertPoint(block->getTerminator());
// Initializes the private reduction variables of `op`.
// Visible steps:
//  - split off an "omp.reduction.init" block and move the builder to just
//    before the terminator of `latestAllocaBlock`;
//  - create an alloca of the reduction type into `byRefVars[i]` (the guarding
//    conditions are elided);
//  - emit the deferred stores;
//  - per variable, inline the initializer region under the name
//    "omp.reduction.neutral", which must yield exactly one value, then store
//    that value into the by-ref alloca or into the private variable, and drop
//    the initializer region's mappings.
// NOTE(review): the function name line, remaining parameters, early return
// for zero reductions, by-ref conditions and final return are elided in this
// listing.
1216 template <
typename OP>
1217 static LogicalResult
1219 llvm::IRBuilderBase &builder,
1220 LLVM::ModuleTranslation &moduleTranslation,
1221 llvm::BasicBlock *latestAllocaBlock,
1227 if (op.getNumReductionVars() == 0)
1230 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1231 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1232 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1233 builder.restoreIP(allocaIP);
1236 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1238 if (!reductionDecls[i].getAllocRegion().empty())
1244 byRefVars[i] = builder.CreateAlloca(
1245 moduleTranslation.convertType(reductionDecls[i].getType()));
// Emit the stores that were queued as (value, address) pairs.
1253 for (
auto [data, addr] : deferredStores)
1254 builder.CreateStore(data, addr);
1259 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1264 reductionVariableMap, i);
1272 "omp.reduction.neutral", builder,
1273 moduleTranslation, &phis)))
1276 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1277 "reduction neutral element declaration region");
1282 if (!reductionDecls[i].getAllocRegion().empty())
1291 builder.CreateStore(phis[0], byRefVars[i]);
1293 privateReductionVariables[i] = byRefVars[i];
1294 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1295 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1298 builder.CreateStore(phis[0], privateReductionVariables[i]);
// The initializer region's mappings are dropped once its value is consumed.
1305 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
// Builds the reduction descriptors handed to the OpenMPIRBuilder.
// First loop: creates one non-atomic and one atomic generator per reduction
// and stores them in the owning containers. Second loop: appends one entry
// per reduction to `reductionInfos`, made of the converted element type, the
// original variable, the private variable, EvalKind::Scalar, the non-atomic
// generator, a null pointer, and the atomic generator. The atomic generator
// stays null when the owning atomic generator is empty.
// NOTE(review): the function name line, remaining parameters and the
// generator factory calls are elided in this listing.
1312 template <
typename T>
1314 T loop, llvm::IRBuilderBase &builder,
1315 LLVM::ModuleTranslation &moduleTranslation,
1321 unsigned numReductions = loop.getNumReductionVars();
1323 for (
unsigned i = 0; i < numReductions; ++i) {
1324 owningReductionGens.push_back(
1326 owningAtomicReductionGens.push_back(
1331 reductionInfos.reserve(numReductions);
1332 for (
unsigned i = 0; i < numReductions; ++i) {
1333 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen =
nullptr;
1334 if (owningAtomicReductionGens[i])
1335 atomicGen = owningAtomicReductionGens[i];
1336 llvm::Value *variable =
1337 moduleTranslation.lookupValue(loop.getReductionVars()[i]);
1338 reductionInfos.push_back(
1339 {moduleTranslation.convertType(reductionDecls[i].
getType()), variable,
1340 privateReductionVariables[i],
1341 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1342 owningReductionGens[i],
1343 nullptr, atomicGen});
// Inlines cleanup regions for private variables. For each non-empty cleanup
// region the builder is moved before the current block's terminator when one
// exists, the region's first block argument is bound to the private variable
// (loaded first when `shouldLoadCleanupRegionArg` is true, which is the
// default), the region is inlined, and its mappings are dropped.
// NOTE(review): the function name line, remaining parameters, the loop
// header, the load type, the inlining call and the return statements are
// elided in this listing.
1348 static LogicalResult
1351 LLVM::ModuleTranslation &moduleTranslation,
1352 llvm::IRBuilderBase &builder, StringRef regionName,
1353 bool shouldLoadCleanupRegionArg =
true) {
1355 if (cleanupRegion->empty())
1361 llvm::Instruction *potentialTerminator =
1362 builder.GetInsertBlock()->empty() ? nullptr
1363 : &builder.GetInsertBlock()->back();
1364 if (potentialTerminator && potentialTerminator->isTerminator())
1365 builder.SetInsertPoint(potentialTerminator);
1366 llvm::Value *privateVarValue =
1367 shouldLoadCleanupRegionArg
1368 ? builder.CreateLoad(
1370 privateVariables[i])
1371 : privateVariables[i];
1373 moduleTranslation.mapValue(entry.
getArgument(0), privateVarValue);
1376 moduleTranslation)))
1381 moduleTranslation.forgetMapping(*cleanupRegion);
// Emits the reduction code for `op` and then the reduction cleanup regions.
// Visible steps:
//  - collect the reduction descriptors;
//  - create a temporary `unreachable` terminator and insert before it, call
//    OpenMPIRBuilder::createReductions, and fail with "failed to convert
//    reductions" when the returned insertion point has no block;
//  - emit a barrier (OMPD_for) at the returned point, erase the temporary
//    terminator and restore the builder after the barrier;
//  - gather each declaration's cleanup region and inline them under the name
//    "omp.reduction.cleanup".
// `isNowait` and `isTeamsReduction` both default to false.
// NOTE(review): the function name line, remaining parameters, the early
// return for zero reductions, and the error checks on the two
// InsertPointOrErrorTy values are elided in this listing.
1389 OP op, llvm::IRBuilderBase &builder,
1390 LLVM::ModuleTranslation &moduleTranslation,
1391 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1394 bool isNowait =
false,
bool isTeamsReduction =
false) {
1396 if (op.getNumReductionVars() == 0)
1403 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1408 owningReductionGens, owningAtomicReductionGens,
1409 privateReductionVariables, reductionInfos);
// The block gets a placeholder terminator so that code can be inserted
// before it; the placeholder is erased again below.
1414 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1415 builder.SetInsertPoint(tempTerminator);
1416 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1417 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1418 isByRef, isNowait, isTeamsReduction);
1423 if (!contInsertPoint->getBlock())
1424 return op->emitOpError() <<
"failed to convert reductions";
1426 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1427 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1432 tempTerminator->eraseFromParent();
1433 builder.restoreIP(*afterIP);
1437 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1438 [](omp::DeclareReductionOp reductionDecl) {
1439 return &reductionDecl.getCleanupRegion();
1442 moduleTranslation, builder,
1443 "omp.reduction.cleanup");
/// Two-step helper over an op's reduction variables. Both visible calls
/// receive `reductionDecls`, `privateReductionVariables`,
/// `reductionVariableMap`, `deferredStores` and `isByRef`; the first takes
/// `allocaIP`, the second takes the block of `allocaIP`.
1454 template <
typename OP>
1457 LLVM::ModuleTranslation &moduleTranslation,
1458 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// Guard for ops that carry no reduction variables.
1463 if (op.getNumReductionVars() == 0)
1469 allocaIP, reductionDecls,
1470 privateReductionVariables, reductionVariableMap,
1471 deferredStores, isByRef)))
1475 allocaIP.getBlock(), reductionDecls,
1476 privateReductionVariables, reductionVariableMap,
1477 isByRef, deferredStores);
/// Resolves the LLVM value to use for `privateVar`.
/// Without a `mappedPrivateVars` entry the plain translation lookup is used;
/// otherwise the result depends on how the mapped block argument's type
/// compares with the private variable's type.
1487 static llvm::Value *
1489 LLVM::ModuleTranslation &moduleTranslation,
// No map, or no entry for this variable: use the direct lookup.
1491 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1492 return moduleTranslation.lookupValue(privateVar);
1494 Value blockArg = (*mappedPrivateVars)[privateVar];
1497 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1498 "A block argument corresponding to a mapped var should have "
// Same type on both sides: the block argument's value can be used directly.
1501 if (privVarType == blockArgType)
1502 return moduleTranslation.lookupValue(blockArg);
// Non-pointer private type: load a value of that type through the block
// argument's translated value.
1508 if (!isa<LLVM::LLVMPointerType>(privVarType))
1509 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1510 moduleTranslation.lookupValue(blockArg));
1512 return moduleTranslation.lookupValue(privateVar);
/// Runs the `init` region of a privatizer for a single private variable.
/// Returns `llvmPrivateVar` unchanged when the privatizer has no init region;
/// a failure to inline the region is reported as a string error.
1520 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1522 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1524 Region &initRegion = privDecl.getInitRegion();
1525 if (initRegion.
empty())
1526 return llvmPrivateVar;
1530 mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1531 assert(nonPrivateVar);
// Bind the init region's mold and private arguments before inlining it.
1532 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1533 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1538 moduleTranslation, &phis)))
1539 return llvm::createStringError(
1540 "failed to inline `init` region of `omp.private`");
1542 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
// Drop the value mappings recorded for the init region.
1547 moduleTranslation.forgetMapping(initRegion);
/// Initializes a group of private variables. The visible statements split off
/// an "omp.private.init" block, then for each (privatizer, MLIR variable,
/// block argument, LLVM variable) tuple obtain an initialized value and map
/// the block argument to it. Errors from the per-variable step are propagated.
1557 LLVM::ModuleTranslation &moduleTranslation,
1561 return llvm::Error::success();
1563 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1569 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1571 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1572 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1575 return privVarOrErr.takeError();
// Record the resulting value and expose it through the block argument.
1577 llvmPrivateVar = privVarOrErr.get();
1578 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1583 return llvm::Error::success();
/// Creates one "omp.private.alloc" alloca per private variable in the block
/// of `allocaIP` and appends each result to `privateVarsInfo.llvmVars`.
/// Returns the block that follows the alloca block after it has been split.
1591 LLVM::ModuleTranslation &moduleTranslation,
1593 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// Split the alloca block at its terminator.
1596 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1597 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1598 allocaTerminator->getIterator()),
1599 true, allocaTerminator->getStableDebugLoc(),
1600 "omp.region.after_alloca");
1602 llvm::IRBuilderBase::InsertPointGuard guard(builder);
// The split changed the block's terminator, so fetch it again.
1604 allocaTerminator = allocaIP.getBlock()->getTerminator();
1605 builder.SetInsertPoint(allocaTerminator);
1607 assert(allocaTerminator->getNumSuccessors() == 1 &&
1608 "This is an unconditional branch created by splitBB");
1610 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1611 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
// `defaultAS` is taken from the data layout's program address space.
1613 unsigned int allocaAS =
1614 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1615 unsigned int defaultAS = moduleTranslation.getLLVMModule()
1617 .getProgramAddressSpace();
1619 for (
auto [privDecl, mlirPrivVar, blockArg] :
1622 llvm::Type *llvmAllocType =
1623 moduleTranslation.convertType(privDecl.getType());
1624 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1625 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1626 llvmAllocType,
nullptr,
"omp.private.alloc");
// Cast the alloca to `defaultAS` when the two address spaces differ.
1627 if (allocaAS != defaultAS)
1628 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1629 builder.getPtrTy(defaultAS));
1631 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1634 return afterAllocas;
/// Inlines the `copy` region of every firstprivate privatizer.
/// Privatizers whose data-sharing type is not FirstPrivate are filtered out
/// by the check inside the loop. For the others, the copy region's mold
/// argument is bound to the original variable and its private argument to the
/// private LLVM variable. When `insertBarrier` is set, a barrier is created
/// afterwards.
1639 LLVM::ModuleTranslation &moduleTranslation,
// Check up front whether any privatizer is firstprivate at all.
1645 bool needsFirstprivate =
1646 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1647 return privOp.getDataSharingType() ==
1648 omp::DataSharingClauseType::FirstPrivate;
1651 if (!needsFirstprivate)
1654 llvm::BasicBlock *copyBlock =
1655 splitBB(builder,
true,
"omp.private.copy");
1658 for (
auto [decl, mlirVar, llvmVar] :
1659 llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
1660 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
// Reference to the privatizer's copy region; it is passed to forgetMapping
// once the region has been inlined.
1664 Region &copyRegion = decl.getCopyRegion();
1668 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1669 assert(nonPrivateVar);
1670 moduleTranslation.mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1673 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1677 moduleTranslation)))
1678 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
// Drop the value mappings recorded for the copy region.
1687 moduleTranslation.forgetMapping(copyRegion);
1690 if (insertBarrier) {
1691 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1692 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1693 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
/// Inlines the `dealloc` region of each privatizer under the name
/// "omp.private.dealloc". The trailing `false` argument lines up with the
/// `shouldLoadCleanupRegionArg` parameter position.
/// NOTE(review): that positional match is inferred — confirm against the callee.
1701 static LogicalResult
1703 LLVM::ModuleTranslation &moduleTranslation,
Location loc,
// Collect a pointer to the dealloc region of every privatizer.
1708 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1709 [](omp::PrivateClauseOp privatizer) {
1710 return &privatizer.getDeallocRegion();
1714 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1715 "omp.private.dealloc",
false)))
1716 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1717 "`omp.private` op in");
// Matches `child` when it is a cancel or cancellation-point op.
1729 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
/// Lowers an `omp::SectionsOp`.
/// Builds one body callback per nested `omp::SectionOp`, hands the callbacks
/// to `OpenMPIRBuilder::createSections`, and ends with a call that takes the
/// reduction declarations and private reduction variables.
1736 static LogicalResult
1738 LLVM::ModuleTranslation &moduleTranslation) {
1739 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1740 using StorableBodyGenCallbackTy =
1741 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1743 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1749 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1753 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1757 sectionsOp.getNumReductionVars());
1761 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1764 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1765 reductionDecls, privateReductionVariables, reductionVariableMap,
1772 auto sectionOp = dyn_cast<omp::SectionOp>(op);
// Reference to this section's region; the callback below captures it (and
// `sectionsOp`) by reference.
1776 Region &region = sectionOp.getRegion();
1777 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
1778 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1779 builder.restoreIP(codeGenIP);
1786 sectionsOp.getRegion().getNumArguments());
// Each section argument is mapped to the LLVM value already translated for
// the corresponding argument of the enclosing sections op.
1787 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1788 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1789 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
1791 moduleTranslation.mapValue(sectionArg, llvmVal);
1798 sectionCBs.push_back(sectionCB);
// Guard for the case where no section callback was collected.
1804 if (sectionCBs.empty())
1807 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
// Privatization callback: the replacement value is the incoming pointer.
1812 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1813 llvm::Value &vPtr, llvm::Value *&replacementValue)
1814 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1815 replacementValue = &vPtr;
// Finalization callback that does nothing and reports success.
1821 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1825 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1826 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1827 moduleTranslation.getOpenMPBuilder()->createSections(
1828 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1829 sectionsOp.getNowait());
1834 builder.restoreIP(*afterIP);
1838 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1839 privateReductionVariables, isByRef, sectionsOp.getNowait());
/// Lowers `singleOp` through `OpenMPIRBuilder::createSingle`, forwarding the
/// nowait flag and the translated copyprivate variables. For each copyprivate
/// variable, the function named by the matching copyprivate symbol is also
/// looked up and collected.
1843 static LogicalResult
1845 LLVM::ModuleTranslation &moduleTranslation) {
1846 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1847 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1852 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1853 builder.restoreIP(codegenIP);
1855 builder, moduleTranslation)
// Finalization callback that does nothing and reports success.
1858 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1862 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
// Index i pairs copyprivate variable i with copyprivate symbol i.
1865 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1866 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1867 auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
1868 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1869 llvmCPFuncs.push_back(
1870 moduleTranslation.lookupFunction(llvmFuncOp.getName()));
1873 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1874 moduleTranslation.getOpenMPBuilder()->createSingle(
1875 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1881 builder.restoreIP(*afterIP);
// Walks every use of each reduction block argument of `teamsOp`.
// Debug declare/value users are collected in `debugUses`; for other users
// the enclosing `omp::DistributeOp` is looked up and compared with `distOp`.
1887 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1892 for (
auto ra : iface.getReductionBlockArgs())
1893 for (
auto &use : ra.getUses()) {
1894 auto *useOp = use.getOwner();
// Debug-info users are set aside here.
1896 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1897 debugUses.push_back(useOp);
1901 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
1906 Operation *currentOp = currentDistOp.getOperation();
// Detects a use nested under a different distribute op than `distOp`.
1907 if (distOp && (distOp != currentOp))
1916 for (
auto use : debugUses)
/// Lowers a teams construct through `OpenMPIRBuilder::createTeams`.
/// The optional operands (num_teams lower/upper, thread_limit, if) are
/// translated only when present and are otherwise passed as null. Reduction
/// handling before and after the construct is gated on `doTeamsReduction`.
1922 static LogicalResult
1924 LLVM::ModuleTranslation &moduleTranslation) {
1925 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1930 unsigned numReductionVars = op.getNumReductionVars();
1934 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1940 if (doTeamsReduction) {
1941 isByRef =
getIsByRef(op.getReductionByref());
1943 assert(isByRef.size() == op.getNumReductionVars());
1946 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
1951 op, reductionArgs, builder, moduleTranslation, allocaIP,
1952 reductionDecls, privateReductionVariables, reductionVariableMap,
1957 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1959 moduleTranslation, allocaIP);
1960 builder.restoreIP(codegenIP);
// Each optional operand below stays null when the op does not provide it.
1966 llvm::Value *numTeamsLower =
nullptr;
1967 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
1968 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
1970 llvm::Value *numTeamsUpper =
nullptr;
1971 if (
Value numTeamsUpperVar = op.getNumTeamsUpper())
1972 numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
1974 llvm::Value *threadLimit =
nullptr;
1975 if (
Value threadLimitVar = op.getThreadLimit())
1976 threadLimit = moduleTranslation.lookupValue(threadLimitVar);
1978 llvm::Value *ifExpr =
nullptr;
1979 if (
Value ifVar = op.getIfExpr())
1980 ifExpr = moduleTranslation.lookupValue(ifVar);
1982 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1983 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1984 moduleTranslation.getOpenMPBuilder()->createTeams(
1985 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
1990 builder.restoreIP(*afterIP);
1991 if (doTeamsReduction) {
1994 op, builder, moduleTranslation, allocaIP, reductionDecls,
1995 privateReductionVariables, isByRef,
/// Converts (dependVar, dependKind) pairs into `DependData` records appended
/// to `dds`. Each record carries the runtime dependence kind, the LLVM type
/// of the translated dependence value, and the value itself.
2003 LLVM::ModuleTranslation &moduleTranslation,
// Guard for an empty dependence list.
2005 if (dependVars.empty())
2007 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2008 llvm::omp::RTLDependenceKindTy type;
2010 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2011 case mlir::omp::ClauseTaskDepend::taskdependin:
2012 type = llvm::omp::RTLDependenceKindTy::DepIn;
// `out` shares the `inout` mapping: both case labels lead to DepInOut.
2017 case mlir::omp::ClauseTaskDepend::taskdependout:
2018 case mlir::omp::ClauseTaskDepend::taskdependinout:
2019 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2021 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2022 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2024 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2025 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2028 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2029 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2030 dds.emplace_back(dd);
/// Pushes a finalization callback onto `ompBuilder`.
/// The callback emits a branch at the given insertion point and records it
/// in `cancelTerminators`. The branch initially targets the insertion
/// point's own block.
2042 llvm::IRBuilderBase &llvmBuilder,
2044 llvm::omp::Directive cancelDirective) {
2045 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
// The guard puts the builder's insertion point back when the callback ends.
2046 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2050 llvmBuilder.restoreIP(ip);
2056 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2057 return llvm::Error::success();
2062 ompBuilder.pushFinalizationCB(
/// Pops the finalization callback from `ompBuilder` and retargets every
/// branch in `cancelTerminators` to the single predecessor of the block of
/// `afterIP`.
2072 llvm::OpenMPIRBuilder &ompBuilder,
2073 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2074 ompBuilder.popFinalizationCB();
// NOTE(review): getSinglePredecessor() returns null when the block does not
// have exactly one predecessor; no check is visible here — confirm that
// callers guarantee a single predecessor.
2075 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2076 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2077 assert(cancelBranch->getNumSuccessors() == 1 &&
2078 "cancel branch should have one target");
2079 cancelBranch->setSuccessor(0, constructFini);
/// Helper that owns the state for a task's context structure: the struct
/// type (`structTy`), the pointer to the allocated struct (`structPtr`), and
/// the GEPs into it. It keeps references to the IR builder and the module
/// translation it was constructed with.
2086 class TaskContextStructManager {
2088 TaskContextStructManager(llvm::IRBuilderBase &builder,
2089 LLVM::ModuleTranslation &moduleTranslation,
2091 : builder{builder}, moduleTranslation{moduleTranslation},
2092 privateDecls{privateDecls} {}
// Allocates the context struct (see the out-of-line definition).
2098 void generateTaskContextStruct();
// Rebuilds `llvmPrivateVarGEPs` from the privatizers.
2104 void createGEPsToPrivateVars();
// Emits a free of `structPtr`.
2107 void freeStructPtr();
2110 return llvmPrivateVarGEPs;
2113 llvm::Value *getStructPtr() {
return structPtr; }
2116 llvm::IRBuilderBase &builder;
2117 LLVM::ModuleTranslation &moduleTranslation;
// Both pointers start out null.
2128 llvm::Value *structPtr =
nullptr;
2130 llvm::Type *structTy =
nullptr;
/// Collects the converted type of each privatizer (subject to the
/// `readsFromMold()` check in the loop) into `privateVarTypes`, then
/// allocates storage for `structTy` with `CreateMalloc`. The resulting
/// pointer is named "omp.task.context_ptr" and stored in `structPtr`.
2134 void TaskContextStructManager::generateTaskContextStruct() {
// Guard for the case with no privatizers.
2135 if (privateDecls.empty())
2137 privateVarTypes.reserve(privateDecls.size());
2139 for (omp::PrivateClauseOp &privOp : privateDecls) {
2142 if (!privOp.readsFromMold())
2144 Type mlirType = privOp.getType();
2145 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
// The allocation size is the constant-expression size of `structTy`.
2151 llvm::DataLayout dataLayout =
2152 builder.GetInsertBlock()->getModule()->getDataLayout();
2153 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2154 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2157 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2159 "omp.task.context_ptr");
/// Rebuilds `llvmPrivateVarGEPs` from `privateDecls`.
/// A privatizer that does not read from its mold gets a null entry; the GEP
/// entries address `structPtr` with the index pair {0, i}.
2162 void TaskContextStructManager::createGEPsToPrivateVars() {
2164 assert(privateVarTypes.empty());
// Start from an empty list sized for one entry per privatizer.
2169 llvmPrivateVarGEPs.clear();
2170 llvmPrivateVarGEPs.reserve(privateDecls.size());
2171 llvm::Value *zero = builder.getInt32(0);
2173 for (
auto privDecl : privateDecls) {
2174 if (!privDecl.readsFromMold()) {
2176 llvmPrivateVarGEPs.push_back(
nullptr);
2179 llvm::Value *iVal = builder.getInt32(i);
2180 llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
2181 llvmPrivateVarGEPs.push_back(gep);
/// Emits a free of `structPtr` immediately before the terminator of the
/// builder's current block.
2186 void TaskContextStructManager::freeStructPtr() {
// The guard restores the builder's insertion point on scope exit.
2190 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2192 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2193 builder.CreateFree(structPtr);
/// Lowers `taskOp` through `OpenMPIRBuilder::createTask`.
/// Before the task is created, "omp.private.copy" and "omp.private.init"
/// blocks are split off, the task context struct is generated, and private
/// variables that read from their mold are initialized through GEPs into
/// that struct. Inside the body callback, the remaining private variables
/// get allocas, block arguments are mapped, the "omp.task.region" is
/// inlined, and the context struct is freed.
2197 static LogicalResult
2199 LLVM::ModuleTranslation &moduleTranslation) {
2200 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2205 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2217 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// The builder must sit at the end of its block. A new "omp.task.start"
// block is created, branched to, and emission continues before that branch.
2222 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2223 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2224 builder.getContext(),
"omp.task.start",
2225 builder.GetInsertBlock()->getParent());
2226 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2227 builder.SetInsertPoint(branchToTaskStartBlock);
2230 llvm::BasicBlock *copyBlock =
2231 splitBB(builder,
true,
"omp.private.copy");
2232 llvm::BasicBlock *initBlock =
2233 splitBB(builder,
true,
"omp.private.init");
2249 moduleTranslation, allocaIP);
// Emit the context-struct allocation in the init block, before its terminator.
2252 builder.SetInsertPoint(initBlock->getTerminator());
2255 taskStructMgr.generateTaskContextStruct();
2262 taskStructMgr.createGEPsToPrivateVars();
2264 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2267 taskStructMgr.getLLVMPrivateVarGEPs())) {
2269 if (!privDecl.readsFromMold())
2271 assert(llvmPrivateVarAlloc &&
2272 "reads from mold so shouldn't have been skipped");
2275 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2276 blockArg, llvmPrivateVarAlloc, initBlock);
2277 if (!privateVarOrErr)
2278 return handleError(privateVarOrErr, *taskOp.getOperation());
// If init produced a different value and the block argument is not a
// pointer, store that value into the struct slot and load it back.
2287 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2288 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2289 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2291 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2292 llvmPrivateVarAlloc);
2294 assert(llvmPrivateVarAlloc->getType() ==
2295 moduleTranslation.convertType(blockArg.getType()));
2305 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2306 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2307 taskOp.getPrivateNeedsBarrier())))
2308 return llvm::failure();
// Continue emission in the task start block.
2311 builder.SetInsertPoint(taskStartBlock);
2313 auto bodyCB = [&](InsertPointTy allocaIP,
2314 InsertPointTy codegenIP) -> llvm::Error {
2318 moduleTranslation, allocaIP);
2321 builder.restoreIP(codegenIP);
2323 llvm::BasicBlock *privInitBlock =
nullptr;
2328 auto [blockArg, privDecl, mlirPrivVar] = zip;
2330 if (privDecl.readsFromMold())
// The alloca goes before the terminator of the alloca block; the guard
// restores the insertion point afterwards.
2333 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2334 llvm::Type *llvmAllocType =
2335 moduleTranslation.convertType(privDecl.getType());
2336 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2337 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2338 llvmAllocType,
nullptr,
"omp.private.alloc");
2341 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2342 blockArg, llvmPrivateVar, privInitBlock);
2343 if (!privateVarOrError)
2344 return privateVarOrError.takeError();
2345 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2346 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
// The GEPs are created again here, inside the body callback.
2349 taskStructMgr.createGEPsToPrivateVars();
2350 for (
auto [i, llvmPrivVar] :
2353 assert(privateVarsInfo.
llvmVars[i] &&
2354 "This is added in the loop above");
2357 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2362 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2366 if (!privateDecl.readsFromMold())
// Non-pointer block arguments receive the loaded value, not the address.
2369 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2370 llvmPrivateVar = builder.CreateLoad(
2371 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2373 assert(llvmPrivateVar->getType() ==
2374 moduleTranslation.convertType(blockArg.getType()));
2375 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2379 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
2381 return llvm::make_error<PreviouslyReportedError>();
2383 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2388 return llvm::make_error<PreviouslyReportedError>();
// Release the context struct.
2391 taskStructMgr.freeStructPtr();
2393 return llvm::Error::success();
2396 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2402 llvm::omp::Directive::OMPD_taskgroup);
2406 moduleTranslation, dds);
2408 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// The op's `untied` flag is passed negated.
2409 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2410 moduleTranslation.getOpenMPBuilder()->createTask(
2411 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2412 moduleTranslation.lookupValue(taskOp.getFinal()),
2413 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
2414 taskOp.getMergeable(),
2415 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2416 moduleTranslation.lookupValue(taskOp.getPriority()));
2424 builder.restoreIP(*afterIP);
/// Lowers a taskgroup construct through `OpenMPIRBuilder::createTaskgroup`.
/// The body callback restores the code-generation insertion point before the
/// region is emitted; afterwards the builder continues at the returned point.
2429 static LogicalResult
2431 LLVM::ModuleTranslation &moduleTranslation) {
2432 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2436 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2437 builder.restoreIP(codegenIP);
2439 builder, moduleTranslation)
2444 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2445 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2446 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
2452 builder.restoreIP(*afterIP);
/// Lowers a taskwait by calling `OpenMPIRBuilder::createTaskwait` at the
/// builder's current insertion point.
2456 static LogicalResult
2458 LLVM::ModuleTranslation &moduleTranslation) {
2462 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
/// Lowers an `omp::WsloopOp` wrapping an `omp::LoopNestOp`.
/// The visible statements derive the schedule settings, set up private and
/// reduction variables, handle the linear clause via `LinearClauseProcessor`,
/// inline the "omp.wsloop.region", and apply the workshare transformation
/// through `OpenMPIRBuilder::applyWorkshareLoop`.
2467 static LogicalResult
2469 LLVM::ModuleTranslation &moduleTranslation) {
2470 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2471 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2475 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2477 assert(isByRef.size() == wsloopOp.getNumReductionVars());
// A missing schedule kind is treated as Static.
2481 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
// The type of the first loop step is used as the induction-variable type.
2484 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
2485 llvm::Type *ivType = step->getType();
2486 llvm::Value *chunk =
nullptr;
// A chunk size, when present, is sign-extended or truncated to `ivType`.
2487 if (wsloopOp.getScheduleChunk()) {
2488 llvm::Value *chunkVar =
2489 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
2490 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2497 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2501 wsloopOp.getNumReductionVars());
2504 builder, moduleTranslation, privateVarsInfo, allocaIP);
2511 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2516 moduleTranslation, allocaIP, reductionDecls,
2517 privateReductionVariables, reductionVariableMap,
2518 deferredStores, isByRef)))
2527 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2529 wsloopOp.getPrivateNeedsBarrier())))
2532 assert(afterAllocas.get()->getSinglePredecessor());
2535 afterAllocas.get()->getSinglePredecessor(),
2536 reductionDecls, privateReductionVariables,
2537 reductionVariableMap, isByRef, deferredStores)))
2541 bool isOrdered = wsloopOp.getOrdered().has_value();
2542 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2543 bool isSimd = wsloopOp.getScheduleSimd();
// A barrier is requested unless the op is marked nowait.
2544 bool loopNeedsBarrier = !wsloopOp.getNowait();
// The loop type depends on whether the parent op is an omp::DistributeOp.
2549 llvm::omp::WorksharingLoopType workshareLoopType =
2550 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
2551 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2552 : llvm::omp::WorksharingLoopType::ForStaticLoop;
2556 llvm::omp::Directive::OMPD_for);
2558 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Linear-clause setup, done only when the op has linear variables.
2561 LinearClauseProcessor linearClauseProcessor;
2562 if (wsloopOp.getLinearVars().size()) {
2563 for (
mlir::Value linearVar : wsloopOp.getLinearVars())
2564 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2566 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
2567 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
2571 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
2579 if (wsloopOp.getLinearVars().size()) {
2580 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2581 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2582 loopInfo->getPreheader());
2585 builder.restoreIP(*afterBarrierIP);
2586 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
2587 loopInfo->getIndVar());
2588 linearClauseProcessor.outlineLinearFinalizationBB(builder,
2589 loopInfo->getExit());
2592 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// The schedule modifier is passed as two booleans: monotonic and nonmonotonic.
2593 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2594 ompBuilder->applyWorkshareLoop(
2595 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2596 convertToScheduleKind(schedule), chunk, isSimd,
2597 scheduleMod == omp::ScheduleModifier::monotonic,
2598 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
// Linear-clause finalization; the saved insertion point is restored afterwards.
2605 if (wsloopOp.getLinearVars().size()) {
2606 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
2607 assert(loopInfo->getLastIter() &&
2608 "`lastiter` in CanonicalLoopInfo is nullptr");
2609 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2610 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
2611 loopInfo->getLastIter());
2614 for (
size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
2615 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
2617 builder.restoreIP(oldIP);
2625 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2626 privateReductionVariables, isByRef, wsloopOp.getNowait(),
/// Lowers a parallel construct through `OpenMPIRBuilder::createParallel`.
/// `bodyGenCB` sets up private and reduction variables, inlines the
/// "omp.par.region", and emits the reductions when there are any.
/// `finiCB` inlines the reduction cleanup regions. The optional if and
/// num_threads operands are translated only when present.
2636 static LogicalResult
2638 LLVM::ModuleTranslation &moduleTranslation) {
2639 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2641 assert(isByRef.size() == opInst.getNumReductionVars());
2642 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2653 opInst.getNumReductionVars());
// Failures inside the callbacks are returned as PreviouslyReportedError.
2656 auto bodyGenCB = [&](InsertPointTy allocaIP,
2657 InsertPointTy codeGenIP) -> llvm::Error {
2659 builder, moduleTranslation, privateVarsInfo, allocaIP);
2661 return llvm::make_error<PreviouslyReportedError>();
2667 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
2670 InsertPointTy(allocaIP.getBlock(),
2671 allocaIP.getBlock()->getTerminator()->getIterator());
2674 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
2675 reductionDecls, privateReductionVariables, reductionVariableMap,
2676 deferredStores, isByRef)))
2677 return llvm::make_error<PreviouslyReportedError>();
2679 assert(afterAllocas.get()->getSinglePredecessor());
2680 builder.restoreIP(codeGenIP);
2686 return llvm::make_error<PreviouslyReportedError>();
2689 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2691 opInst.getPrivateNeedsBarrier())))
2692 return llvm::make_error<PreviouslyReportedError>();
2696 afterAllocas.get()->getSinglePredecessor(),
2697 reductionDecls, privateReductionVariables,
2698 reductionVariableMap, isByRef, deferredStores)))
2699 return llvm::make_error<PreviouslyReportedError>();
2704 moduleTranslation, allocaIP);
2708 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
2710 return regionBlock.takeError();
// Reductions are emitted only when the op has reduction variables.
2713 if (opInst.getNumReductionVars() > 0) {
2719 owningReductionGens, owningAtomicReductionGens,
2720 privateReductionVariables, reductionInfos);
2723 builder.SetInsertPoint((*regionBlock)->getTerminator());
// Placeholder terminator, erased once createReductions has succeeded.
2726 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
2727 builder.SetInsertPoint(tempTerminator);
2729 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
2730 ompBuilder->createReductions(
2731 builder.saveIP(), allocaIP, reductionInfos, isByRef,
2733 if (!contInsertPoint)
2734 return contInsertPoint.takeError();
2736 if (!contInsertPoint->getBlock())
2737 return llvm::make_error<PreviouslyReportedError>();
2739 tempTerminator->eraseFromParent();
2740 builder.restoreIP(*contInsertPoint);
2743 return llvm::Error::success();
2746 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2747 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
// The finalization callback emits at `codeGenIP` and restores the previous
// insertion point before it returns success.
2756 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2757 InsertPointTy oldIP = builder.saveIP();
2758 builder.restoreIP(codeGenIP);
2763 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
2764 [](omp::DeclareReductionOp reductionDecl) {
2765 return &reductionDecl.getCleanupRegion();
2768 reductionCleanupRegions, privateReductionVariables,
2769 moduleTranslation, builder,
"omp.reduction.cleanup")))
2770 return llvm::createStringError(
2771 "failed to inline `cleanup` region of `omp.declare_reduction`");
2776 return llvm::make_error<PreviouslyReportedError>();
2778 builder.restoreIP(oldIP);
2779 return llvm::Error::success();
// Optional operands stay null when absent; proc-bind starts at the default kind.
2782 llvm::Value *ifCond =
nullptr;
2783 if (
auto ifVar = opInst.getIfExpr())
2784 ifCond = moduleTranslation.lookupValue(ifVar);
2785 llvm::Value *numThreads =
nullptr;
2786 if (
auto numThreadsVar = opInst.getNumThreads())
2787 numThreads = moduleTranslation.lookupValue(numThreadsVar);
2788 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
2789 if (
auto bind = opInst.getProcBindKind())
2793 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2795 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2797 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2798 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
2799 ifCond, numThreads, pbKind, isCancellable);
2804 builder.restoreIP(*afterIP);
/// Maps an `omp::ClauseOrderKind` to `llvm::omp::OrderKind`.
/// The visible paths return OMP_ORDER_unknown early and map Concurrent to
/// OMP_ORDER_concurrent; anything else reaches `llvm_unreachable`.
2809 static llvm::omp::OrderKind
2812 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2814 case omp::ClauseOrderKind::Concurrent:
2815 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2817 llvm_unreachable(
"Unknown ClauseOrderKind kind");
/// Lowers an `omp::SimdOp`.
/// The visible statements set up private and reduction variables, translate
/// the simdlen/safelen/aligned clauses, inline the "omp.simd.region", call
/// `OpenMPIRBuilder::applySimd`, combine each private reduction value into
/// the original variable, and inline the reduction cleanup regions.
2821 static LogicalResult
2823 LLVM::ModuleTranslation &moduleTranslation) {
2824 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2825 auto simdOp = cast<omp::SimdOp>(opInst);
2833 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2836 simdOp.getNumReductionVars());
2841 assert(isByRef.size() == simdOp.getNumReductionVars());
2843 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2847 builder, moduleTranslation, privateVarsInfo, allocaIP);
2852 moduleTranslation, allocaIP, reductionDecls,
2853 privateReductionVariables, reductionVariableMap,
2854 deferredStores, isByRef)))
2865 assert(afterAllocas.get()->getSinglePredecessor());
2868 afterAllocas.get()->getSinglePredecessor(),
2869 reductionDecls, privateReductionVariables,
2870 reductionVariableMap, isByRef, deferredStores)))
// simdlen and safelen become 64-bit constants when set, and stay null otherwise.
2873 llvm::ConstantInt *simdlen =
nullptr;
2874 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
2875 simdlen = builder.getInt64(simdlenVar.value());
2877 llvm::ConstantInt *safelen =
nullptr;
2878 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
2879 safelen = builder.getInt64(safelenVar.value());
2881 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
2884 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
2885 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
2887 for (
size_t i = 0; i < operands.size(); ++i) {
2888 llvm::Value *alignment =
nullptr;
2889 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
2890 llvm::Type *ty = llvmVal->
getType();
2892 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
2893 alignment = builder.getInt64(intAttr.getInt());
2894 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
2895 assert(alignment &&
"Invalid alignment value");
// Alignment values that are not a power of two are special-cased here.
2899 if (!intAttr.getValue().isPowerOf2())
// The load is emitted in `sourceBlock`; the previous insertion point is
// restored right after.
2902 auto curInsert = builder.saveIP();
2903 builder.SetInsertPoint(sourceBlock);
2904 llvmVal = builder.CreateLoad(ty, llvmVal);
2905 builder.restoreIP(curInsert);
2906 alignedVars[llvmVal] = alignment;
2910 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
2915 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2917 ompBuilder->applySimd(loopInfo, alignedVars,
2919 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
2921 order, simdlen, safelen);
2928 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
2929 privateReductionVariables))) {
2930 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
2932 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
2933 llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
2934 llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());
2938 llvm::Value *redValue = originalVariable;
2941 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
2942 llvm::Value *privateRedValue = builder.CreateLoad(
2943 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
2944 llvm::Value *reduced;
// The generator combines the two values; the result is stored back into the
// original variable.
2946 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
2949 builder.restoreIP(res.get());
2953 builder.CreateStore(reduced, originalVariable);
2958 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
2959 [](omp::DeclareReductionOp reductionDecl) {
2960 return &reductionDecl.getCleanupRegion();
2963 moduleTranslation, builder,
2964 "omp.reduction.cleanup")))
/// Lowers an `omp::LoopNestOp` into canonical loops, one
/// `createCanonicalLoop` call per loop dimension. The resulting loop infos
/// are collapsed and stored in the enclosing `OpenMPLoopInfoStackFrame`, and
/// the builder continues after the first loop.
2973 static LogicalResult
2975 LLVM::ModuleTranslation &moduleTranslation) {
2976 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2977 auto loopOp = cast<omp::LoopNestOp>(opInst);
2980 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2985 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
2986 llvm::Value *iv) -> llvm::Error {
// The induction variable is bound to the region argument whose index equals
// the number of loops created so far.
2988 moduleTranslation.mapValue(
2989 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
2994 bodyInsertPoints.push_back(ip);
// Only the last loop of the nest inlines the region body.
2996 if (loopInfos.size() != loopOp.getNumLoops() - 1)
2997 return llvm::Error::success();
3000 builder.restoreIP(ip);
3002 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
3004 return regionBlock.takeError();
3006 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3007 return llvm::Error::success();
3015 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3016 llvm::Value *lowerBound =
3017 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
3018 llvm::Value *upperBound =
3019 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
3020 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
// By default the location and compute point come from `ompLoc`; the
// assignments below replace them with the last body insertion point and the
// first loop's preheader.
3025 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3026 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3028 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3030 computeIP = loopInfos.front()->getPreheaderIP();
3034 ompBuilder->createCanonicalLoop(
3035 loc, bodyGen, lowerBound, upperBound, step,
3036 true, loopOp.getLoopInclusive(), computeIP);
3041 loopInfos.push_back(*loopResult);
// The continuation point is taken from the first loop info before collapsing.
3046 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
3047 loopInfos.front()->getAfterIP();
3051 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
3052 [&](OpenMPLoopInfoStackFrame &frame) {
3053 frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
3061 builder.restoreIP(afterIP);
// Emits a canonical loop through OpenMPIRBuilder::createCanonicalLoop using
// the translated trip count of `op`, binding the op's induction variable to
// the generated one. If the op has a CLI value, the resulting
// CanonicalLoopInfo is registered for it.
// NOTE(review): the line naming the function and parts of the body callback
// are not visible in this excerpt.
3066 static LogicalResult
3068 LLVM::ModuleTranslation &moduleTranslation) {
3069 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3071 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
3072 Value loopIV = op.getInductionVar();
3073 Value loopTC = op.getTripCount();
3075 llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
3078 ompBuilder->createCanonicalLoop(
3080 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
// Make the MLIR induction variable resolve to the generated LLVM value
// before continuing at the body insertion point.
3083 moduleTranslation.mapValue(loopIV, llvmIV);
3085 builder.restoreIP(ip);
3090 return bodyGenStatus.takeError();
3092 llvmTC,
"omp.loop");
3096 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
// Continue emitting code after the generated loop.
3097 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
3098 builder.restoreIP(afterIP);
// Record the loop only when the op actually provides a CLI value.
3101 if (
Value cli = op.getCli())
3102 moduleTranslation.mapOmpLoop(cli, llvmCLI);
// Applies OpenMPIRBuilder::unrollLoopHeuristic to the CanonicalLoopInfo that
// was previously registered for the op's applyee, then invalidates that
// registration.
// NOTE(review): the line naming the function is not visible in this excerpt.
3109 static LogicalResult
3111 LLVM::ModuleTranslation &moduleTranslation) {
3112 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3114 Value applyee = op.getApplyee();
3115 assert(applyee &&
"Loop to apply unrolling on required");
3117 llvm::CanonicalLoopInfo *consBuilderCLI =
3118 moduleTranslation.lookupOMPLoop(applyee);
3119 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3120 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
// Drop the applyee's loop mapping after the unroll call.
3122 moduleTranslation.invalidateOmpLoop(applyee);
// Maps omp::ClauseMemoryOrderKind values to llvm::AtomicOrdering.
// NOTE(review): the function name, its parameter, the guard of the first
// return and the switch header are not visible in this excerpt.
3127 static llvm::AtomicOrdering
// Early return; presumably the default used when no memory order is given
// (the condition is not visible here -- confirm).
3130 return llvm::AtomicOrdering::Monotonic;
3133 case omp::ClauseMemoryOrderKind::Seq_cst:
3134 return llvm::AtomicOrdering::SequentiallyConsistent;
3135 case omp::ClauseMemoryOrderKind::Acq_rel:
3136 return llvm::AtomicOrdering::AcquireRelease;
3137 case omp::ClauseMemoryOrderKind::Acquire:
3138 return llvm::AtomicOrdering::Acquire;
3139 case omp::ClauseMemoryOrderKind::Release:
3140 return llvm::AtomicOrdering::Release;
// Relaxed yields the same ordering as the early return above.
3141 case omp::ClauseMemoryOrderKind::Relaxed:
3142 return llvm::AtomicOrdering::Monotonic;
3144 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
// Lowers an omp::AtomicReadOp by calling OpenMPIRBuilder::createAtomicRead
// with the translated X and V operands, both described with the op's element
// type, and resumes at the insertion point that call returns.
// NOTE(review): the function name, the initializer of allocaIP and the
// definition of AO are not visible in this excerpt.
3148 static LogicalResult
3150 LLVM::ModuleTranslation &moduleTranslation) {
3151 auto readOp = cast<omp::AtomicReadOp>(opInst);
3155 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3156 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3159 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3162 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3163 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
3165 llvm::Type *elementType =
3166 moduleTranslation.convertType(readOp.getElementType());
// Both descriptors pass `false` for the two trailing AtomicOpValue fields.
3168 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
3169 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
3170 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
// Lowers an omp::AtomicWriteOp by calling OpenMPIRBuilder::createAtomicWrite
// with the translated destination (the op's X operand) and the translated
// expression value; the destination is typed after the expression's type.
// NOTE(review): the function name, the initializer of allocaIP, the
// definition of `ao` and the statement wrapping the final call are not
// visible in this excerpt.
3175 static LogicalResult
3177 LLVM::ModuleTranslation &moduleTranslation) {
3178 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3182 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3183 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3186 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3188 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3189 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
3190 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3191 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
3194 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
// Case/Default chain mapping LLVM dialect binary operations to the matching
// llvm::AtomicRMWInst::BinOp value; anything not listed yields BAD_BINOP.
// NOTE(review): the enclosing function header and the start of this
// expression (presumably an llvm::TypeSwitch) are not visible in this
// excerpt.
3202 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
3203 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
3204 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
3205 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
3206 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
3207 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
3208 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
3209 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
3210 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
3211 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
// Fills the three output flags from the atomic-control attribute of
// atomicUpdateOp. All flags are first reset to false; they are overwritten
// only when atomicUpdateOp is non-null and carries that attribute.
// NOTE(review): the first line of the signature is not visible in this
// excerpt.
3215 bool &isIgnoreDenormalMode,
3216 bool &isFineGrainedMemory,
3217 bool &isRemoteMemory) {
3218 isIgnoreDenormalMode =
false;
3219 isFineGrainedMemory =
false;
3220 isRemoteMemory =
false;
// Only consult the attribute when the op exists and has it attached.
3221 if (atomicUpdateOp &&
3222 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3223 mlir::omp::AtomicControlAttr atomicControlAttr =
3224 atomicUpdateOp.getAtomicControlAttr();
3225 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3226 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3227 isRemoteMemory = atomicControlAttr.getRemoteMemory();
// Lowers an atomic update operation through
// OpenMPIRBuilder::createAtomicUpdate. The op's region is translated inside
// an update callback, and the value of its omp.yield terminator becomes the
// callback's result.
// NOTE(review): the function name, several guards and local declarations
// (e.g. innerOp, mlirExpr, bb, allocaIP, updateFn) are not visible in this
// excerpt.
3232 static LogicalResult
3234 llvm::IRBuilderBase &builder,
3235 LLVM::ModuleTranslation &moduleTranslation) {
3236 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3241 auto &innerOpList = opInst.getRegion().front().getOperations();
3242 bool isXBinopExpr{
false};
3243 llvm::AtomicRMWInst::BinOp binop;
3245 llvm::Value *llvmExpr =
nullptr;
3246 llvm::Value *llvmX =
nullptr;
3247 llvm::Type *llvmXElementType =
nullptr;
// A region holding exactly two operations gets dedicated handling: an error
// is emitted if the region argument is not an operand of the inner op, and
// isXBinopExpr records whether the region argument is operand 0.
3248 if (innerOpList.size() == 2) {
3254 opInst.getRegion().getArgument(0))) {
3255 return opInst.emitError(
"no atomic update operation with region argument"
3256 " as operand found inside atomic.update region");
3259 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
3261 llvmExpr = moduleTranslation.lookupValue(mlirExpr);
// NOTE(review): the branch structure around this assignment is not visible;
// presumably it is the fallback for regions that do not have two operations.
3265 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3267 llvmX = moduleTranslation.lookupValue(opInst.getX());
// The element type of X is taken from the type of the region argument.
3268 llvmXElementType = moduleTranslation.convertType(
3269 opInst.getRegion().getArgument(0).getType());
3270 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3274 llvm::AtomicOrdering atomicOrdering =
// Update callback: binds the region's first argument to the value passed in,
// translates the region block at the current insertion block, and returns
// the translated operand of the omp.yield terminator.
3279 [&opInst, &moduleTranslation](
3280 llvm::Value *atomicx,
3283 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
3284 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3285 if (
failed(moduleTranslation.convertBlock(bb,
true, builder)))
3286 return llvm::make_error<PreviouslyReportedError>();
3288 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
3289 assert(yieldop && yieldop.getResults().size() == 1 &&
3290 "terminator must be omp.yield op and it must have exactly one "
3292 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
// Atomic-control flags forwarded to createAtomicUpdate below; the call that
// fills them is not visible in this excerpt.
3295 bool isIgnoreDenormalMode;
3296 bool isFineGrainedMemory;
3297 bool isRemoteMemory;
3302 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3303 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3304 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
3305 atomicOrdering, binop, updateFn,
3306 isXBinopExpr, isIgnoreDenormalMode,
3307 isFineGrainedMemory, isRemoteMemory);
// Resume emission after the generated atomic update.
3312 builder.restoreIP(*afterIP);
// Lowers an atomic capture operation through
// OpenMPIRBuilder::createAtomicCapture. The capture contains an atomic read
// plus either an atomic update or an atomic write (asserted below); the
// update callback yields the written expression or the translated result of
// the update region.
// NOTE(review): the function name, several guards and local declarations
// (e.g. mlirExpr, innerOp, allocaIP, updateFn) are not visible in this
// excerpt.
3316 static LogicalResult
3318 llvm::IRBuilderBase &builder,
3319 LLVM::ModuleTranslation &moduleTranslation) {
3320 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3325 bool isXBinopExpr =
false, isPostfixUpdate =
false;
3326 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3328 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
3329 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
3331 assert((atomicUpdateOp || atomicWriteOp) &&
3332 "internal op must be an atomic.update or atomic.write op");
// Write form: always flagged as a postfix update; the expression is the
// value being written.
3334 if (atomicWriteOp) {
3335 isPostfixUpdate =
true;
3336 mlirExpr = atomicWriteOp.getExpr();
// Update form: it is a postfix update when the update op is the second op of
// the capture.
3338 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
3339 atomicCaptureOp.getAtomicUpdateOp().getOperation();
3340 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
// Same two-operation region handling as for a plain atomic update.
3343 if (innerOpList.size() == 2) {
3346 atomicUpdateOp.getRegion().getArgument(0))) {
3347 return atomicUpdateOp.emitError(
3348 "no atomic update operation with region argument"
3349 " as operand found inside atomic.update region");
3353 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
3356 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3360 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
// X, V and the element type all come from the capture's atomic read op.
3361 llvm::Value *llvmX =
3362 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
3363 llvm::Value *llvmV =
3364 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
3365 llvm::Type *llvmXElementType = moduleTranslation.convertType(
3366 atomicCaptureOp.getAtomicReadOp().getElementType());
3367 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3370 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
3374 llvm::AtomicOrdering atomicOrdering =
// Update callback. NOTE(review): the condition selecting the first return
// (presumably the atomic-write case) is not visible here.
3378 [&](llvm::Value *atomicx,
3381 return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
3382 Block &bb = *atomicUpdateOp.getRegion().
begin();
3383 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
3385 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3386 if (
failed(moduleTranslation.convertBlock(bb,
true, builder)))
3387 return llvm::make_error<PreviouslyReportedError>();
3389 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
3390 assert(yieldop && yieldop.getResults().size() == 1 &&
3391 "terminator must be omp.yield op and it must have exactly one "
3393 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
// Atomic-control flags forwarded to createAtomicCapture below.
3396 bool isIgnoreDenormalMode;
3397 bool isFineGrainedMemory;
3398 bool isRemoteMemory;
3400 isFineGrainedMemory, isRemoteMemory);
3403 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3404 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3405 ompBuilder->createAtomicCapture(
3406 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
3407 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
3408 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
// Resume emission after the generated atomic capture.
3413 builder.restoreIP(*afterIP);
// Maps an omp::ClauseCancellationConstructType to the corresponding
// llvm::omp::Directive (Loop -> OMPD_for, Parallel -> OMPD_parallel,
// Sections -> OMPD_sections, Taskgroup -> OMPD_taskgroup).
// NOTE(review): the first line of the signature is not visible in this
// excerpt.
3418 omp::ClauseCancellationConstructType directive) {
3419 switch (directive) {
3420 case omp::ClauseCancellationConstructType::Loop:
3421 return llvm::omp::Directive::OMPD_for;
3422 case omp::ClauseCancellationConstructType::Parallel:
3423 return llvm::omp::Directive::OMPD_parallel;
3424 case omp::ClauseCancellationConstructType::Sections:
3425 return llvm::omp::Directive::OMPD_sections;
3426 case omp::ClauseCancellationConstructType::Taskgroup:
3427 return llvm::omp::Directive::OMPD_taskgroup;
3429 llvm_unreachable(
"Unhandled cancellation construct type");
// Lowers a cancel operation through OpenMPIRBuilder::createCancel, passing
// the translated `if` condition (null when the op has none) and the
// cancelled directive, then resumes at the returned insertion point.
// NOTE(review): the function name and the initializer of cancelledDirective
// are not visible in this excerpt.
3432 static LogicalResult
3434 LLVM::ModuleTranslation &moduleTranslation) {
3438 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3439 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
// ifCond stays null unless the op carries an `if` expression.
3441 llvm::Value *ifCond =
nullptr;
3442 if (
Value ifVar = op.getIfExpr())
3443 ifCond = moduleTranslation.lookupValue(ifVar);
3445 llvm::omp::Directive cancelledDirective =
3448 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3449 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
3454 builder.restoreIP(afterIP.get());
// Lowers a cancellation point through
// OpenMPIRBuilder::createCancellationPoint for the cancelled directive and
// resumes at the returned insertion point.
// NOTE(review): the function name and the initializer of cancelledDirective
// are not visible in this excerpt.
3459 static LogicalResult
3461 llvm::IRBuilderBase &builder,
3462 LLVM::ModuleTranslation &moduleTranslation) {
3466 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3467 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3469 llvm::omp::Directive cancelledDirective =
3472 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3473 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
3478 builder.restoreIP(afterIP.get());
// Lowers an omp::ThreadprivateOp. The op's symbol address must resolve to an
// LLVM::AddressOfOp of a global. When not compiling for a target device the
// op's result is mapped to the value returned by
// OpenMPIRBuilder::createCachedThreadPrivate; otherwise it is mapped to the
// global itself.
// NOTE(review): the function name, the definition of symOp and the body of
// the address-space-cast branch are not visible in this excerpt.
3485 static LogicalResult
3487 LLVM::ModuleTranslation &moduleTranslation) {
3488 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3489 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3490 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
3495 Value symAddr = threadprivateOp.getSymAddr();
3498 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
// Anything other than an address-of operation is rejected with an error.
3501 if (!isa<LLVM::AddressOfOp>(symOp))
3502 return opInst.
emitError(
"Addressing symbol not found");
3503 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3505 LLVM::GlobalOp global =
3506 addressOfOp.getGlobal(moduleTranslation.symbolTable());
3507 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
// Host path: the store size of the global's value type and a
// "<symbol>.cache" name are passed to createCachedThreadPrivate.
3509 if (!ompBuilder->Config.isTargetDevice()) {
3510 llvm::Type *type = globalValue->getValueType();
3511 llvm::TypeSize typeSize =
3512 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
3514 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
3515 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
3516 ompLoc, globalValue, size, global.getSymName() +
".cache");
3517 moduleTranslation.mapValue(opInst.
getResult(0), callInst);
// Target-device path: the result refers directly to the global.
3519 moduleTranslation.mapValue(opInst.
getResult(0), globalValue);
// Maps mlir::omp::DeclareTargetDeviceType (host / nohost / any) to the
// matching llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind value.
// NOTE(review): the line with the function name and parameter is not
// visible in this excerpt.
3525 static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
3527 switch (deviceClause) {
3528 case mlir::omp::DeclareTargetDeviceType::host:
3529 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
3531 case mlir::omp::DeclareTargetDeviceType::nohost:
3532 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
3534 case mlir::omp::DeclareTargetDeviceType::any:
3535 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
3538 llvm_unreachable(
"unhandled device clause");
// Maps mlir::omp::DeclareTargetCaptureClause (to / link / enter) to the
// matching llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
// value.
// NOTE(review): the line with the function name is not visible in this
// excerpt.
3541 static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
3543 mlir::omp::DeclareTargetCaptureClause captureClause) {
3544 switch (captureClause) {
3545 case mlir::omp::DeclareTargetCaptureClause::to:
3546 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
3547 case mlir::omp::DeclareTargetCaptureClause::link:
3548 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
3549 case mlir::omp::DeclareTargetCaptureClause::enter:
3550 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
3552 llvm_unreachable(
"unhandled capture clause");
// Builds a name suffix in `suffix` through a raw_svector_ostream. The pieces
// visible here are a "_%x"-formatted FileID obtained from
// OpenMPIRBuilder::getTargetEntryUniqueInfo (fed with the file name and line
// of `loc`) and the literal "_decl_tgt_ref_ptr".
// NOTE(review): the start of the signature, the declarations of `suffix` and
// `loc`, any conditions and the return are not visible in this excerpt.
3557 llvm::OpenMPIRBuilder &ompBuilder) {
3559 llvm::raw_svector_ostream os(suffix);
// Callback giving the (file name, line) pair of `loc`.
3562 auto fileInfoCallBack = [&loc]() {
3563 return std::pair<std::string, uint64_t>(
3564 llvm::StringRef(loc.getFilename()), loc.getLine());
3568 "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
3570 os <<
"_decl_tgt_ref_ptr";
// For a value defined by an LLVM::AddressOfOp, looks up the referenced
// symbol in the parent module and checks whether it implements
// DeclareTargetInterface with a `link` capture clause.
// NOTE(review): the function signature and the statements executed when the
// check succeeds or fails are not visible in this excerpt.
3576 if (
auto addressOfOp = value.
getDefiningOp<LLVM::AddressOfOp>()) {
3577 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
3578 Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
3579 if (
auto declareTargetGlobal =
3580 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
3581 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3582 mlir::omp::DeclareTargetCaptureClause::link)
// Returns the named LLVM module value associated with a declare-target
// global. It applies when the operation behind the value is an
// LLVM::AddressOfOp of a global whose capture clause is `link`, or `to`
// while unified shared memory is required; the value is looked up by the
// global's symbol name, with `suffix` appended in the second lookup.
// NOTE(review): the function name, the definitions of `op` and `suffix`,
// the body of the address-space-cast branch, the argument of the first
// getNamedValue call and the fall-through return are not visible in this
// excerpt.
3591 static llvm::Value *
3593 LLVM::ModuleTranslation &moduleTranslation) {
3594 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3596 if (
auto addrCast = llvm::dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
3601 if (
auto addressOfOp = llvm::dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
3602 if (
auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
3603 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
3604 addressOfOp.getGlobalName()))) {
3606 if (
auto declareTargetGlobal =
3607 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
3608 gOp.getOperation())) {
// Applies to `link` globals, and to `to` globals when unified shared memory
// is required.
3612 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
3613 mlir::omp::DeclareTargetCaptureClause::link) ||
3614 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3615 mlir::omp::DeclareTargetCaptureClause::to &&
3616 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
// If the symbol name already contains the suffix, the first lookup is used
// (its argument is not visible here); otherwise the suffix is appended.
3620 if (gOp.getSymName().contains(suffix))
3621 return moduleTranslation.getLLVMModule()->getNamedValue(
3624 return moduleTranslation.getLLVMModule()->getNamedValue(
3625 (gOp.getSymName().str() + suffix.str()).str());
// Extension of llvm::OpenMPIRBuilder::MapInfosTy that additionally carries a
// `Mappers` sequence (its declaration is not visible in this excerpt).
3636 struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
// Appends curInfo's mappers first, then delegates to the base-class append
// for the remaining fields.
3640 void append(MapInfosTy &curInfo) {
3641 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
3642 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
// Extension of MapInfosTy that keeps per-entry IsDeclareTarget, MapClause,
// OriginalValue and BaseType sequences (their declarations are not visible
// in this excerpt).
3651 struct MapInfoData : MapInfosTy {
// Appends every sequence of CurInfo to this object, then delegates to
// MapInfosTy::append for the inherited fields.
3663 void append(MapInfoData &CurInfo) {
3664 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
3665 CurInfo.IsDeclareTarget.end());
3666 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
3667 OriginalValue.append(CurInfo.OriginalValue.begin(),
3668 CurInfo.OriginalValue.end());
3669 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
3670 MapInfosTy::append(CurInfo);
3676 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3677 arrTy.getElementType()))
// Computes a size value for a mapped entity. When clauseOp is an
// omp::MapInfoOp with bounds, an element count is accumulated by multiplying
// across its MapBoundsOp entries (using each entry's upper and lower bound),
// and the result returned is that count times the underlying type size in
// bytes (bits / 8).
// NOTE(review): the first line of the signature, the definition of
// underlyingTypeSzInBits, the operators combining the bounds and the return
// for the no-bounds case are not visible in this excerpt.
3693 Operation *clauseOp, llvm::Value *basePointer,
3694 llvm::Type *baseType, llvm::IRBuilderBase &builder,
3695 LLVM::ModuleTranslation &moduleTranslation) {
3696 if (
auto memberClause =
3697 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
3702 if (!memberClause.getBounds().empty()) {
// Start from one element and fold in every bounds entry.
3703 llvm::Value *elementCount = builder.getInt64(1);
3704 for (
auto bounds : memberClause.getBounds()) {
3705 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
3706 bounds.getDefiningOp())) {
3711 elementCount = builder.CreateMul(
3715 moduleTranslation.lookupValue(boundOp.getUpperBound()),
3716 moduleTranslation.lookupValue(boundOp.getLowerBound())),
3717 builder.getInt64(1)));
// Array types get dedicated handling for the element size (body not visible
// here).
3724 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
3732 return builder.CreateMul(elementCount,
3733 builder.getInt64(underlyingTypeSzInBits / 8));
// Populates `mapData` with one entry per map-related operand. Three groups
// are processed in order: regular map operands (mapVars), use-device
// address/pointer operands, and has-device-address operands. Each entry
// pushes onto the parallel sequences of mapData (OriginalValue, Pointers,
// BasePointers, BaseType, Sizes, MapClause, Types, Mappers, IsAMapping,
// IsAMember, ...).
// NOTE(review): the start of the signature, several declarations (e.g.
// offloadPtr, mapType, index) and some branch conditions are not visible in
// this excerpt.
3742 LLVM::ModuleTranslation &moduleTranslation,
DataLayout &dl,
// Helper: scans the members of every MapInfoOp in mapVars for mapOp.
3746 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
3754 for (
Value mapValue : mapVars) {
3755 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3756 for (
auto member : map.getMembers())
3757 if (member == mapOp)
// Group 1: regular map operands.
3764 for (
Value mapValue : mapVars) {
3765 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
// The var-ptr-ptr operand is preferred over var-ptr when present.
3767 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3768 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
3769 mapData.Pointers.push_back(mapData.OriginalValue.back());
// A non-null reference pointer marks the entry as declare target and becomes
// its base pointer; otherwise the original value is the base pointer.
3771 if (llvm::Value *refPtr =
3773 moduleTranslation)) {
3774 mapData.IsDeclareTarget.push_back(
true);
3775 mapData.BasePointers.push_back(refPtr);
3777 mapData.IsDeclareTarget.push_back(
false);
3778 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3781 mapData.BaseType.push_back(
3782 moduleTranslation.convertType(mapOp.getVarType()));
3783 mapData.Sizes.push_back(
3784 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
3785 mapData.BaseType.back(), builder, moduleTranslation));
3786 mapData.MapClause.push_back(mapOp.getOperation());
3787 mapData.Types.push_back(
3788 llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType()));
3790 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
// A mapper is recorded only when the clause names one; otherwise nullptr.
3792 if (mapOp.getMapperId())
3793 mapData.Mappers.push_back(
3794 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3795 mapOp, mapOp.getMapperIdAttr()));
3797 mapData.Mappers.push_back(
nullptr);
3798 mapData.IsAMapping.push_back(
true);
3799 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
// Helper: when `val` matches the original value of an existing mapping
// entry, that entry gets OMP_MAP_RETURN_PARAM and its device-info kind is
// set.
3802 auto findMapInfo = [&mapData](llvm::Value *val,
3803 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3806 for (llvm::Value *basePtr : mapData.OriginalValue) {
3807 if (basePtr == val && mapData.IsAMapping[index]) {
3809 mapData.Types[index] |=
3810 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
3811 mapData.DevicePointers[index] = devInfoTy;
// Group 2: use-device operands. A new zero-sized RETURN_PARAM entry is added
// only when findMapInfo reports no match.
3820 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3821 for (
Value mapValue : useDevOperands) {
3822 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3824 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3825 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
3828 if (!findMapInfo(origValue, devInfoTy)) {
3829 mapData.OriginalValue.push_back(origValue);
3830 mapData.Pointers.push_back(mapData.OriginalValue.back());
3831 mapData.IsDeclareTarget.push_back(
false);
3832 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3833 mapData.BaseType.push_back(
3834 moduleTranslation.convertType(mapOp.getVarType()));
3835 mapData.Sizes.push_back(builder.getInt64(0));
3836 mapData.MapClause.push_back(mapOp.getOperation());
3837 mapData.Types.push_back(
3838 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
3840 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
3841 mapData.DevicePointers.push_back(devInfoTy);
3842 mapData.Mappers.push_back(
nullptr);
3843 mapData.IsAMapping.push_back(
false);
3844 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
3849 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3850 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
// Group 3: has-device-address operands. These are recorded with the original
// value as both base pointer and pointer, and are never marked as mappings.
3852 for (
Value mapValue : hasDevAddrOperands) {
3853 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3855 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3856 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
3858 static_cast<llvm::omp::OpenMPOffloadMappingFlags
>(mapOp.getMapType());
3859 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
3861 mapData.OriginalValue.push_back(origValue);
3862 mapData.BasePointers.push_back(origValue);
3863 mapData.Pointers.push_back(origValue);
3864 mapData.IsDeclareTarget.push_back(
false);
3865 mapData.BaseType.push_back(
3866 moduleTranslation.convertType(mapOp.getVarType()));
3867 mapData.Sizes.push_back(
3868 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
3869 mapData.MapClause.push_back(mapOp.getOperation());
// With OMP_MAP_ALWAYS set, the clause's own map type and mapper are kept;
// presumably the LITERAL/nullptr pushes further down are the alternative
// branch (its `else` is not visible here).
3870 if (llvm::to_underlying(mapType & mapTypeAlways)) {
3874 mapData.Types.push_back(mapType);
3878 if (mapOp.getMapperId()) {
3879 mapData.Mappers.push_back(
3880 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3881 mapOp, mapOp.getMapperIdAttr()));
3883 mapData.Mappers.push_back(
nullptr);
3886 mapData.Types.push_back(
3887 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
3888 mapData.Mappers.push_back(
nullptr);
3891 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
3892 mapData.DevicePointers.push_back(
3893 llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3894 mapData.IsAMapping.push_back(
false);
3895 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
// Returns the position of memberOp within mapData.MapClause; asserts that
// the operation is present.
// NOTE(review): the function signature is not visible in this excerpt.
3900 auto *res = llvm::find(mapData.MapClause, memberOp);
3901 assert(res != mapData.MapClause.end() &&
3902 "MapInfoOp for member not found in MapData, cannot return index");
3903 return std::distance(mapData.MapClause.begin(), res);
// Selects one member MapInfoOp of mapInfo using its members-index attribute.
// With a single index entry, member 0 is returned directly. Otherwise member
// positions are sorted by comparing their index paths element by element
// (shorter path first when the compared elements do not decide), and the
// member at the front of the sorted order is returned.
// NOTE(review): the function signature, the declaration of `indices` and the
// statements taken for each of the three element comparisons are not visible
// in this excerpt.
3908 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
3910 if (indexAttr.size() == 1)
3911 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
// Start from the identity order 0, 1, 2, ...
3914 std::iota(indices.begin(), indices.end(), 0);
3916 llvm::sort(indices, [&](
const size_t a,
const size_t b) {
3917 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
3918 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
// zip stops at the shorter of the two index paths.
3919 for (
const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
3920 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
3921 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
3923 if (aIndex == bIndex)
3926 if (aIndex < bIndex)
3929 if (aIndex > bIndex)
// Reached when the element loop did not return: the shorter path sorts
// first.
3936 return memberIndicesA.size() < memberIndicesB.size();
3939 return llvm::cast<omp::MapInfoOp>(
3940 mapInfo.getMembers()[indices.front()].getDefiningOp());
// Builds the list of index values derived from a set of map bounds. Two
// strategies are visible: one pushes a leading 0 followed by each bound's
// lower bound (iterating the bounds in reverse), the other combines all
// lower bounds into a single index, each scaled by the product of the
// extents of the bounds after the first up to that dimension.
// NOTE(review): most of the signature, the condition choosing between the
// two strategies (presumably isArrayTy) and the return statement are not
// visible in this excerpt.
3962 std::vector<llvm::Value *>
3964 llvm::IRBuilderBase &builder,
bool isArrayTy,
3966 std::vector<llvm::Value *> idx;
// Strategy 1: one index per bound, preceded by a constant 0.
3977 idx.push_back(builder.getInt64(0));
3978 for (
int i = bounds.size() - 1; i >= 0; --i) {
3979 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3980 bounds[i].getDefiningOp())) {
3981 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
// Strategy 2: running products of extents; entry 0 is the constant 1 and
// entry i multiplies extent i by entry i - 1.
4003 std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
4004 for (
size_t i = 1; i < bounds.size(); ++i) {
4005 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4006 bounds[i].getDefiningOp())) {
4007 dimensionIndexSizeOffset.push_back(builder.CreateMul(
4008 moduleTranslation.lookupValue(boundOp.getExtent()),
4009 dimensionIndexSizeOffset[i - 1]));
// Accumulate lowerBound * scale into a single index. The first term creates
// the entry and later terms are added to it (the selecting condition is not
// visible here).
4017 for (
int i = bounds.size() - 1; i >= 0; --i) {
4018 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4019 bounds[i].getDefiningOp())) {
4021 idx.emplace_back(builder.CreateMul(
4022 moduleTranslation.lookupValue(boundOp.getLowerBound()),
4023 dimensionIndexSizeOffset[i]));
4025 idx.back() = builder.CreateAdd(
4026 idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
4027 boundOp.getLowerBound()),
4028 dimensionIndexSizeOffset[i]));
// Emits the combinedInfo entry for a parent map that has members and returns
// the member-of flag that refers to it. The parent's size is the byte
// distance between a low and a high address: for a non-partial map these
// span one element of the parent's base type; otherwise they span from the
// first mapped member to one element past the last mapped member. For a
// non-partial map, a second entry carrying the parent's own map type is
// added.
// NOTE(review): host-only (asserted). The start of the signature and several
// declarations (parentClause, firstMemberIdx, lastMemberIdx) are not visible
// in this excerpt.
4052 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
4053 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
4054 MapInfoData &mapData, uint64_t mapDataIndex,
bool isTargetParams) {
4055 assert(!ompBuilder.Config.isTargetDevice() &&
4056 "function only supported for host device codegen");
// First entry: TARGET_PARAM or NONE (the selecting condition is not visible
// here; presumably isTargetParams).
4059 combinedInfo.Types.emplace_back(
4061 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
4062 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
4063 combinedInfo.DevicePointers.emplace_back(
4064 mapData.DevicePointers[mapDataIndex]);
4065 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4067 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4068 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
4078 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4080 llvm::Value *lowAddr, *highAddr;
// Full map: the range covers exactly one element of the parent's base type.
4081 if (!parentClause.getPartialMap()) {
4082 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4083 builder.getPtrTy());
4084 highAddr = builder.CreatePointerCast(
4085 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4086 mapData.Pointers[mapDataIndex], 1),
4087 builder.getPtrTy());
4088 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
// Partial map: the range runs from the first mapped member to one element
// past the last mapped member.
4090 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4093 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
4094 builder.getPtrTy());
4097 highAddr = builder.CreatePointerCast(
4098 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
4099 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
4100 builder.getPtrTy());
4101 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
// Size in bytes = high - low, measured in i8 units and cast to i64.
4104 llvm::Value *size = builder.CreateIntCast(
4105 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
4106 builder.getInt64Ty(),
4108 combinedInfo.Sizes.push_back(size);
// The member-of flag is derived from the index of the entry just added.
4110 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
4111 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
// Full map only: add a second entry with the parent's own map type, tagged
// with the member-of flag.
4119 if (!parentClause.getPartialMap()) {
4124 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
4125 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4126 combinedInfo.Types.emplace_back(mapFlag);
4127 combinedInfo.DevicePointers.emplace_back(
4129 combinedInfo.Mappers.emplace_back(
nullptr);
4131 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4132 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
4133 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4134 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
4136 return memberOfFlag;
4148 if (mapOp.getVarPtrPtr())
// Emits combinedInfo entries for every member of the parent map at
// mapDataIndex. Each member entry has TARGET_PARAM cleared and MEMBER_OF set
// and is tagged with memberOfFlag. Two kinds of entry are visible: a
// pointer-sized entry based on the parent's base pointer, and the member's
// own entry, which may additionally get PTR_AND_OBJ.
// NOTE(review): host-only (asserted). The start of the signature, several
// declarations (parentClause, memberClause, memberDataIdx, mapFlag) and the
// conditions guarding the individual steps are not visible in this excerpt.
4162 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
4163 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
4164 MapInfoData &mapData, uint64_t mapDataIndex,
4165 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
4166 assert(!ompBuilder.Config.isTargetDevice() &&
4167 "function only supported for host device codegen");
4170 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4172 for (
auto mappedMembers : parentClause.getMembers()) {
4174 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
4177 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
// Pointer-sized entry: base pointer is the parent's, pointer is the member's
// base pointer, and the size is the data layout's pointer size.
4188 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4189 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4190 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4191 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4192 combinedInfo.Types.emplace_back(mapFlag);
4193 combinedInfo.DevicePointers.emplace_back(
4195 combinedInfo.Mappers.emplace_back(
nullptr);
4196 combinedInfo.Names.emplace_back(
4198 combinedInfo.BasePointers.emplace_back(
4199 mapData.BasePointers[mapDataIndex]);
4200 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
4201 combinedInfo.Sizes.emplace_back(builder.getInt64(
4202 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
// The member's own entry.
4208 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4209 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4210 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4211 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4213 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4215 combinedInfo.Types.emplace_back(mapFlag);
4216 combinedInfo.DevicePointers.emplace_back(
4217 mapData.DevicePointers[memberDataIdx]);
4218 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
4219 combinedInfo.Names.emplace_back(
4221 uint64_t basePointerIndex =
4223 combinedInfo.BasePointers.emplace_back(
4224 mapData.BasePointers[basePointerIndex]);
4225 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
4227 llvm::Value *size = mapData.Sizes[memberDataIdx];
// The size becomes 0 when the member pointer is null (the guard for this
// select is not visible here).
4229 size = builder.CreateSelect(
4230 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
4231 builder.getInt64(0), size);
4234 combinedInfo.Sizes.emplace_back(size);
// Copies the mapData entry at mapDataIdx into combinedInfo after adjusting
// its map flags: PTR_AND_OBJ may be added, TARGET_PARAM is added for target
// parameters that are not declare target, and LITERAL is added for by-copy
// captures when a further condition holds. When mapDataParentIdx is
// non-negative, the parent's base pointer is used instead of the entry's
// own.
// NOTE(review): the start of the signature and some conditions (for
// PTR_AND_OBJ and the second half of the LITERAL test) are not visible in
// this excerpt.
4239 MapInfosTy &combinedInfo,
bool isTargetParams,
4240 int mapDataParentIdx = -1) {
4244 auto mapFlag = mapData.Types[mapDataIdx];
4245 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
4249 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4251 if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
4252 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4254 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
4256 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
// Base pointer: the parent's when a parent index was supplied, else the
// entry's own.
4261 if (mapDataParentIdx >= 0)
4262 combinedInfo.BasePointers.emplace_back(
4263 mapData.BasePointers[mapDataParentIdx]);
4265 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
4267 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
4268 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
4269 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
4270 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
4271 combinedInfo.Types.emplace_back(mapFlag);
4272 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
// Host-side handling of the map entry at `mapDataIndex` when its map clause
// has members.
// NOTE(review): the head of the signature, the callee names and the end of
// the special-case branch are missing from this excerpt.
4276 llvm::IRBuilderBase &builder,
4277 llvm::OpenMPIRBuilder &ompBuilder,
4279 MapInfoData &mapData, uint64_t mapDataIndex,
4280 bool isTargetParams) {
// This path must not run when the OpenMPIRBuilder is configured for a target
// device.
4281 assert(!ompBuilder.Config.isTargetDevice() &&
4282 "function only supported for host device codegen");
4285 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
// Special case: the parent is a partial map with exactly one member; that
// single member's MapInfoOp is fetched here.
4290 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
4291 auto memberClause = llvm::cast<omp::MapInfoOp>(
4292 parentClause.getMembers()[0].getDefiningOp());
// General case: a flag is produced by a first call taking the parent entry
// and then handed to a second call together with the same map data.
4309 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
4311 combinedInfo, mapData, mapDataIndex, isTargetParams);
4313 combinedInfo, mapData, mapDataIndex,
4314 memberOfParentFlag);
// Rewrites the pointers stored in `mapData` according to each map clause's
// capture kind. Entries flagged as declare-target are left untouched.
// NOTE(review): the head of the signature and several body lines (including
// the `break`s and closing braces) are missing from this excerpt.
4323 LLVM::ModuleTranslation &moduleTranslation,
4324 llvm::IRBuilderBase &builder) {
// Host-only: asserts the OpenMPIRBuilder is not in target-device mode.
4325 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4326 "function only supported for host device codegen");
4327 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4329 if (!mapData.IsDeclareTarget[i]) {
4330 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4331 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
4341 switch (captureKind) {
// ByRef: start from the recorded pointer, optionally load through it (the
// guarding condition is elided), and apply a GEP when offset indices exist.
// Only Pointers[i] is updated on this path.
4342 case omp::VariableCaptureKind::ByRef: {
4343 llvm::Value *newV = mapData.Pointers[i];
4345 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
4348 newV = builder.CreateLoad(builder.getPtrTy(), newV);
4350 if (!offsetIdx.empty())
4351 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
4353 mapData.Pointers[i] = newV;
// ByCopy: obtain the value itself (loading it when the recorded pointer has
// pointer type), then round-trip it through a pointer-typed temporary named
// ".casted" so that it ends up as a pointer-typed value.
4355 case omp::VariableCaptureKind::ByCopy: {
4356 llvm::Type *type = mapData.BaseType[i];
4358 if (mapData.Pointers[i]->getType()->isPointerTy())
4359 newV = builder.CreateLoad(type, mapData.Pointers[i]);
4361 newV = mapData.Pointers[i];
// The insertion point and debug location are saved before the alloca is
// created and restored right after it.
// NOTE(review): where the alloca is placed is decided on an elided line.
4364 auto curInsert = builder.saveIP();
4365 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
4367 auto *memTempAlloc =
4368 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
4369 builder.SetCurrentDebugLocation(DbgLoc);
4370 builder.restoreIP(curInsert);
4372 builder.CreateStore(newV, memTempAlloc);
4373 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
// Unlike ByRef, both the pointer and the base pointer are replaced.
4376 mapData.Pointers[i] = newV;
4377 mapData.BasePointers[i] = newV;
// The remaining capture kinds are reported as an error on the map clause op.
4379 case omp::VariableCaptureKind::This:
4380 case omp::VariableCaptureKind::VLAType:
4381 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
// Walks every entry of `mapData` and dispatches it to the appropriate map
// processing routine. `isTargetParams` defaults to false.
// NOTE(review): the head of the signature, the callee names and the loop's
// `continue`/closing lines are missing from this excerpt.
4390 LLVM::ModuleTranslation &moduleTranslation,
4392 MapInfoData &mapData,
bool isTargetParams =
false) {
// Host-only: asserts the OpenMPIRBuilder is not in target-device mode.
4393 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4394 "function only supported for host device codegen");
4409 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4416 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
// Entries that are members of another map get special treatment here; the
// statement controlled by this `if` is not shown (presumably a skip).
4419 if (mapData.IsAMember[i])
// NOTE(review): the dyn_cast result is used on the very next line without a
// null check; `cast<>` would state the expectation explicitly if MapClause
// always holds MapInfoOps - confirm.
4422 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
// Maps that carry members take the parent-with-members path.
4423 if (!mapInfoOp.getMembers().empty()) {
4425 combinedInfo, mapData, i, isTargetParams);
// Tail of a forward declaration whose first lines are not visible in this
// excerpt; its parameter list ends with the module translation and the name
// of the mapper function.
4435 LLVM::ModuleTranslation &moduleTranslation,
4436 llvm::StringRef mapperFuncName);
// Derives the platform-specific mapper function name for an
// omp::DeclareMapperOp and checks whether the module translation already
// knows a function with that name.
// NOTE(review): the head of the signature and everything after the lookup
// are missing from this excerpt.
4440 LLVM::ModuleTranslation &moduleTranslation) {
// Host-only: asserts the OpenMPIRBuilder is not in target-device mode.
4441 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4442 "function only supported for host device codegen");
4443 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
// The name is built from the fixed "omp_mapper" prefix and the mapper's
// symbol name.
4444 std::string mapperFuncName =
4445 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
4446 {
"omp_mapper", declMapperOp.getSymName()});
// An already-registered function with this name is looked up first.
4448 if (
auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
// Emits the function for a user-defined mapper (omp::DeclareMapperOp) and
// registers it with the module translation under `mapperFuncName`.
// NOTE(review): the head of the signature, the callback variable names and
// the emitting call's name are missing from this excerpt.
4457 LLVM::ModuleTranslation &moduleTranslation,
4458 llvm::StringRef mapperFuncName) {
// Host-only: asserts the OpenMPIRBuilder is not in target-device mode.
4459 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4460 "function only supported for host device codegen");
4461 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4462 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
4464 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4465 llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
4468 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4471 MapInfosTy combinedInfo;
// Callback that produces the map information for the mapper body: it binds
// the mapper's symbol value to `ptrPHI`, translates the first block of the
// mapper region at the given insertion point, and fills `combinedInfo`.
4473 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
4474 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
4475 builder.restoreIP(codeGenIP);
4476 moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
4477 moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
4478 builder.GetInsertBlock());
// A failed block conversion is surfaced as PreviouslyReportedError.
4479 if (
failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
4482 return llvm::make_error<PreviouslyReportedError>();
4483 MapInfoData mapData;
4486 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);
// The value/block mappings created for the mapper region are dropped again
// before returning.
4490 moduleTranslation.forgetMapping(declMapperOp.getRegion());
4491 return combinedInfo;
// Part of a second callback that inspects combinedInfo.Mappers[i]; the
// statement controlled by this `if` is not shown.
4495 if (!combinedInfo.Mappers[i])
4502 genMapInfoCB, varType, mapperFuncName, customMapperCB);
// Errors from the emission are propagated; on success the new function is
// recorded under the mapper name.
4504 return newFn.takeError();
4505 moduleTranslation.mapFunction(mapperFuncName, *newFn);
// Translates the OpenMP data-mapping operations (omp::TargetDataOp,
// omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp) through
// OpenMPIRBuilder::createTargetData.
// NOTE(review): the line carrying the function name and many body lines are
// missing from this excerpt; the comments only describe what is visible.
4509 static LogicalResult
4511 LLVM::ModuleTranslation &moduleTranslation) {
4512 llvm::Value *ifCond =
nullptr;
// The device id stays at OMP_DEVICEID_UNDEF unless a constant is found below.
4513 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
4517 llvm::omp::RuntimeFunction RTLFn;
4520 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4521 llvm::OpenMPIRBuilder::TargetDataInfo info(
true,
// Offloading applies when compiling for the device or when the host build
// has at least one target triple configured.
4523 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
4524 bool isOffloadEntry =
4525 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
4527 LogicalResult result =
// Each case reads the optional `if` expression and the device clause. The
// device id is only picked up when the device value is defined by an
// LLVM::ConstantOp holding an IntegerAttr.
4529 .Case([&](omp::TargetDataOp dataOp) {
4533 if (
auto ifVar = dataOp.getIfExpr())
4534 ifCond = moduleTranslation.lookupValue(ifVar);
4536 if (
auto devId = dataOp.getDevice())
4537 if (
auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4538 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4539 deviceID = intAttr.getInt();
// Only this case collects use_device_ptr / use_device_addr variables.
4541 mapVars = dataOp.getMapVars();
4542 useDevicePtrVars = dataOp.getUseDevicePtrVars();
4543 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
4546 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
4550 if (
auto ifVar = enterDataOp.getIfExpr())
4551 ifCond = moduleTranslation.lookupValue(ifVar);
4553 if (
auto devId = enterDataOp.getDevice())
4554 if (
auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4555 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4556 deviceID = intAttr.getInt();
// The runtime entry point depends on the `nowait` clause.
4558 enterDataOp.getNowait()
4559 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
4560 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
4561 mapVars = enterDataOp.getMapVars();
4562 info.HasNoWait = enterDataOp.getNowait();
4565 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
4569 if (
auto ifVar = exitDataOp.getIfExpr())
4570 ifCond = moduleTranslation.lookupValue(ifVar);
4572 if (
auto devId = exitDataOp.getDevice())
4573 if (
auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4574 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4575 deviceID = intAttr.getInt();
// The runtime entry point depends on the `nowait` clause.
4577 RTLFn = exitDataOp.getNowait()
4578 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
4579 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
4580 mapVars = exitDataOp.getMapVars();
4581 info.HasNoWait = exitDataOp.getNowait();
4584 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
4588 if (
auto ifVar = updateDataOp.getIfExpr())
4589 ifCond = moduleTranslation.lookupValue(ifVar);
4591 if (
auto devId = updateDataOp.getDevice())
4592 if (
auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4593 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4594 deviceID = intAttr.getInt();
// The runtime entry point depends on the `nowait` clause.
4597 updateDataOp.getNowait()
4598 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
4599 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
4600 mapVars = updateDataOp.getMapVars();
4601 info.HasNoWait = updateDataOp.getNowait();
4605 llvm_unreachable(
"unexpected operation");
// Without offloading the `if` condition is forced to false, overriding any
// value read from the operation above.
4612 if (!isOffloadEntry)
4613 ifCond = builder.getFalse();
4615 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4616 MapInfoData mapData;
4618 builder, useDevicePtrVars, useDeviceAddrVars);
// Callback that fills `combinedInfo` from `mapData` at the insertion point
// it is given and returns a reference to it.
4621 MapInfosTy combinedInfo;
4622 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
4623 builder.restoreIP(codeGenIP);
4624 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
4625 return combinedInfo;
// Helper lambda: pairs each use_device block argument with its variable,
// then scans the collected map entries for one with the same base pointer
// and the requested device-info type. The entry's base pointer, optionally
// passed through `mapper`, is bound to the block argument when non-null.
4631 [&moduleTranslation](
4632 llvm::OpenMPIRBuilder::DeviceInfoTy type,
4636 for (
auto [arg, useDevVar] :
4637 llvm::zip_equal(blockArgs, useDeviceVars)) {
// The base of a map is its var_ptr_ptr when present, otherwise its var_ptr.
4639 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
4640 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
4641 : mapInfoOp.getVarPtr();
4644 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
4645 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
4646 mapInfoData.MapClause, mapInfoData.DevicePointers,
4647 mapInfoData.BasePointers)) {
4648 auto mapOp = cast<omp::MapInfoOp>(mapClause);
4649 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
4650 devicePointer != type)
4653 if (llvm::Value *devPtrInfoMap =
4654 mapper ? mapper(basePointer) : basePointer) {
4655 moduleTranslation.mapValue(arg, devPtrInfoMap);
// Body-generation callback; the assert below states that it is only valid
// for omp::TargetDataOp. It behaves differently per BodyGenTy phase.
4662 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
4663 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
4664 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
4667 builder.restoreIP(codeGenIP);
4668 assert(isa<omp::TargetDataOp>(op) &&
4669 "BodyGen requested for non TargetDataOp");
4670 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
// NOTE(review): `®ion` looks like a mis-encoded `&region` (an HTML-entity
// artefact); left untouched here because it is a code token.
4671 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
4672 switch (bodyGenType) {
// Priv: when device pointer info exists, bind the use_device_addr arguments
// to values loaded from the recorded slots and the use_device_ptr arguments
// to the slots themselves.
4673 case BodyGenTy::Priv:
4675 if (!info.DevicePtrInfoMap.empty()) {
4676 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4677 blockArgIface.getUseDeviceAddrBlockArgs(),
4678 useDeviceAddrVars, mapData,
4679 [&](llvm::Value *basePointer) -> llvm::Value * {
4680 if (!info.DevicePtrInfoMap[basePointer].second)
4682 return builder.CreateLoad(
4684 info.DevicePtrInfoMap[basePointer].second);
4686 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4687 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
4688 mapData, [&](llvm::Value *basePointer) {
4689 return info.DevicePtrInfoMap[basePointer].second;
4693 moduleTranslation)))
4694 return llvm::make_error<PreviouslyReportedError>();
// DupNoPriv: with no device pointer info, and only when NOT compiling for
// the target device, bind the block arguments straight to the base pointers.
4697 case BodyGenTy::DupNoPriv:
4698 if (info.DevicePtrInfoMap.empty()) {
4701 if (!ompBuilder->Config.IsTargetDevice.value_or(
false)) {
4702 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4703 blockArgIface.getUseDeviceAddrBlockArgs(),
4704 useDeviceAddrVars, mapData);
4705 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4706 blockArgIface.getUseDevicePtrBlockArgs(),
4707 useDevicePtrVars, mapData);
// NoPriv: the same binding, but only when compiling for the target device.
4711 case BodyGenTy::NoPriv:
4713 if (info.DevicePtrInfoMap.empty()) {
4716 if (ompBuilder->Config.IsTargetDevice.value_or(
false)) {
4717 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4718 blockArgIface.getUseDeviceAddrBlockArgs(),
4719 useDeviceAddrVars, mapData);
4720 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4721 blockArgIface.getUseDevicePtrBlockArgs(),
4722 useDevicePtrVars, mapData);
4726 moduleTranslation)))
4727 return llvm::make_error<PreviouslyReportedError>();
4731 return builder.saveIP();
// Callback consulted per map entry for a custom mapper; it records in `info`
// that a mapper is in use when the entry has one (early-exit line elided).
4734 auto customMapperCB =
4736 if (!combinedInfo.Mappers[i])
4738 info.HasMapper =
true;
4743 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4744 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// omp::TargetDataOp takes the createTargetData form whose trailing arguments
// are elided here; every other op passes the selected runtime function.
4746 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
4747 if (isa<omp::TargetDataOp>(op))
4748 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
4749 builder.getInt64(deviceID), ifCond,
4750 info, genMapInfoCB, customMapperCB,
4753 return ompBuilder->createTargetData(
4754 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
4755 info, genMapInfoCB, customMapperCB, &RTLFn);
// Code generation continues after the emitted construct.
4761 builder.restoreIP(*afterIP);
// Translates omp::DistributeOp through OpenMPIRBuilder::createDistribute,
// optionally handling a reduction that belongs to the enclosing teams op.
// NOTE(review): the line carrying the function name and many body lines are
// missing from this excerpt; the comments only describe what is visible.
4765 static LogicalResult
4767 LLVM::ModuleTranslation &moduleTranslation) {
4768 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4769 auto distributeOp = cast<omp::DistributeOp>(opInst);
// Whether this distribute also performs the teams reduction; the deciding
// expression is on elided lines.
4776 bool doDistributeReduction =
// A null `teamsOp` yields zero reduction variables.
4780 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
// Reduction setup reads the by-ref flags and reduction block arguments from
// the teams op.
4785 if (doDistributeReduction) {
4786 isByRef =
getIsByRef(teamsOp.getReductionByref());
4787 assert(isByRef.size() == teamsOp.getNumReductionVars());
4790 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4794 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
4795 .getReductionBlockArgs();
4798 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
4799 reductionDecls, privateReductionVariables, reductionVariableMap,
// Body callback for createDistribute. Every failing step is reported as
// PreviouslyReportedError.
4804 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4805 auto bodyGenCB = [&](InsertPointTy allocaIP,
4806 InsertPointTy codeGenIP) -> llvm::Error {
4810 moduleTranslation, allocaIP);
4813 builder.restoreIP(codeGenIP);
4819 return llvm::make_error<PreviouslyReportedError>();
4824 return llvm::make_error<PreviouslyReportedError>();
4827 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
4829 distributeOp.getPrivateNeedsBarrier())))
4830 return llvm::make_error<PreviouslyReportedError>();
// NOTE(review): this re-declares `ompBuilder`, which appears to shadow the
// function-scope variable of the same name declared near the top.
4832 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4833 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4836 builder, moduleTranslation);
4838 return regionBlock.takeError();
4839 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// When the nested wrapper is not an omp::WsloopOp, a workshare loop is
// applied here with a fixed configuration: static schedule, no chunk, not
// ordered, not simd, no barrier, DistributeStaticLoop loop type.
4844 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
4847 auto schedule = omp::ClauseScheduleKind::Static;
4848 bool isOrdered =
false;
4849 std::optional<omp::ScheduleModifier> scheduleMod;
4850 bool isSimd =
false;
4851 llvm::omp::WorksharingLoopType workshareLoopType =
4852 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
4853 bool loopNeedsBarrier =
false;
4854 llvm::Value *chunk =
nullptr;
4856 llvm::CanonicalLoopInfo *loopInfo =
// `scheduleMod` is never assigned in the visible lines, so both modifier
// comparisons below compare against an empty optional.
4858 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4859 ompBuilder->applyWorkshareLoop(
4860 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
4861 convertToScheduleKind(schedule), chunk, isSimd,
4862 scheduleMod == omp::ScheduleModifier::monotonic,
4863 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4867 return wsloopIP.takeError();
4871 distributeOp.getLoc(), privVarsInfo.
llvmVars,
4873 return llvm::make_error<PreviouslyReportedError>();
4875 return llvm::Error::success();
// Emit the distribute construct and continue after it.
4878 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4880 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4881 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4882 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
4887 builder.restoreIP(*afterIP);
// The reduction, if any, is finalized after the construct has been emitted.
4889 if (doDistributeReduction) {
4892 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
4893 privateReductionVariables, isByRef,
// Materializes the OpenMP flags attribute of a module as an LLVM module flag
// plus a set of `__omp_rtl_*` global flags.
// NOTE(review): the head of the signature, the early returns and the
// argument labels of the createGlobalFlag calls are missing from this
// excerpt.
4903 LLVM::ModuleTranslation &moduleTranslation) {
// NOTE(review): `cast<>` asserts on a type mismatch instead of returning
// null, so this check looks ineffective; `dyn_cast<>` may have been
// intended - confirm.
4904 if (!cast<mlir::ModuleOp>(op))
4907 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
// The device version is recorded as a module flag with `Max` merge behavior.
4909 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
4910 attribute.getOpenmpDeviceVersion());
// The statement controlled by this `if` is not shown; presumably the global
// flags below are skipped when no GPU library is used.
4912 if (attribute.getNoGpuLib())
4915 ompBuilder->createGlobalFlag(
4916 attribute.getDebugKind() ,
4917 "__omp_rtl_debug_kind");
4918 ompBuilder->createGlobalFlag(
4920 .getAssumeTeamsOversubscription()
4922 "__omp_rtl_assume_teams_oversubscription");
4923 ompBuilder->createGlobalFlag(
4925 .getAssumeThreadsOversubscription()
4927 "__omp_rtl_assume_threads_oversubscription");
4928 ompBuilder->createGlobalFlag(
4929 attribute.getAssumeNoThreadState() ,
4930 "__omp_rtl_assume_no_thread_state");
4931 ompBuilder->createGlobalFlag(
4933 .getAssumeNoNestedParallelism()
4935 "__omp_rtl_assume_no_nested_parallelism");
// Builds the llvm::TargetRegionEntryInfo that identifies a target region
// from its parent name, source file identity and line number.
// `parentName` defaults to the empty string.
// NOTE(review): the head of the signature, the hash computation and the
// `else` line are missing from this excerpt.
4940 omp::TargetOp targetOp,
4941 llvm::StringRef parentName =
"") {
// A file/line/column location must be reachable from the op's location.
4942 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
4944 assert(fileLoc &&
"No file found from location");
4945 StringRef fileName = fileLoc.getFilename().getValue();
4947 llvm::sys::fs::UniqueID id;
4948 uint64_t line = fileLoc.getLine();
// If the file system cannot supply a unique id (an error code is returned),
// fall back to the fixed device id 0xdeadf17e together with `fileHash`.
4949 if (
auto ec = llvm::sys::fs::getUniqueID(fileName,
id)) {
4951 size_t deviceId = 0xdeadf17e;
4953 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
// Otherwise the device and file numbers of the unique id are used.
4955 targetInfo = llvm::TargetRegionEntryInfo(parentName,
id.
getDevice(),
4956 id.getFile(), line);
// Device-side fix-up for declare-target map entries: inside `func`, uses of
// an entry's original value are replaced by a load from the entry's base
// pointer.
// NOTE(review): the head of the signature and some body lines are missing
// from this excerpt.
4962 LLVM::ModuleTranslation &moduleTranslation,
4963 llvm::IRBuilderBase &builder, llvm::Function *func) {
// Device-only: asserts the OpenMPIRBuilder is in target-device mode.
4964 assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4965 "function only supported for target device codegen");
// The builder's insertion point is restored when this guard goes out of
// scope.
4966 llvm::IRBuilderBase::InsertPointGuard guard(builder);
4967 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4980 if (mapData.IsDeclareTarget[i]) {
// When the original value is a constant, its users within `func` are first
// converted to instructions.
4987 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
4988 convertUsersOfConstantsToInstructions(constant, func,
false);
// Users are copied into `userVec` before being rewritten; presumably to
// avoid walking the use list while it is being modified.
4995 for (llvm::User *user : mapData.OriginalValue[i]->users())
4996 userVec.push_back(user);
// Only instruction users that belong to `func` are rewritten. Each gets its
// own load, created with the user's debug location and moved directly in
// front of the user.
4998 for (llvm::User *user : userVec) {
4999 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
5000 if (insn->getFunction() == func) {
5001 builder.SetCurrentDebugLocation(insn->getDebugLoc());
5002 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
5003 mapData.BasePointers[i]);
5004 load->moveBefore(insn->getIterator());
5005 user->replaceUsesOfWith(mapData.OriginalValue[i], load);
// Device-side accessor for a kernel argument: the argument is spilled to a
// stack slot and `retVal` is then produced according to the capture kind of
// the map entry that corresponds to `input`. Returns the builder's insertion
// point after the generated code.
// NOTE(review): the line carrying the function name and several body lines
// (including the `switch` head and the ByCopy body) are missing from this
// excerpt.
5052 static llvm::IRBuilderBase::InsertPoint
5054 llvm::Value *input, llvm::Value *&retVal,
5055 llvm::IRBuilderBase &builder,
5056 llvm::OpenMPIRBuilder &ompBuilder,
5057 LLVM::ModuleTranslation &moduleTranslation,
5058 llvm::IRBuilderBase::InsertPoint allocaIP,
5059 llvm::IRBuilderBase::InsertPoint codeGenIP) {
// Device-only: asserts the OpenMPIRBuilder is in target-device mode.
5060 assert(ompBuilder.Config.isTargetDevice() &&
5061 "function only supported for target device codegen");
5062 builder.restoreIP(allocaIP);
// ByRef and an alignment of 0 are the defaults used when no map entry
// matches `input`.
5064 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
5065 LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
5066 ompBuilder.M.getContext());
5067 unsigned alignmentValue = 0;
// Look for the map entry whose original value is `input` and take its
// capture kind and the preferred alignment of its variable type.
5069 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
5070 if (mapData.OriginalValue[i] == input) {
5071 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5072 capture = mapOp.getMapCaptureType();
5074 alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
5075 mapOp.getVarType(), ompBuilder.M.getDataLayout());
5079 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
5080 unsigned int defaultAS =
5081 ompBuilder.M.getDataLayout().getProgramAddressSpace();
// The slot is allocated in the alloca address space; when that differs from
// the program address space and the argument is a pointer, the slot address
// is cast to the program address space before the argument is stored.
5084 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
5086 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
5087 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
5089 builder.CreateStore(&arg, v);
// The remaining code is emitted at the code generation point.
5091 builder.restoreIP(codeGenIP);
5094 case omp::VariableCaptureKind::ByCopy: {
// ByRef: load from the slot using the preferred alignment of the slot's
// type.
5098 case omp::VariableCaptureKind::ByRef: {
5099 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
5101 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
// `align` metadata is attached to the load only when an alignment was found
// for the mapped variable.
5116 if (v->getType()->isPointerTy() && alignmentValue) {
5117 llvm::MDBuilder MDB(builder.getContext());
5118 loadInst->setMetadata(
5119 llvm::LLVMContext::MD_align,
5122 llvm::Type::getInt64Ty(builder.getContext()),
// The remaining capture kinds are rejected by assertion.
5129 case omp::VariableCaptureKind::This:
5130 case omp::VariableCaptureKind::VLAType:
5133 assert(
false &&
"Currently unsupported capture kind");
5137 return builder.saveIP();
// For each host_eval variable of the target op, determines which clause of a
// nested operation consumes the matching block argument and stores the host
// value in the corresponding output.
// NOTE(review): the signature, the loop over users and several closing lines
// are missing from this excerpt.
5154 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
5155 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
5156 blockArgIface.getHostEvalBlockArgs())) {
5157 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
// omp::TeamsOp: the block argument may feed the lower or upper num_teams
// bound or the thread_limit; anything else is treated as unreachable.
5161 .Case([&](omp::TeamsOp teamsOp) {
5162 if (teamsOp.getNumTeamsLower() == blockArg)
5163 numTeamsLower = hostEvalVar;
5164 else if (teamsOp.getNumTeamsUpper() == blockArg)
5165 numTeamsUpper = hostEvalVar;
5166 else if (teamsOp.getThreadLimit() == blockArg)
5167 threadLimit = hostEvalVar;
5169 llvm_unreachable(
"unsupported host_eval use");
// omp::ParallelOp: only num_threads is accepted.
5171 .Case([&](omp::ParallelOp parallelOp) {
5172 if (parallelOp.getNumThreads() == blockArg)
5173 numThreads = hostEvalVar;
5175 llvm_unreachable(
"unsupported host_eval use");
// omp::LoopNestOp: the block argument may appear among the lower bounds,
// upper bounds or steps; each matching position is replaced by the host
// value.
5177 .Case([&](omp::LoopNestOp loopOp) {
5178 auto processBounds =
5183 if (lb == blockArg) {
5186 (*outBounds)[i] = hostEvalVar;
5192 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
5193 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
// `processBounds` is the left operand of `||`, so it always runs even when
// `found` is already true.
5195 found = processBounds(loopOp.getLoopSteps(), steps) || found;
5197 assert(found &&
"unsupported host_eval use");
5200 llvm_unreachable(
"unsupported host_eval use");
// Template helper that yields `op` as an `OpTy`, either directly or via its
// parent operation.
// NOTE(review): the signature, the return for the direct match and the
// non-immediate-parent path are missing from this excerpt.
5213 template <
typename OpTy>
// Direct match: `op` itself is an `OpTy`.
5218 if (OpTy casted = dyn_cast<OpTy>(op))
// With `immediateParent` set, only the direct parent is considered; the
// result is null when that parent is absent or of another type.
5221 if (immediateParent)
5222 return dyn_cast_if_present<OpTy>(op->
getParentOp());
// Fragment of a helper that returns the integer held by a constant op's
// IntegerAttr, or std::nullopt when no such integer is available.
// NOTE(review): the signature and the conditions leading to the first
// return are missing from this excerpt.
5231 return std::nullopt;
5234 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5235 return constAttr.getInt();
// The constant's value is not an IntegerAttr.
5237 return std::nullopt;
// Converts a size in bits to bytes; unsigned integer division, so any
// remainder below 8 bits is discarded.
// NOTE(review): the enclosing function is not visible in this excerpt.
5242 uint64_t sizeInBytes = sizeInBits / 8;
// Template helper that, for an op with reduction variables, gathers the type
// of each reduction declaration into a literal LLVM struct type.
// NOTE(review): the signature, the lookup of `reductions` and the size
// computation / return are missing from this excerpt.
5246 template <
typename OpTy>
5248 if (op.getNumReductionVars() > 0) {
// One struct member per reduction declaration, in declaration order.
5253 members.reserve(reductions.size());
5254 for (omp::DeclareReductionOp &red : reductions)
5255 members.push_back(red.getType());
5257 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
// Fills the compile-time ("default") kernel attributes of a target region:
// execution mode, team and thread bounds, and reduction buffer settings.
// NOTE(review): the head of the signature and many body lines are missing
// from this excerpt; the comments only describe what is visible.
5273 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
5274 bool isTargetDevice,
bool isGPU) {
5277 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
// The branch for `!isTargetDevice` is elided; the clause values read
// directly from the teams/parallel ops below presumably belong to the
// alternative path.
5278 if (!isTargetDevice) {
5285 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5286 numTeamsLower = teamsOp.getNumTeamsLower();
5287 numTeamsUpper = teamsOp.getNumTeamsUpper();
5288 threadLimit = teamsOp.getThreadLimit();
5291 if (
auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5292 numThreads = parallelOp.getNumThreads();
// Team bounds start at min 1 / max -1 and are then refined depending on
// which construct encloses the captured op.
5297 int32_t minTeamsVal = 1, maxTeamsVal = -1;
5298 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5301 if (numTeamsUpper) {
5303 minTeamsVal = maxTeamsVal = *val;
5305 minTeamsVal = maxTeamsVal = 0;
5307 }
else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
5309 castOrGetParentOfType<omp::SimdOp>(capturedOp,
5311 minTeamsVal = maxTeamsVal = 1;
5313 minTeamsVal = maxTeamsVal = -1;
5318 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &result) {
// -1 marks "not set" for each of the thread limits below.
5332 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
5333 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
5334 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
5337 int32_t maxThreadsVal = -1;
5338 if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5339 setMaxValueFromClause(numThreads, maxThreadsVal);
5340 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
// Combine the three limits: a negative running value is replaced outright,
// otherwise a non-negative candidate replaces it only when smaller.
5347 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
5348 if (combinedMaxThreadsVal < 0 ||
5349 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
5350 combinedMaxThreadsVal = teamsThreadLimitVal;
5352 if (combinedMaxThreadsVal < 0 ||
5353 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
5354 combinedMaxThreadsVal = maxThreadsVal;
// Reduction data is only sized for GPU targets with a captured op.
5356 int32_t reductionDataSize = 0;
5357 if (isGPU && capturedOp) {
5358 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
// Execution mode: generic and spmd together select GENERIC_SPMD, generic
// alone selects GENERIC, anything else selects SPMD.
5363 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5365 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5366 omp::TargetRegionFlags::spmd) &&
5367 "invalid kernel flags");
5369 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5370 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5371 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5372 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5373 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5374 attrs.MinTeams = minTeamsVal;
5375 attrs.MaxTeams.front() = maxTeamsVal;
5376 attrs.MinThreads = 1;
5377 attrs.MaxThreads.front() = combinedMaxThreadsVal;
5378 attrs.ReductionDataSize = reductionDataSize;
// A fixed buffer length of 1024 is used whenever reduction data is present.
5381 if (attrs.ReductionDataSize != 0)
5382 attrs.ReductionBufferLength = 1024;
// Fills the runtime kernel attributes of a target region with the LLVM
// values of its host-evaluated clauses and, when requested by the kernel
// flags, the total loop trip count.
// NOTE(review): the head of the signature, the callee that fills the clause
// values and several `if` lines are missing from this excerpt.
5393 LLVM::ModuleTranslation &moduleTranslation,
5394 omp::TargetOp targetOp,
Operation *capturedOp,
5395 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
// No loop nest around the captured op means zero loops.
5396 omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
5397 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
5399 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
5403 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
// Each clause value is translated with lookupValue; the guards for some of
// the assignments below are on elided lines.
5406 if (
Value targetThreadLimit = targetOp.getThreadLimit())
5407 attrs.TargetThreadLimit.front() =
5408 moduleTranslation.lookupValue(targetThreadLimit);
5411 attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);
5414 attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);
5416 if (teamsThreadLimit)
5417 attrs.TeamsThreadLimit.front() =
5418 moduleTranslation.lookupValue(teamsThreadLimit);
5421 attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
// The trip count is only computed when the kernel flags contain trip_count.
5423 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
5424 omp::TargetRegionFlags::trip_count)) {
5425 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5426 attrs.LoopTripCount =
nullptr;
// One trip count per loop in the nest: the first becomes the running value,
// later ones are multiplied into it.
5431 for (
auto [loopLower, loopUpper, loopStep] :
5432 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
5433 llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
5434 llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
5435 llvm::Value *step = moduleTranslation.lookupValue(loopStep);
5437 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
5438 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
5439 loc, lowerBound, upperBound, step,
true,
5440 loopOp.getLoopInclusive());
5442 if (!attrs.LoopTripCount) {
5443 attrs.LoopTripCount = tripCount;
5448 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
// Translates omp::TargetOp through OpenMPIRBuilder::createTarget: it sets up
// the debug location, privatization, the outlined body callback, map
// information, argument accessors, kernel attributes and kernel inputs.
// NOTE(review): the line carrying the function name and many body lines are
// missing from this excerpt; the comments only describe what is visible.
5454 static LogicalResult
5456 LLVM::ModuleTranslation &moduleTranslation) {
5457 auto targetOp = cast<omp::TargetOp>(opInst);
// Remember the debug location that is current when translation starts.
5461 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
5470 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
5471 assert(parentBB &&
"No insert block is set for the builder");
5472 llvm::Function *parentLLVMFn = parentBB->getParent();
5473 assert(parentLLVMFn &&
"Parent Function must be valid");
// When the parent function has a subprogram, a location is rebuilt from the
// saved line/column/inlined-at with that subprogram as scope (the
// assignment target is on an elided line).
5474 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
5476 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
5477 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
5479 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5480 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
5481 bool isGPU = ompBuilder->Config.isGPU();
5484 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
5485 auto &targetRegion = targetOp.getRegion();
// Set inside the body callback once the outlined function exists.
5502 llvm::Function *llvmOutlinedFn =
nullptr;
// Offloading applies when compiling for the device or when the host build
// has at least one target triple configured.
5506 bool isOffloadEntry =
5507 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
// Private variables whose privatizer needs a map are associated with the
// region's corresponding map block argument.
5514 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
5516 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
5517 std::optional<DenseI64ArrayAttr> privateMapIndices =
5518 targetOp.getPrivateMapsAttr();
5520 for (
auto [privVarIdx, privVarSymPair] :
5522 auto privVar = std::get<0>(privVarSymPair);
5523 auto privSym = std::get<1>(privVarSymPair);
5525 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
5526 omp::PrivateClauseOp privatizer =
// Privatizers that need no map are handled by an elided statement
// (presumably skipped).
5529 if (!privatizer.needsMap())
5533 targetOp.getMappedValueForPrivateVar(privVarIdx);
5534 assert(mappedValue &&
"Expected to find mapped value for a privatized "
5535 "variable that needs mapping");
// `varType` is only used inside an assert, hence [[maybe_unused]].
5540 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
5541 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
// The type check applies only to private variables that are not pointers.
5545 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
5547 varType == privVar.getType() &&
5548 "Type of private var doesn't match the type of the mapped value");
// The block argument is found at the map-argument start offset plus the
// private variable's map index.
5552 mappedPrivateVars.insert(
5554 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
5555 (*privateMapIndices)[privVarIdx])});
// Body callback: generates the contents of the outlined target function.
5559 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5560 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
5561 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
// The builder's insertion point is restored when the guard is destroyed;
// the debug location is cleared for the code emitted here.
5562 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5563 builder.SetCurrentDebugLocation(llvm::DebugLoc());
5566 llvm::Function *llvmParentFn =
5567 moduleTranslation.lookupFunction(parentFn.getName());
5568 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
5569 assert(llvmParentFn && llvmOutlinedFn &&
5570 "Both parent and outlined functions must exist at this point");
// The outlined function receives the subprogram of the saved location's
// scope when both the location and the parent's subprogram exist.
5572 if (outlinedFnLoc && llvmParentFn->getSubprogram())
5573 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
// "target-cpu" and "target-features" are copied from the parent function
// when they are string attributes.
5575 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
5576 attr.isStringAttribute())
5577 llvmOutlinedFn->addFnAttr(attr);
5579 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
5580 attr.isStringAttribute())
5581 llvmOutlinedFn->addFnAttr(attr);
// Map block arguments (and, below, the `hda` block arguments) are bound to
// the LLVM value of the var_ptr of their defining MapInfoOp.
5583 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
5584 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5585 llvm::Value *mapOpValue =
5586 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
5587 moduleTranslation.mapValue(arg, mapOpValue);
5589 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
5590 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5591 llvm::Value *mapOpValue =
5592 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
5593 moduleTranslation.mapValue(arg, mapOpValue);
// Privatization steps (callee names elided); each failure is reported as
// PreviouslyReportedError.
5602 allocaIP, &mappedPrivateVars);
5605 return llvm::make_error<PreviouslyReportedError>();
5607 builder.restoreIP(codeGenIP);
5609 &mappedPrivateVars),
5612 return llvm::make_error<PreviouslyReportedError>();
5615 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
5617 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
5618 return llvm::make_error<PreviouslyReportedError>();
// Collect the dealloc region of every privatizer for later cleanup.
5622 std::back_inserter(privateCleanupRegions),
5623 [](omp::PrivateClauseOp privatizer) {
5624 return &privatizer.getDeallocRegion();
// Translate the target region itself under the "omp.target" name.
5628 targetRegion,
"omp.target", builder, moduleTranslation);
5631 return exitBlock.takeError();
// Private cleanup regions, when present, are emitted at the exit block and
// the resulting insertion point is returned.
5633 builder.SetInsertPoint(*exitBlock);
5634 if (!privateCleanupRegions.empty()) {
5636 privateCleanupRegions, privateVarsInfo.
llvmVars,
5637 moduleTranslation, builder,
"omp.targetop.private.cleanup",
5639 return llvm::createStringError(
5640 "failed to inline `dealloc` region of `omp.private` "
5641 "op in the target region");
5643 return builder.saveIP();
// Without cleanup the insertion point is the end of the exit block.
5646 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
5649 StringRef parentName = parentFn.getName();
5651 llvm::TargetRegionEntryInfo entryInfo;
5655 MapInfoData mapData;
// Map-info callback; `true` is passed as genMapInfos' final argument.
5660 MapInfosTy combinedInfos;
5662 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
5663 builder.restoreIP(codeGenIP);
5664 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
true);
5665 return combinedInfos;
// Argument accessor callback: on the host the argument itself is the
// result; otherwise the work is delegated to a call whose name is elided.
5668 auto argAccessorCB = [&](
llvm::Argument &arg, llvm::Value *input,
5669 llvm::Value *&retVal, InsertPointTy allocaIP,
5670 InsertPointTy codeGenIP)
5671 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5672 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5673 builder.SetCurrentDebugLocation(llvm::DebugLoc());
5679 if (!isTargetDevice) {
5680 retVal = cast<llvm::Value>(&arg);
5685 *ompBuilder, moduleTranslation,
5686 allocaIP, codeGenIP);
// Kernel attributes: defaults are always initialized (call name elided);
// runtime attributes only on the host.
5689 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
5690 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
5691 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
5693 isTargetDevice, isGPU);
5697 if (!isTargetDevice)
5699 targetCapturedOp, runtimeAttrs);
// Kernel inputs: host_eval values are bound to their block arguments and
// passed to the kernel only when they are not constants.
5707 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
5708 llvm::Value *value = moduleTranslation.lookupValue(var);
5709 moduleTranslation.mapValue(arg, value);
5711 if (!llvm::isa<llvm::Constant>(value))
5712 kernelInput.push_back(value);
// Mapped values are added unless they are declare-target or members of
// another map.
5715 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
5722 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
5723 kernelInput.push_back(mapData.OriginalValue[i]);
5728 moduleTranslation, dds);
5730 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5732 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5734 llvm::OpenMPIRBuilder::TargetDataInfo info(
// Callback consulted per map entry for a custom mapper; it records in `info`
// that a mapper is in use when the entry has one (early-exit line elided).
5738 auto customMapperCB =
5740 if (!combinedInfos.Mappers[i])
5742 info.HasMapper =
true;
// The optional `if` clause of the target op.
5747 llvm::Value *ifCond =
nullptr;
5748 if (
Value targetIfCond = targetOp.getIfExpr())
5749 ifCond = moduleTranslation.lookupValue(targetIfCond);
// Emit the target construct with everything gathered above.
5751 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5752 moduleTranslation.getOpenMPBuilder()->createTarget(
5753 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
5754 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
5755 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
5760 builder.restoreIP(*afterIP);
// Extra device-only post-processing follows (statement elided).
5764 if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
5771 static LogicalResult
5773 LLVM::ModuleTranslation &moduleTranslation) {
5781 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
5782 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
5784 if (!offloadMod.getIsTargetDevice())
5787 omp::DeclareTargetDeviceType declareType =
5788 attribute.getDeviceType().getValue();
5790 if (declareType == omp::DeclareTargetDeviceType::host) {
5791 llvm::Function *llvmFunc =
5792 moduleTranslation.lookupFunction(funcOp.getName());
5793 llvmFunc->dropAllReferences();
5794 llvmFunc->eraseFromParent();
5800 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
5801 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
5802 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
5803 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5804 bool isDeclaration = gOp.isDeclaration();
5805 bool isExternallyVisible =
5808 llvm::StringRef mangledName = gOp.getSymName();
5809 auto captureClause =
5815 std::vector<llvm::GlobalVariable *> generatedRefs;
5817 std::vector<llvm::Triple> targetTriple;
5818 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
5820 LLVM::LLVMDialect::getTargetTripleAttrName()));
5821 if (targetTripleAttr)
5822 targetTriple.emplace_back(targetTripleAttr.data());
5824 auto fileInfoCallBack = [&loc]() {
5825 std::string filename =
"";
5826 std::uint64_t lineNo = 0;
5829 filename = loc.getFilename().str();
5830 lineNo = loc.getLine();
5833 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
5837 ompBuilder->registerTargetGlobalVariable(
5838 captureClause, deviceClause, isDeclaration, isExternallyVisible,
5839 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
5840 generatedRefs,
false, targetTriple,
5842 gVal->getType(), gVal);
5844 if (ompBuilder->Config.isTargetDevice() &&
5845 (attribute.getCaptureClause().getValue() !=
5846 mlir::omp::DeclareTargetCaptureClause::to ||
5847 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
5848 ompBuilder->getAddrOfDeclareTargetVar(
5849 captureClause, deviceClause, isDeclaration, isExternallyVisible,
5850 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
5851 generatedRefs,
false, targetTriple, gVal->getType(),
5873 if (mlir::isa<omp::ThreadprivateOp>(op))
5876 if (mlir::isa<omp::TargetAllocMemOp>(op) ||
5877 mlir::isa<omp::TargetFreeMemOp>(op))
5881 if (
auto declareTargetIface =
5882 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
5883 parentFn.getOperation()))
5884 if (declareTargetIface.isDeclareTarget() &&
5885 declareTargetIface.getDeclareTargetDeviceType() !=
5886 mlir::omp::DeclareTargetDeviceType::host)
5893 llvm::Module *llvmModule) {
5894 llvm::Type *i64Ty = builder.getInt64Ty();
5895 llvm::Type *i32Ty = builder.getInt32Ty();
5896 llvm::Type *returnType = builder.getPtrTy(0);
5897 llvm::FunctionType *fnType =
5899 llvm::Function *func = cast<llvm::Function>(
5900 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
5904 static LogicalResult
5906 LLVM::ModuleTranslation &moduleTranslation) {
5907 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
5912 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
5916 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
5918 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
5919 mlir::Type heapTy = allocMemOp.getAllocatedType();
5920 llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
5921 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
5922 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
5923 for (
auto typeParam : allocMemOp.getTypeparams())
5925 builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
5927 llvm::CallInst *call =
5928 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
5929 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
5932 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
5937 llvm::Module *llvmModule) {
5938 llvm::Type *ptrTy = builder.getPtrTy(0);
5939 llvm::Type *i32Ty = builder.getInt32Ty();
5940 llvm::Type *voidTy = builder.getVoidTy();
5941 llvm::FunctionType *fnType =
5943 llvm::Function *func = dyn_cast<llvm::Function>(
5944 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
5948 static LogicalResult
5950 LLVM::ModuleTranslation &moduleTranslation) {
5951 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
5956 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
5960 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
5963 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
5965 llvm::Value *intToPtr =
5966 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
5967 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
5973 static LogicalResult
5975 LLVM::ModuleTranslation &moduleTranslation) {
5976 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5984 bool isOutermostLoopWrapper =
5985 isa_and_present<omp::LoopWrapperInterface>(op) &&
5986 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
5988 if (isOutermostLoopWrapper)
5989 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
5993 .Case([&](omp::BarrierOp op) -> LogicalResult {
5997 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5998 ompBuilder->createBarrier(builder.saveIP(),
5999 llvm::omp::OMPD_barrier);
6001 if (res.succeeded()) {
6004 builder.restoreIP(*afterIP);
6008 .Case([&](omp::TaskyieldOp op) {
6012 ompBuilder->createTaskyield(builder.saveIP());
6015 .Case([&](omp::FlushOp op) {
6027 ompBuilder->createFlush(builder.saveIP());
6030 .Case([&](omp::ParallelOp op) {
6033 .Case([&](omp::MaskedOp) {
6036 .Case([&](omp::MasterOp) {
6039 .Case([&](omp::CriticalOp) {
6042 .Case([&](omp::OrderedRegionOp) {
6045 .Case([&](omp::OrderedOp) {
6048 .Case([&](omp::WsloopOp) {
6051 .Case([&](omp::SimdOp) {
6054 .Case([&](omp::AtomicReadOp) {
6057 .Case([&](omp::AtomicWriteOp) {
6060 .Case([&](omp::AtomicUpdateOp op) {
6063 .Case([&](omp::AtomicCaptureOp op) {
6066 .Case([&](omp::CancelOp op) {
6069 .Case([&](omp::CancellationPointOp op) {
6072 .Case([&](omp::SectionsOp) {
6075 .Case([&](omp::SingleOp op) {
6078 .Case([&](omp::TeamsOp op) {
6081 .Case([&](omp::TaskOp op) {
6084 .Case([&](omp::TaskgroupOp op) {
6087 .Case([&](omp::TaskwaitOp op) {
6090 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
6091 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
6092 omp::CriticalDeclareOp>([](
auto op) {
6105 .Case([&](omp::ThreadprivateOp) {
6108 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
6109 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
6112 .Case([&](omp::TargetOp) {
6115 .Case([&](omp::DistributeOp) {
6118 .Case([&](omp::LoopNestOp) {
6121 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
6128 .Case([&](omp::NewCliOp op) {
6133 .Case([&](omp::CanonicalLoopOp op) {
6136 .Case([&](omp::UnrollHeuristicOp op) {
6145 .Case([&](omp::TargetAllocMemOp) {
6148 .Case([&](omp::TargetFreeMemOp) {
6153 <<
"not yet implemented: " << inst->
getName();
6156 if (isOutermostLoopWrapper)
6157 moduleTranslation.stackPop();
6162 static LogicalResult
6164 LLVM::ModuleTranslation &moduleTranslation) {
6168 static LogicalResult
6170 LLVM::ModuleTranslation &moduleTranslation) {
6171 if (isa<omp::TargetOp>(op))
6173 if (isa<omp::TargetDataOp>(op))
6177 if (isa<omp::TargetOp>(oper)) {
6179 return WalkResult::interrupt();
6180 return WalkResult::skip();
6182 if (isa<omp::TargetDataOp>(oper)) {
6184 return WalkResult::interrupt();
6185 return WalkResult::skip();
6192 if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
6193 oper->getParentOfType<LLVM::LLVMFuncOp>() &&
6194 !oper->getRegions().empty()) {
6195 if (
auto blockArgsIface =
6196 dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
6202 if (isa<mlir::omp::AtomicUpdateOp>(oper))
6203 for (
auto [operand, arg] :
6204 llvm::zip_equal(oper->getOperands(),
6205 oper->getRegion(0).getArguments())) {
6206 moduleTranslation.mapValue(
6207 arg, builder.CreateLoad(
6208 moduleTranslation.convertType(arg.getType()),
6209 moduleTranslation.lookupValue(operand)));
6213 if (
auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
6214 assert(builder.GetInsertBlock() &&
6215 "No insert block is set for the builder");
6216 for (
auto iv : loopNest.getIVs()) {
6218 moduleTranslation.mapValue(
6220 moduleTranslation.convertType(iv.getType())));
6224 for (
Region &region : oper->getRegions()) {
6231 region, oper->getName().getStringRef().str() +
".fake.region",
6232 builder, moduleTranslation, &phis);
6234 return WalkResult::interrupt();
6236 builder.SetInsertPoint(result.get(), result.get()->end());
6239 return WalkResult::skip();
6242 return WalkResult::advance();
6243 }).wasInterrupted();
6244 return failure(interrupted);
6251 class OpenMPDialectLLVMIRTranslationInterface
6260 LLVM::ModuleTranslation &moduleTranslation)
const final;
6267 LLVM::ModuleTranslation &moduleTranslation)
const final;
6272 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6275 LLVM::ModuleTranslation &moduleTranslation)
const {
6278 .Case(
"omp.is_target_device",
6280 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6281 llvm::OpenMPIRBuilderConfig &
config =
6282 moduleTranslation.getOpenMPBuilder()->Config;
6283 config.setIsTargetDevice(deviceAttr.getValue());
6290 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6291 llvm::OpenMPIRBuilderConfig &
config =
6292 moduleTranslation.getOpenMPBuilder()->Config;
6293 config.setIsGPU(gpuAttr.getValue());
6298 .Case(
"omp.host_ir_filepath",
6300 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6301 llvm::OpenMPIRBuilder *ompBuilder =
6302 moduleTranslation.getOpenMPBuilder();
6303 ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
6310 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6314 .Case(
"omp.version",
6316 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6317 llvm::OpenMPIRBuilder *ompBuilder =
6318 moduleTranslation.getOpenMPBuilder();
6319 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
6320 versionAttr.getVersion());
6325 .Case(
"omp.declare_target",
6327 if (
auto declareTargetAttr =
6328 dyn_cast<omp::DeclareTargetAttr>(attr))
6333 .Case(
"omp.requires",
6335 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6336 using Requires = omp::ClauseRequires;
6337 Requires flags = requiresAttr.getValue();
6338 llvm::OpenMPIRBuilderConfig &
config =
6339 moduleTranslation.getOpenMPBuilder()->Config;
6340 config.setHasRequiresReverseOffload(
6341 bitEnumContainsAll(flags, Requires::reverse_offload));
6342 config.setHasRequiresUnifiedAddress(
6343 bitEnumContainsAll(flags, Requires::unified_address));
6344 config.setHasRequiresUnifiedSharedMemory(
6345 bitEnumContainsAll(flags, Requires::unified_shared_memory));
6346 config.setHasRequiresDynamicAllocators(
6347 bitEnumContainsAll(flags, Requires::dynamic_allocators));
6352 .Case(
"omp.target_triples",
6354 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
6355 llvm::OpenMPIRBuilderConfig &
config =
6356 moduleTranslation.getOpenMPBuilder()->Config;
6357 config.TargetTriples.clear();
6358 config.TargetTriples.reserve(triplesAttr.size());
6359 for (
Attribute tripleAttr : triplesAttr) {
6360 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
6361 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
6379 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
6380 Operation *op, llvm::IRBuilderBase &builder,
6381 LLVM::ModuleTranslation &moduleTranslation)
const {
6383 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6384 if (ompBuilder->Config.isTargetDevice()) {
6394 registry.
insert<omp::OpenMPDialect>();
6396 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
union mlir::linalg::@1244::ArityGroupAndKind::Kind kind
static llvm::Value * getRefPtrIfDeclareTarget(mlir::Value value, LLVM::ModuleTranslation &moduleTranslation)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName)
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable.
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type.
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized.
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct.
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables.
llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, bool isTargetParams, int mapDataParentIdx=-1)
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static bool isTargetDeviceOp(Operation *op)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
Handling of DeclareReductionOp's cleanup region.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static LogicalResult initReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::BasicBlock *latestAllocaBlock, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef, SmallVectorImpl< DeferredStore > &deferredStores)
Inline reductions' init regions.
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion point...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, bool isTargetParams=false)
LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation, omp::BlockArgOpenMPOpInterface blockArgIface)
Maps block arguments from blockArgIface (which are MLIR values) to the corresponding LLVM values of t...
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool >> attr)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static LogicalResult convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< mlir::Value > &mlirPrivateVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static uint64_t getReductionDataSize(OpTy &op)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static void collectReductionInfo(T loop, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< OwningReductionGen > &owningReductionGens, SmallVectorImpl< OwningAtomicReductionGen > &owningAtomicReductionGens, const ArrayRef< llvm::Value * > privateReductionVariables, SmallVectorImpl< llvm::OpenMPIRBuilder::ReductionInfo > &reductionInfos)
Collect reduction info.
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Given an OpenMP MLIR operation, create the corresponding LLVM IR (including OpenMP runtime calls).
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static bool isDeclareTargetLink(mlir::Value value)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Base class for dialect interfaces providing translation to LLVM IR.
virtual LogicalResult amendOperation(Operation *op, ArrayRef< llvm::Instruction * > instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to act on an operation that has dialect attributes from the derive...
virtual LogicalResult convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to provide translation of the operations to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
unsigned getNumArguments()
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Runtime
Potential runtimes for AMD GPU kernels.
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars
RAII object calling stackPush/stackPop on construction/destruction.