26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Frontend/OpenMP/OMPConstants.h"
31 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
32 #include "llvm/IR/Constants.h"
33 #include "llvm/IR/DebugInfoMetadata.h"
34 #include "llvm/IR/DerivedTypes.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/MDBuilder.h"
37 #include "llvm/IR/ReplaceConstant.h"
38 #include "llvm/Support/FileSystem.h"
39 #include "llvm/TargetParser/Triple.h"
40 #include "llvm/Transforms/Utils/ModuleUtils.h"
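// Maps an MLIR OpenMP schedule clause value to the corresponding
// OpenMPIRBuilder schedule kind; an absent clause lowers to the default
// schedule.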
static llvm::omp::ScheduleKind
convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
  if (!schedKind.has_value())
    return llvm::omp::OMP_SCHEDULE_Default;
  switch (schedKind.value()) {
  case omp::ClauseScheduleKind::Static:
    return llvm::omp::OMP_SCHEDULE_Static;
  case omp::ClauseScheduleKind::Dynamic:
    return llvm::omp::OMP_SCHEDULE_Dynamic;
  case omp::ClauseScheduleKind::Guided:
    return llvm::omp::OMP_SCHEDULE_Guided;
  case omp::ClauseScheduleKind::Auto:
    return llvm::omp::OMP_SCHEDULE_Auto;
  case omp::ClauseScheduleKind::Runtime:
    return llvm::omp::OMP_SCHEDULE_Runtime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}
class OpenMPAllocaStackFrame
  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;

class OpenMPLoopInfoStackFrame
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
  void log(raw_ostream &) const override {
  std::error_code convertToErrorCode() const override {
        "PreviouslyReportedError doesn't support ECError conversion");
class LinearClauseProcessor {
  llvm::BasicBlock *linearFinalizationBB;
  llvm::BasicBlock *linearExitBB;
  llvm::BasicBlock *linearLastIterExitBB;

  void createLinearVar(llvm::IRBuilderBase &builder,
    if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
      linearPreconditionVars.push_back(builder.CreateAlloca(
          linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
      llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
          linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
      linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
      linearLoopBodyTemps.push_back(linearLoopBodyTemp);
      linearOrigVars.push_back(linearVarAlloca);

  linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  initLinearVar(llvm::IRBuilderBase &builder,
                llvm::BasicBlock *loopPreHeader) {
    builder.SetInsertPoint(loopPreHeader->getTerminator());
    for (size_t index = 0; index < linearOrigVars.size(); index++) {
      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        builder.saveIP(), llvm::omp::OMPD_barrier);
    return afterBarrierIP;
  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
                       llvm::Value *loopInductionVar) {
    builder.SetInsertPoint(loopBody->getTerminator());
    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
      llvm::LoadInst *linearVarStart =
          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
                             linearPreconditionVars[index]);
      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
      auto addInst = builder.CreateAdd(linearVarStart, mulInst);
      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
                                   llvm::BasicBlock *loopExit) {
    linearFinalizationBB = loopExit->splitBasicBlock(
        loopExit->getTerminator(), "omp_loop.linear_finalization");
    linearExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  finalizeLinearVar(llvm::IRBuilderBase &builder,
                    llvm::Value *lastIter) {
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    llvm::Value *loopLastIterLoad = builder.CreateLoad(
        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
    llvm::Value *isLast =
        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
            llvm::Type::getInt32Ty(builder.getContext()), 0));
    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
    for (size_t index = 0; index < linearOrigVars.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
                             linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
    linearFinalizationBB->getTerminator()->eraseFromParent();
    builder.SetInsertPoint(linearExitBB->getTerminator());
        builder.saveIP(), llvm::omp::OMPD_barrier);
  void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
    for (llvm::User *user : linearOrigVal[varIndex]->users())
      users.push_back(user);
    for (auto *user : users) {
      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
        if (userInst->getParent()->getName().str() == BBName)
          user->replaceUsesOfWith(linearOrigVal[varIndex],
                                  linearLoopBodyTemps[varIndex]);
                SymbolRefAttr symbolName) {
  omp::PrivateClauseOp privatizer =
      SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
  assert(privatizer && "privatizer not found in the symbol table");
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  auto checkBare = [&todo](auto op, LogicalResult &result) {
      result = todo("ompx_bare");
  auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
    omp::ClauseCancellationConstructType cancelledDirective =
        op.getCancelDirective();
    if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
      if (isa_and_nonnull<omp::TaskloopOp>(parent))
        result = todo("cancel directive inside of taskloop");
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  auto checkDevice = [&todo](auto op, LogicalResult &result) {
      result = todo("device");
  auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
    if (op.getDistScheduleChunkSize())
      result = todo("dist_schedule with chunk_size");
  auto checkHint = [](auto op, LogicalResult &) {
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
    if (!op.getIsDevicePtrVars().empty())
      result = todo("is_device_ptr");
  auto checkLinear = [&todo](auto op, LogicalResult &result) {
    if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
      result = todo("linear");
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
      result = todo("nowait");
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
    if (op.getParLevelSimd())
      result = todo("parallelization-level");
  auto checkPriority = [&todo](auto op, LogicalResult &result) {
    if (op.getPriority())
      result = todo("priority");
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
      if (!op.getPrivateVars().empty() && op.getNowait())
        result = todo("privatization for deferred target tasks");
      if (!op.getPrivateVars().empty() || op.getPrivateSyms())
        result = todo("privatization");
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  auto checkUntied = [&todo](auto op, LogicalResult &result) {
      result = todo("untied");
391 LogicalResult result = success();
393 .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
394 .Case([&](omp::CancellationPointOp op) {
395 checkCancelDirective(op, result);
397 .Case([&](omp::DistributeOp op) {
398 checkAllocate(op, result);
399 checkDistSchedule(op, result);
400 checkOrder(op, result);
402 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
403 .Case([&](omp::SectionsOp op) {
404 checkAllocate(op, result);
405 checkPrivate(op, result);
406 checkReduction(op, result);
408 .Case([&](omp::SingleOp op) {
409 checkAllocate(op, result);
410 checkPrivate(op, result);
412 .Case([&](omp::TeamsOp op) {
413 checkAllocate(op, result);
414 checkPrivate(op, result);
416 .Case([&](omp::TaskOp op) {
417 checkAllocate(op, result);
418 checkInReduction(op, result);
420 .Case([&](omp::TaskgroupOp op) {
421 checkAllocate(op, result);
422 checkTaskReduction(op, result);
424 .Case([&](omp::TaskwaitOp op) {
425 checkDepend(op, result);
426 checkNowait(op, result);
428 .Case([&](omp::TaskloopOp op) {
430 checkUntied(op, result);
431 checkPriority(op, result);
433 .Case([&](omp::WsloopOp op) {
434 checkAllocate(op, result);
435 checkLinear(op, result);
436 checkOrder(op, result);
437 checkReduction(op, result);
439 .Case([&](omp::ParallelOp op) {
440 checkAllocate(op, result);
441 checkReduction(op, result);
443 .Case([&](omp::SimdOp op) {
444 checkLinear(op, result);
445 checkReduction(op, result);
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
          [&](auto op) { checkDepend(op, result); })
451 .Case([&](omp::TargetOp op) {
452 checkAllocate(op, result);
453 checkBare(op, result);
454 checkDevice(op, result);
455 checkInReduction(op, result);
456 checkIsDevicePtr(op, result);
457 checkPrivate(op, result);
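// Folds an llvm::Error into a LogicalResult (inferred from the handlers
// below): PreviouslyReportedError means a diagnostic was already emitted,
// while any other error still needs to be reported on the operation.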
467 LogicalResult result = success();
  llvm::handleAllErrors(
      [&](const PreviouslyReportedError &) { result = failure(); },
      [&](const llvm::ErrorInfoBase &err) {
479 template <
typename T>
489 static llvm::OpenMPIRBuilder::InsertPointTy
495 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
497 [&](OpenMPAllocaStackFrame &frame) {
498 allocaInsertPoint = frame.allocaInsertPoint;
502 return allocaInsertPoint;
  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);

  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
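// Walks the translation stack to find the CanonicalLoopInfo recorded for the
// innermost enclosing loop construct, if any.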
static llvm::CanonicalLoopInfo *
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        loopInfo = frame.loopInfo;
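// Converts the blocks of an OpenMP op's region into LLVM IR basic blocks and
// creates a continuation block (with PHIs for any yielded values) that
// receives control once the region is done.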
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());

  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  unsigned numYields = 0;

  if (!isLoopWrapper) {
    bool operandsProcessed = false;
      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));
          operandsProcessed = true;
          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());
            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");

  if (!continuationBlockPHITypes.empty())
         continuationBlockPHIs &&
         "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
621 if (bb->isEntryBlock()) {
622 assert(sourceTerminator->getNumSuccessors() == 1 &&
623 "provided entry block has multiple successors");
624 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
625 "ContinuationBlock is not the successor of the entry block");
626 sourceTerminator->setSuccessor(0, llvmBB);
629 llvm::IRBuilderBase::InsertPointGuard guard(builder);
631 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
632 return llvm::make_error<PreviouslyReportedError>();
637 builder.CreateBr(continuationBlock);
648 Operation *terminator = bb->getTerminator();
649 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
650 builder.CreateBr(continuationBlock);
      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
653 (*continuationBlockPHIs)[i]->addIncoming(
667 return continuationBlock;
673 case omp::ClauseProcBindKind::Close:
674 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
675 case omp::ClauseProcBindKind::Master:
676 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
677 case omp::ClauseProcBindKind::Primary:
678 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
679 case omp::ClauseProcBindKind::Spread:
680 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
  llvm_unreachable("Unknown ClauseProcBindKind kind");
692 omp::BlockArgOpenMPOpInterface blockArgIface) {
694 blockArgIface.getBlockArgsPairs(blockArgsPairs);
  for (auto [var, arg] : blockArgsPairs)
712 .Case([&](omp::SimdOp op) {
714 cast<omp::BlockArgOpenMPOpInterface>(*op));
        op.emitWarning() << "simd information on composite construct discarded";

        return op->emitError() << "cannot ignore wrapper";
727 auto maskedOp = cast<omp::MaskedOp>(opInst);
728 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
733 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = maskedOp.getRegion();
736 builder.restoreIP(codeGenIP);
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::Value *filterVal = nullptr;
  if (auto filterVar = maskedOp.getFilteredThreadId()) {
    filterVal = moduleTranslation.lookupValue(filterVar);
  llvm::LLVMContext &llvmContext = builder.getContext();
  assert(filterVal != nullptr);
755 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
756 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
763 builder.restoreIP(*afterIP);
771 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
772 auto masterOp = cast<omp::MasterOp>(opInst);
777 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = masterOp.getRegion();
780 builder.restoreIP(codeGenIP);
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
790 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
791 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
798 builder.restoreIP(*afterIP);
806 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
807 auto criticalOp = cast<omp::CriticalOp>(opInst);
812 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = cast<omp::CriticalOp>(opInst).getRegion();
815 builder.restoreIP(codeGenIP);
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
  llvm::Constant *hint = nullptr;
830 if (criticalOp.getNameAttr()) {
833 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
834 auto criticalDeclareOp =
835 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
839 static_cast<int>(criticalDeclareOp.getHint()));
841 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
          ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
848 builder.restoreIP(*afterIP);
855 template <
typename OP>
858 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
859 mlirVars.reserve(blockArgs.size());
860 llvmVars.reserve(blockArgs.size());
861 collectPrivatizationDecls<OP>(op);
864 mlirVars.push_back(privateVar);
876 void collectPrivatizationDecls(OP op) {
877 std::optional<ArrayAttr> attr = op.getPrivateSyms();
881 privatizers.reserve(privatizers.size() + attr->size());
882 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
889 template <
typename T>
893 std::optional<ArrayAttr> attr = op.getReductionSyms();
897 reductions.reserve(reductions.size() + op.getNumReductionVars());
898 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
899 reductions.push_back(
900 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
919 if (llvm::hasSingleElement(region)) {
920 llvm::Instruction *potentialTerminator =
921 builder.GetInsertBlock()->empty() ? nullptr
922 : &builder.GetInsertBlock()->back();
924 if (potentialTerminator && potentialTerminator->isTerminator())
925 potentialTerminator->removeFromParent();
    moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
            region.front(), true, builder)))
933 if (continuationBlockArgs)
935 *continuationBlockArgs,
942 if (potentialTerminator && potentialTerminator->isTerminator()) {
943 llvm::BasicBlock *block = builder.GetInsertBlock();
944 if (block->empty()) {
950 potentialTerminator->insertInto(block, block->begin());
952 potentialTerminator->insertAfter(&block->back());
966 if (continuationBlockArgs)
967 llvm::append_range(*continuationBlockArgs, phis);
968 builder.SetInsertPoint(*continuationBlock,
969 (*continuationBlock)->getFirstInsertionPt());
976 using OwningReductionGen =
977 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
978 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
980 using OwningAtomicReductionGen =
981 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
982 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
989 static OwningReductionGen
995 OwningReductionGen gen =
996 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
997 llvm::Value *lhs, llvm::Value *rhs,
                llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
1002 builder.restoreIP(insertPoint);
1005 "omp.reduction.nonatomic.body", builder,
1006 moduleTranslation, &phis)))
1007 return llvm::createStringError(
1008 "failed to inline `combiner` region of `omp.declare_reduction`");
1009 result = llvm::getSingleElement(phis);
1010 return builder.saveIP();
1019 static OwningAtomicReductionGen
1021 llvm::IRBuilderBase &builder,
1023 if (decl.getAtomicReductionRegion().empty())
1024 return OwningAtomicReductionGen();
1029 OwningAtomicReductionGen atomicGen =
1030 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
                llvm::Value *lhs, llvm::Value *rhs) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
1035 builder.restoreIP(insertPoint);
1038 "omp.reduction.atomic.body", builder,
1039 moduleTranslation, &phis)))
1040 return llvm::createStringError(
1041 "failed to inline `atomic` region of `omp.declare_reduction`");
1042 assert(phis.empty());
1043 return builder.saveIP();
1049 static LogicalResult
1052 auto orderedOp = cast<omp::OrderedOp>(opInst);
1057 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1058 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1059 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1061 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
1063 size_t indexVecValues = 0;
1064 while (indexVecValues < vecValues.size()) {
1066 storeValues.reserve(numLoops);
1067 for (
unsigned i = 0; i < numLoops; i++) {
1068 storeValues.push_back(vecValues[indexVecValues]);
1071 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1073 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
    builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
        ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1082 static LogicalResult
1085 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1086 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1091 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1094 builder.restoreIP(codeGenIP);
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1104 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1105 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1107 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1112 builder.restoreIP(*afterIP);
1118 struct DeferredStore {
1119 DeferredStore(llvm::Value *value, llvm::Value *address)
1120 : value(value), address(address) {}
1123 llvm::Value *address;
1130 template <
typename T>
1131 static LogicalResult
1133 llvm::IRBuilderBase &builder,
1135 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1141 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1142 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1145 deferredStores.reserve(loop.getNumReductionVars());
1147 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1148 Region &allocRegion = reductionDecls[i].getAllocRegion();
1150 if (allocRegion.
empty())
1155 builder, moduleTranslation, &phis)))
1156 return loop.emitError(
1157 "failed to inline `alloc` region of `omp.declare_reduction`");
      assert(phis.size() == 1 && "expected one allocation to be yielded");
1160 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1164 llvm::Value *var = builder.CreateAlloca(
1165 moduleTranslation.
convertType(reductionDecls[i].getType()));
1167 llvm::Type *ptrTy = builder.getPtrTy();
1168 llvm::Value *castVar =
1169 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1170 llvm::Value *castPhi =
1171 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1173 deferredStores.emplace_back(castPhi, castVar);
1175 privateReductionVariables[i] = castVar;
1176 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1177 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
      assert(allocRegion.empty() &&
             "allocation is implicit for by-val reduction");
1181 llvm::Value *var = builder.CreateAlloca(
1182 moduleTranslation.
convertType(reductionDecls[i].getType()));
1184 llvm::Type *ptrTy = builder.getPtrTy();
1185 llvm::Value *castVar =
1186 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1188 moduleTranslation.
mapValue(reductionArgs[i], castVar);
1189 privateReductionVariables[i] = castVar;
1190 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1198 template <
typename T>
1205 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1206 Region &initializerRegion = reduction.getInitializerRegion();
1209 mlir::Value mlirSource = loop.getReductionVars()[i];
1210 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1211 assert(llvmSource &&
"lookup reduction var");
1212 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), llvmSource);
1215 llvm::Value *allocation =
1216 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1217 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
                                 llvm::BasicBlock *block = nullptr) {
  if (block == nullptr)
    block = builder.GetInsertBlock();

  if (block->empty() || block->getTerminator() == nullptr)
    builder.SetInsertPoint(block);
1230 builder.SetInsertPoint(block->getTerminator());
1238 template <
typename OP>
1239 static LogicalResult
1241 llvm::IRBuilderBase &builder,
1243 llvm::BasicBlock *latestAllocaBlock,
1249 if (op.getNumReductionVars() == 0)
  llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
1253 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1254 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1255 builder.restoreIP(allocaIP);
1258 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1260 if (!reductionDecls[i].getAllocRegion().empty())
1266 byRefVars[i] = builder.CreateAlloca(
1267 moduleTranslation.
convertType(reductionDecls[i].getType()));
1275 for (
auto [data, addr] : deferredStores)
1276 builder.CreateStore(data, addr);
1281 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1286 reductionVariableMap, i);
1289 "omp.reduction.neutral", builder,
1290 moduleTranslation, &phis)))
    assert(phis.size() == 1 && "expected one value to be yielded from the "
                               "reduction neutral element declaration region");
1299 if (!reductionDecls[i].getAllocRegion().empty())
1308 builder.CreateStore(phis[0], byRefVars[i]);
1310 privateReductionVariables[i] = byRefVars[i];
1311 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1312 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1315 builder.CreateStore(phis[0], privateReductionVariables[i]);
1322 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
1329 template <
typename T>
1331 T loop, llvm::IRBuilderBase &builder,
1338 unsigned numReductions = loop.getNumReductionVars();
1340 for (
unsigned i = 0; i < numReductions; ++i) {
1341 owningReductionGens.push_back(
1343 owningAtomicReductionGens.push_back(
1348 reductionInfos.reserve(numReductions);
1349 for (
unsigned i = 0; i < numReductions; ++i) {
1350 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen =
nullptr;
1351 if (owningAtomicReductionGens[i])
1352 atomicGen = owningAtomicReductionGens[i];
1353 llvm::Value *variable =
1354 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
1355 reductionInfos.push_back(
1357 privateReductionVariables[i],
1358 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1359 owningReductionGens[i],
1360 nullptr, atomicGen});
1365 static LogicalResult
1369 llvm::IRBuilderBase &builder, StringRef regionName,
1370 bool shouldLoadCleanupRegionArg =
true) {
1372 if (cleanupRegion->empty())
1378 llvm::Instruction *potentialTerminator =
1379 builder.GetInsertBlock()->empty() ? nullptr
1380 : &builder.GetInsertBlock()->back();
1381 if (potentialTerminator && potentialTerminator->isTerminator())
1382 builder.SetInsertPoint(potentialTerminator);
1383 llvm::Value *privateVarValue =
1384 shouldLoadCleanupRegionArg
1385 ? builder.CreateLoad(
1387 privateVariables[i])
1388 : privateVariables[i];
1393 moduleTranslation)))
1406 OP op, llvm::IRBuilderBase &builder,
1408 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
    bool isNowait = false, bool isTeamsReduction = false) {
1413 if (op.getNumReductionVars() == 0)
1425 owningReductionGens, owningAtomicReductionGens,
1426 privateReductionVariables, reductionInfos);
1431 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1432 builder.SetInsertPoint(tempTerminator);
1433 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1434 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1435 isByRef, isNowait, isTeamsReduction);
1440 if (!contInsertPoint->getBlock())
    return op->emitOpError() << "failed to convert reductions";
1443 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1444 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1449 tempTerminator->eraseFromParent();
1450 builder.restoreIP(*afterIP);
1454 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1455 [](omp::DeclareReductionOp reductionDecl) {
1456 return &reductionDecl.getCleanupRegion();
1459 moduleTranslation, builder,
1460 "omp.reduction.cleanup");
1471 template <
typename OP>
1475 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1480 if (op.getNumReductionVars() == 0)
1486 allocaIP, reductionDecls,
1487 privateReductionVariables, reductionVariableMap,
1488 deferredStores, isByRef)))
1492 allocaIP.getBlock(), reductionDecls,
1493 privateReductionVariables, reductionVariableMap,
1494 isByRef, deferredStores);
1504 static llvm::Value *
  if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1511 Value blockArg = (*mappedPrivateVars)[privateVar];
1514 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1515 "A block argument corresponding to a mapped var should have "
1518 if (privVarType == blockArgType)
1525 if (!isa<LLVM::LLVMPointerType>(privVarType))
1526 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1539 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1541 Region &initRegion = privDecl.getInitRegion();
1542 if (initRegion.
empty())
1543 return llvmPrivateVar;
1547 mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1548 assert(nonPrivateVar);
  moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
  moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1555 moduleTranslation, &phis)))
1556 return llvm::createStringError(
1557 "failed to inline `init` region of `omp.private`");
  assert(phis.size() == 1 && "expected one allocation to be yielded");
1578 return llvm::Error::success();
  llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
1586 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1588 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1589 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1592 return privVarOrErr.takeError();
1594 llvmPrivateVar = privVarOrErr.get();
1595 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1600 return llvm::Error::success();
1610 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1613 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1614 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1615 allocaTerminator->getIterator()),
1616 true, allocaTerminator->getStableDebugLoc(),
1617 "omp.region.after_alloca");
1619 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1621 allocaTerminator = allocaIP.getBlock()->getTerminator();
1622 builder.SetInsertPoint(allocaTerminator);
1624 assert(allocaTerminator->getNumSuccessors() == 1 &&
1625 "This is an unconditional branch created by splitBB");
1627 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1628 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
  unsigned int allocaAS =
      moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
          .getProgramAddressSpace();

  for (auto [privDecl, mlirPrivVar, blockArg] :
    llvm::Type *llvmAllocType =
        moduleTranslation.convertType(privDecl.getType());
    builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
    llvm::Value *llvmPrivateVar = builder.CreateAlloca(
        llvmAllocType, nullptr, "omp.private.alloc");
    if (allocaAS != defaultAS)
      llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
                                                   builder.getPtrTy(defaultAS));

    privateVarsInfo.llvmVars.push_back(llvmPrivateVar);

  return afterAllocas;
1662 bool needsFirstprivate =
1663 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1664 return privOp.getDataSharingType() ==
1665 omp::DataSharingClauseType::FirstPrivate;
1668 if (!needsFirstprivate)
  llvm::BasicBlock *copyBlock =
      splitBB(builder, true, "omp.private.copy");

  for (auto [decl, mlirVar, llvmVar] :
       llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
    if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)

    Region &copyRegion = decl.getCopyRegion();
1685 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1686 assert(nonPrivateVar);
1687 moduleTranslation.
mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1690 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1694 moduleTranslation)))
      return decl.emitError("failed to inline `copy` region of `omp.private`");
1707 if (insertBarrier) {
1709 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1710 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1718 static LogicalResult
1725 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1726 [](omp::PrivateClauseOp privatizer) {
1727 return &privatizer.getDeallocRegion();
          privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
          "omp.private.dealloc", false)))
    return mlir::emitError(loc, "failed to inline `dealloc` region of an "
                                "`omp.private` op in");
1746 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1753 static LogicalResult
1756 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1757 using StorableBodyGenCallbackTy =
1758 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1760 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1766 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1770 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1774 sectionsOp.getNumReductionVars());
1778 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1781 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1782 reductionDecls, privateReductionVariables, reductionVariableMap,
1789 auto sectionOp = dyn_cast<omp::SectionOp>(op);
    Region &region = sectionOp.getRegion();
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
                         InsertPointTy allocaIP, InsertPointTy codeGenIP) {
      builder.restoreIP(codeGenIP);
                   sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               sectionsOp.getRegion().getArguments(), region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
        moduleTranslation.mapValue(sectionArg, llvmVal);
1815 sectionCBs.push_back(sectionCB);
1821 if (sectionCBs.empty())
1824 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1829 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1830 llvm::Value &vPtr, llvm::Value *&replacementValue)
1831 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1832 replacementValue = &vPtr;
1838 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1842 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1843 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1845 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1846 sectionsOp.getNowait());
1851 builder.restoreIP(*afterIP);
1855 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1856 privateReductionVariables, isByRef, sectionsOp.getNowait());
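// Translates omp.single, including copyprivate variables and the helper
// functions used to broadcast them.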
1860 static LogicalResult
1863 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1864 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1869 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1870 builder.restoreIP(codegenIP);
1872 builder, moduleTranslation)
1875 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1879 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
  for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
    llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1884 auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
1885 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1886 llvmCPFuncs.push_back(
1890 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1892 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1898 builder.restoreIP(*afterIP);
1904 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
  for (auto ra : iface.getReductionBlockArgs())
    for (auto &use : ra.getUses()) {
1911 auto *useOp = use.getOwner();
1913 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1914 debugUses.push_back(useOp);
1918 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
1923 Operation *currentOp = currentDistOp.getOperation();
1924 if (distOp && (distOp != currentOp))
1933 for (
auto use : debugUses)
1939 static LogicalResult
1942 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1947 unsigned numReductionVars = op.getNumReductionVars();
1951 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1957 if (doTeamsReduction) {
1958 isByRef =
getIsByRef(op.getReductionByref());
1960 assert(isByRef.size() == op.getNumReductionVars());
1963 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
1968 op, reductionArgs, builder, moduleTranslation, allocaIP,
1969 reductionDecls, privateReductionVariables, reductionVariableMap,
1974 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1976 moduleTranslation, allocaIP);
1977 builder.restoreIP(codegenIP);
  llvm::Value *numTeamsLower = nullptr;
  if (Value numTeamsLowerVar = op.getNumTeamsLower())
    numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);

  llvm::Value *numTeamsUpper = nullptr;
  if (Value numTeamsUpperVar = op.getNumTeamsUpper())
    numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);

  llvm::Value *threadLimit = nullptr;
  if (Value threadLimitVar = op.getThreadLimit())
    threadLimit = moduleTranslation.lookupValue(threadLimitVar);

  llvm::Value *ifExpr = nullptr;
  if (Value ifVar = op.getIfExpr())
1999 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2000 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2002 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2007 builder.restoreIP(*afterIP);
2008 if (doTeamsReduction) {
2011 op, builder, moduleTranslation, allocaIP, reductionDecls,
2012 privateReductionVariables, isByRef,
2022 if (dependVars.empty())
  for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2025 llvm::omp::RTLDependenceKindTy type;
2027 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2028 case mlir::omp::ClauseTaskDepend::taskdependin:
2029 type = llvm::omp::RTLDependenceKindTy::DepIn;
2034 case mlir::omp::ClauseTaskDepend::taskdependout:
2035 case mlir::omp::ClauseTaskDepend::taskdependinout:
2036 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2038 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2039 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2041 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2042 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2045 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2046 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2047 dds.emplace_back(dd);
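// Cancellation support: a finalization callback records the branch created
// for each cancelled region so that, once the construct has been emitted,
// those branches can be retargeted to the construct's finalization block.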
2059 llvm::IRBuilderBase &llvmBuilder,
2061 llvm::omp::Directive cancelDirective) {
2062 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2063 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2067 llvmBuilder.restoreIP(ip);
2073 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2074 return llvm::Error::success();
2079 ompBuilder.pushFinalizationCB(
2089 llvm::OpenMPIRBuilder &ompBuilder,
2090 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2091 ompBuilder.popFinalizationCB();
2092 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2093 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2094 assert(cancelBranch->getNumSuccessors() == 1 &&
2095 "cancel branch should have one target");
2096 cancelBranch->setSuccessor(0, constructFini);
2103 class TaskContextStructManager {
2105 TaskContextStructManager(llvm::IRBuilderBase &builder,
2108 : builder{builder}, moduleTranslation{moduleTranslation},
2109 privateDecls{privateDecls} {}
2115 void generateTaskContextStruct();
2121 void createGEPsToPrivateVars();
2124 void freeStructPtr();
2127 return llvmPrivateVarGEPs;
  llvm::Value *getStructPtr() { return structPtr; }

  llvm::IRBuilderBase &builder;

  llvm::Value *structPtr = nullptr;

  llvm::Type *structTy = nullptr;
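// Heap-allocates a context structure with one field per privatizer that reads
// from its mold, so the task body can still reach the initialized private
// copies after the task-creating function returns.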
2151 void TaskContextStructManager::generateTaskContextStruct() {
2152 if (privateDecls.empty())
2154 privateVarTypes.reserve(privateDecls.size());
2156 for (omp::PrivateClauseOp &privOp : privateDecls) {
2159 if (!privOp.readsFromMold())
2161 Type mlirType = privOp.getType();
2162 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
2168 llvm::DataLayout dataLayout =
2169 builder.GetInsertBlock()->getModule()->getDataLayout();
2170 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2171 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2174 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2176 "omp.task.context_ptr");
2179 void TaskContextStructManager::createGEPsToPrivateVars() {
2181 assert(privateVarTypes.empty());
2186 llvmPrivateVarGEPs.clear();
2187 llvmPrivateVarGEPs.reserve(privateDecls.size());
2188 llvm::Value *zero = builder.getInt32(0);
  for (auto privDecl : privateDecls) {
    if (!privDecl.readsFromMold()) {
      llvmPrivateVarGEPs.push_back(nullptr);
2196 llvm::Value *iVal = builder.getInt32(i);
2197 llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
2198 llvmPrivateVarGEPs.push_back(gep);
2203 void TaskContextStructManager::freeStructPtr() {
2207 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2209 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2210 builder.CreateFree(structPtr);
2214 static LogicalResult
2217 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2222 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2234 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2239 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2240 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2241 builder.getContext(),
"omp.task.start",
2242 builder.GetInsertBlock()->getParent());
2243 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2244 builder.SetInsertPoint(branchToTaskStartBlock);
  llvm::BasicBlock *copyBlock =
      splitBB(builder, true, "omp.private.copy");
  llvm::BasicBlock *initBlock =
      splitBB(builder, true, "omp.private.init");
2266 moduleTranslation, allocaIP);
2269 builder.SetInsertPoint(initBlock->getTerminator());
2272 taskStructMgr.generateTaskContextStruct();
2279 taskStructMgr.createGEPsToPrivateVars();
2281 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2284 taskStructMgr.getLLVMPrivateVarGEPs())) {
2286 if (!privDecl.readsFromMold())
2288 assert(llvmPrivateVarAlloc &&
2289 "reads from mold so shouldn't have been skipped");
2292 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2293 blockArg, llvmPrivateVarAlloc, initBlock);
2294 if (!privateVarOrErr)
2295 return handleError(privateVarOrErr, *taskOp.getOperation());
2304 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2305 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2306 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2308 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2309 llvmPrivateVarAlloc);
2311 assert(llvmPrivateVarAlloc->getType() ==
2312 moduleTranslation.
convertType(blockArg.getType()));
2322 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2323 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2324 taskOp.getPrivateNeedsBarrier())))
2325 return llvm::failure();
2328 builder.SetInsertPoint(taskStartBlock);
2330 auto bodyCB = [&](InsertPointTy allocaIP,
2331 InsertPointTy codegenIP) -> llvm::Error {
2335 moduleTranslation, allocaIP);
2338 builder.restoreIP(codegenIP);
2340 llvm::BasicBlock *privInitBlock =
nullptr;
2345 auto [blockArg, privDecl, mlirPrivVar] = zip;
2347 if (privDecl.readsFromMold())
2350 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2351 llvm::Type *llvmAllocType =
2352 moduleTranslation.
convertType(privDecl.getType());
2353 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
      llvm::Value *llvmPrivateVar = builder.CreateAlloca(
          llvmAllocType, nullptr, "omp.private.alloc");
2358 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2359 blockArg, llvmPrivateVar, privInitBlock);
2360 if (!privateVarOrError)
2361 return privateVarOrError.takeError();
2362 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2363 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2366 taskStructMgr.createGEPsToPrivateVars();
2367 for (
auto [i, llvmPrivVar] :
2370 assert(privateVarsInfo.
llvmVars[i] &&
2371 "This is added in the loop above");
2374 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2379 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2383 if (!privateDecl.readsFromMold())
2386 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2387 llvmPrivateVar = builder.CreateLoad(
2388 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2390 assert(llvmPrivateVar->getType() ==
2391 moduleTranslation.
convertType(blockArg.getType()));
2392 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
        taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
    if (failed(handleError(continuationBlockOrError, *taskOp)))
      return llvm::make_error<PreviouslyReportedError>();
2400 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2405 return llvm::make_error<PreviouslyReportedError>();
2408 taskStructMgr.freeStructPtr();
2410 return llvm::Error::success();
2419 llvm::omp::Directive::OMPD_taskgroup);
2423 moduleTranslation, dds);
2425 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2426 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2428 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
          moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
          taskOp.getMergeable(),
          moduleTranslation.lookupValue(taskOp.getEventHandle()),
          moduleTranslation.lookupValue(taskOp.getPriority()));
2441 builder.restoreIP(*afterIP);
2446 static LogicalResult
2449 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2453 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2454 builder.restoreIP(codegenIP);
2456 builder, moduleTranslation)
2461 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2462 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2469 builder.restoreIP(*afterIP);
2473 static LogicalResult
2484 static LogicalResult
2488 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2492 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2494 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2498 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
  llvm::Type *ivType = step->getType();
  llvm::Value *chunk = nullptr;
  if (wsloopOp.getScheduleChunk()) {
    llvm::Value *chunkVar =
        moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
    chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2514 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2518 wsloopOp.getNumReductionVars());
2521 builder, moduleTranslation, privateVarsInfo, allocaIP);
2528 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2533 moduleTranslation, allocaIP, reductionDecls,
2534 privateReductionVariables, reductionVariableMap,
2535 deferredStores, isByRef)))
2544 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2546 wsloopOp.getPrivateNeedsBarrier())))
2549 assert(afterAllocas.get()->getSinglePredecessor());
2552 afterAllocas.get()->getSinglePredecessor(),
2553 reductionDecls, privateReductionVariables,
2554 reductionVariableMap, isByRef, deferredStores)))
2558 bool isOrdered = wsloopOp.getOrdered().has_value();
2559 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2560 bool isSimd = wsloopOp.getScheduleSimd();
2561 bool loopNeedsBarrier = !wsloopOp.getNowait();
  llvm::omp::WorksharingLoopType workshareLoopType =
      llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
          ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
          : llvm::omp::WorksharingLoopType::ForStaticLoop;
2573 llvm::omp::Directive::OMPD_for);
2575 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2578 LinearClauseProcessor linearClauseProcessor;
  if (wsloopOp.getLinearVars().size()) {
    for (mlir::Value linearVar : wsloopOp.getLinearVars())
      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);

      wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
2596 if (wsloopOp.getLinearVars().size()) {
2597 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2598 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2599 loopInfo->getPreheader());
2602 builder.restoreIP(*afterBarrierIP);
2603 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
2604 loopInfo->getIndVar());
2605 linearClauseProcessor.outlineLinearFinalizationBB(builder,
2606 loopInfo->getExit());
2609 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2610 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2611 ompBuilder->applyWorkshareLoop(
2612 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2613 convertToScheduleKind(schedule), chunk, isSimd,
2614 scheduleMod == omp::ScheduleModifier::monotonic,
2615 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2622 if (wsloopOp.getLinearVars().size()) {
2623 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
2624 assert(loopInfo->getLastIter() &&
2625 "`lastiter` in CanonicalLoopInfo is nullptr");
2626 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2627 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
2628 loopInfo->getLastIter());
    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
      linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
    builder.restoreIP(oldIP);
2642 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2643 privateReductionVariables, isByRef, wsloopOp.getNowait(),
2653 static LogicalResult
2656 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2658 assert(isByRef.size() == opInst.getNumReductionVars());
2670 opInst.getNumReductionVars());
2673 auto bodyGenCB = [&](InsertPointTy allocaIP,
2674 InsertPointTy codeGenIP) -> llvm::Error {
2676 builder, moduleTranslation, privateVarsInfo, allocaIP);
2678 return llvm::make_error<PreviouslyReportedError>();
2684 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
2687 InsertPointTy(allocaIP.getBlock(),
2688 allocaIP.getBlock()->getTerminator()->getIterator());
2691 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
2692 reductionDecls, privateReductionVariables, reductionVariableMap,
2693 deferredStores, isByRef)))
2694 return llvm::make_error<PreviouslyReportedError>();
2696 assert(afterAllocas.get()->getSinglePredecessor());
2697 builder.restoreIP(codeGenIP);
2703 return llvm::make_error<PreviouslyReportedError>();
            opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
            opInst.getPrivateNeedsBarrier())))
2709 return llvm::make_error<PreviouslyReportedError>();
2713 afterAllocas.get()->getSinglePredecessor(),
2714 reductionDecls, privateReductionVariables,
2715 reductionVariableMap, isByRef, deferredStores)))
2716 return llvm::make_error<PreviouslyReportedError>();
2721 moduleTranslation, allocaIP);
2725 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
2727 return regionBlock.takeError();
2730 if (opInst.getNumReductionVars() > 0) {
2736 owningReductionGens, owningAtomicReductionGens,
2737 privateReductionVariables, reductionInfos);
2740 builder.SetInsertPoint((*regionBlock)->getTerminator());
2743 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
2744 builder.SetInsertPoint(tempTerminator);
2746 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
2747 ompBuilder->createReductions(
2748 builder.saveIP(), allocaIP, reductionInfos, isByRef,
2750 if (!contInsertPoint)
2751 return contInsertPoint.takeError();
2753 if (!contInsertPoint->getBlock())
2754 return llvm::make_error<PreviouslyReportedError>();
2756 tempTerminator->eraseFromParent();
2757 builder.restoreIP(*contInsertPoint);
2760 return llvm::Error::success();
2763 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2764 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
2773 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2774 InsertPointTy oldIP = builder.saveIP();
2775 builder.restoreIP(codeGenIP);
2780 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
2781 [](omp::DeclareReductionOp reductionDecl) {
2782 return &reductionDecl.getCleanupRegion();
2785 reductionCleanupRegions, privateReductionVariables,
2786 moduleTranslation, builder,
"omp.reduction.cleanup")))
2787 return llvm::createStringError(
2788 "failed to inline `cleanup` region of `omp.declare_reduction`");
2793 return llvm::make_error<PreviouslyReportedError>();
2795 builder.restoreIP(oldIP);
2796 return llvm::Error::success();
  llvm::Value *ifCond = nullptr;
  if (auto ifVar = opInst.getIfExpr())

  llvm::Value *numThreads = nullptr;
  if (auto numThreadsVar = opInst.getNumThreads())
    numThreads = moduleTranslation.lookupValue(numThreadsVar);
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindKind())
2810 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2812 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2814 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2815 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
2816 ifCond, numThreads, pbKind, isCancellable);
2821 builder.restoreIP(*afterIP);
2826 static llvm::omp::OrderKind
2829 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2831 case omp::ClauseOrderKind::Concurrent:
2832 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2834 llvm_unreachable("Unknown ClauseOrderKind kind");
2838 static LogicalResult
2842 auto simdOp = cast<omp::SimdOp>(opInst);
2848 if (simdOp.isComposite()) {
2853 builder, moduleTranslation);
2861 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2865 builder, moduleTranslation, privateVarsInfo, allocaIP);
2874 llvm::ConstantInt *simdlen = nullptr;
2875 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
2876 simdlen = builder.getInt64(simdlenVar.value());
2878 llvm::ConstantInt *safelen = nullptr;
2879 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
2880 safelen = builder.getInt64(safelenVar.value());
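// Gather the aligned clause operands: each aligned pointer is loaded in the
// block preceding the simd region and recorded with its alignment so that it
// can be passed to applySimd below.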
2882 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
2885 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
2886 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
2888 for (size_t i = 0; i < operands.size(); ++i) {
2889 llvm::Value *alignment = nullptr;
2890 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
2891 llvm::Type *ty = llvmVal->getType();
2893 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
2894 alignment = builder.getInt64(intAttr.getInt());
2895 assert(ty->isPointerTy() && "Invalid type for aligned variable");
2896 assert(alignment && "Invalid alignment value");
2897 auto curInsert = builder.saveIP();
2898 builder.SetInsertPoint(sourceBlock);
2899 llvmVal = builder.CreateLoad(ty, llvmVal);
2900 builder.restoreIP(curInsert);
2901 alignedVars[llvmVal] = alignment;
2905 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
2910 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2912 ompBuilder->applySimd(loopInfo, alignedVars,
2914 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
2916 order, simdlen, safelen);
2924 static LogicalResult
2928 auto loopOp = cast<omp::LoopNestOp>(opInst);
2931 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2936 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
2937 llvm::Value *iv) -> llvm::Error {
2940 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
2945 bodyInsertPoints.push_back(ip);
2947 if (loopInfos.size() != loopOp.getNumLoops() - 1)
2948 return llvm::Error::success();
2951 builder.restoreIP(ip);
2953 loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
2955 return regionBlock.takeError();
2957 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2958 return llvm::Error::success();
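// Create one canonical loop per level of the omp.loop_nest, nesting each new
// loop at the insertion point left by the previous body callback.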
2966 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
2967 llvm::Value *lowerBound =
2968 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
2969 llvm::Value *upperBound =
2970 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
2971 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
2976 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
2977 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
2979 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
2981 computeIP = loopInfos.front()->getPreheaderIP();
2985 ompBuilder->createCanonicalLoop(
2986 loc, bodyGen, lowerBound, upperBound, step,
2987 true, loopOp.getLoopInclusive(), computeIP);
2992 loopInfos.push_back(*loopResult);
2997 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
2998 loopInfos.front()->getAfterIP();
3002 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
3003 [&](OpenMPLoopInfoStackFrame &frame) {
3004 frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
3012 builder.restoreIP(afterIP);
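// Translate an MLIR memory-order clause into the corresponding
// llvm::AtomicOrdering; the absence of a clause maps to monotonic ordering.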
3017 static llvm::AtomicOrdering
3020 return llvm::AtomicOrdering::Monotonic;
3023 case omp::ClauseMemoryOrderKind::Seq_cst:
3024 return llvm::AtomicOrdering::SequentiallyConsistent;
3025 case omp::ClauseMemoryOrderKind::Acq_rel:
3026 return llvm::AtomicOrdering::AcquireRelease;
3027 case omp::ClauseMemoryOrderKind::Acquire:
3028 return llvm::AtomicOrdering::Acquire;
3029 case omp::ClauseMemoryOrderKind::Release:
3030 return llvm::AtomicOrdering::Release;
3031 case omp::ClauseMemoryOrderKind::Relaxed:
3032 return llvm::AtomicOrdering::Monotonic;
3034 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3038 static LogicalResult
3041 auto readOp = cast<omp::AtomicReadOp>(opInst);
3046 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3049 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3052 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3053 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
3055 llvm::Type *elementType =
3056 moduleTranslation.convertType(readOp.getElementType());
3058 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
3059 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
3060 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3065 static LogicalResult
3068 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3073 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3076 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3078 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3079 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
3080 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3081 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, false,
3084 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3092 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
3093 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
3094 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
3095 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
3096 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
3097 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
3098 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
3099 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
3100 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
3101 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
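// Convert omp.atomic.update: the single operation inside the region that uses
// the region argument determines the atomic RMW binop, and the region body is
// re-emitted inside the update callback below.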
3105 static LogicalResult
3107 llvm::IRBuilderBase &builder,
3114 auto &innerOpList = opInst.getRegion().front().getOperations();
3115 bool isXBinopExpr{false};
3116 llvm::AtomicRMWInst::BinOp binop;
3118 llvm::Value *llvmExpr = nullptr;
3119 llvm::Value *llvmX = nullptr;
3120 llvm::Type *llvmXElementType = nullptr;
3121 if (innerOpList.size() == 2) {
3127 opInst.getRegion().getArgument(0))) {
3128 return opInst.emitError("no atomic update operation with region argument"
3129 " as operand found inside atomic.update region");
3132 isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
3134 llvmExpr = moduleTranslation.lookupValue(mlirExpr);
3138 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3140 llvmX = moduleTranslation.lookupValue(opInst.getX());
3142 opInst.getRegion().getArgument(0).getType());
3143 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3147 llvm::AtomicOrdering atomicOrdering =
3152 [&opInst, &moduleTranslation](llvm::Value *atomicx,
3156 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
3157 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3158 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
3159 return llvm::make_error<PreviouslyReportedError>();
3161 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
3162 assert(yieldop && yieldop.getResults().size() == 1 &&
3163 "terminator must be omp.yield op and it must have exactly one "
3165 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
3170 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3171 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3172 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
3173 atomicOrdering, binop, updateFn,
3179 builder.restoreIP(*afterIP);
3183 static LogicalResult
3185 llvm::IRBuilderBase &builder,
3192 bool isXBinopExpr = false, isPostfixUpdate = false;
3193 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3195 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
3196 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
3198 assert((atomicUpdateOp || atomicWriteOp) &&
3199 "internal op must be an atomic.update or atomic.write op");
3201 if (atomicWriteOp) {
3202 isPostfixUpdate =
true;
3203 mlirExpr = atomicWriteOp.getExpr();
3205 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
3206 atomicCaptureOp.getAtomicUpdateOp().getOperation();
3207 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
3210 if (innerOpList.size() == 2) {
3213 atomicUpdateOp.getRegion().getArgument(0))) {
3214 return atomicUpdateOp.emitError(
3215 "no atomic update operation with region argument"
3216 " as operand found inside atomic.update region");
3220 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
3223 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3227 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
3228 llvm::Value *llvmX =
3229 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
3230 llvm::Value *llvmV =
3231 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
3232 llvm::Type *llvmXElementType = moduleTranslation.convertType(
3233 atomicCaptureOp.getAtomicReadOp().getElementType());
3234 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3237 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
3241 llvm::AtomicOrdering atomicOrdering =
3245 [&](llvm::Value *atomicx,
3248 return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
3249 Block &bb = *atomicUpdateOp.getRegion().begin();
3250 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
3252 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3253 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
3254 return llvm::make_error<PreviouslyReportedError>();
3256 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
3257 assert(yieldop && yieldop.getResults().size() == 1 &&
3258 "terminator must be omp.yield op and it must have exactly one "
3260 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
3265 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3266 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3267 ompBuilder->createAtomicCapture(
3268 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
3269 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr);
3271 if (failed(
handleError(afterIP, *atomicCaptureOp)))
3274 builder.restoreIP(*afterIP);
3279 omp::ClauseCancellationConstructType directive) {
3280 switch (directive) {
3281 case omp::ClauseCancellationConstructType::Loop:
3282 return llvm::omp::Directive::OMPD_for;
3283 case omp::ClauseCancellationConstructType::Parallel:
3284 return llvm::omp::Directive::OMPD_parallel;
3285 case omp::ClauseCancellationConstructType::Sections:
3286 return llvm::omp::Directive::OMPD_sections;
3287 case omp::ClauseCancellationConstructType::Taskgroup:
3288 return llvm::omp::Directive::OMPD_taskgroup;
3292 static LogicalResult
3298 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3301 llvm::Value *ifCond = nullptr;
3302 if (Value ifVar = op.getIfExpr())
3305 llvm::omp::Directive cancelledDirective =
3308 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3309 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
3311 if (failed(
handleError(afterIP, *op.getOperation())))
3314 builder.restoreIP(afterIP.get());
3319 static LogicalResult
3321 llvm::IRBuilderBase &builder,
3326 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3329 llvm::omp::Directive cancelledDirective =
3332 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3333 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
3335 if (failed(
handleError(afterIP, *op.getOperation())))
3338 builder.restoreIP(afterIP.get());
3345 static LogicalResult
3348 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3350 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
3355 Value symAddr = threadprivateOp.getSymAddr();
3358 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3361 if (!isa<LLVM::AddressOfOp>(symOp))
3362 return opInst.emitError("Addressing symbol not found");
3363 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3365 LLVM::GlobalOp global =
3366 addressOfOp.getGlobal(moduleTranslation.symbolTable());
3367 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
3369 if (!ompBuilder->Config.isTargetDevice()) {
3370 llvm::Type *type = globalValue->getValueType();
3371 llvm::TypeSize typeSize =
3372 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
3374 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
3375 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
3376 ompLoc, globalValue, size, global.getSymName() +
".cache");
3385 static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
3387 switch (deviceClause) {
3388 case mlir::omp::DeclareTargetDeviceType::host:
3389 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
3391 case mlir::omp::DeclareTargetDeviceType::nohost:
3392 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
3394 case mlir::omp::DeclareTargetDeviceType::any:
3395 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
3398 llvm_unreachable("unhandled device clause");
3401 static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
3403 mlir::omp::DeclareTargetCaptureClause captureClause) {
3404 switch (captureClause) {
3405 case mlir::omp::DeclareTargetCaptureClause::to:
3406 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
3407 case mlir::omp::DeclareTargetCaptureClause::link:
3408 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
3409 case mlir::omp::DeclareTargetCaptureClause::enter:
3410 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
3412 llvm_unreachable("unhandled capture clause");
3417 llvm::OpenMPIRBuilder &ompBuilder) {
3419 llvm::raw_svector_ostream os(suffix);
3422 auto fileInfoCallBack = [&loc]() {
3423 return std::pair<std::string, uint64_t>(
3424 llvm::StringRef(loc.getFilename()), loc.getLine());
3428 "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
3430 os <<
"_decl_tgt_ref_ptr";
3436 if (auto addressOfOp =
3437 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
3438 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
3439 Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
3440 if (auto declareTargetGlobal =
3441 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
3442 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3443 mlir::omp::DeclareTargetCaptureClause::link)
3452 static llvm::Value *
3459 if (auto addressOfOp =
3460 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
3461 if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
3462 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
3463 addressOfOp.getGlobalName()))) {
3465 if (auto declareTargetGlobal =
3466 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
3467 gOp.getOperation())) {
3471 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
3472 mlir::omp::DeclareTargetCaptureClause::link) ||
3473 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3474 mlir::omp::DeclareTargetCaptureClause::to &&
3475 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
3479 if (gOp.getSymName().contains(suffix))
3484 (gOp.getSymName().str() + suffix.str()).str());
3495 struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
3499 void append(MapInfosTy &curInfo) {
3500 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
3501 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
3510 struct MapInfoData : MapInfosTy {
3522 void append(MapInfoData &CurInfo) {
3523 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
3524 CurInfo.IsDeclareTarget.end());
3525 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
3526 OriginalValue.append(CurInfo.OriginalValue.begin(),
3527 CurInfo.OriginalValue.end());
3528 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
3529 MapInfosTy::append(CurInfo);
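// Peel nested LLVM array types to reach the scalar element type underlying a
// mapped array.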
3535 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3536 arrTy.getElementType()))
3552 Operation *clauseOp, llvm::Value *basePointer,
3553 llvm::Type *baseType, llvm::IRBuilderBase &builder,
3555 if (auto memberClause =
3556 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
3561 if (!memberClause.getBounds().empty()) {
3562 llvm::Value *elementCount = builder.getInt64(1);
3563 for (auto bounds : memberClause.getBounds()) {
3564 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
3565 bounds.getDefiningOp())) {
3570 elementCount = builder.CreateMul(
3574 moduleTranslation.lookupValue(boundOp.getUpperBound()),
3575 moduleTranslation.lookupValue(boundOp.getLowerBound())),
3576 builder.getInt64(1)));
3583 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
3591 return builder.CreateMul(elementCount,
3592 builder.getInt64(underlyingTypeSzInBits / 8));
3605 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
3613 for (Value mapValue : mapVars) {
3614 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3615 for (auto member : map.getMembers())
3616 if (member == mapOp)
3623 for (Value mapValue : mapVars) {
3624 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3626 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3627 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
3628 mapData.Pointers.push_back(mapData.OriginalValue.back());
3630 if (llvm::Value *refPtr =
3632 moduleTranslation)) {
3633 mapData.IsDeclareTarget.push_back(true);
3634 mapData.BasePointers.push_back(refPtr);
3636 mapData.IsDeclareTarget.push_back(false);
3637 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3640 mapData.BaseType.push_back(
3641 moduleTranslation.convertType(mapOp.getVarType()));
3642 mapData.Sizes.push_back(
3643 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
3644 mapData.BaseType.back(), builder, moduleTranslation));
3645 mapData.MapClause.push_back(mapOp.getOperation());
3646 mapData.Types.push_back(
3647 llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType()));
3651 if (mapOp.getMapperId())
3652 mapData.Mappers.push_back(
3653 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3654 mapOp, mapOp.getMapperIdAttr()));
3656 mapData.Mappers.push_back(
nullptr);
3657 mapData.IsAMapping.push_back(
true);
3658 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
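// use_device_ptr/use_device_addr operands that already appear in a map clause
// are flagged as return parameters on the existing entry instead of being
// added as a separate mapping.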
3661 auto findMapInfo = [&mapData](llvm::Value *val,
3662 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3665 for (llvm::Value *basePtr : mapData.OriginalValue) {
3666 if (basePtr == val && mapData.IsAMapping[index]) {
3668 mapData.Types[index] |=
3669 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
3670 mapData.DevicePointers[index] = devInfoTy;
3679 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3680 for (Value mapValue : useDevOperands) {
3681 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3683 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3684 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
3687 if (!findMapInfo(origValue, devInfoTy)) {
3688 mapData.OriginalValue.push_back(origValue);
3689 mapData.Pointers.push_back(mapData.OriginalValue.back());
3690 mapData.IsDeclareTarget.push_back(
false);
3691 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3692 mapData.BaseType.push_back(
3693 moduleTranslation.
convertType(mapOp.getVarType()));
3694 mapData.Sizes.push_back(builder.getInt64(0));
3695 mapData.MapClause.push_back(mapOp.getOperation());
3696 mapData.Types.push_back(
3697 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
3700 mapData.DevicePointers.push_back(devInfoTy);
3701 mapData.Mappers.push_back(
nullptr);
3702 mapData.IsAMapping.push_back(
false);
3703 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
3708 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3709 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
3711 for (Value mapValue : hasDevAddrOperands) {
3712 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3714 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3715 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
3717 static_cast<llvm::omp::OpenMPOffloadMappingFlags>(mapOp.getMapType());
3718 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
3720 mapData.OriginalValue.push_back(origValue);
3721 mapData.BasePointers.push_back(origValue);
3722 mapData.Pointers.push_back(origValue);
3723 mapData.IsDeclareTarget.push_back(false);
3724 mapData.BaseType.push_back(
3725 moduleTranslation.convertType(mapOp.getVarType()));
3726 mapData.Sizes.push_back(
3727 builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
3728 mapData.MapClause.push_back(mapOp.getOperation());
3729 if (llvm::to_underlying(mapType & mapTypeAlways)) {
3733 mapData.Types.push_back(mapType);
3737 if (mapOp.getMapperId()) {
3738 mapData.Mappers.push_back(
3739 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3740 mapOp, mapOp.getMapperIdAttr()));
3742 mapData.Mappers.push_back(
nullptr);
3745 mapData.Types.push_back(
3746 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
3747 mapData.Mappers.push_back(
nullptr);
3751 mapData.DevicePointers.push_back(
3752 llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3753 mapData.IsAMapping.push_back(
false);
3754 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
3759 auto *res = llvm::find(mapData.MapClause, memberOp);
3760 assert(res != mapData.MapClause.end() &&
3761 "MapInfoOp for member not found in MapData, cannot return index");
3762 return std::distance(mapData.MapClause.begin(), res);
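// When a parent op maps several members, sort their member-index paths to
// locate the first (or last) mapped member; this is used to bound the extent
// of the parent mapping.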
3767 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
3769 if (indexAttr.size() == 1)
3770 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
3773 std::iota(indices.begin(), indices.end(), 0);
3775 llvm::sort(indices.begin(), indices.end(),
3776 [&](const size_t a, const size_t b) {
3777 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
3778 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
3779 for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
3780 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
3781 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
3783 if (aIndex == bIndex)
3786 if (aIndex < bIndex)
3789 if (aIndex > bIndex)
3796 return memberIndicesA.size() < memberIndicesB.size();
3799 return llvm::cast<omp::MapInfoOp>(
3800 mapInfo.getMembers()[indices.front()].getDefiningOp());
3822 std::vector<llvm::Value *>
3824 llvm::IRBuilderBase &builder,
bool isArrayTy,
3826 std::vector<llvm::Value *> idx;
3837 idx.push_back(builder.getInt64(0));
3838 for (int i = bounds.size() - 1; i >= 0; --i) {
3839 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3840 bounds[i].getDefiningOp())) {
3841 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
3863 std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
3864 for (size_t i = 1; i < bounds.size(); ++i) {
3865 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3866 bounds[i].getDefiningOp())) {
3867 dimensionIndexSizeOffset.push_back(builder.CreateMul(
3868 moduleTranslation.lookupValue(boundOp.getExtent()),
3869 dimensionIndexSizeOffset[i - 1]));
3877 for (int i = bounds.size() - 1; i >= 0; --i) {
3878 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3879 bounds[i].getDefiningOp())) {
3881 idx.emplace_back(builder.CreateMul(
3882 moduleTranslation.lookupValue(boundOp.getLowerBound()),
3883 dimensionIndexSizeOffset[i]));
3885 idx.back() = builder.CreateAdd(
3886 idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
3887 boundOp.getLowerBound()),
3888 dimensionIndexSizeOffset[i]));
3913 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
3914 MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) {
3915 assert(!ompBuilder.Config.isTargetDevice() &&
3916 "function only supported for host device codegen");
3919 combinedInfo.Types.emplace_back(
3921 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
3922 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
3923 combinedInfo.DevicePointers.emplace_back(
3924 mapData.DevicePointers[mapDataIndex]);
3925 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
3927 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3928 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3938 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3940 llvm::Value *lowAddr, *highAddr;
3941 if (!parentClause.getPartialMap()) {
3942 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
3943 builder.getPtrTy());
3944 highAddr = builder.CreatePointerCast(
3945 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
3946 mapData.Pointers[mapDataIndex], 1),
3947 builder.getPtrTy());
3948 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3950 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3953 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
3954 builder.getPtrTy());
3957 highAddr = builder.CreatePointerCast(
3958 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
3959 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
3960 builder.getPtrTy());
3961 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
3964 llvm::Value *size = builder.CreateIntCast(
3965 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
3966 builder.getInt64Ty(),
3968 combinedInfo.Sizes.push_back(size);
3970 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
3971 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
3979 if (!parentClause.getPartialMap()) {
3984 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
3985 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3986 combinedInfo.Types.emplace_back(mapFlag);
3987 combinedInfo.DevicePointers.emplace_back(
3989 combinedInfo.Mappers.emplace_back(
nullptr);
3991 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3992 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3993 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3994 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
3996 return memberOfFlag;
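// A map entry whose varPtrPtr is set maps data reached through a pointer, so
// it is treated as a PTR_AND_OBJ style mapping.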
4008 if (mapOp.getVarPtrPtr())
4023 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
4024 MapInfoData &mapData, uint64_t mapDataIndex,
4025 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
4026 assert(!ompBuilder.Config.isTargetDevice() &&
4027 "function only supported for host device codegen");
4030 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4032 for (auto mappedMembers : parentClause.getMembers()) {
4034 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
4037 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
4048 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4049 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4050 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4051 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4052 combinedInfo.Types.emplace_back(mapFlag);
4053 combinedInfo.DevicePointers.emplace_back(
4055 combinedInfo.Mappers.emplace_back(
nullptr);
4056 combinedInfo.Names.emplace_back(
4058 combinedInfo.BasePointers.emplace_back(
4059 mapData.BasePointers[mapDataIndex]);
4060 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
4061 combinedInfo.Sizes.emplace_back(builder.getInt64(
4062 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
4068 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4069 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4070 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4071 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4073 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4075 combinedInfo.Types.emplace_back(mapFlag);
4076 combinedInfo.DevicePointers.emplace_back(
4077 mapData.DevicePointers[memberDataIdx]);
4078 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
4079 combinedInfo.Names.emplace_back(
4081 uint64_t basePointerIndex =
4083 combinedInfo.BasePointers.emplace_back(
4084 mapData.BasePointers[basePointerIndex]);
4085 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
4087 llvm::Value *size = mapData.Sizes[memberDataIdx];
4089 size = builder.CreateSelect(
4090 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
4091 builder.getInt64(0), size);
4094 combinedInfo.Sizes.emplace_back(size);
4099 MapInfosTy &combinedInfo, bool isTargetParams,
4100 int mapDataParentIdx = -1) {
4104 auto mapFlag = mapData.Types[mapDataIdx];
4105 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
4109 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4111 if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
4112 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4114 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
4116 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4121 if (mapDataParentIdx >= 0)
4122 combinedInfo.BasePointers.emplace_back(
4123 mapData.BasePointers[mapDataParentIdx]);
4125 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
4127 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
4128 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
4129 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
4130 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
4131 combinedInfo.Types.emplace_back(mapFlag);
4132 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
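// Map a parent structure together with its mapped members: the combined
// parent entry is emitted first, then one entry per member carrying the
// MEMBER_OF flag.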
4136 llvm::IRBuilderBase &builder,
4137 llvm::OpenMPIRBuilder &ompBuilder,
4139 MapInfoData &mapData, uint64_t mapDataIndex,
4140 bool isTargetParams) {
4141 assert(!ompBuilder.Config.isTargetDevice() &&
4142 "function only supported for host device codegen");
4145 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4150 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
4151 auto memberClause = llvm::cast<omp::MapInfoOp>(
4152 parentClause.getMembers()[0].getDefiningOp());
4169 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
4171 combinedInfo, mapData, mapDataIndex, isTargetParams);
4173 combinedInfo, mapData, mapDataIndex,
4174 memberOfParentFlag);
4184 llvm::IRBuilderBase &builder) {
4186 "function only supported for host device codegen");
4187 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4189 if (!mapData.IsDeclareTarget[i]) {
4190 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4191 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
4201 switch (captureKind) {
4202 case omp::VariableCaptureKind::ByRef: {
4203 llvm::Value *newV = mapData.Pointers[i];
4205 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
4208 newV = builder.CreateLoad(builder.getPtrTy(), newV);
4210 if (!offsetIdx.empty())
4211 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
4213 mapData.Pointers[i] = newV;
4215 case omp::VariableCaptureKind::ByCopy: {
4216 llvm::Type *type = mapData.BaseType[i];
4218 if (mapData.Pointers[i]->getType()->isPointerTy())
4219 newV = builder.CreateLoad(type, mapData.Pointers[i]);
4221 newV = mapData.Pointers[i];
4224 auto curInsert = builder.saveIP();
4226 auto *memTempAlloc =
4227 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
4228 builder.restoreIP(curInsert);
4230 builder.CreateStore(newV, memTempAlloc);
4231 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
4234 mapData.Pointers[i] = newV;
4235 mapData.BasePointers[i] = newV;
4237 case omp::VariableCaptureKind::This:
4238 case omp::VariableCaptureKind::VLAType:
4239 mapData.MapClause[i]->emitOpError("Unhandled capture kind");
4250 MapInfoData &mapData, bool isTargetParams = false) {
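// Build the combined MapInfos consumed by the OpenMPIRBuilder from the
// collected MapInfoData, skipping entries that are members of another map
// (those are emitted together with their parent).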
4252 "function only supported for host device codegen");
4274 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4277 if (mapData.IsAMember[i])
4280 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
4281 if (!mapInfoOp.getMembers().empty()) {
4283 combinedInfo, mapData, i, isTargetParams);
4294 llvm::StringRef mapperFuncName);
4300 "function only supported for host device codegen");
4301 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4302 std::string mapperFuncName =
4304 {"omp_mapper", declMapperOp.getSymName()});
4306 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
4316 llvm::StringRef mapperFuncName) {
4318 "function only supported for host device codegen");
4319 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4320 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
4323 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
4326 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4329 MapInfosTy combinedInfo;
4331 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
4332 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
4333 builder.restoreIP(codeGenIP);
4334 moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
4335 moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
4336 builder.GetInsertBlock());
4337 if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
4340 return llvm::make_error<PreviouslyReportedError>();
4341 MapInfoData mapData;
4344 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);
4349 return combinedInfo;
4353 if (!combinedInfo.Mappers[i])
4360 genMapInfoCB, varType, mapperFuncName, customMapperCB);
4362 return newFn.takeError();
4363 moduleTranslation.mapFunction(mapperFuncName, *newFn);
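// Translate omp.target_data, omp.target_enter_data, omp.target_exit_data and
// omp.target_update by collecting their map clauses and selecting the
// corresponding offloading runtime call.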
4367 static LogicalResult
4370 llvm::Value *ifCond = nullptr;
4371 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
4375 llvm::omp::RuntimeFunction RTLFn;
4379 llvm::OpenMPIRBuilder::TargetDataInfo info(true,
4382 LogicalResult result =
4384 .Case([&](omp::TargetDataOp dataOp) {
4388 if (auto ifVar = dataOp.getIfExpr())
4391 if (auto devId = dataOp.getDevice())
4393 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4394 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4395 deviceID = intAttr.getInt();
4397 mapVars = dataOp.getMapVars();
4398 useDevicePtrVars = dataOp.getUseDevicePtrVars();
4399 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
4402 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
4406 if (auto ifVar = enterDataOp.getIfExpr())
4409 if (auto devId = enterDataOp.getDevice())
4411 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4412 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4413 deviceID = intAttr.getInt();
4415 enterDataOp.getNowait()
4416 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
4417 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
4418 mapVars = enterDataOp.getMapVars();
4419 info.HasNoWait = enterDataOp.getNowait();
4422 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
4426 if (auto ifVar = exitDataOp.getIfExpr())
4429 if (auto devId = exitDataOp.getDevice())
4431 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4432 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4433 deviceID = intAttr.getInt();
4435 RTLFn = exitDataOp.getNowait()
4436 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
4437 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
4438 mapVars = exitDataOp.getMapVars();
4439 info.HasNoWait = exitDataOp.getNowait();
4442 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
4446 if (auto ifVar = updateDataOp.getIfExpr())
4449 if (auto devId = updateDataOp.getDevice())
4451 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4452 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4453 deviceID = intAttr.getInt();
4456 updateDataOp.getNowait()
4457 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
4458 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
4459 mapVars = updateDataOp.getMapVars();
4460 info.HasNoWait = updateDataOp.getNowait();
4464 llvm_unreachable(
"unexpected operation");
4471 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4472 MapInfoData mapData;
4474 builder, useDevicePtrVars, useDeviceAddrVars);
4477 MapInfosTy combinedInfo;
4478 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
4479 builder.restoreIP(codeGenIP);
4480 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
4481 return combinedInfo;
4487 [&moduleTranslation](
4488 llvm::OpenMPIRBuilder::DeviceInfoTy type,
4492 for (auto [arg, useDevVar] :
4493 llvm::zip_equal(blockArgs, useDeviceVars)) {
4495 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
4496 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
4497 : mapInfoOp.getVarPtr();
4500 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
4501 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
4502 mapInfoData.MapClause, mapInfoData.DevicePointers,
4503 mapInfoData.BasePointers)) {
4504 auto mapOp = cast<omp::MapInfoOp>(mapClause);
4505 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
4506 devicePointer != type)
4509 if (llvm::Value *devPtrInfoMap =
4510 mapper ? mapper(basePointer) : basePointer) {
4511 moduleTranslation.
mapValue(arg, devPtrInfoMap);
4518 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
4519 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
4520 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
4521 builder.restoreIP(codeGenIP);
4522 assert(isa<omp::TargetDataOp>(op) &&
4523 "BodyGen requested for non TargetDataOp");
4524 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
4525 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
4526 switch (bodyGenType) {
4527 case BodyGenTy::Priv:
4529 if (!info.DevicePtrInfoMap.empty()) {
4530 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4531 blockArgIface.getUseDeviceAddrBlockArgs(),
4532 useDeviceAddrVars, mapData,
4533 [&](llvm::Value *basePointer) -> llvm::Value * {
4534 if (!info.DevicePtrInfoMap[basePointer].second)
4536 return builder.CreateLoad(
4538 info.DevicePtrInfoMap[basePointer].second);
4540 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4541 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
4542 mapData, [&](llvm::Value *basePointer) {
4543 return info.DevicePtrInfoMap[basePointer].second;
4547 moduleTranslation)))
4548 return llvm::make_error<PreviouslyReportedError>();
4551 case BodyGenTy::DupNoPriv:
4554 builder.restoreIP(codeGenIP);
4556 case BodyGenTy::NoPriv:
4558 if (info.DevicePtrInfoMap.empty()) {
4561 if (ompBuilder->Config.IsTargetDevice.value_or(
false)) {
4562 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4563 blockArgIface.getUseDeviceAddrBlockArgs(),
4564 useDeviceAddrVars, mapData);
4565 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4566 blockArgIface.getUseDevicePtrBlockArgs(),
4567 useDevicePtrVars, mapData);
4571 moduleTranslation)))
4572 return llvm::make_error<PreviouslyReportedError>();
4576 return builder.saveIP();
4579 auto customMapperCB =
4581 if (!combinedInfo.Mappers[i])
4583 info.HasMapper =
true;
4588 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4589 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4591 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
4592 if (isa<omp::TargetDataOp>(op))
4593 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
4594 builder.getInt64(deviceID), ifCond,
4595 info, genMapInfoCB, customMapperCB,
4598 return ompBuilder->createTargetData(
4599 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
4600 info, genMapInfoCB, customMapperCB, &RTLFn);
4606 builder.restoreIP(*afterIP);
4610 static LogicalResult
4614 auto distributeOp = cast<omp::DistributeOp>(opInst);
4621 bool doDistributeReduction =
4625 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
4630 if (doDistributeReduction) {
4631 isByRef =
getIsByRef(teamsOp.getReductionByref());
4632 assert(isByRef.size() == teamsOp.getNumReductionVars());
4635 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4639 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
4640 .getReductionBlockArgs();
4643 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
4644 reductionDecls, privateReductionVariables, reductionVariableMap,
4649 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4650 auto bodyGenCB = [&](InsertPointTy allocaIP,
4651 InsertPointTy codeGenIP) -> llvm::Error {
4655 moduleTranslation, allocaIP);
4658 builder.restoreIP(codeGenIP);
4664 return llvm::make_error<PreviouslyReportedError>();
4669 return llvm::make_error<PreviouslyReportedError>();
4672 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
4674 distributeOp.getPrivateNeedsBarrier())))
4675 return llvm::make_error<PreviouslyReportedError>();
4678 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4681 builder, moduleTranslation);
4683 return regionBlock.takeError();
4684 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4689 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
4692 auto schedule = omp::ClauseScheduleKind::Static;
4693 bool isOrdered = false;
4694 std::optional<omp::ScheduleModifier> scheduleMod;
4695 bool isSimd = false;
4696 llvm::omp::WorksharingLoopType workshareLoopType =
4697 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
4698 bool loopNeedsBarrier = false;
4699 llvm::Value *chunk = nullptr;
4701 llvm::CanonicalLoopInfo *loopInfo =
4703 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4704 ompBuilder->applyWorkshareLoop(
4705 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
4706 convertToScheduleKind(schedule), chunk, isSimd,
4707 scheduleMod == omp::ScheduleModifier::monotonic,
4708 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4712 return wsloopIP.takeError();
4716 distributeOp.getLoc(), privVarsInfo.
llvmVars,
4718 return llvm::make_error<PreviouslyReportedError>();
4720 return llvm::Error::success();
4723 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4725 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4726 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4727 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
4732 builder.restoreIP(*afterIP);
4734 if (doDistributeReduction) {
4737 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
4738 privateReductionVariables, isByRef,
4749 if (!cast<mlir::ModuleOp>(op))
4754 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
4755 attribute.getOpenmpDeviceVersion());
4757 if (attribute.getNoGpuLib())
4760 ompBuilder->createGlobalFlag(
4761 attribute.getDebugKind() ,
4762 "__omp_rtl_debug_kind");
4763 ompBuilder->createGlobalFlag(
4765 .getAssumeTeamsOversubscription()
4767 "__omp_rtl_assume_teams_oversubscription");
4768 ompBuilder->createGlobalFlag(
4770 .getAssumeThreadsOversubscription()
4772 "__omp_rtl_assume_threads_oversubscription");
4773 ompBuilder->createGlobalFlag(
4774 attribute.getAssumeNoThreadState() ,
4775 "__omp_rtl_assume_no_thread_state");
4776 ompBuilder->createGlobalFlag(
4778 .getAssumeNoNestedParallelism()
4780 "__omp_rtl_assume_no_nested_parallelism");
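// Derive a llvm::TargetRegionEntryInfo that uniquely identifies this target
// region from its source location (parent function name, file unique ID and
// line number).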
4785 omp::TargetOp targetOp,
4786 llvm::StringRef parentName = "") {
4787 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
4789 assert(fileLoc && "No file found from location");
4790 StringRef fileName = fileLoc.getFilename().getValue();
4792 llvm::sys::fs::UniqueID id;
4793 uint64_t line = fileLoc.getLine();
4794 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
4796 size_t deviceId = 0xdeadf17e;
4798 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
4800 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
4801 id.getFile(), line);
4808 llvm::IRBuilderBase &builder, llvm::Function *func) {
4810 "function only supported for target device codegen");
4811 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4824 if (mapData.IsDeclareTarget[i]) {
4831 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
4832 convertUsersOfConstantsToInstructions(constant, func, false);
4839 for (llvm::User *user : mapData.OriginalValue[i]->users())
4840 userVec.push_back(user);
4842 for (llvm::User *user : userVec) {
4843 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
4844 if (insn->getFunction() == func) {
4845 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
4846 mapData.BasePointers[i]);
4847 load->moveBefore(insn->getIterator());
4848 user->replaceUsesOfWith(mapData.OriginalValue[i], load);
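// Materialize a kernel argument on the device: the argument is stored to an
// alloca (address-space cast to the default AS if needed) and then reloaded
// according to its capture kind.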
4895 static llvm::IRBuilderBase::InsertPoint
4897 llvm::Value *input, llvm::Value *&retVal,
4898 llvm::IRBuilderBase &builder,
4899 llvm::OpenMPIRBuilder &ompBuilder,
4901 llvm::IRBuilderBase::InsertPoint allocaIP,
4902 llvm::IRBuilderBase::InsertPoint codeGenIP) {
4903 assert(ompBuilder.Config.isTargetDevice() &&
4904 "function only supported for target device codegen");
4905 builder.restoreIP(allocaIP);
4907 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
4909 ompBuilder.M.getContext());
4910 unsigned alignmentValue = 0;
4912 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
4913 if (mapData.OriginalValue[i] == input) {
4914 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4915 capture = mapOp.getMapCaptureType();
4918 mapOp.getVarType(), ompBuilder.M.getDataLayout());
4922 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
4923 unsigned int defaultAS =
4924 ompBuilder.M.getDataLayout().getProgramAddressSpace();
4927 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
4929 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
4930 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
4932 builder.CreateStore(&arg, v);
4934 builder.restoreIP(codeGenIP);
4937 case omp::VariableCaptureKind::ByCopy: {
4941 case omp::VariableCaptureKind::ByRef: {
4942 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
4944 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
4959 if (v->getType()->isPointerTy() && alignmentValue) {
4960 llvm::MDBuilder MDB(builder.getContext());
4961 loadInst->setMetadata(
4962 llvm::LLVMContext::MD_align,
4965 llvm::Type::getInt64Ty(builder.getContext()),
4972 case omp::VariableCaptureKind::This:
4973 case omp::VariableCaptureKind::VLAType:
4976 assert(
false &&
"Currently unsupported capture kind");
4980 return builder.saveIP();
4997 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
4998 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
4999 blockArgIface.getHostEvalBlockArgs())) {
5000 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
5004 .Case([&](omp::TeamsOp teamsOp) {
5005 if (teamsOp.getNumTeamsLower() == blockArg)
5006 numTeamsLower = hostEvalVar;
5007 else if (teamsOp.getNumTeamsUpper() == blockArg)
5008 numTeamsUpper = hostEvalVar;
5009 else if (teamsOp.getThreadLimit() == blockArg)
5010 threadLimit = hostEvalVar;
5012 llvm_unreachable(
"unsupported host_eval use");
5014 .Case([&](omp::ParallelOp parallelOp) {
5015 if (parallelOp.getNumThreads() == blockArg)
5016 numThreads = hostEvalVar;
5018 llvm_unreachable(
"unsupported host_eval use");
5020 .Case([&](omp::LoopNestOp loopOp) {
5021 auto processBounds =
5026 if (lb == blockArg) {
5029 (*outBounds)[i] = hostEvalVar;
5035 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
5036 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
5038 found = processBounds(loopOp.getLoopSteps(), steps) || found;
5040 assert(found && "unsupported host_eval use");
5043 llvm_unreachable("unsupported host_eval use");
5056 template <typename OpTy>
5061 if (OpTy casted = dyn_cast<OpTy>(op))
5064 if (immediateParent)
5065 return dyn_cast_if_present<OpTy>(op->getParentOp());
5074 return std::nullopt;
5077 dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
5078 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5079 return constAttr.getInt();
5081 return std::nullopt;
5086 uint64_t sizeInBytes = sizeInBits / 8;
5090 template <
typename OpTy>
5092 if (op.getNumReductionVars() > 0) {
5097 members.reserve(reductions.size());
5098 for (omp::DeclareReductionOp &red : reductions)
5099 members.push_back(red.getType());
5101 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
5117 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
5118 bool isTargetDevice, bool isGPU) {
5121 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
5122 if (!isTargetDevice) {
5129 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5130 numTeamsLower = teamsOp.getNumTeamsLower();
5131 numTeamsUpper = teamsOp.getNumTeamsUpper();
5132 threadLimit = teamsOp.getThreadLimit();
5135 if (
auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5136 numThreads = parallelOp.getNumThreads();
5141 int32_t minTeamsVal = 1, maxTeamsVal = -1;
5142 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5145 if (numTeamsUpper) {
5147 minTeamsVal = maxTeamsVal = *val;
5149 minTeamsVal = maxTeamsVal = 0;
5151 }
else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
5153 castOrGetParentOfType<omp::SimdOp>(capturedOp,
5155 minTeamsVal = maxTeamsVal = 1;
5157 minTeamsVal = maxTeamsVal = -1;
5162 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &result) {
5176 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
5177 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
5178 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
5181 int32_t maxThreadsVal = -1;
5182 if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5183 setMaxValueFromClause(numThreads, maxThreadsVal);
5184 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
5191 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
5192 if (combinedMaxThreadsVal < 0 ||
5193 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
5194 combinedMaxThreadsVal = teamsThreadLimitVal;
5196 if (combinedMaxThreadsVal < 0 ||
5197 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
5198 combinedMaxThreadsVal = maxThreadsVal;
5200 int32_t reductionDataSize = 0;
5201 if (isGPU && capturedOp) {
5202 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
5207 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5209 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5210 omp::TargetRegionFlags::spmd) &&
5211 "invalid kernel flags");
5213 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5214 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5215 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5216 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5217 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5218 attrs.MinTeams = minTeamsVal;
5219 attrs.MaxTeams.front() = maxTeamsVal;
5220 attrs.MinThreads = 1;
5221 attrs.MaxThreads.front() = combinedMaxThreadsVal;
5222 attrs.ReductionDataSize = reductionDataSize;
5225 if (attrs.ReductionDataSize != 0)
5226 attrs.ReductionBufferLength = 1024;
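// Fill the kernel's runtime launch attributes (thread limits, number of teams
// and, when required, the loop trip count) from clause values evaluated on
// the host.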
5238 omp::TargetOp targetOp, Operation *capturedOp,
5239 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
5240 omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
5241 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
5243 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
5247 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
5250 if (Value targetThreadLimit = targetOp.getThreadLimit())
5251 attrs.TargetThreadLimit.front() =
5255 attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);
5258 attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);
5260 if (teamsThreadLimit)
5261 attrs.TeamsThreadLimit.front() =
5265 attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
5267 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
5268 omp::TargetRegionFlags::trip_count)) {
5270 attrs.LoopTripCount = nullptr;
5275 for (auto [loopLower, loopUpper, loopStep] :
5276 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
5277 llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
5278 llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
5279 llvm::Value *step = moduleTranslation.lookupValue(loopStep);
5281 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
5282 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
5283 loc, lowerBound, upperBound, step, true,
5284 loopOp.getLoopInclusive());
5286 if (!attrs.LoopTripCount) {
5287 attrs.LoopTripCount = tripCount;
5292 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
5298 static LogicalResult
5301 auto targetOp = cast<omp::TargetOp>(opInst);
5306 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
5307 bool isGPU = ompBuilder->Config.isGPU();
5310 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
5311 auto &targetRegion = targetOp.getRegion();
5328 llvm::Function *llvmOutlinedFn =
nullptr;
5332 bool isOffloadEntry =
5333 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
5340 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
5342 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
5343 std::optional<DenseI64ArrayAttr> privateMapIndices =
5344 targetOp.getPrivateMapsAttr();
5346 for (
auto [privVarIdx, privVarSymPair] :
5348 auto privVar = std::get<0>(privVarSymPair);
5349 auto privSym = std::get<1>(privVarSymPair);
5351 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
5352 omp::PrivateClauseOp privatizer =
5355 if (!privatizer.needsMap())
5359 targetOp.getMappedValueForPrivateVar(privVarIdx);
5360 assert(mappedValue && "Expected to find mapped value for a privatized "
5361 "variable that needs mapping");
5366 auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
5367 [[maybe_unused]] Type varType = mapInfoOp.getVarType();
5371 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
5373 varType == privVar.getType() &&
5374 "Type of private var doesn't match the type of the mapped value");
5378 mappedPrivateVars.insert(
5380 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
5381 (*privateMapIndices)[privVarIdx])});
5385 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5386 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
5387 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5388 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5389 builder.SetCurrentDebugLocation(llvm::DebugLoc());
5392 llvm::Function *llvmParentFn =
5394 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
5395 assert(llvmParentFn && llvmOutlinedFn &&
5396 "Both parent and outlined functions must exist at this point");
5398 if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
5399 attr.isStringAttribute())
5400 llvmOutlinedFn->addFnAttr(attr);
5402 if (auto attr = llvmParentFn->getFnAttribute("target-features");
5403 attr.isStringAttribute())
5404 llvmOutlinedFn->addFnAttr(attr);
5406 for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
5407 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5408 llvm::Value *mapOpValue =
5409 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
5410 moduleTranslation.mapValue(arg, mapOpValue);
5412 for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
5413 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5414 llvm::Value *mapOpValue =
5415 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
5416 moduleTranslation.mapValue(arg, mapOpValue);
5425 allocaIP, &mappedPrivateVars);
5428 return llvm::make_error<PreviouslyReportedError>();
5430 builder.restoreIP(codeGenIP);
5432 &mappedPrivateVars),
5435 return llvm::make_error<PreviouslyReportedError>();
5438 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
5440 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
5441 return llvm::make_error<PreviouslyReportedError>();
5445 std::back_inserter(privateCleanupRegions),
5446 [](omp::PrivateClauseOp privatizer) {
5447 return &privatizer.getDeallocRegion();
5451 targetRegion,
"omp.target", builder, moduleTranslation);
5454 return exitBlock.takeError();
5456 builder.SetInsertPoint(*exitBlock);
5457 if (!privateCleanupRegions.empty()) {
5459 privateCleanupRegions, privateVarsInfo.
llvmVars,
5460 moduleTranslation, builder,
"omp.targetop.private.cleanup",
5462 return llvm::createStringError(
5463 "failed to inline `dealloc` region of `omp.private` "
5464 "op in the target region");
5466 return builder.saveIP();
5469 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
5472   StringRef parentName = parentFn.getName();
5474   llvm::TargetRegionEntryInfo entryInfo;
5478   MapInfoData mapData;
5483   MapInfosTy combinedInfos;
5485       [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
5486     builder.restoreIP(codeGenIP);
5487     genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true);
5488     return combinedInfos;
5491   auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
5492                            llvm::Value *&retVal, InsertPointTy allocaIP,
5493                            InsertPointTy codeGenIP)
5494       -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5495     llvm::IRBuilderBase::InsertPointGuard guard(builder);
5496     builder.SetCurrentDebugLocation(llvm::DebugLoc());
5502     if (!isTargetDevice) {
5503       retVal = cast<llvm::Value>(&arg);
5508                                          *ompBuilder, moduleTranslation,
5509                                          allocaIP, codeGenIP);
5512   llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
5513   llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
5514   Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
5516                          isTargetDevice, isGPU);
5520   if (!isTargetDevice)
5522                           targetCapturedOp, runtimeAttrs);
5530   for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
5531     llvm::Value *value = moduleTranslation.lookupValue(var);
5532     moduleTranslation.mapValue(arg, value);
5534     if (!llvm::isa<llvm::Constant>(value))
5535       kernelInput.push_back(value);
5538   for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
5545     if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
5546       kernelInput.push_back(mapData.OriginalValue[i]);
5551                   moduleTranslation, dds);
5553   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5555   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5557   llvm::OpenMPIRBuilder::TargetDataInfo info(
5561   auto customMapperCB =
5563     if (!combinedInfos.Mappers[i])
5565     info.HasMapper = true;
5570   llvm::Value *ifCond = nullptr;
5571   if (Value targetIfCond = targetOp.getIfExpr())
5572     ifCond = moduleTranslation.lookupValue(targetIfCond);
5574   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5576       ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
5577       defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
5578       argAccessorCB, customMapperCB, dds, targetOp.getNowait());
5583   builder.restoreIP(*afterIP);
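createTarget, like the other OpenMPIRBuilder entry points used here, reports failures as an llvm::Error wrapped in InsertPointOrErrorTy, and afterIP may only be dereferenced once that error has been consumed. A minimal sketch of turning such an error into an MLIR diagnostic plus LogicalResult (the translation uses its own handleError helper; this standalone variant is illustrative):

static LogicalResult reportLLVMError(llvm::Error error, Operation &op) {
  if (!error)
    return success();
  // Emit every contained error as a diagnostic attached to the offending op.
  llvm::handleAllErrors(std::move(error),
                        [&](const llvm::ErrorInfoBase &info) {
                          op.emitError() << info.message();
                        });
  return failure();
}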
5594 static LogicalResult
5604   if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
5605     if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
5607       if (!offloadMod.getIsTargetDevice())
5610     omp::DeclareTargetDeviceType declareType =
5611         attribute.getDeviceType().getValue();
5613     if (declareType == omp::DeclareTargetDeviceType::host) {
5614       llvm::Function *llvmFunc =
5616       llvmFunc->dropAllReferences();
5617       llvmFunc->eraseFromParent();
5623   if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
5624     llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
5625     if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
5627       bool isDeclaration = gOp.isDeclaration();
5628       bool isExternallyVisible =
5631       llvm::StringRef mangledName = gOp.getSymName();
5632       auto captureClause =
5638       std::vector<llvm::GlobalVariable *> generatedRefs;
5640       std::vector<llvm::Triple> targetTriple;
5641       auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
5643               LLVM::LLVMDialect::getTargetTripleAttrName()));
5644       if (targetTripleAttr)
5645         targetTriple.emplace_back(targetTripleAttr.data());
5647       auto fileInfoCallBack = [&loc]() {
5648         std::string filename = "";
5649         std::uint64_t lineNo = 0;
5652           filename = loc.getFilename().str();
5653           lineNo = loc.getLine();
5656         return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
5660       ompBuilder->registerTargetGlobalVariable(
5661           captureClause, deviceClause, isDeclaration, isExternallyVisible,
5662           ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
5663           generatedRefs, false, targetTriple,
5665           gVal->getType(), gVal);
5667       if (ompBuilder->Config.isTargetDevice() &&
5668           (attribute.getCaptureClause().getValue() !=
5669                mlir::omp::DeclareTargetCaptureClause::to ||
5670            ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
5671         ompBuilder->getAddrOfDeclareTargetVar(
5672             captureClause, deviceClause, isDeclaration, isExternallyVisible,
5673             ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
5674             generatedRefs, false, targetTriple, gVal->getType(),
5696   if (mlir::isa<omp::ThreadprivateOp>(op))
5700   if (auto declareTargetIface =
5701           llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
5702               parentFn.getOperation()))
5703     if (declareTargetIface.isDeclareTarget() &&
5704         declareTargetIface.getDeclareTargetDeviceType() !=
5705             mlir::omp::DeclareTargetDeviceType::host)
5713 static LogicalResult
5724   bool isOutermostLoopWrapper =
5725       isa_and_present<omp::LoopWrapperInterface>(op) &&
5726       !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
5728   if (isOutermostLoopWrapper)
5729     moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
5733       .Case([&](omp::BarrierOp op) -> LogicalResult {
5737         llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5738             ompBuilder->createBarrier(builder.saveIP(),
5739                                       llvm::omp::OMPD_barrier);
5741         if (res.succeeded()) {
5744           builder.restoreIP(*afterIP);
5748       .Case([&](omp::TaskyieldOp op) {
5752         ompBuilder->createTaskyield(builder.saveIP());
5755       .Case([&](omp::FlushOp op) {
5767         ompBuilder->createFlush(builder.saveIP());
5770       .Case([&](omp::ParallelOp op) {
5773       .Case([&](omp::MaskedOp) {
5776       .Case([&](omp::MasterOp) {
5779       .Case([&](omp::CriticalOp) {
5782       .Case([&](omp::OrderedRegionOp) {
5785       .Case([&](omp::OrderedOp) {
5788       .Case([&](omp::WsloopOp) {
5791       .Case([&](omp::SimdOp) {
5794       .Case([&](omp::AtomicReadOp) {
5797       .Case([&](omp::AtomicWriteOp) {
5800       .Case([&](omp::AtomicUpdateOp op) {
5803       .Case([&](omp::AtomicCaptureOp op) {
5806       .Case([&](omp::CancelOp op) {
5809       .Case([&](omp::CancellationPointOp op) {
5812       .Case([&](omp::SectionsOp) {
5815       .Case([&](omp::SingleOp op) {
5818       .Case([&](omp::TeamsOp op) {
5821       .Case([&](omp::TaskOp op) {
5824       .Case([&](omp::TaskgroupOp op) {
5827       .Case([&](omp::TaskwaitOp op) {
5830       .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
5831             omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
5832             omp::CriticalDeclareOp>([](auto op) {
5845       .Case([&](omp::ThreadprivateOp) {
5848       .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
5849             omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
5852       .Case([&](omp::TargetOp) {
5855       .Case([&](omp::DistributeOp) {
5858       .Case([&](omp::LoopNestOp) {
5861       .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
5870              << "not yet implemented: " << inst->getName();
5873   if (isOutermostLoopWrapper)
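The operation dispatch above is a single llvm::TypeSwitch: every .Case forwards to the matching convertOmp* helper and .Default emits the "not yet implemented" diagnostic. A stripped-down sketch of the same idiom (the chosen op and the trivial case body are illustrative):

static LogicalResult dispatchSketch(Operation *op) {
  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
      // The handled type is deduced from the lambda's parameter.
      .Case([](omp::BarrierOp) { return success(); })
      .Default([](Operation *inst) -> LogicalResult {
        return inst->emitError() << "not yet implemented: " << inst->getName();
      });
}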
5879 static LogicalResult
5885 static LogicalResult
5888   if (isa<omp::TargetOp>(op))
5890   if (isa<omp::TargetDataOp>(op))
5894     if (isa<omp::TargetOp>(oper)) {
5896         return WalkResult::interrupt();
5897       return WalkResult::skip();
5899     if (isa<omp::TargetDataOp>(oper)) {
5901         return WalkResult::interrupt();
5902       return WalkResult::skip();
5909     if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
5910         oper->getParentOfType<LLVM::LLVMFuncOp>() &&
5911         !oper->getRegions().empty()) {
5912       if (auto blockArgsIface =
5913               dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
5919         if (isa<mlir::omp::AtomicUpdateOp>(oper))
5920           for (auto [operand, arg] :
5921                llvm::zip_equal(oper->getOperands(),
5922                                oper->getRegion(0).getArguments())) {
5924                 arg, builder.CreateLoad(
5930       if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
5931         assert(builder.GetInsertBlock() &&
5932                "No insert block is set for the builder");
5933         for (auto iv : loopNest.getIVs()) {
5941       for (Region &region : oper->getRegions()) {
5948             region, oper->getName().getStringRef().str() + ".fake.region",
5949             builder, moduleTranslation, &phis);
5951           return WalkResult::interrupt();
5953         builder.SetInsertPoint(result.get(), result.get()->end());
5956       return WalkResult::skip();
5959     return WalkResult::advance();
5960   }).wasInterrupted();
5961   return failure(interrupted);
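Only the target-related regions are translated on the device pass; the walk above steers that traversal with WalkResult: interrupt() aborts the whole walk on failure, skip() keeps it from descending into a region that was just handled, and advance() continues normally. A self-contained sketch of the pattern (the predicate and the trivial handling are illustrative; a real handler would translate the region and interrupt on failure):

static LogicalResult walkTargetsSketch(Operation *op) {
  bool interrupted =
      op->walk<WalkOrder::PreOrder>([&](Operation *nested) {
        if (isa<omp::TargetOp>(nested)) {
          // A real implementation would translate the target region here and
          // return WalkResult::interrupt() if that translation failed.
          return WalkResult::skip(); // do not descend into the target body
        }
        return WalkResult::advance();
      }).wasInterrupted();
  return failure(interrupted);
}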
5968 class OpenMPDialectLLVMIRTranslationInterface
5989 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
5995       .Case("omp.is_target_device",
5997             if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
5998               llvm::OpenMPIRBuilderConfig &config =
6000               config.setIsTargetDevice(deviceAttr.getValue());
6007             if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6008               llvm::OpenMPIRBuilderConfig &config =
6010               config.setIsGPU(gpuAttr.getValue());
6015       .Case("omp.host_ir_filepath",
6017             if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6018               llvm::OpenMPIRBuilder *ompBuilder =
6020               ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
6027             if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6031       .Case("omp.version",
6033             if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6034               llvm::OpenMPIRBuilder *ompBuilder =
6036               ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
6037                                           versionAttr.getVersion());
6042       .Case("omp.declare_target",
6044             if (auto declareTargetAttr =
6045                     dyn_cast<omp::DeclareTargetAttr>(attr))
6050       .Case("omp.requires",
6052             if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6053               using Requires = omp::ClauseRequires;
6054               Requires flags = requiresAttr.getValue();
6055               llvm::OpenMPIRBuilderConfig &config =
6057               config.setHasRequiresReverseOffload(
6058                   bitEnumContainsAll(flags, Requires::reverse_offload));
6059               config.setHasRequiresUnifiedAddress(
6060                   bitEnumContainsAll(flags, Requires::unified_address));
6061               config.setHasRequiresUnifiedSharedMemory(
6062                   bitEnumContainsAll(flags, Requires::unified_shared_memory));
6063               config.setHasRequiresDynamicAllocators(
6064                   bitEnumContainsAll(flags, Requires::dynamic_allocators));
6069       .Case("omp.target_triples",
6071             if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
6072               llvm::OpenMPIRBuilderConfig &config =
6074               config.TargetTriples.clear();
6075               config.TargetTriples.reserve(triplesAttr.size());
6076               for (Attribute tripleAttr : triplesAttr) {
6077                 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
6078                   config.TargetTriples.emplace_back(tripleStrAttr.getValue());
6096 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
6097     Operation *op, llvm::IRBuilderBase &builder,
6101   if (ompBuilder->Config.isTargetDevice()) {
6112   registry.insert<omp::OpenMPDialect>();
6114   dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
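The registration hook above is what registerOpenMPDialectTranslation wires up. A typical embedding, sketched on the assumption that the standard translateModuleToLLVMIR entry point from mlir/Target/LLVMIR/Export.h is used and that an MLIRContext named context and a ModuleOp named mlirModule already exist:

// Register the OpenMP-to-LLVM-IR translation, then translate a module.
mlir::DialectRegistry registry;
mlir::registerOpenMPDialectTranslation(registry);
context.appendDialectRegistry(registry);

llvm::LLVMContext llvmContext;
std::unique_ptr<llvm::Module> llvmModule =
    mlir::translateModuleToLLVMIR(mlirModule, llvmContext);
if (!llvmModule)
  llvm::errs() << "failed to translate module to LLVM IR\n";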
static llvm::Value * getRefPtrIfDeclareTarget(mlir::Value value, LLVM::ModuleTranslation &moduleTranslation)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable.
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type.
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized.
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct.
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables.
llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, bool isTargetParams, int mapDataParentIdx=-1)
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static bool isTargetDeviceOp(Operation *op)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
Handling of a DeclareReductionOp's cleanup region.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static LogicalResult initReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::BasicBlock *latestAllocaBlock, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef, SmallVectorImpl< DeferredStore > &deferredStores)
Inline reductions' init regions.
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, bool isTargetParams=false)
LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation, omp::BlockArgOpenMPOpInterface blockArgIface)
Maps block arguments from blockArgIface (which are MLIR values) to the corresponding LLVM values of t...
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool >> attr)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static LogicalResult convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< mlir::Value > &mlirPrivateVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static uint64_t getReductionDataSize(OpTy &op)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static LogicalResult convertIgnoredWrapper(omp::LoopWrapperInterface opInst, LLVM::ModuleTranslation &moduleTranslation)
Helper function to map block arguments defined by ignored loop wrappers to LLVM values and prevent an...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static void collectReductionInfo(T loop, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< OwningReductionGen > &owningReductionGens, SmallVectorImpl< OwningAtomicReductionGen > &owningAtomicReductionGens, const ArrayRef< llvm::Value * > privateReductionVariables, SmallVectorImpl< llvm::OpenMPIRBuilder::ReductionInfo > &reductionInfos)
Collect reduction info.
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Given an OpenMP MLIR operation, create the corresponding LLVM IR (including OpenMP runtime calls).
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static bool isDeclareTargetLink(mlir::Value value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Base class for dialect interfaces providing translation to LLVM IR.
virtual LogicalResult amendOperation(Operation *op, ArrayRef< llvm::Instruction * > instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to act on an operation that has dialect attributes from the derive...
virtual LogicalResult convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to provide translation of the operations to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
Concrete CRTP base class for ModuleTranslation stack frames.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
unsigned getNumArguments()
BlockListType & getBlocks()
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Runtime
Potential runtimes for AMD GPU kernels.
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars
RAII object calling stackPush/stackPop on construction/destruction.
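The entry above describes the guard used around every stack frame push in this translation. An assumed shape for such a guard, built only from the documented stackPush/stackPop interface (a sketch, not the upstream class):

// Pushes a ModuleTranslation stack frame on construction and pops it on
// destruction, so frames cannot be leaked on early returns.
template <typename FrameT>
struct ScopedStackFrame {
  LLVM::ModuleTranslation &moduleTranslation;
  template <typename... Args>
  ScopedStackFrame(LLVM::ModuleTranslation &mt, Args &&...args)
      : moduleTranslation(mt) {
    moduleTranslation.stackPush<FrameT>(std::forward<Args>(args)...);
  }
  ~ScopedStackFrame() { moduleTranslation.stackPop(); }
};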