26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Frontend/OpenMP/OMPConstants.h"
31 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
32 #include "llvm/IR/Constants.h"
33 #include "llvm/IR/DebugInfoMetadata.h"
34 #include "llvm/IR/DerivedTypes.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/MDBuilder.h"
37 #include "llvm/IR/ReplaceConstant.h"
38 #include "llvm/Support/FileSystem.h"
39 #include "llvm/TargetParser/Triple.h"
40 #include "llvm/Transforms/Utils/ModuleUtils.h"
52 static llvm::omp::ScheduleKind
53 convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
54 if (!schedKind.has_value())
55 return llvm::omp::OMP_SCHEDULE_Default;
56 switch (schedKind.value()) {
57 case omp::ClauseScheduleKind::Static:
58 return llvm::omp::OMP_SCHEDULE_Static;
59 case omp::ClauseScheduleKind::Dynamic:
60 return llvm::omp::OMP_SCHEDULE_Dynamic;
61 case omp::ClauseScheduleKind::Guided:
62 return llvm::omp::OMP_SCHEDULE_Guided;
63 case omp::ClauseScheduleKind::Auto:
64 return llvm::omp::OMP_SCHEDULE_Auto;
66 return llvm::omp::OMP_SCHEDULE_Runtime;
68 llvm_unreachable("unhandled schedule clause argument");
73 class OpenMPAllocaStackFrame
78 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
79 : allocaInsertPoint(allocaIP) {}
80 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
86 class OpenMPLoopInfoStackFrame
90 llvm::CanonicalLoopInfo *loopInfo = nullptr;
109 class PreviouslyReportedError
110 : public llvm::ErrorInfo<PreviouslyReportedError> {
112 void log(raw_ostream &) const override {
116 std::error_code convertToErrorCode() const override {
118 "PreviouslyReportedError doesn't support ECError conversion");
136 class LinearClauseProcessor {
144 llvm::BasicBlock *linearFinalizationBB;
145 llvm::BasicBlock *linearExitBB;
146 llvm::BasicBlock *linearLastIterExitBB;
150 void createLinearVar(llvm::IRBuilderBase &builder,
153 if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
155 linearPreconditionVars.push_back(builder.CreateAlloca(
156 linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
157 llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
158 linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
159 linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
160 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
161 linearOrigVars.push_back(linearVarAlloca);
168 linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
172 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
173 initLinearVar(llvm::IRBuilderBase &builder,
175 llvm::BasicBlock *loopPreHeader) {
176 builder.SetInsertPoint(loopPreHeader->getTerminator());
177 for (size_t index = 0; index < linearOrigVars.size(); index++) {
178 llvm::LoadInst *linearVarLoad = builder.CreateLoad(
179 linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
180 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
182 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
184 builder.saveIP(), llvm::omp::OMPD_barrier);
185 return afterBarrierIP;
189 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
190 llvm::Value *loopInductionVar) {
191 builder.SetInsertPoint(loopBody->getTerminator());
192 for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
194 llvm::LoadInst *linearVarStart =
195 builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
197 linearPreconditionVars[index]);
198 auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
199 auto addInst = builder.CreateAdd(linearVarStart, mulInst);
200 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
206 void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
207 llvm::BasicBlock *loopExit) {
208 linearFinalizationBB = loopExit->splitBasicBlock(
209 loopExit->getTerminator(), "omp_loop.linear_finalization");
210 linearExitBB = linearFinalizationBB->splitBasicBlock(
211 linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
212 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
213 linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
217 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
218 finalizeLinearVar(llvm::IRBuilderBase &builder,
220 llvm::Value *lastIter) {
222 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
223 llvm::Value *loopLastIterLoad = builder.CreateLoad(
224 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
225 llvm::Value *isLast =
226 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
228 llvm::Type::getInt32Ty(builder.getContext()), 0));
230 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
231 for (size_t index = 0; index < linearOrigVars.size(); index++) {
232 llvm::LoadInst *linearVarTemp =
233 builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
234 linearLoopBodyTemps[index]);
235 builder.CreateStore(linearVarTemp, linearOrigVars[index]);
241 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
242 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
243 linearFinalizationBB->getTerminator()->eraseFromParent();
245 builder.SetInsertPoint(linearExitBB->getTerminator());
247 builder.saveIP(), llvm::omp::OMPD_barrier);
252 void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
255 for (llvm::User *user : linearOrigVal[varIndex]->users())
256 users.push_back(user);
257 for (auto *user : users) {
258 if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
259 if (userInst->getParent()->getName().str() == BBName)
260 user->replaceUsesOfWith(linearOrigVal[varIndex],
261 linearLoopBodyTemps[varIndex]);
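// Illustrative sketch (assumption, not part of the original source): the helper
// is expected to be driven in loop-lowering order, mirroring its use in the
// wsloop translation further below:
//   LinearClauseProcessor lcp;
//   lcp.createLinearVar(builder, moduleTranslation, linearVar);   // per linear var
//   lcp.initLinearStep(moduleTranslation, linearStep);            // per step value
//   lcp.initLinearVar(builder, moduleTranslation, loopPreHeader); // emits barrier
//   lcp.updateLinearVar(builder, loopBody, loopInductionVar);
//   lcp.outlineLinearFinalizationBB(builder, loopExit);
//   lcp.finalizeLinearVar(builder, moduleTranslation, lastIter);  // writes back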
272 SymbolRefAttr symbolName) {
273 omp::PrivateClauseOp privatizer =
274 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
276 assert(privatizer && "privatizer not found in the symbol table");
287 auto todo = [&op](StringRef clauseName) {
288 return op.emitError() << "not yet implemented: Unhandled clause "
289 << clauseName << " in " << op.getName()
293 auto checkAllocate = [&todo](auto op, LogicalResult &result) {
294 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
295 result = todo("allocate");
297 auto checkBare = [&todo](auto op, LogicalResult &result) {
299 result = todo("ompx_bare");
301 auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
302 omp::ClauseCancellationConstructType cancelledDirective =
303 op.getCancelDirective();
306 if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
313 if (isa_and_nonnull<omp::TaskloopOp>(parent))
314 result = todo("cancel directive inside of taskloop");
317 auto checkDepend = [&todo](auto op, LogicalResult &result) {
318 if (!op.getDependVars().empty() || op.getDependKinds())
319 result = todo("depend");
321 auto checkDevice = [&todo](auto op, LogicalResult &result) {
323 result = todo("device");
325 auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
326 if (op.getDistScheduleChunkSize())
327 result = todo("dist_schedule with chunk_size");
329 auto checkHint = [](auto op, LogicalResult &) {
333 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
334 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
335 op.getInReductionSyms())
336 result = todo("in_reduction");
338 auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
339 if (!op.getIsDevicePtrVars().empty())
340 result = todo("is_device_ptr");
342 auto checkLinear = [&todo](auto op, LogicalResult &result) {
343 if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
344 result = todo("linear");
346 auto checkNowait = [&todo](auto op, LogicalResult &result) {
348 result = todo("nowait");
350 auto checkOrder = [&todo](auto op, LogicalResult &result) {
351 if (op.getOrder() || op.getOrderMod())
352 result = todo("order");
354 auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
355 if (op.getParLevelSimd())
356 result = todo("parallelization-level");
358 auto checkPriority = [&todo](auto op, LogicalResult &result) {
359 if (op.getPriority())
360 result = todo("priority");
362 auto checkPrivate = [&todo](auto op, LogicalResult &result) {
363 if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
365 if (!op.getPrivateVars().empty() && op.getNowait())
366 result = todo("privatization for deferred target tasks");
368 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
369 result = todo("privatization");
372 auto checkReduction = [&todo](auto op, LogicalResult &result) {
373 if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op))
374 if (!op.getReductionVars().empty() || op.getReductionByref() ||
375 op.getReductionSyms())
376 result = todo("reduction");
377 if (op.getReductionMod() &&
378 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
379 result = todo("reduction with modifier");
381 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
382 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
383 op.getTaskReductionSyms())
384 result = todo("task_reduction");
386 auto checkUntied = [&todo](auto op, LogicalResult &result) {
388 result = todo("untied");
391 LogicalResult result = success();
393 .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
394 .Case([&](omp::CancellationPointOp op) {
395 checkCancelDirective(op, result);
397 .Case([&](omp::DistributeOp op) {
398 checkAllocate(op, result);
399 checkDistSchedule(op, result);
400 checkOrder(op, result);
402 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
403 .Case([&](omp::SectionsOp op) {
404 checkAllocate(op, result);
405 checkPrivate(op, result);
406 checkReduction(op, result);
408 .Case([&](omp::SingleOp op) {
409 checkAllocate(op, result);
410 checkPrivate(op, result);
412 .Case([&](omp::TeamsOp op) {
413 checkAllocate(op, result);
414 checkPrivate(op, result);
416 .Case([&](omp::TaskOp op) {
417 checkAllocate(op, result);
418 checkInReduction(op, result);
420 .Case([&](omp::TaskgroupOp op) {
421 checkAllocate(op, result);
422 checkTaskReduction(op, result);
424 .Case([&](omp::TaskwaitOp op) {
425 checkDepend(op, result);
426 checkNowait(op, result);
428 .Case([&](omp::TaskloopOp op) {
430 checkUntied(op, result);
431 checkPriority(op, result);
433 .Case([&](omp::WsloopOp op) {
434 checkAllocate(op, result);
435 checkLinear(op, result);
436 checkOrder(op, result);
437 checkReduction(op, result);
439 .Case([&](omp::ParallelOp op) {
440 checkAllocate(op, result);
441 checkReduction(op, result);
443 .Case([&](omp::SimdOp op) {
444 checkLinear(op, result);
445 checkReduction(op, result);
447 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
448 omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
449 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
450 [&](auto op) { checkDepend(op, result); })
451 .Case([&](omp::TargetOp op) {
452 checkAllocate(op, result);
453 checkBare(op, result);
454 checkDevice(op, result);
455 checkInReduction(op, result);
456 checkIsDevicePtr(op, result);
457 checkPrivate(op, result);
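// Illustrative note (not part of the original source): each checkFoo lambda only
// overwrites `result` when the offending clause is actually present, so after the
// TypeSwitch `result` fails iff at least one unsupported clause was diagnosed,
// e.g. a target op carrying an allocate clause reports
//   "not yet implemented: Unhandled clause allocate in omp.target ..."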
467 LogicalResult result = success();
469 llvm::handleAllErrors(
471 [&](const PreviouslyReportedError &) { result = failure(); },
472 [&](const llvm::ErrorInfoBase &err) {
479 template <typename T>
489 static llvm::OpenMPIRBuilder::InsertPointTy
495 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
497 [&](OpenMPAllocaStackFrame &frame) {
498 allocaInsertPoint = frame.allocaInsertPoint;
502 return allocaInsertPoint;
511 if (builder.GetInsertBlock() ==
512 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
513 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
514 "Assuming end of basic block");
515 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
516 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
517 builder.GetInsertBlock()->getNextNode());
518 builder.CreateBr(entryBB);
519 builder.SetInsertPoint(entryBB);
522 llvm::BasicBlock &funcEntryBlock =
523 builder.GetInsertBlock()->getParent()->getEntryBlock();
524 return llvm::OpenMPIRBuilder::InsertPointTy(
525 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
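// Illustrative note (not part of the original source): when no
// OpenMPAllocaStackFrame is found during the stack walk, allocas default to the
// first insertion point of the function's entry block, i.e. privatized storage
// is hoisted into `entry:` rather than being created inside the construct body.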
531 static llvm::CanonicalLoopInfo *
533 llvm::CanonicalLoopInfo *loopInfo = nullptr;
534 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
535 [&](OpenMPLoopInfoStackFrame &frame) {
536 loopInfo = frame.loopInfo;
548 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
551 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
553 llvm::BasicBlock *continuationBlock =
554 splitBB(builder, true, "omp.region.cont");
555 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
557 llvm::LLVMContext &llvmContext = builder.getContext();
558 for (Block &bb : region) {
559 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
560 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
561 builder.GetInsertBlock()->getNextNode());
562 moduleTranslation.mapBlock(&bb, llvmBB);
565 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
572 unsigned numYields = 0;
574 if (!isLoopWrapper) {
575 bool operandsProcessed = false;
577 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
578 if (!operandsProcessed) {
579 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
580 continuationBlockPHITypes.push_back(
581 moduleTranslation.convertType(yield->getOperand(i).getType()));
583 operandsProcessed = true;
585 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
586 "mismatching number of values yielded from the region");
587 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
588 llvm::Type *operandType =
589 moduleTranslation.convertType(yield->getOperand(i).getType());
591 assert(continuationBlockPHITypes[i] == operandType &&
592 "values of mismatching types yielded from the region");
602 if (!continuationBlockPHITypes.empty())
604 continuationBlockPHIs &&
605 "expected continuation block PHIs if converted regions yield values");
606 if (continuationBlockPHIs) {
607 llvm::IRBuilderBase::InsertPointGuard guard(builder);
608 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
609 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
610 for (llvm::Type *ty : continuationBlockPHITypes)
611 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
617 for (Block *bb : blocks) {
618 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
621 if (bb->isEntryBlock()) {
622 assert(sourceTerminator->getNumSuccessors() == 1 &&
623 "provided entry block has multiple successors");
624 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
625 "ContinuationBlock is not the successor of the entry block");
626 sourceTerminator->setSuccessor(0, llvmBB);
629 llvm::IRBuilderBase::InsertPointGuard guard(builder);
631 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
632 return llvm::make_error<PreviouslyReportedError>();
637 builder.CreateBr(continuationBlock);
648 Operation *terminator = bb->getTerminator();
649 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
650 builder.CreateBr(continuationBlock);
652 for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
653 (*continuationBlockPHIs)[i]->addIncoming(
667 return continuationBlock;
673 case omp::ClauseProcBindKind::Close:
674 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
675 case omp::ClauseProcBindKind::Master:
676 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
677 case omp::ClauseProcBindKind::Primary:
678 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
679 case omp::ClauseProcBindKind::Spread:
680 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
682 llvm_unreachable("Unknown ClauseProcBindKind kind");
692 omp::BlockArgOpenMPOpInterface blockArgIface) {
694 blockArgIface.getBlockArgsPairs(blockArgsPairs);
695 for (auto [var, arg] : blockArgsPairs)
712 .Case([&](omp::SimdOp op) {
714 cast<omp::BlockArgOpenMPOpInterface>(*op));
715 op.emitWarning() << "simd information on composite construct discarded";
719 return op->emitError() << "cannot ignore wrapper";
727 auto maskedOp = cast<omp::MaskedOp>(opInst);
728 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
733 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
735 auto ®ion = maskedOp.getRegion();
736 builder.restoreIP(codeGenIP);
744 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
746 llvm::Value *filterVal = nullptr;
747 if (auto filterVar = maskedOp.getFilteredThreadId()) {
748 filterVal = moduleTranslation.lookupValue(filterVar);
750 llvm::LLVMContext &llvmContext = builder.getContext();
754 assert(filterVal != nullptr);
755 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
756 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
763 builder.restoreIP(*afterIP);
771 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
772 auto masterOp = cast<omp::MasterOp>(opInst);
777 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
779 auto ®ion = masterOp.getRegion();
780 builder.restoreIP(codeGenIP);
788 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
790 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
791 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
798 builder.restoreIP(*afterIP);
806 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
807 auto criticalOp = cast<omp::CriticalOp>(opInst);
812 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
814 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
815 builder.restoreIP(codeGenIP);
823 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
825 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
826 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
827 llvm::Constant *hint = nullptr;
830 if (criticalOp.getNameAttr()) {
833 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
834 auto criticalDeclareOp =
835 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
839 static_cast<int>(criticalDeclareOp.getHint()));
841 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
843 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
848 builder.restoreIP(*afterIP);
855 template <typename OP>
858 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
859 mlirVars.reserve(blockArgs.size());
860 llvmVars.reserve(blockArgs.size());
861 collectPrivatizationDecls<OP>(op);
864 mlirVars.push_back(privateVar);
876 void collectPrivatizationDecls(OP op) {
877 std::optional<ArrayAttr> attr = op.getPrivateSyms();
881 privatizers.reserve(privatizers.size() + attr->size());
882 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
889 template <typename T>
893 std::optional<ArrayAttr> attr = op.getReductionSyms();
897 reductions.reserve(reductions.size() + op.getNumReductionVars());
898 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
899 reductions.push_back(
900 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
911 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
919 if (llvm::hasSingleElement(region)) {
920 llvm::Instruction *potentialTerminator =
921 builder.GetInsertBlock()->empty() ? nullptr
922 : &builder.GetInsertBlock()->back();
924 if (potentialTerminator && potentialTerminator->isTerminator())
925 potentialTerminator->removeFromParent();
926 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
929 region.front(), true, builder)))
933 if (continuationBlockArgs)
935 *continuationBlockArgs,
942 if (potentialTerminator && potentialTerminator->isTerminator()) {
943 llvm::BasicBlock *block = builder.GetInsertBlock();
944 if (block->empty()) {
950 potentialTerminator->insertInto(block, block->begin());
952 potentialTerminator->insertAfter(&block->back());
966 if (continuationBlockArgs)
967 llvm::append_range(*continuationBlockArgs, phis);
968 builder.SetInsertPoint(*continuationBlock,
969 (*continuationBlock)->getFirstInsertionPt());
976 using OwningReductionGen =
977 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
978 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
980 using OwningAtomicReductionGen =
981 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
982 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
989 static OwningReductionGen
995 OwningReductionGen gen =
996 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
997 llvm::Value *lhs, llvm::Value *rhs,
998 llvm::Value *&result) mutable
999 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1000 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
1001 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
1002 builder.restoreIP(insertPoint);
1005 "omp.reduction.nonatomic.body", builder,
1006 moduleTranslation, &phis)))
1007 return llvm::createStringError(
1008 "failed to inline `combiner` region of `omp.declare_reduction`");
1009 result = llvm::getSingleElement(phis);
1010 return builder.saveIP();
1019 static OwningAtomicReductionGen
1021 llvm::IRBuilderBase &builder,
1023 if (decl.getAtomicReductionRegion().empty())
1024 return OwningAtomicReductionGen();
1029 OwningAtomicReductionGen atomicGen =
1030 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1031 llvm::Value *lhs, llvm::Value *rhs) mutable
1032 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1033 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
1034 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
1035 builder.restoreIP(insertPoint);
1038 "omp.reduction.atomic.body", builder,
1039 moduleTranslation, &phis)))
1040 return llvm::createStringError(
1041 "failed to inline `atomic` region of `omp.declare_reduction`");
1042 assert(phis.empty());
1043 return builder.saveIP();
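// Illustrative note (not part of the original source): the two generators above
// are later bundled into llvm::OpenMPIRBuilder::ReductionInfo entries (see the
// reduction-info collection below); when a declaration has no `atomic` region,
// the atomic callback stays empty and only the non-atomic combiner is available.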
1049 static LogicalResult
1052 auto orderedOp = cast<omp::OrderedOp>(opInst);
1057 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1058 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1059 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1061 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1063 size_t indexVecValues = 0;
1064 while (indexVecValues < vecValues.size()) {
1066 storeValues.reserve(numLoops);
1067 for (unsigned i = 0; i < numLoops; i++) {
1068 storeValues.push_back(vecValues[indexVecValues]);
1071 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1073 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1074 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1075 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1082 static LogicalResult
1085 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1086 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1091 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1093 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1094 builder.restoreIP(codeGenIP);
1102 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1104 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1105 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1107 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1112 builder.restoreIP(*afterIP);
1118 struct DeferredStore {
1119 DeferredStore(llvm::Value *value, llvm::Value *address)
1120 : value(value), address(address) {}
1123 llvm::Value *address;
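// Illustrative note (not part of the original source): a DeferredStore just
// records a (value, address) pair so the store can be emitted once all reduction
// allocas exist, e.g.
//   deferredStores.emplace_back(castPhi, castVar);   // record now
//   for (auto [data, addr] : deferredStores)
//     builder.CreateStore(data, addr);               // emit later, at a safe point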
1130 template <typename T>
1131 static LogicalResult
1133 llvm::IRBuilderBase &builder,
1135 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1141 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1142 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1145 deferredStores.reserve(loop.getNumReductionVars());
1147 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1148 Region &allocRegion = reductionDecls[i].getAllocRegion();
1150 if (allocRegion.empty())
1155 builder, moduleTranslation, &phis)))
1156 return loop.emitError(
1157 "failed to inline `alloc` region of `omp.declare_reduction`");
1159 assert(phis.size() == 1 && "expected one allocation to be yielded");
1160 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1164 llvm::Value *var = builder.CreateAlloca(
1165 moduleTranslation.convertType(reductionDecls[i].getType()));
1167 llvm::Type *ptrTy = builder.getPtrTy();
1168 llvm::Value *castVar =
1169 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1170 llvm::Value *castPhi =
1171 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1173 deferredStores.emplace_back(castPhi, castVar);
1175 privateReductionVariables[i] = castVar;
1176 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1177 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1179 assert(allocRegion.empty() &&
1180 "allocaction is implicit for by-val reduction");
1181 llvm::Value *var = builder.CreateAlloca(
1182 moduleTranslation.convertType(reductionDecls[i].getType()));
1184 llvm::Type *ptrTy = builder.getPtrTy();
1185 llvm::Value *castVar =
1186 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1188 moduleTranslation.mapValue(reductionArgs[i], castVar);
1189 privateReductionVariables[i] = castVar;
1190 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1198 template <typename T>
1205 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1206 Region &initializerRegion = reduction.getInitializerRegion();
1209 mlir::Value mlirSource = loop.getReductionVars()[i];
1210 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1211 assert(llvmSource && "lookup reduction var");
1212 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
1215 llvm::Value *allocation =
1216 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1217 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1223 llvm::BasicBlock *block = nullptr) {
1224 if (block == nullptr)
1225 block = builder.GetInsertBlock();
1227 if (block->empty() || block->getTerminator() == nullptr)
1228 builder.SetInsertPoint(block);
1230 builder.SetInsertPoint(block->getTerminator());
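// Illustrative note (not part of the original source): the helper above leaves
// the builder either at the end of a still-open block or just before its
// terminator, so callers can append instructions without splitting the block:
//   builder.SetInsertPoint(block);                   // block has no terminator yet
//   builder.SetInsertPoint(block->getTerminator());  // block already terminated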
1238 template <typename OP>
1239 static LogicalResult
1241 llvm::IRBuilderBase &builder,
1243 llvm::BasicBlock *latestAllocaBlock,
1249 if (op.getNumReductionVars() == 0)
1252 llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
1253 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1254 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1255 builder.restoreIP(allocaIP);
1258 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1260 if (!reductionDecls[i].getAllocRegion().empty())
1266 byRefVars[i] = builder.CreateAlloca(
1267 moduleTranslation.convertType(reductionDecls[i].getType()));
1275 for (auto [data, addr] : deferredStores)
1276 builder.CreateStore(data, addr);
1281 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1286 reductionVariableMap, i);
1289 "omp.reduction.neutral", builder,
1290 moduleTranslation, &phis)))
1293 assert(phis.size() == 1 && "expected one value to be yielded from the "
1294 "reduction neutral element declaration region");
1299 if (!reductionDecls[i].getAllocRegion().empty())
1308 builder.CreateStore(phis[0], byRefVars[i]);
1310 privateReductionVariables[i] = byRefVars[i];
1311 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1312 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1315 builder.CreateStore(phis[0], privateReductionVariables[i]);
1322 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
1329 template <typename T>
1331 T loop, llvm::IRBuilderBase &builder,
1338 unsigned numReductions = loop.getNumReductionVars();
1340 for (unsigned i = 0; i < numReductions; ++i) {
1341 owningReductionGens.push_back(
1343 owningAtomicReductionGens.push_back(
1348 reductionInfos.reserve(numReductions);
1349 for (unsigned i = 0; i < numReductions; ++i) {
1350 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
1351 if (owningAtomicReductionGens[i])
1352 atomicGen = owningAtomicReductionGens[i];
1353 llvm::Value *variable =
1354 moduleTranslation.lookupValue(loop.getReductionVars()[i]);
1355 reductionInfos.push_back(
1357 privateReductionVariables[i],
1358 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1359 owningReductionGens[i],
1360 nullptr, atomicGen});
1365 static LogicalResult
1369 llvm::IRBuilderBase &builder, StringRef regionName,
1370 bool shouldLoadCleanupRegionArg = true) {
1372 if (cleanupRegion->empty())
1378 llvm::Instruction *potentialTerminator =
1379 builder.GetInsertBlock()->empty() ? nullptr
1380 : &builder.GetInsertBlock()->back();
1381 if (potentialTerminator && potentialTerminator->isTerminator())
1382 builder.SetInsertPoint(potentialTerminator);
1383 llvm::Value *privateVarValue =
1384 shouldLoadCleanupRegionArg
1385 ? builder.CreateLoad(
1387 privateVariables[i])
1388 : privateVariables[i];
1393 moduleTranslation)))
1406 OP op, llvm::IRBuilderBase &builder,
1408 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1411 bool isNowait = false, bool isTeamsReduction = false) {
1413 if (op.getNumReductionVars() == 0)
1425 owningReductionGens, owningAtomicReductionGens,
1426 privateReductionVariables, reductionInfos);
1431 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1432 builder.SetInsertPoint(tempTerminator);
1433 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1434 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1435 isByRef, isNowait, isTeamsReduction);
1440 if (!contInsertPoint->getBlock())
1441 return op->emitOpError() << "failed to convert reductions";
1443 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1444 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1449 tempTerminator->eraseFromParent();
1450 builder.restoreIP(*afterIP);
1454 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1455 [](omp::DeclareReductionOp reductionDecl) {
1456 return &reductionDecl.getCleanupRegion();
1459 moduleTranslation, builder,
1460 "omp.reduction.cleanup");
1471 template <typename OP>
1475 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1480 if (op.getNumReductionVars() == 0)
1486 allocaIP, reductionDecls,
1487 privateReductionVariables, reductionVariableMap,
1488 deferredStores, isByRef)))
1492 allocaIP.getBlock(), reductionDecls,
1493 privateReductionVariables, reductionVariableMap,
1494 isByRef, deferredStores);
1504 static llvm::Value *
1508 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1511 Value blockArg = (*mappedPrivateVars)[privateVar];
1514 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1515 "A block argument corresponding to a mapped var should have "
1518 if (privVarType == blockArgType)
1525 if (!isa<LLVM::LLVMPointerType>(privVarType))
1526 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1539 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1541 Region &initRegion = privDecl.getInitRegion();
1542 if (initRegion.empty())
1543 return llvmPrivateVar;
1547 mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1548 assert(nonPrivateVar);
1549 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1550 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1555 moduleTranslation, &phis)))
1556 return llvm::createStringError(
1557 "failed to inline `init` region of `omp.private`");
1559 assert(phis.size() == 1 && "expected one allocation to be yielded");
1578 return llvm::Error::success();
1580 llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
1586 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1588 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1589 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1592 return privVarOrErr.takeError();
1594 llvmPrivateVar = privVarOrErr.get();
1595 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1600 return llvm::Error::success();
1610 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1613 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1614 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1615 allocaTerminator->getIterator()),
1616 true, allocaTerminator->getStableDebugLoc(),
1617 "omp.region.after_alloca");
1619 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1621 allocaTerminator = allocaIP.getBlock()->getTerminator();
1622 builder.SetInsertPoint(allocaTerminator);
1624 assert(allocaTerminator->getNumSuccessors() == 1 &&
1625 "This is an unconditional branch created by splitBB");
1627 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1628 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1630 unsigned int allocaAS =
1631 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1634 .getProgramAddressSpace();
1636 for (auto [privDecl, mlirPrivVar, blockArg] :
1639 llvm::Type *llvmAllocType =
1640 moduleTranslation.convertType(privDecl.getType());
1641 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1642 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1643 llvmAllocType, nullptr, "omp.private.alloc");
1644 if (allocaAS != defaultAS)
1645 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1646 builder.getPtrTy(defaultAS));
1648 privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
1651 return afterAllocas;
1662 bool needsFirstprivate =
1663 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1664 return privOp.getDataSharingType() ==
1665 omp::DataSharingClauseType::FirstPrivate;
1668 if (!needsFirstprivate)
1671 llvm::BasicBlock *copyBlock =
1672 splitBB(builder, true, "omp.private.copy");
1675 for (auto [decl, mlirVar, llvmVar] :
1676 llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
1677 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1681 Region ©Region = decl.getCopyRegion();
1685 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1686 assert(nonPrivateVar);
1687 moduleTranslation.mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1690 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1694 moduleTranslation)))
1695 return decl.emitError("failed to inline `copy` region of `omp.private`");
1707 if (insertBarrier) {
1709 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1710 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1718 static LogicalResult
1725 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1726 [](omp::PrivateClauseOp privatizer) {
1727 return &privatizer.getDeallocRegion();
1731 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1732 "omp.private.dealloc", false)))
1733 return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1734 "`omp.private` op in");
1746 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1753 static LogicalResult
1756 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1757 using StorableBodyGenCallbackTy =
1758 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1760 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1766 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1770 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1774 sectionsOp.getNumReductionVars());
1778 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1781 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1782 reductionDecls, privateReductionVariables, reductionVariableMap,
1789 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1793 Region ®ion = sectionOp.getRegion();
1794 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1795 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1796 builder.restoreIP(codeGenIP);
1803 sectionsOp.getRegion().getNumArguments());
1804 for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
1805 sectionsOp.getRegion().getArguments(), region.getArguments())) {
1806 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
1808 moduleTranslation.mapValue(sectionArg, llvmVal);
1815 sectionCBs.push_back(sectionCB);
1821 if (sectionCBs.empty())
1824 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1829 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1830 llvm::Value &vPtr, llvm::Value *&replacementValue)
1831 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1832 replacementValue = &vPtr;
1838 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1842 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1843 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1845 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1846 sectionsOp.getNowait());
1851 builder.restoreIP(*afterIP);
1855 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1856 privateReductionVariables, isByRef, sectionsOp.getNowait());
1860 static LogicalResult
1863 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1864 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1869 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1870 builder.restoreIP(codegenIP);
1872 builder, moduleTranslation)
1875 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1879 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1882 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
1883 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1884 auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
1885 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1886 llvmCPFuncs.push_back(
1890 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1892 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1898 builder.restoreIP(*afterIP);
1904 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1909 for (auto ra : iface.getReductionBlockArgs())
1910 for (auto &use : ra.getUses()) {
1911 auto *useOp = use.getOwner();
1913 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1914 debugUses.push_back(useOp);
1918 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
1923 Operation *currentOp = currentDistOp.getOperation();
1924 if (distOp && (distOp != currentOp))
1933 for (auto use : debugUses)
1939 static LogicalResult
1942 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1947 unsigned numReductionVars = op.getNumReductionVars();
1951 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1957 if (doTeamsReduction) {
1958 isByRef = getIsByRef(op.getReductionByref());
1960 assert(isByRef.size() == op.getNumReductionVars());
1963 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
1968 op, reductionArgs, builder, moduleTranslation, allocaIP,
1969 reductionDecls, privateReductionVariables, reductionVariableMap,
1974 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1976 moduleTranslation, allocaIP);
1977 builder.restoreIP(codegenIP);
1983 llvm::Value *numTeamsLower = nullptr;
1984 if (Value numTeamsLowerVar = op.getNumTeamsLower())
1985 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
1987 llvm::Value *numTeamsUpper = nullptr;
1988 if (Value numTeamsUpperVar = op.getNumTeamsUpper())
1989 numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
1991 llvm::Value *threadLimit = nullptr;
1992 if (Value threadLimitVar = op.getThreadLimit())
1993 threadLimit = moduleTranslation.lookupValue(threadLimitVar);
1995 llvm::Value *ifExpr = nullptr;
1996 if (Value ifVar = op.getIfExpr())
1999 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2000 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2002 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2007 builder.restoreIP(*afterIP);
2008 if (doTeamsReduction) {
2011 op, builder, moduleTranslation, allocaIP, reductionDecls,
2012 privateReductionVariables, isByRef,
2022 if (dependVars.empty())
2024 for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2025 llvm::omp::RTLDependenceKindTy type;
2027 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2028 case mlir::omp::ClauseTaskDepend::taskdependin:
2029 type = llvm::omp::RTLDependenceKindTy::DepIn;
2034 case mlir::omp::ClauseTaskDepend::taskdependout:
2035 case mlir::omp::ClauseTaskDepend::taskdependinout:
2036 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2038 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2039 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2041 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2042 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2045 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2046 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2047 dds.emplace_back(dd);
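// Illustrative note (not part of the original source): e.g. an MLIR
// `depend(taskdependin : %x)` entry becomes a DependData carrying
// llvm::omp::RTLDependenceKindTy::DepIn together with the LLVM value that %x
// translates to, while both `out` and `inout` collapse to DepInOut.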
2059 llvm::IRBuilderBase &llvmBuilder,
2061 llvm::omp::Directive cancelDirective) {
2062 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2063 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2067 llvmBuilder.restoreIP(ip);
2073 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2074 return llvm::Error::success();
2079 ompBuilder.pushFinalizationCB(
2089 llvm::OpenMPIRBuilder &ompBuilder,
2090 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2091 ompBuilder.popFinalizationCB();
2092 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2093 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2094 assert(cancelBranch->getNumSuccessors() == 1 &&
2095 "cancel branch should have one target");
2096 cancelBranch->setSuccessor(0, constructFini);
2103 class TaskContextStructManager {
2105 TaskContextStructManager(llvm::IRBuilderBase &builder,
2108 : builder{builder}, moduleTranslation{moduleTranslation},
2109 privateDecls{privateDecls} {}
2115 void generateTaskContextStruct();
2121 void createGEPsToPrivateVars();
2124 void freeStructPtr();
2127 return llvmPrivateVarGEPs;
2130 llvm::Value *getStructPtr() { return structPtr; }
2133 llvm::IRBuilderBase &builder;
2145 llvm::Value *structPtr = nullptr;
2147 llvm::Type *structTy = nullptr;
2151 void TaskContextStructManager::generateTaskContextStruct() {
2152 if (privateDecls.empty())
2154 privateVarTypes.reserve(privateDecls.size());
2156 for (omp::PrivateClauseOp &privOp : privateDecls) {
2159 if (!privOp.readsFromMold())
2161 Type mlirType = privOp.getType();
2162 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2168 llvm::DataLayout dataLayout =
2169 builder.GetInsertBlock()->getModule()->getDataLayout();
2170 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2171 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2174 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2176 "omp.task.context_ptr");
2179 void TaskContextStructManager::createGEPsToPrivateVars() {
2181 assert(privateVarTypes.empty());
2186 llvmPrivateVarGEPs.clear();
2187 llvmPrivateVarGEPs.reserve(privateDecls.size());
2188 llvm::Value *zero = builder.getInt32(0);
2190 for (auto privDecl : privateDecls) {
2191 if (!privDecl.readsFromMold()) {
2193 llvmPrivateVarGEPs.push_back(nullptr);
2196 llvm::Value *iVal = builder.getInt32(i);
2197 llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
2198 llvmPrivateVarGEPs.push_back(gep);
2203 void TaskContextStructManager::freeStructPtr() {
2207 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2209 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2210 builder.CreateFree(structPtr);
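// Illustrative sketch (assumption, not part of the original source): the intended
// lifecycle of the manager, mirroring its use in the task translation below:
//   TaskContextStructManager taskStructMgr{builder, moduleTranslation, privateDecls};
//   taskStructMgr.generateTaskContextStruct(); // malloc one struct for the task
//   taskStructMgr.createGEPsToPrivateVars();   // one GEP per mold-reading private
//   ... use taskStructMgr.getLLVMPrivateVarGEPs() / getStructPtr() ...
//   taskStructMgr.freeStructPtr();             // free once the task body is done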
2214 static LogicalResult
2217 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2222 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2234 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2239 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2240 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2241 builder.getContext(), "omp.task.start",
2242 builder.GetInsertBlock()->getParent());
2243 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2244 builder.SetInsertPoint(branchToTaskStartBlock);
2247 llvm::BasicBlock *copyBlock =
2248 splitBB(builder, true, "omp.private.copy");
2249 llvm::BasicBlock *initBlock =
2250 splitBB(builder, true, "omp.private.init");
2266 moduleTranslation, allocaIP);
2269 builder.SetInsertPoint(initBlock->getTerminator());
2272 taskStructMgr.generateTaskContextStruct();
2279 taskStructMgr.createGEPsToPrivateVars();
2281 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2284 taskStructMgr.getLLVMPrivateVarGEPs())) {
2286 if (!privDecl.readsFromMold())
2288 assert(llvmPrivateVarAlloc &&
2289 "reads from mold so shouldn't have been skipped");
2292 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2293 blockArg, llvmPrivateVarAlloc, initBlock);
2294 if (!privateVarOrErr)
2295 return handleError(privateVarOrErr, *taskOp.getOperation());
2304 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2305 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2306 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2308 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2309 llvmPrivateVarAlloc);
2311 assert(llvmPrivateVarAlloc->getType() ==
2312 moduleTranslation.convertType(blockArg.getType()));
2322 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2323 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2324 taskOp.getPrivateNeedsBarrier())))
2325 return llvm::failure();
2328 builder.SetInsertPoint(taskStartBlock);
2330 auto bodyCB = [&](InsertPointTy allocaIP,
2331 InsertPointTy codegenIP) -> llvm::Error {
2335 moduleTranslation, allocaIP);
2338 builder.restoreIP(codegenIP);
2340 llvm::BasicBlock *privInitBlock = nullptr;
2345 auto [blockArg, privDecl, mlirPrivVar] = zip;
2347 if (privDecl.readsFromMold())
2350 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2351 llvm::Type *llvmAllocType =
2352 moduleTranslation.convertType(privDecl.getType());
2353 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2354 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2355 llvmAllocType, nullptr, "omp.private.alloc");
2358 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2359 blockArg, llvmPrivateVar, privInitBlock);
2360 if (!privateVarOrError)
2361 return privateVarOrError.takeError();
2362 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2363 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2366 taskStructMgr.createGEPsToPrivateVars();
2367 for (auto [i, llvmPrivVar] :
2370 assert(privateVarsInfo.llvmVars[i] &&
2371 "This is added in the loop above");
2374 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2379 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2383 if (!privateDecl.readsFromMold())
2386 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2387 llvmPrivateVar = builder.CreateLoad(
2388 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2390 assert(llvmPrivateVar->getType() ==
2391 moduleTranslation.convertType(blockArg.getType()));
2392 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2396 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2397 if (failed(handleError(continuationBlockOrError, *taskOp)))
2398 return llvm::make_error<PreviouslyReportedError>();
2400 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2405 return llvm::make_error<PreviouslyReportedError>();
2408 taskStructMgr.freeStructPtr();
2410 return llvm::Error::success();
2419 llvm::omp::Directive::OMPD_taskgroup);
2423 moduleTranslation, dds);
2425 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2426 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2428 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2430 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
2431 taskOp.getMergeable(),
2432 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2433 moduleTranslation.lookupValue(taskOp.getPriority()));
2441 builder.restoreIP(*afterIP);
2446 static LogicalResult
2449 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2453 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2454 builder.restoreIP(codegenIP);
2456 builder, moduleTranslation)
2461 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2462 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2469 builder.restoreIP(*afterIP);
2473 static LogicalResult
2484 static LogicalResult
2488 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2492 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2494 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2498 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2501 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
2502 llvm::Type *ivType = step->getType();
2503 llvm::Value *chunk = nullptr;
2504 if (wsloopOp.getScheduleChunk()) {
2505 llvm::Value *chunkVar =
2506 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
2507 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2514 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2518 wsloopOp.getNumReductionVars());
2521 builder, moduleTranslation, privateVarsInfo, allocaIP);
2528 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2533 moduleTranslation, allocaIP, reductionDecls,
2534 privateReductionVariables, reductionVariableMap,
2535 deferredStores, isByRef)))
2544 wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2546 wsloopOp.getPrivateNeedsBarrier())))
2549 assert(afterAllocas.get()->getSinglePredecessor());
2552 afterAllocas.get()->getSinglePredecessor(),
2553 reductionDecls, privateReductionVariables,
2554 reductionVariableMap, isByRef, deferredStores)))
2558 bool isOrdered = wsloopOp.getOrdered().has_value();
2559 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2560 bool isSimd = wsloopOp.getScheduleSimd();
2561 bool loopNeedsBarrier = !wsloopOp.getNowait();
2566 llvm::omp::WorksharingLoopType workshareLoopType =
2567 llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
2568 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2569 : llvm::omp::WorksharingLoopType::ForStaticLoop;
2573 llvm::omp::Directive::OMPD_for);
2575 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2578 LinearClauseProcessor linearClauseProcessor;
2579 if (wsloopOp.getLinearVars().size()) {
2580 for (mlir::Value linearVar : wsloopOp.getLinearVars())
2581 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2583 for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
2584 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
2588 wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
2596 if (wsloopOp.getLinearVars().size()) {
2597 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2598 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2599 loopInfo->getPreheader());
2602 builder.restoreIP(*afterBarrierIP);
2603 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
2604 loopInfo->getIndVar());
2605 linearClauseProcessor.outlineLinearFinalizationBB(builder,
2606 loopInfo->getExit());
2609 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2610 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2611 ompBuilder->applyWorkshareLoop(
2612 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2613 convertToScheduleKind(schedule), chunk, isSimd,
2614 scheduleMod == omp::ScheduleModifier::monotonic,
2615 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2622 if (wsloopOp.getLinearVars().size()) {
2623 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
2624 assert(loopInfo->getLastIter() &&
2625 "`lastiter` in CanonicalLoopInfo is nullptr");
2626 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2627 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
2628 loopInfo->getLastIter());
2631 for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
2632 linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
2634 builder.restoreIP(oldIP);
2642 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2643 privateReductionVariables, isByRef, wsloopOp.getNowait(),
2653 static LogicalResult
2656 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2658 assert(isByRef.size() == opInst.getNumReductionVars());
2670 opInst.getNumReductionVars());
2673 auto bodyGenCB = [&](InsertPointTy allocaIP,
2674 InsertPointTy codeGenIP) -> llvm::Error {
2676 builder, moduleTranslation, privateVarsInfo, allocaIP);
2678 return llvm::make_error<PreviouslyReportedError>();
2684 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
2687 InsertPointTy(allocaIP.getBlock(),
2688 allocaIP.getBlock()->getTerminator()->getIterator());
2691 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
2692 reductionDecls, privateReductionVariables, reductionVariableMap,
2693 deferredStores, isByRef)))
2694 return llvm::make_error<PreviouslyReportedError>();
2696 assert(afterAllocas.get()->getSinglePredecessor());
2697 builder.restoreIP(codeGenIP);
2703 return llvm::make_error<PreviouslyReportedError>();
2706 opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
2708 opInst.getPrivateNeedsBarrier())))
2709 return llvm::make_error<PreviouslyReportedError>();
2713 afterAllocas.get()->getSinglePredecessor(),
2714 reductionDecls, privateReductionVariables,
2715 reductionVariableMap, isByRef, deferredStores)))
2716 return llvm::make_error<PreviouslyReportedError>();
2721 moduleTranslation, allocaIP);
2725 opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
2727 return regionBlock.takeError();
2730 if (opInst.getNumReductionVars() > 0) {
2736 owningReductionGens, owningAtomicReductionGens,
2737 privateReductionVariables, reductionInfos);
2740 builder.SetInsertPoint((*regionBlock)->getTerminator());
2743 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
2744 builder.SetInsertPoint(tempTerminator);
2746 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
2747 ompBuilder->createReductions(
2748 builder.saveIP(), allocaIP, reductionInfos, isByRef,
2750 if (!contInsertPoint)
2751 return contInsertPoint.takeError();
2753 if (!contInsertPoint->getBlock())
2754 return llvm::make_error<PreviouslyReportedError>();
2756 tempTerminator->eraseFromParent();
2757 builder.restoreIP(*contInsertPoint);
2760 return llvm::Error::success();
2763 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2764 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
2773 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2774 InsertPointTy oldIP = builder.saveIP();
2775 builder.restoreIP(codeGenIP);
2780 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
2781 [](omp::DeclareReductionOp reductionDecl) {
2782 return &reductionDecl.getCleanupRegion();
2785 reductionCleanupRegions, privateReductionVariables,
2786 moduleTranslation, builder, "omp.reduction.cleanup")))
2787 return llvm::createStringError(
2788 "failed to inline `cleanup` region of `omp.declare_reduction`");
2793 return llvm::make_error<PreviouslyReportedError>();
2795 builder.restoreIP(oldIP);
2796 return llvm::Error::success();
2799 llvm::Value *ifCond = nullptr;
2800 if (auto ifVar = opInst.getIfExpr())
2802 llvm::Value *numThreads = nullptr;
2803 if (auto numThreadsVar = opInst.getNumThreads())
2804 numThreads = moduleTranslation.lookupValue(numThreadsVar);
2805 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
2806 if (auto bind = opInst.getProcBindKind())
2810 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2812 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2814 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2815 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
2816 ifCond, numThreads, pbKind, isCancellable);
2821 builder.restoreIP(*afterIP);
2826 static llvm::omp::OrderKind
2829 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2831 case omp::ClauseOrderKind::Concurrent:
2832 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2834 llvm_unreachable("Unknown ClauseOrderKind kind");
2838 static LogicalResult
2842 auto simdOp = cast<omp::SimdOp>(opInst);
2848 if (simdOp.isComposite()) {
2853 builder, moduleTranslation);
2861 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2865 builder, moduleTranslation, privateVarsInfo, allocaIP);
2874 llvm::ConstantInt *simdlen = nullptr;
2875 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
2876 simdlen = builder.getInt64(simdlenVar.value());
2878 llvm::ConstantInt *safelen = nullptr;
2879 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
2880 safelen = builder.getInt64(safelenVar.value());
2882 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
2885 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
2886 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
2888 for (size_t i = 0; i < operands.size(); ++i) {
2889 llvm::Value *alignment = nullptr;
2890 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
2891 llvm::Type *ty = llvmVal->getType();
2893 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
2894 alignment = builder.getInt64(intAttr.getInt());
2895 assert(ty->isPointerTy() && "Invalid type for aligned variable");
2896 assert(alignment && "Invalid alignment value");
2897 auto curInsert = builder.saveIP();
2898 builder.SetInsertPoint(sourceBlock);
2899 llvmVal = builder.CreateLoad(ty, llvmVal);
2900 builder.restoreIP(curInsert);
2901 alignedVars[llvmVal] = alignment;
2905 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
2910 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2912 ompBuilder->applySimd(loopInfo, alignedVars,
2914 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
2916 order, simdlen, safelen);
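// Converts an omp.loop_nest into a nest of OpenMPIRBuilder canonical loops and
// collapses them into a single CanonicalLoopInfo, which the enclosing
// worksharing, simd or distribute construct picks up from the loop-info stack frame.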
2924 static LogicalResult
2928 auto loopOp = cast<omp::LoopNestOp>(opInst);
2931 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2936 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
2937 llvm::Value *iv) -> llvm::Error {
2940 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
2945 bodyInsertPoints.push_back(ip);
2947 if (loopInfos.size() != loopOp.getNumLoops() - 1)
2948 return llvm::Error::success();
2951 builder.restoreIP(ip);
2953 loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
2955 return regionBlock.takeError();
2957 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2958 return llvm::Error::success();
2966 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
2967 llvm::Value *lowerBound =
2968 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
2969 llvm::Value *upperBound =
2970 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
2971 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
2976 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
2977 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
2979 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
2981 computeIP = loopInfos.front()->getPreheaderIP();
2985 ompBuilder->createCanonicalLoop(
2986 loc, bodyGen, lowerBound, upperBound, step,
2987 true, loopOp.getLoopInclusive(), computeIP);
2992 loopInfos.push_back(*loopResult);
2997 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
2998 loopInfos.front()->getAfterIP();
3002 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
3003 [&](OpenMPLoopInfoStackFrame &frame) {
3004 frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
3012 builder.restoreIP(afterIP);
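// Map the memory_order clause (if any) onto the llvm::AtomicOrdering used when
// emitting the atomic read, write, update and capture lowerings below.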
3017 static llvm::AtomicOrdering
3020 return llvm::AtomicOrdering::Monotonic;
3023 case omp::ClauseMemoryOrderKind::Seq_cst:
3024 return llvm::AtomicOrdering::SequentiallyConsistent;
3025 case omp::ClauseMemoryOrderKind::Acq_rel:
3026 return llvm::AtomicOrdering::AcquireRelease;
3027 case omp::ClauseMemoryOrderKind::Acquire:
3028 return llvm::AtomicOrdering::Acquire;
3029 case omp::ClauseMemoryOrderKind::Release:
3030 return llvm::AtomicOrdering::Release;
3031 case omp::ClauseMemoryOrderKind::Relaxed:
3032 return llvm::AtomicOrdering::Monotonic;
3034 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3038 static LogicalResult
3041 auto readOp = cast<omp::AtomicReadOp>(opInst);
3046 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3049 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3052 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3053 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
3055 llvm::Type *elementType =
3056 moduleTranslation.convertType(readOp.getElementType());
3058 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
3059 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
3060 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
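// omp.atomic.write lowers to OpenMPIRBuilder::createAtomicWrite, using the same
// memory-ordering translation as the atomic read above.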
3065 static LogicalResult
3068 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3073 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3076 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3078 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3079 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
3080 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3081 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, false,
3084 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3092 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
3093 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
3094 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
3095 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
3096 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
3097 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
3098 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
3099 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
3100 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
3101 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
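// omp.atomic.update carries the update expression as a single-block region whose
// argument is the current value of the atomic location; simple binary updates can
// map onto atomicrmw, everything else goes through the update callback below.
// Roughly, the region this lowering expects looks like (illustrative shape only):
//   omp.atomic.update %x : !llvm.ptr {
//   ^bb0(%xval : i32):
//     %newval = llvm.add %xval, %expr : i32
//     omp.yield(%newval : i32)
//   }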
3105 static LogicalResult
3107 llvm::IRBuilderBase &builder,
3114 auto &innerOpList = opInst.getRegion().front().getOperations();
3115 bool isXBinopExpr{false};
3116 llvm::AtomicRMWInst::BinOp binop;
3118 llvm::Value *llvmExpr = nullptr;
3119 llvm::Value *llvmX = nullptr;
3120 llvm::Type *llvmXElementType = nullptr;
3121 if (innerOpList.size() == 2) {
3127 opInst.getRegion().getArgument(0))) {
3128 return opInst.emitError("no atomic update operation with region argument"
3129 " as operand found inside atomic.update region");
3132 isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
3134 llvmExpr = moduleTranslation.lookupValue(mlirExpr);
3138 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3140 llvmX = moduleTranslation.lookupValue(opInst.getX());
3142 opInst.getRegion().getArgument(0).getType());
3143 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3147 llvm::AtomicOrdering atomicOrdering =
3152 [&opInst, &moduleTranslation](
3153 llvm::Value *atomicx,
3156 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
3157 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3158 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
3159 return llvm::make_error<PreviouslyReportedError>();
3161 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
3162 assert(yieldop && yieldop.getResults().size() == 1 &&
3163 "terminator must be omp.yield op and it must have exactly one "
3165 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
3170 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3171 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3172 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
3173 atomicOrdering, binop, updateFn,
3179 builder.restoreIP(*afterIP);
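// omp.atomic.capture pairs an atomic read with either an atomic write or an
// atomic update; the lowering inspects which form is present (and its position)
// to decide whether the captured value is the pre- or post-update one.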
3183 static LogicalResult
3185 llvm::IRBuilderBase &builder,
3192 bool isXBinopExpr = false, isPostfixUpdate = false;
3193 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3195 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
3196 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
3198 assert((atomicUpdateOp || atomicWriteOp) &&
3199 "internal op must be an atomic.update or atomic.write op");
3201 if (atomicWriteOp) {
3202 isPostfixUpdate = true;
3203 mlirExpr = atomicWriteOp.getExpr();
3205 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
3206 atomicCaptureOp.getAtomicUpdateOp().getOperation();
3207 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
3210 if (innerOpList.size() == 2) {
3213 atomicUpdateOp.getRegion().getArgument(0))) {
3214 return atomicUpdateOp.emitError(
3215 "no atomic update operation with region argument"
3216 " as operand found inside atomic.update region");
3220 innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
3223 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3227 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
3228 llvm::Value *llvmX =
3229 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
3230 llvm::Value *llvmV =
3231 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
3232 llvm::Type *llvmXElementType = moduleTranslation.convertType(
3233 atomicCaptureOp.getAtomicReadOp().getElementType());
3234 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3237 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
3241 llvm::AtomicOrdering atomicOrdering =
3245 [&](llvm::Value *atomicx,
3248 return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
3249 Block &bb = *atomicUpdateOp.getRegion().begin();
3250 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
3252 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3253 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
3254 return llvm::make_error<PreviouslyReportedError>();
3256 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
3257 assert(yieldop && yieldop.getResults().size() == 1 &&
3258 "terminator must be omp.yield op and it must have exactly one "
3260 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
3265 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3266 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3267 ompBuilder->createAtomicCapture(
3268 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
3269 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr);
3271 if (failed(handleError(afterIP, *atomicCaptureOp)))
3274 builder.restoreIP(*afterIP);
3279 omp::ClauseCancellationConstructType directive) {
3280 switch (directive) {
3281 case omp::ClauseCancellationConstructType::Loop:
3282 return llvm::omp::Directive::OMPD_for;
3283 case omp::ClauseCancellationConstructType::Parallel:
3284 return llvm::omp::Directive::OMPD_parallel;
3285 case omp::ClauseCancellationConstructType::Sections:
3286 return llvm::omp::Directive::OMPD_sections;
3287 case omp::ClauseCancellationConstructType::Taskgroup:
3288 return llvm::omp::Directive::OMPD_taskgroup;
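// Lower omp.cancel to OpenMPIRBuilder::createCancel, guarded by the optional
// `if` clause condition and targeting the directive selected above.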
3292 static LogicalResult
3298 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3301 llvm::Value *ifCond = nullptr;
3302 if (Value ifVar = op.getIfExpr())
3305 llvm::omp::Directive cancelledDirective =
3308 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3309 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
3311 if (failed(handleError(afterIP, *op.getOperation())))
3314 builder.restoreIP(afterIP.get());
3319 static LogicalResult
3321 llvm::IRBuilderBase &builder,
3326 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3329 llvm::omp::Directive cancelledDirective =
3332 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3333 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
3335 if (failed(handleError(afterIP, *op.getOperation())))
3338 builder.restoreIP(afterIP.get());
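// omp.threadprivate resolves the addressed global and, on the host, routes
// accesses through the cached threadprivate runtime entry created by
// createCachedThreadPrivate.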
3345 static LogicalResult
3348 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3350 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
3355 Value symAddr = threadprivateOp.getSymAddr();
3358 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3361 if (!isa<LLVM::AddressOfOp>(symOp))
3362 return opInst.emitError("Addressing symbol not found");
3363 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3365 LLVM::GlobalOp global =
3366 addressOfOp.getGlobal(moduleTranslation.symbolTable());
3367 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
3369 if (!ompBuilder->Config.isTargetDevice()) {
3370 llvm::Type *type = globalValue->getValueType();
3371 llvm::TypeSize typeSize =
3372 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
3374 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
3375 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
3376 ompLoc, globalValue, size, global.getSymName() + ".cache");
3385 static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
3387 switch (deviceClause) {
3388 case mlir::omp::DeclareTargetDeviceType::host:
3389 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
3391 case mlir::omp::DeclareTargetDeviceType::nohost:
3392 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
3394 case mlir::omp::DeclareTargetDeviceType::any:
3395 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
3398 llvm_unreachable("unhandled device clause");
3401 static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
3403 mlir::omp::DeclareTargetCaptureClause captureClause) {
3404 switch (captureClause) {
3405 case mlir::omp::DeclareTargetCaptureClause::to:
3406 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
3407 case mlir::omp::DeclareTargetCaptureClause::link:
3408 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
3409 case mlir::omp::DeclareTargetCaptureClause::enter:
3410 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
3412 llvm_unreachable("unhandled capture clause");
3417 llvm::OpenMPIRBuilder &ompBuilder) {
3419 llvm::raw_svector_ostream os(suffix);
3422 auto fileInfoCallBack = [&loc]() {
3423 return std::pair<std::string, uint64_t>(
3424 llvm::StringRef(loc.getFilename()), loc.getLine());
3428 "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
3430 os << "_decl_tgt_ref_ptr";
3436 if (auto addressOfOp =
3437 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
3438 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
3439 Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
3440 if (auto declareTargetGlobal =
3441 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
3442 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3443 mlir::omp::DeclareTargetCaptureClause::link)
3452 static llvm::Value *
3459 if (auto addressOfOp =
3460 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
3461 if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
3462 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
3463 addressOfOp.getGlobalName()))) {
3465 if (auto declareTargetGlobal =
3466 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
3467 gOp.getOperation())) {
3471 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
3472 mlir::omp::DeclareTargetCaptureClause::link) ||
3473 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3474 mlir::omp::DeclareTargetCaptureClause::to &&
3475 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
3479 if (gOp.getSymName().contains(suffix))
3484 (gOp.getSymName().str() + suffix.str()).str());
3495 struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
3499 void append(MapInfosTy &curInfo) {
3500 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
3501 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
3510 struct MapInfoData : MapInfosTy {
3522 void append(MapInfoData &CurInfo) {
3523 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
3524 CurInfo.IsDeclareTarget.end());
3525 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
3526 OriginalValue.append(CurInfo.OriginalValue.begin(),
3527 CurInfo.OriginalValue.end());
3528 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
3529 MapInfosTy::append(CurInfo);
3535 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3536 arrTy.getElementType()))
3552 Operation *clauseOp, llvm::Value *basePointer,
3553 llvm::Type *baseType, llvm::IRBuilderBase &builder,
3555 if (auto memberClause =
3556 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
3561 if (!memberClause.getBounds().empty()) {
3562 llvm::Value *elementCount = builder.getInt64(1);
3563 for (auto bounds : memberClause.getBounds()) {
3564 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
3565 bounds.getDefiningOp())) {
3570 elementCount = builder.CreateMul(
3574 moduleTranslation.lookupValue(boundOp.getUpperBound()),
3575 moduleTranslation.lookupValue(boundOp.getLowerBound())),
3576 builder.getInt64(1)));
3583 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
3591 return builder.CreateMul(elementCount,
3592 builder.getInt64(underlyingTypeSzInBits / 8));
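// Collect per-map-clause data (base pointers, pointee pointers, sizes, map
// types and custom mappers) for map, use_device_ptr/addr and has_device_addr
// operands into MapInfoData before handing it to the OpenMPIRBuilder.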
3605 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
3613 for (Value mapValue : mapVars) {
3614 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3615 for (auto member : map.getMembers())
3616 if (member == mapOp)
3623 for (Value mapValue : mapVars) {
3624 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3626 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3627 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
3628 mapData.Pointers.push_back(mapData.OriginalValue.back());
3630 if (llvm::Value *refPtr =
3632 moduleTranslation)) {
3633 mapData.IsDeclareTarget.push_back(true);
3634 mapData.BasePointers.push_back(refPtr);
3636 mapData.IsDeclareTarget.push_back(false);
3637 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3640 mapData.BaseType.push_back(
3641 moduleTranslation.convertType(mapOp.getVarType()));
3642 mapData.Sizes.push_back(
3643 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
3644 mapData.BaseType.back(), builder, moduleTranslation));
3645 mapData.MapClause.push_back(mapOp.getOperation());
3646 mapData.Types.push_back(
3647 llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType()));
3651 if (mapOp.getMapperId())
3652 mapData.Mappers.push_back(
3653 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3654 mapOp, mapOp.getMapperIdAttr()));
3656 mapData.Mappers.push_back(nullptr);
3657 mapData.IsAMapping.push_back(true);
3658 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
3661 auto findMapInfo = [&mapData](llvm::Value *val,
3662 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3665 for (llvm::Value *basePtr : mapData.OriginalValue) {
3666 if (basePtr == val && mapData.IsAMapping[index]) {
3668 mapData.Types[index] |=
3669 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
3670 mapData.DevicePointers[index] = devInfoTy;
3679 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3680 for (Value mapValue : useDevOperands) {
3681 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3683 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3684 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
3687 if (!findMapInfo(origValue, devInfoTy)) {
3688 mapData.OriginalValue.push_back(origValue);
3689 mapData.Pointers.push_back(mapData.OriginalValue.back());
3690 mapData.IsDeclareTarget.push_back(false);
3691 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3692 mapData.BaseType.push_back(
3693 moduleTranslation.convertType(mapOp.getVarType()));
3694 mapData.Sizes.push_back(builder.getInt64(0));
3695 mapData.MapClause.push_back(mapOp.getOperation());
3696 mapData.Types.push_back(
3697 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
3700 mapData.DevicePointers.push_back(devInfoTy);
3701 mapData.Mappers.push_back(nullptr);
3702 mapData.IsAMapping.push_back(false);
3703 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
3708 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3709 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
3711 for (Value mapValue : hasDevAddrOperands) {
3712 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3714 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3715 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
3717 static_cast<llvm::omp::OpenMPOffloadMappingFlags>(mapOp.getMapType());
3718 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
3720 mapData.OriginalValue.push_back(origValue);
3721 mapData.BasePointers.push_back(origValue);
3722 mapData.Pointers.push_back(origValue);
3723 mapData.IsDeclareTarget.push_back(false);
3724 mapData.BaseType.push_back(
3725 moduleTranslation.convertType(mapOp.getVarType()));
3726 mapData.Sizes.push_back(
3727 builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
3728 mapData.MapClause.push_back(mapOp.getOperation());
3729 if (llvm::to_underlying(mapType & mapTypeAlways)) {
3733 mapData.Types.push_back(mapType);
3737 if (mapOp.getMapperId()) {
3738 mapData.Mappers.push_back(
3739 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3740 mapOp, mapOp.getMapperIdAttr()));
3742 mapData.Mappers.push_back(nullptr);
3745 mapData.Types.push_back(
3746 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
3747 mapData.Mappers.push_back(nullptr);
3751 mapData.DevicePointers.push_back(
3752 llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3753 mapData.IsAMapping.push_back(false);
3754 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
3759 auto *res = llvm::find(mapData.MapClause, memberOp);
3760 assert(res != mapData.MapClause.end() &&
3761 "MapInfoOp for member not found in MapData, cannot return index");
3762 return std::distance(mapData.MapClause.begin(), res);
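// Order the member maps of a parent by their index paths so the parent entry can
// be anchored at the first mapped member (or extended to the last one).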
3767 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
3769 if (indexAttr.size() == 1)
3770 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
3773 std::iota(indices.begin(), indices.end(), 0);
3775 llvm::sort(indices.begin(), indices.end(),
3776 [&](const size_t a, const size_t b) {
3777 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
3778 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
3779 for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
3780 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
3781 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
3783 if (aIndex == bIndex)
3786 if (aIndex < bIndex)
3789 if (aIndex > bIndex)
3796 return memberIndicesA.size() < memberIndicesB.size();
3799 return llvm::cast<omp::MapInfoOp>(
3800 mapInfo.getMembers()[indices.front()].getDefiningOp());
3822 std::vector<llvm::Value *>
3824 llvm::IRBuilderBase &builder, bool isArrayTy,
3826 std::vector<llvm::Value *> idx;
3837 idx.push_back(builder.getInt64(0));
3838 for (int i = bounds.size() - 1; i >= 0; --i) {
3839 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3840 bounds[i].getDefiningOp())) {
3841 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
3863 std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
3864 for (size_t i = 1; i < bounds.size(); ++i) {
3865 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3866 bounds[i].getDefiningOp())) {
3867 dimensionIndexSizeOffset.push_back(builder.CreateMul(
3868 moduleTranslation.lookupValue(boundOp.getExtent()),
3869 dimensionIndexSizeOffset[i - 1]));
3877 for (int i = bounds.size() - 1; i >= 0; --i) {
3878 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3879 bounds[i].getDefiningOp())) {
3881 idx.emplace_back(builder.CreateMul(
3882 moduleTranslation.lookupValue(boundOp.getLowerBound()),
3883 dimensionIndexSizeOffset[i]));
3885 idx.back() = builder.CreateAdd(
3886 idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
3887 boundOp.getLowerBound()),
3888 dimensionIndexSizeOffset[i]));
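// Emit the combined-entry record for a parent map: one (optionally TARGET_PARAM)
// entry for the parent object plus a size entry spanning from the first to the
// last mapped member, returning the MEMBER_OF flag used to tag the members.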
3913 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
3914 MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) {
3915 assert(!ompBuilder.Config.isTargetDevice() &&
3916 "function only supported for host device codegen");
3919 combinedInfo.Types.emplace_back(
3921 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
3922 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
3923 combinedInfo.DevicePointers.emplace_back(
3924 mapData.DevicePointers[mapDataIndex]);
3925 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
3927 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3928 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3938 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3940 llvm::Value *lowAddr, *highAddr;
3941 if (!parentClause.getPartialMap()) {
3942 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
3943 builder.getPtrTy());
3944 highAddr = builder.CreatePointerCast(
3945 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
3946 mapData.Pointers[mapDataIndex], 1),
3947 builder.getPtrTy());
3948 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3950 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3953 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
3954 builder.getPtrTy());
3957 highAddr = builder.CreatePointerCast(
3958 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
3959 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
3960 builder.getPtrTy());
3961 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
3964 llvm::Value *size = builder.CreateIntCast(
3965 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
3966 builder.getInt64Ty(),
3968 combinedInfo.Sizes.push_back(size);
3970 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
3971 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
3979 if (!parentClause.getPartialMap()) {
3984 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
3985 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3986 combinedInfo.Types.emplace_back(mapFlag);
3987 combinedInfo.DevicePointers.emplace_back(
3989 combinedInfo.Mappers.emplace_back(nullptr);
3991 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3992 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3993 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3994 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
3996 return memberOfFlag;
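// A map clause that carries a var_ptr_ptr is mapping the pointee of a
// pointer-like variable (e.g. descriptor-backed values), which is what the
// PTR_AND_OBJ handling below keys off.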
4008 if (mapOp.getVarPtrPtr())
4023 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
4024 MapInfoData &mapData, uint64_t mapDataIndex,
4025 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
4026 assert(!ompBuilder.Config.isTargetDevice() &&
4027 "function only supported for host device codegen");
4030 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4032 for (auto mappedMembers : parentClause.getMembers()) {
4034 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
4037 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
4048 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4049 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4050 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4051 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4052 combinedInfo.Types.emplace_back(mapFlag);
4053 combinedInfo.DevicePointers.emplace_back(
4055 combinedInfo.Mappers.emplace_back(nullptr);
4056 combinedInfo.Names.emplace_back(
4058 combinedInfo.BasePointers.emplace_back(
4059 mapData.BasePointers[mapDataIndex]);
4060 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
4061 combinedInfo.Sizes.emplace_back(builder.getInt64(
4062 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
4068 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4069 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4070 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4071 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4073 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4075 combinedInfo.Types.emplace_back(mapFlag);
4076 combinedInfo.DevicePointers.emplace_back(
4077 mapData.DevicePointers[memberDataIdx]);
4078 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
4079 combinedInfo.Names.emplace_back(
4081 uint64_t basePointerIndex =
4083 combinedInfo.BasePointers.emplace_back(
4084 mapData.BasePointers[basePointerIndex]);
4085 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
4087 llvm::Value *size = mapData.Sizes[memberDataIdx];
4089 size = builder.CreateSelect(
4090 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
4091 builder.getInt64(0), size);
4094 combinedInfo.Sizes.emplace_back(size);
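// Emit a single combined-info entry for a map clause with no member maps,
// adding TARGET_PARAM / LITERAL / PTR_AND_OBJ flags as the clause requires.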
4099 MapInfosTy &combinedInfo, bool isTargetParams,
4100 int mapDataParentIdx = -1) {
4104 auto mapFlag = mapData.Types[mapDataIdx];
4105 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
4109 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4111 if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
4112 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4114 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
4116 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4121 if (mapDataParentIdx >= 0)
4122 combinedInfo.BasePointers.emplace_back(
4123 mapData.BasePointers[mapDataParentIdx]);
4125 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
4127 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
4128 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
4129 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
4130 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
4131 combinedInfo.Types.emplace_back(mapFlag);
4132 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
4136 llvm::IRBuilderBase &builder,
4137 llvm::OpenMPIRBuilder &ompBuilder,
4139 MapInfoData &mapData, uint64_t mapDataIndex,
4140 bool isTargetParams) {
4141 assert(!ompBuilder.Config.isTargetDevice() &&
4142 "function only supported for host device codegen");
4145 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4150 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
4151 auto memberClause = llvm::cast<omp::MapInfoOp>(
4152 parentClause.getMembers()[0].getDefiningOp());
4169 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
4171 combinedInfo, mapData, mapDataIndex, isTargetParams);
4173 combinedInfo, mapData, mapDataIndex,
4174 memberOfParentFlag);
4184 llvm::IRBuilderBase &builder) {
4186 "function only supported for host device codegen");
4187 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4189 if (!mapData.IsDeclareTarget[i]) {
4190 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4191 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
4201 switch (captureKind) {
4202 case omp::VariableCaptureKind::ByRef: {
4203 llvm::Value *newV = mapData.Pointers[i];
4205 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
4208 newV = builder.CreateLoad(builder.getPtrTy(), newV);
4210 if (!offsetIdx.empty())
4211 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
4213 mapData.Pointers[i] = newV;
4215 case omp::VariableCaptureKind::ByCopy: {
4216 llvm::Type *type = mapData.BaseType[i];
4218 if (mapData.Pointers[i]->getType()->isPointerTy())
4219 newV = builder.CreateLoad(type, mapData.Pointers[i]);
4221 newV = mapData.Pointers[i];
4224 auto curInsert = builder.saveIP();
4226 auto *memTempAlloc =
4227 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
4228 builder.restoreIP(curInsert);
4230 builder.CreateStore(newV, memTempAlloc);
4231 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
4234 mapData.Pointers[i] = newV;
4235 mapData.BasePointers[i] = newV;
4237 case omp::VariableCaptureKind::This:
4238 case omp::VariableCaptureKind::VLAType:
4239 mapData.MapClause[i]->emitOpError("Unhandled capture kind");
4250 MapInfoData &mapData, bool isTargetParams = false) {
4252 "function only supported for host device codegen");
4274 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4277 if (mapData.IsAMember[i])
4280 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
4281 if (!mapInfoOp.getMembers().empty()) {
4283 combinedInfo, mapData, i, isTargetParams);
4294 llvm::StringRef mapperFuncName);
4300 "function only supported for host device codegen");
4301 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4302 std::string mapperFuncName =
4304 {"omp_mapper", declMapperOp.getSymName()});
4306 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
4316 llvm::StringRef mapperFuncName) {
4318 "function only supported for host device codegen");
4319 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4320 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
4323 llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
4326 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4329 MapInfosTy combinedInfo;
4331 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
4332 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
4333 builder.restoreIP(codeGenIP);
4334 moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
4335 moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
4336 builder.GetInsertBlock());
4337 if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
4340 return llvm::make_error<PreviouslyReportedError>();
4341 MapInfoData mapData;
4344 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);
4349 return combinedInfo;
4353 if (!combinedInfo.Mappers[i])
4360 genMapInfoCB, varType, mapperFuncName, customMapperCB);
4362 return newFn.takeError();
4363 moduleTranslation.mapFunction(mapperFuncName, *newFn);
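// Shared lowering for omp.target_data, omp.target_enter_data,
// omp.target_exit_data and omp.target_update: each case below extracts the if,
// device and map clauses and selects the matching __tgt_target_data_* runtime entry.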
4367 static LogicalResult
4370 llvm::Value *ifCond = nullptr;
4371 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
4375 llvm::omp::RuntimeFunction RTLFn;
4379 llvm::OpenMPIRBuilder::TargetDataInfo info(true,
4381 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
4382 bool isOffloadEntry =
4383 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
4385 LogicalResult result =
4387 .Case([&](omp::TargetDataOp dataOp) {
4391 if (auto ifVar = dataOp.getIfExpr())
4394 if (auto devId = dataOp.getDevice())
4396 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4397 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4398 deviceID = intAttr.getInt();
4400 mapVars = dataOp.getMapVars();
4401 useDevicePtrVars = dataOp.getUseDevicePtrVars();
4402 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
4405 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
4409 if (auto ifVar = enterDataOp.getIfExpr())
4412 if (auto devId = enterDataOp.getDevice())
4414 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4415 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4416 deviceID = intAttr.getInt();
4418 enterDataOp.getNowait()
4419 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
4420 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
4421 mapVars = enterDataOp.getMapVars();
4422 info.HasNoWait = enterDataOp.getNowait();
4425 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
4429 if (auto ifVar = exitDataOp.getIfExpr())
4432 if (auto devId = exitDataOp.getDevice())
4434 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4435 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4436 deviceID = intAttr.getInt();
4438 RTLFn = exitDataOp.getNowait()
4439 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
4440 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
4441 mapVars = exitDataOp.getMapVars();
4442 info.HasNoWait = exitDataOp.getNowait();
4445 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
4449 if (auto ifVar = updateDataOp.getIfExpr())
4452 if (auto devId = updateDataOp.getDevice())
4454 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4455 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4456 deviceID = intAttr.getInt();
4459 updateDataOp.getNowait()
4460 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
4461 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
4462 mapVars = updateDataOp.getMapVars();
4463 info.HasNoWait = updateDataOp.getNowait();
4467 llvm_unreachable("unexpected operation");
4474 if (!isOffloadEntry)
4475 ifCond = builder.getFalse();
4477 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4478 MapInfoData mapData;
4480 builder, useDevicePtrVars, useDeviceAddrVars);
4483 MapInfosTy combinedInfo;
4484 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
4485 builder.restoreIP(codeGenIP);
4486 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
4487 return combinedInfo;
4493 [&moduleTranslation](
4494 llvm::OpenMPIRBuilder::DeviceInfoTy type,
4498 for (auto [arg, useDevVar] :
4499 llvm::zip_equal(blockArgs, useDeviceVars)) {
4501 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
4502 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
4503 : mapInfoOp.getVarPtr();
4506 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
4507 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
4508 mapInfoData.MapClause, mapInfoData.DevicePointers,
4509 mapInfoData.BasePointers)) {
4510 auto mapOp = cast<omp::MapInfoOp>(mapClause);
4511 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
4512 devicePointer != type)
4515 if (llvm::Value *devPtrInfoMap =
4516 mapper ? mapper(basePointer) : basePointer) {
4517 moduleTranslation.mapValue(arg, devPtrInfoMap);
4524 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
4525 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
4526 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
4527 builder.restoreIP(codeGenIP);
4528 assert(isa<omp::TargetDataOp>(op) &&
4529 "BodyGen requested for non TargetDataOp");
4530 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
4531 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
4532 switch (bodyGenType) {
4533 case BodyGenTy::Priv:
4535 if (!info.DevicePtrInfoMap.empty()) {
4536 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4537 blockArgIface.getUseDeviceAddrBlockArgs(),
4538 useDeviceAddrVars, mapData,
4539 [&](llvm::Value *basePointer) -> llvm::Value * {
4540 if (!info.DevicePtrInfoMap[basePointer].second)
4542 return builder.CreateLoad(
4544 info.DevicePtrInfoMap[basePointer].second);
4546 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4547 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
4548 mapData, [&](llvm::Value *basePointer) {
4549 return info.DevicePtrInfoMap[basePointer].second;
4553 moduleTranslation)))
4554 return llvm::make_error<PreviouslyReportedError>();
4557 case BodyGenTy::DupNoPriv:
4560 builder.restoreIP(codeGenIP);
4562 case BodyGenTy::NoPriv:
4564 if (info.DevicePtrInfoMap.empty()) {
4567 if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
4568 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4569 blockArgIface.getUseDeviceAddrBlockArgs(),
4570 useDeviceAddrVars, mapData);
4571 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4572 blockArgIface.getUseDevicePtrBlockArgs(),
4573 useDevicePtrVars, mapData);
4577 moduleTranslation)))
4578 return llvm::make_error<PreviouslyReportedError>();
4582 return builder.saveIP();
4585 auto customMapperCB =
4587 if (!combinedInfo.Mappers[i])
4589 info.HasMapper = true;
4594 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4595 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4597 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
4598 if (isa<omp::TargetDataOp>(op))
4599 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
4600 builder.getInt64(deviceID), ifCond,
4601 info, genMapInfoCB, customMapperCB,
4604 return ompBuilder->createTargetData(
4605 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
4606 info, genMapInfoCB, customMapperCB, &RTLFn);
4612 builder.restoreIP(*afterIP);
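// omp.distribute lowering: when composed with omp.wsloop the worksharing loop
// handles scheduling; otherwise the collapsed loop nest is given the
// DistributeStaticLoop worksharing-loop type directly.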
4616 static LogicalResult
4620 auto distributeOp = cast<omp::DistributeOp>(opInst);
4627 bool doDistributeReduction =
4631 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
4636 if (doDistributeReduction) {
4637 isByRef = getIsByRef(teamsOp.getReductionByref());
4638 assert(isByRef.size() == teamsOp.getNumReductionVars());
4641 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4645 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
4646 .getReductionBlockArgs();
4649 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
4650 reductionDecls, privateReductionVariables, reductionVariableMap,
4655 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4656 auto bodyGenCB = [&](InsertPointTy allocaIP,
4657 InsertPointTy codeGenIP) -> llvm::Error {
4661 moduleTranslation, allocaIP);
4664 builder.restoreIP(codeGenIP);
4670 return llvm::make_error<PreviouslyReportedError>();
4675 return llvm::make_error<PreviouslyReportedError>();
4678 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
4680 distributeOp.getPrivateNeedsBarrier())))
4681 return llvm::make_error<PreviouslyReportedError>();
4684 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4687 builder, moduleTranslation);
4689 return regionBlock.takeError();
4690 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4695 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
4698 auto schedule = omp::ClauseScheduleKind::Static;
4699 bool isOrdered = false;
4700 std::optional<omp::ScheduleModifier> scheduleMod;
4701 bool isSimd = false;
4702 llvm::omp::WorksharingLoopType workshareLoopType =
4703 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
4704 bool loopNeedsBarrier = false;
4705 llvm::Value *chunk = nullptr;
4707 llvm::CanonicalLoopInfo *loopInfo =
4709 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4710 ompBuilder->applyWorkshareLoop(
4711 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
4712 convertToScheduleKind(schedule), chunk, isSimd,
4713 scheduleMod == omp::ScheduleModifier::monotonic,
4714 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4718 return wsloopIP.takeError();
4722 distributeOp.getLoc(), privVarsInfo.llvmVars,
4724 return llvm::make_error<PreviouslyReportedError>();
4726 return llvm::Error::success();
4729 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4731 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4732 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4733 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
4738 builder.restoreIP(*afterIP);
4740 if (doDistributeReduction) {
4743 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
4744 privateReductionVariables, isByRef,
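// Module-level offloading attributes: record the OpenMP device version as an
// "openmp-device" module flag and emit the __omp_rtl_* configuration globals
// consumed by the device runtime.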
4755 if (!cast<mlir::ModuleOp>(op))
4760 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
4761 attribute.getOpenmpDeviceVersion());
4763 if (attribute.getNoGpuLib())
4766 ompBuilder->createGlobalFlag(
4767 attribute.getDebugKind(),
4768 "__omp_rtl_debug_kind");
4769 ompBuilder->createGlobalFlag(
4771 .getAssumeTeamsOversubscription()
4773 "__omp_rtl_assume_teams_oversubscription");
4774 ompBuilder->createGlobalFlag(
4776 .getAssumeThreadsOversubscription()
4778 "__omp_rtl_assume_threads_oversubscription");
4779 ompBuilder->createGlobalFlag(
4780 attribute.getAssumeNoThreadState(),
4781 "__omp_rtl_assume_no_thread_state");
4782 ompBuilder->createGlobalFlag(
4784 .getAssumeNoNestedParallelism()
4786 "__omp_rtl_assume_no_nested_parallelism");
4791 omp::TargetOp targetOp,
4792 llvm::StringRef parentName = "") {
4793 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
4795 assert(fileLoc && "No file found from location");
4796 StringRef fileName = fileLoc.getFilename().getValue();
4798 llvm::sys::fs::UniqueID id;
4799 uint64_t line = fileLoc.getLine();
4800 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
4802 size_t deviceId = 0xdeadf17e;
4804 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
4806 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
4807 id.getFile(), line);
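// On the device side, uses of declare-target globals inside the outlined kernel
// are rewritten to load through the corresponding reference pointer instead of
// referring to the host global directly.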
4814 llvm::IRBuilderBase &builder, llvm::Function *func) {
4816 "function only supported for target device codegen");
4817 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4830 if (mapData.IsDeclareTarget[i]) {
4837 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
4838 convertUsersOfConstantsToInstructions(constant, func, false);
4845 for (llvm::User *user : mapData.OriginalValue[i]->users())
4846 userVec.push_back(user);
4848 for (llvm::User *user : userVec) {
4849 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
4850 if (insn->getFunction() == func) {
4851 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
4852 mapData.BasePointers[i]);
4853 load->moveBefore(insn->getIterator());
4854 user->replaceUsesOfWith(mapData.OriginalValue[i], load);
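// createDeviceArgumentAccessor materializes a kernel argument inside the device
// entry point: the argument is spilled to an alloca and then made available
// either by value or by reference, matching the map clause's capture kind.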
4901 static llvm::IRBuilderBase::InsertPoint
4903 llvm::Value *input, llvm::Value *&retVal,
4904 llvm::IRBuilderBase &builder,
4905 llvm::OpenMPIRBuilder &ompBuilder,
4907 llvm::IRBuilderBase::InsertPoint allocaIP,
4908 llvm::IRBuilderBase::InsertPoint codeGenIP) {
4909 assert(ompBuilder.Config.isTargetDevice() &&
4910 "function only supported for target device codegen");
4911 builder.restoreIP(allocaIP);
4913 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
4915 ompBuilder.M.getContext());
4916 unsigned alignmentValue = 0;
4918 for (size_t i = 0; i < mapData.MapClause.size(); ++i)
4919 if (mapData.OriginalValue[i] == input) {
4920 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4921 capture = mapOp.getMapCaptureType();
4924 mapOp.getVarType(), ompBuilder.M.getDataLayout());
4928 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
4929 unsigned int defaultAS =
4930 ompBuilder.M.getDataLayout().getProgramAddressSpace();
4933 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
4935 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
4936 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
4938 builder.CreateStore(&arg, v);
4940 builder.restoreIP(codeGenIP);
4943 case omp::VariableCaptureKind::ByCopy: {
4947 case omp::VariableCaptureKind::ByRef: {
4948 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
4950 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
4965 if (v->getType()->isPointerTy() && alignmentValue) {
4966 llvm::MDBuilder MDB(builder.getContext());
4967 loadInst->setMetadata(
4968 llvm::LLVMContext::MD_align,
4971 llvm::Type::getInt64Ty(builder.getContext()),
4978 case omp::VariableCaptureKind::This:
4979 case omp::VariableCaptureKind::VLAType:
4982 assert(false && "Currently unsupported capture kind");
4986 return builder.saveIP();
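// The host_eval clause forwards values that must be evaluated on the host
// (num_teams, thread_limit, num_threads and loop bounds for trip-count/SPMD
// kernels) into the target region as block arguments; the walk below matches
// each block argument back to the nested clause that consumes it.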
5003 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
5004 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
5005 blockArgIface.getHostEvalBlockArgs())) {
5006 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
5010 .Case([&](omp::TeamsOp teamsOp) {
5011 if (teamsOp.getNumTeamsLower() == blockArg)
5012 numTeamsLower = hostEvalVar;
5013 else if (teamsOp.getNumTeamsUpper() == blockArg)
5014 numTeamsUpper = hostEvalVar;
5015 else if (teamsOp.getThreadLimit() == blockArg)
5016 threadLimit = hostEvalVar;
5018 llvm_unreachable("unsupported host_eval use");
5020 .Case([&](omp::ParallelOp parallelOp) {
5021 if (parallelOp.getNumThreads() == blockArg)
5022 numThreads = hostEvalVar;
5024 llvm_unreachable("unsupported host_eval use");
5026 .Case([&](omp::LoopNestOp loopOp) {
5027 auto processBounds =
5032 if (lb == blockArg) {
5035 (*outBounds)[i] = hostEvalVar;
5041 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
5042 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
5044 found = processBounds(loopOp.getLoopSteps(), steps) || found;
5046 assert(found && "unsupported host_eval use");
5049 llvm_unreachable("unsupported host_eval use");
5062 template <typename OpTy>
5067 if (OpTy casted = dyn_cast<OpTy>(op))
5070 if (immediateParent)
5071 return dyn_cast_if_present<OpTy>(op->getParentOp());
5080 return std::nullopt;
5083 dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
5084 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5085 return constAttr.getInt();
5087 return std::nullopt;
5092 uint64_t sizeInBytes = sizeInBits / 8;
5096 template <typename OpTy>
5098 if (op.getNumReductionVars() > 0) {
5103 members.reserve(reductions.size());
5104 for (omp::DeclareReductionOp &red : reductions)
5105 members.push_back(red.getType());
5107 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
5123 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
5124 bool isTargetDevice, bool isGPU) {
5127 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
5128 if (!isTargetDevice) {
5135 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5136 numTeamsLower = teamsOp.getNumTeamsLower();
5137 numTeamsUpper = teamsOp.getNumTeamsUpper();
5138 threadLimit = teamsOp.getThreadLimit();
5141 if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5142 numThreads = parallelOp.getNumThreads();
5147 int32_t minTeamsVal = 1, maxTeamsVal = -1;
5148 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5151 if (numTeamsUpper) {
5153 minTeamsVal = maxTeamsVal = *val;
5155 minTeamsVal = maxTeamsVal = 0;
5157 } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
5159 castOrGetParentOfType<omp::SimdOp>(capturedOp,
5161 minTeamsVal = maxTeamsVal = 1;
5163 minTeamsVal = maxTeamsVal = -1;
5168 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
5182 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
5183 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
5184 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
5187 int32_t maxThreadsVal = -1;
5188 if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5189 setMaxValueFromClause(numThreads, maxThreadsVal);
5190 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
5197 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
5198 if (combinedMaxThreadsVal < 0 ||
5199 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
5200 combinedMaxThreadsVal = teamsThreadLimitVal;
5202 if (combinedMaxThreadsVal < 0 ||
5203 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
5204 combinedMaxThreadsVal = maxThreadsVal;
5206 int32_t reductionDataSize = 0;
5207 if (isGPU && capturedOp) {
5208 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
5213 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5215 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5216 omp::TargetRegionFlags::spmd) &&
5217 "invalid kernel flags");
5219 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5220 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5221 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5222 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5223 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5224 attrs.MinTeams = minTeamsVal;
5225 attrs.MaxTeams.front() = maxTeamsVal;
5226 attrs.MinThreads = 1;
5227 attrs.MaxThreads.front() = combinedMaxThreadsVal;
5228 attrs.ReductionDataSize = reductionDataSize;
5231 if (attrs.ReductionDataSize != 0)
5232 attrs.ReductionBufferLength = 1024;
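// Runtime kernel attributes complement the compile-time defaults above: they
// carry the host-evaluated num_teams/thread_limit/num_threads values and, when
// requested, the product of the collapsed loop trip counts.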
5244 omp::TargetOp targetOp, Operation *capturedOp,
5245 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
5246 omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
5247 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
5249 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
5253 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
5256 if (Value targetThreadLimit = targetOp.getThreadLimit())
5257 attrs.TargetThreadLimit.front() =
5261 attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);
5264 attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);
5266 if (teamsThreadLimit)
5267 attrs.TeamsThreadLimit.front() =
5271 attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
5273 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
5274 omp::TargetRegionFlags::trip_count)) {
5276 attrs.LoopTripCount = nullptr;
5281 for (auto [loopLower, loopUpper, loopStep] :
5282 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
5283 llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
5284 llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
5285 llvm::Value *step = moduleTranslation.lookupValue(loopStep);
5287 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
5288 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
5289 loc, lowerBound, upperBound, step, true,
5290 loopOp.getLoopInclusive());
5292 if (!attrs.LoopTripCount) {
5293 attrs.LoopTripCount = tripCount;
5298 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
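// Main omp.target lowering: collect map/host_eval/private clauses, build the
// kernel default and runtime attributes, and hand everything to the
// OpenMPIRBuilder to emit either the device entry point or the host-side launch.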
5304 static LogicalResult
5307 auto targetOp = cast<omp::TargetOp>(opInst);
5312 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
5313 bool isGPU = ompBuilder->Config.isGPU();
5316 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
5317 auto &targetRegion = targetOp.getRegion();
5334 llvm::Function *llvmOutlinedFn = nullptr;
5338 bool isOffloadEntry =
5339 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
5346 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
5348 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
5349 std::optional<DenseI64ArrayAttr> privateMapIndices =
5350 targetOp.getPrivateMapsAttr();
5352 for (auto [privVarIdx, privVarSymPair] :
5354 auto privVar = std::get<0>(privVarSymPair);
5355 auto privSym = std::get<1>(privVarSymPair);
5357 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
5358 omp::PrivateClauseOp privatizer =
5361 if (!privatizer.needsMap())
5365 targetOp.getMappedValueForPrivateVar(privVarIdx);
5366 assert(mappedValue && "Expected to find mapped value for a privatized "
5367 "variable that needs mapping");
5372 auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
5373 [[maybe_unused]] Type varType = mapInfoOp.getVarType();
5377 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
5379 varType == privVar.getType() &&
5380 "Type of private var doesn't match the type of the mapped value");
5384 mappedPrivateVars.insert(
5386 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
5387 (*privateMapIndices)[privVarIdx])});
5391 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5392 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
5393 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5394 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5395 builder.SetCurrentDebugLocation(llvm::DebugLoc());
5398 llvm::Function *llvmParentFn =
5400 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
5401 assert(llvmParentFn && llvmOutlinedFn &&
5402 "Both parent and outlined functions must exist at this point");
5404 if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
5405 attr.isStringAttribute())
5406 llvmOutlinedFn->addFnAttr(attr);
5408 if (auto attr = llvmParentFn->getFnAttribute("target-features");
5409 attr.isStringAttribute())
5410 llvmOutlinedFn->addFnAttr(attr);
5412 for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
5413 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5414 llvm::Value *mapOpValue =
5415 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
5416 moduleTranslation.mapValue(arg, mapOpValue);
5418 for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
5419 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5420 llvm::Value *mapOpValue =
5421 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
5422 moduleTranslation.mapValue(arg, mapOpValue);
5431 allocaIP, &mappedPrivateVars);
5434 return llvm::make_error<PreviouslyReportedError>();
5436 builder.restoreIP(codeGenIP);
5438 &mappedPrivateVars),
5441 return llvm::make_error<PreviouslyReportedError>();
5444 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
5446 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
5447 return llvm::make_error<PreviouslyReportedError>();
5451 std::back_inserter(privateCleanupRegions),
5452 [](omp::PrivateClauseOp privatizer) {
5453 return &privatizer.getDeallocRegion();
5457 targetRegion,
"omp.target", builder, moduleTranslation);
5460 return exitBlock.takeError();
5462 builder.SetInsertPoint(*exitBlock);
5463 if (!privateCleanupRegions.empty()) {
5465 privateCleanupRegions, privateVarsInfo.
llvmVars,
5466 moduleTranslation, builder,
"omp.targetop.private.cleanup",
5468 return llvm::createStringError(
5469 "failed to inline `dealloc` region of `omp.private` "
5470 "op in the target region");
5472 return builder.saveIP();
5475 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;
  getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                builder, /*useDevPtrOperands=*/{},
                                /*useDevAddrOperands=*/{}, hdaVars);

  MapInfosTy combinedInfos;
  auto genMapInfoCB =
      [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
                /*isTargetParams=*/true);
    return combinedInfos;
  };

  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());

    // On the host, the kernel argument is used as-is.
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(&arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
                                        *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP);
  };

  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
  initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
                         isTargetDevice, isGPU);

  // Collect host-evaluated clause values when compiling for the host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
                           targetCapturedOp, runtimeAttrs);

  // Pass host-evaluated values as kernel parameters, except for constants, and
  // map the corresponding block arguments to their LLVM values.
  SmallVector<llvm::Value *> kernelInput;
  for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(var);
    moduleTranslation.mapValue(arg, value);

    if (!llvm::isa<llvm::Constant>(value))
      kernelInput.push_back(value);
  }

  for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
    // Declare target variables are not passed as kernel arguments.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(mapData.OriginalValue[i]);
  }

  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/false, /*SeparateBeginEndCalls=*/true);

  auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfos.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
                                            moduleTranslation);
  };

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(targetIfCond);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createTarget(
          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
          argAccessorCB, customMapperCB, dds, targetOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  return success();
}
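// Illustrative example (not from this file): for C source such as
//
//   int x = 0;
//   #pragma omp target map(tofrom: x)
//   { x += 1; }
//
// the host-side translation passes the mapped storage for x as a kernel
// argument (it ends up in kernelInput above), while bodyCB emits the outlined
// body of the target region that performs the increment.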
static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
                         LLVM::ModuleTranslation &moduleTranslation) {
  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
            op->getParentOfType<ModuleOp>().getOperation())) {
      if (!offloadMod.getIsTargetDevice())
        return success();

      omp::DeclareTargetDeviceType declareType =
          attribute.getDeviceType().getValue();

      // Host-only functions are not needed on the device; drop them.
      if (declareType == omp::DeclareTargetDeviceType::host) {
        llvm::Function *llvmFunc =
            moduleTranslation.lookupFunction(funcOp.getName());
        llvmFunc->dropAllReferences();
        llvmFunc->eraseFromParent();
      }
    }
  }

  if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
    llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
    if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
      llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
      bool isDeclaration = gOp.isDeclaration();
      bool isExternallyVisible =
          gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
      auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
      llvm::StringRef mangledName = gOp.getSymName();
      auto captureClause =
          convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
      auto deviceClause =
          convertToDeviceClauseKind(attribute.getDeviceType().getValue());
      std::vector<llvm::GlobalVariable *> generatedRefs;

      std::vector<llvm::Triple> targetTriple;
      auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
          op->getParentOfType<mlir::ModuleOp>()->getAttr(
              LLVM::LLVMDialect::getTargetTripleAttrName()));
      if (targetTripleAttr)
        targetTriple.emplace_back(targetTripleAttr.data());

      auto fileInfoCallBack = [&loc]() {
        std::string filename = "";
        std::uint64_t lineNo = 0;

        if (loc) {
          filename = loc.getFilename().str();
          lineNo = loc.getLine();
        }

        return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
                                                     lineNo);
      };

      ompBuilder->registerTargetGlobalVariable(
          captureClause, deviceClause, isDeclaration, isExternallyVisible,
          ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
          generatedRefs, /*OpenMPSIMD=*/false, targetTriple,
          /*GlobalInitializer=*/nullptr, /*VariableLinkage=*/nullptr,
          gVal->getType(), gVal);

      if (ompBuilder->Config.isTargetDevice() &&
          (attribute.getCaptureClause().getValue() !=
               mlir::omp::DeclareTargetCaptureClause::to ||
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        ompBuilder->getAddrOfDeclareTargetVar(
            captureClause, deviceClause, isDeclaration, isExternallyVisible,
            ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
            generatedRefs, /*OpenMPSIMD=*/false, targetTriple, gVal->getType(),
            /*GlobalInitializer=*/nullptr, /*VariableLinkage=*/nullptr);
      }
    }
  }

  return success();
}
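// Illustrative example (sketch, not from this file): a global made visible to
// the device, e.g.
//
//   int counter;
//   #pragma omp declare target link(counter)
//
// is registered with the offload runtime here; with the link clause (or when
// unified shared memory is required), getAddrOfDeclareTargetVar additionally
// materializes a reference-pointer indirection through which device code
// reaches the host variable.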
// Returns true if the operation must also be translated during the device
// pass.
static bool isTargetDeviceOp(Operation *op) {
  if (mlir::isa<omp::ThreadprivateOp>(op))
    return true;

  if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
    if (auto declareTargetIface =
            llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
                parentFn.getOperation()))
      if (declareTargetIface.isDeclareTarget() &&
          declareTargetIface.getDeclareTargetDeviceType() !=
              mlir::omp::DeclareTargetDeviceType::host)
        return true;

  return false;
}
static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  bool isOutermostLoopWrapper =
      isa_and_present<omp::LoopWrapperInterface>(op) &&
      !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();

  LogicalResult result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::BarrierOp op) -> LogicalResult {
            llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
                ompBuilder->createBarrier(builder.saveIP(),
                                          llvm::omp::OMPD_barrier);
            LogicalResult res = handleError(afterIP, *op);
            if (res.succeeded()) {
              // The insertion point may have moved to a continuation block.
              builder.restoreIP(*afterIP);
            }
            return res;
          })
          .Case([&](omp::TaskyieldOp op) {
            ompBuilder->createTaskyield(builder.saveIP());
            return success();
          })
          .Case([&](omp::FlushOp op) {
            ompBuilder->createFlush(builder.saveIP());
            return success();
          })
          .Case([&](omp::ParallelOp op) {
            return convertOmpParallel(op, builder, moduleTranslation);
          })
          .Case([&](omp::MaskedOp) {
            return convertOmpMasked(*op, builder, moduleTranslation);
          })
          .Case([&](omp::MasterOp) {
            return convertOmpMaster(*op, builder, moduleTranslation);
          })
          .Case([&](omp::CriticalOp) {
            return convertOmpCritical(*op, builder, moduleTranslation);
          })
          .Case([&](omp::OrderedRegionOp) {
            return convertOmpOrderedRegion(*op, builder, moduleTranslation);
          })
          .Case([&](omp::OrderedOp) {
            return convertOmpOrdered(*op, builder, moduleTranslation);
          })
          .Case([&](omp::WsloopOp) {
            return convertOmpWsloop(*op, builder, moduleTranslation);
          })
          .Case([&](omp::SimdOp) {
            return convertOmpSimd(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicReadOp) {
            return convertOmpAtomicRead(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicWriteOp) {
            return convertOmpAtomicWrite(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicUpdateOp op) {
            return convertOmpAtomicUpdate(op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicCaptureOp op) {
            return convertOmpAtomicCapture(op, builder, moduleTranslation);
          })
          .Case([&](omp::CancelOp op) {
            return convertOmpCancel(op, builder, moduleTranslation);
          })
          .Case([&](omp::CancellationPointOp op) {
            return convertOmpCancellationPoint(op, builder, moduleTranslation);
          })
          .Case([&](omp::SectionsOp) {
            return convertOmpSections(*op, builder, moduleTranslation);
          })
          .Case([&](omp::SingleOp op) {
            return convertOmpSingle(op, builder, moduleTranslation);
          })
          .Case([&](omp::TeamsOp op) {
            return convertOmpTeams(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskOp op) {
            return convertOmpTaskOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskgroupOp op) {
            return convertOmpTaskgroupOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskwaitOp op) {
            return convertOmpTaskwaitOp(op, builder, moduleTranslation);
          })
          .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
                omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
                omp::CriticalDeclareOp>([](auto op) {
            // Handled (or intentionally ignored) as part of the translation
            // of the operations that contain or reference them.
            return success();
          })
          .Case([&](omp::ThreadprivateOp) {
            return convertOmpThreadprivate(*op, builder, moduleTranslation);
          })
          .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
                omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
            return convertOmpTargetData(op, builder, moduleTranslation);
          })
          .Case([&](omp::TargetOp) {
            return convertOmpTarget(*op, builder, moduleTranslation);
          })
          .Case([&](omp::DistributeOp) {
            return convertOmpDistribute(*op, builder, moduleTranslation);
          })
          .Case([&](omp::LoopNestOp) {
            return convertOmpLoopNest(*op, builder, moduleTranslation);
          })
          .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
              [&](auto op) {
                // Handled as part of the op that owns them.
                return success();
              })
          .Default([&](Operation *inst) {
            return inst->emitError()
                   << "not yet implemented: " << inst->getName();
          });

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPop();

  return result;
}
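// The dispatch above uses llvm::TypeSwitch as a type-based switch over the
// operation. A minimal, self-contained sketch of the pattern (FooOp, BarOp,
// convertFoo and convertBar are hypothetical):
//
//   LogicalResult dispatch(Operation *op) {
//     return llvm::TypeSwitch<Operation *, LogicalResult>(op)
//         .Case([&](FooOp foo) { return convertFoo(foo); })
//         .Case([&](BarOp bar) { return convertBar(bar); })
//         .Default([](Operation *o) {
//           return o->emitError() << "unhandled operation";
//         });
//   }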
static LogicalResult
convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  return convertHostOrTargetOperation(op, builder, moduleTranslation);
}

static LogicalResult
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  if (isa<omp::TargetOp>(op))
    return convertOmpTarget(*op, builder, moduleTranslation);
  if (isa<omp::TargetDataOp>(op))
    return convertOmpTargetData(op, builder, moduleTranslation);

  bool interrupted =
      op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
          if (isa<omp::TargetOp>(oper)) {
            if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }
          if (isa<omp::TargetDataOp>(oper)) {
            if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }

          // Other OpenMP ops with regions are not translated for real here;
          // their block arguments and regions are only processed so that any
          // nested target-related ops can be found and the IR stays valid.
          if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
              oper->getParentOfType<LLVM::LLVMFuncOp>() &&
              !oper->getRegions().empty()) {
            if (auto blockArgsIface =
                    dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
              forwardArgs(moduleTranslation, blockArgsIface);

            if (isa<mlir::omp::AtomicUpdateOp>(oper))
              for (auto [operand, arg] :
                   llvm::zip_equal(oper->getOperands(),
                                   oper->getRegion(0).getArguments())) {
                moduleTranslation.mapValue(
                    arg, builder.CreateLoad(
                             moduleTranslation.convertType(arg.getType()),
                             moduleTranslation.lookupValue(operand)));
              }

            if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
              assert(builder.GetInsertBlock() &&
                     "No insert block is set for the builder");
              for (auto iv : loopNest.getIVs()) {
                // Map the induction variable to a placeholder value so the
                // region remains valid while only nested target ops matter.
                moduleTranslation.mapValue(
                    iv, llvm::PoisonValue::get(
                            moduleTranslation.convertType(iv.getType())));
              }
            }

            for (Region &region : oper->getRegions()) {
              SmallVector<llvm::PHINode *> phis;
              auto result = convertOmpOpRegions(
                  region,
                  oper->getName().getStringRef().str() + ".fake.region",
                  builder, moduleTranslation, &phis);
              if (failed(handleError(result, *oper)))
                return WalkResult::interrupt();

              builder.SetInsertPoint(result.get(), result.get()->end());
            }

            return WalkResult::skip();
          }

          return WalkResult::advance();
        }).wasInterrupted();
  return failure(interrupted);
}
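// Illustrative note (example only): during the device pass, for a function
// such as
//
//   void f(int *a) {
//   #pragma omp parallel for
//     for (int i = 0; i < 10; ++i)
//       a[i] = i;
//   #pragma omp target map(tofrom : a[0 : 10])
//     { a[0] = 42; }
//   }
//
// only the omp.target region is translated for real; the surrounding host
// constructs are merely walked so that their regions and values stay
// consistent in the IR, since their host code is never emitted for the device.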
/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR.
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;

  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;
};
LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
    Operation *op, ArrayRef<llvm::Instruction *> instructions,
    NamedAttribute attribute,
    LLVM::ModuleTranslation &moduleTranslation) const {
  return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
             attribute.getName())
      .Case("omp.is_target_device",
            [&](Attribute attr) {
              if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsTargetDevice(deviceAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.is_gpu",
            [&](Attribute attr) {
              if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsGPU(gpuAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.host_ir_filepath",
            [&](Attribute attr) {
              if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.flags",
            [&](Attribute attr) {
              if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
                return convertFlagsAttr(op, rtlAttr, moduleTranslation);
              return failure();
            })
      .Case("omp.version",
            [&](Attribute attr) {
              if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
                                            versionAttr.getVersion());
                return success();
              }
              return failure();
            })
      .Case("omp.declare_target",
            [&](Attribute attr) {
              if (auto declareTargetAttr =
                      dyn_cast<omp::DeclareTargetAttr>(attr))
                return convertDeclareTargetAttr(op, declareTargetAttr,
                                                moduleTranslation);
              return failure();
            })
      .Case("omp.requires",
            [&](Attribute attr) {
              if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
                using Requires = omp::ClauseRequires;
                Requires flags = requiresAttr.getValue();
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setHasRequiresReverseOffload(
                    bitEnumContainsAll(flags, Requires::reverse_offload));
                config.setHasRequiresUnifiedAddress(
                    bitEnumContainsAll(flags, Requires::unified_address));
                config.setHasRequiresUnifiedSharedMemory(
                    bitEnumContainsAll(flags, Requires::unified_shared_memory));
                config.setHasRequiresDynamicAllocators(
                    bitEnumContainsAll(flags, Requires::dynamic_allocators));
                return success();
              }
              return failure();
            })
      .Case("omp.target_triples",
            [&](Attribute attr) {
              if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.TargetTriples.clear();
                config.TargetTriples.reserve(triplesAttr.size());
                for (Attribute tripleAttr : triplesAttr) {
                  if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
                    config.TargetTriples.emplace_back(tripleStrAttr.getValue());
                  else
                    return failure();
                }
                return success();
              }
              return failure();
            })
      .Default([](Attribute) {
        // Fall through for omp attributes that do not require lowering.
        return success();
      })(attribute.getValue());
}
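// Example (illustrative): once amendOperation has processed module attributes
// such as omp.is_target_device and omp.is_gpu, later translation code can
// branch on the resulting configuration:
//
//   llvm::OpenMPIRBuilderConfig &config =
//       moduleTranslation.getOpenMPBuilder()->Config;
//   if (config.isTargetDevice() && config.isGPU()) {
//     // device (GPU) code generation path
//   }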
LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
    Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  if (ompBuilder->Config.isTargetDevice()) {
    if (isTargetDeviceOp(op)) {
      return convertTargetDeviceOp(op, builder, moduleTranslation);
    } else {
      return convertTargetOpsInNest(op, builder, moduleTranslation);
    }
  }
  return convertHostOrTargetOperation(op, builder, moduleTranslation);
}

void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
  registry.insert<omp::OpenMPDialect>();
  registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
    dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
  });
}
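// Typical usage (illustrative sketch; registerLLVMDialectTranslation is
// assumed to be available from the LLVM dialect translation library):
//
//   DialectRegistry registry;
//   registerLLVMDialectTranslation(registry);
//   registerOpenMPDialectTranslation(registry);
//   context.appendDialectRegistry(registry);
//
// After this, translateModuleToLLVMIR can lower modules containing OpenMP
// dialect operations.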