26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Frontend/OpenMP/OMPConstants.h"
31 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
32 #include "llvm/IR/Constants.h"
33 #include "llvm/IR/DebugInfoMetadata.h"
34 #include "llvm/IR/DerivedTypes.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/MDBuilder.h"
37 #include "llvm/IR/ReplaceConstant.h"
38 #include "llvm/Support/FileSystem.h"
39 #include "llvm/TargetParser/Triple.h"
40 #include "llvm/Transforms/Utils/ModuleUtils.h"
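// Maps an MLIR 'omp' schedule clause kind to the corresponding OpenMPIRBuilder
// schedule kind, returning the default when no schedule clause was given.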
52 static llvm::omp::ScheduleKind
53 convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
54 if (!schedKind.has_value())
55 return llvm::omp::OMP_SCHEDULE_Default;
56 switch (schedKind.value()) {
57 case omp::ClauseScheduleKind::Static:
58 return llvm::omp::OMP_SCHEDULE_Static;
59 case omp::ClauseScheduleKind::Dynamic:
60 return llvm::omp::OMP_SCHEDULE_Dynamic;
61 case omp::ClauseScheduleKind::Guided:
62 return llvm::omp::OMP_SCHEDULE_Guided;
63 case omp::ClauseScheduleKind::Auto:
64 return llvm::omp::OMP_SCHEDULE_Auto;
65 case omp::ClauseScheduleKind::Runtime:
66 return llvm::omp::OMP_SCHEDULE_Runtime;
67 }
68 llvm_unreachable("unhandled schedule clause argument");
73 class OpenMPAllocaStackFrame
78 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
79 : allocaInsertPoint(allocaIP) {}
80 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
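// ModuleTranslation stack frame holding the CanonicalLoopInfo of the loop
// nest currently being translated, so enclosing constructs can retrieve it.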
86 class OpenMPLoopInfoStackFrame
90 llvm::CanonicalLoopInfo *loopInfo = nullptr;
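// Error type signalling that a diagnostic has already been emitted, so
// callers only need to propagate the failure without reporting it again.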
109 class PreviouslyReportedError
110 : public llvm::ErrorInfo<PreviouslyReportedError> {
112 void log(raw_ostream &) const override {
116 std::error_code convertToErrorCode() const override {
117 llvm_unreachable(
118 "PreviouslyReportedError doesn't support ECError conversion");
136 class LinearClauseProcessor {
144 llvm::BasicBlock *linearFinalizationBB;
145 llvm::BasicBlock *linearExitBB;
146 llvm::BasicBlock *linearLastIterExitBB;
150 void createLinearVar(llvm::IRBuilderBase &builder,
153 if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
155 linearPreconditionVars.push_back(builder.CreateAlloca(
156 linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
157 llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
158 linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
159 linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
160 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
161 linearOrigVars.push_back(linearVarAlloca);
168 linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
172 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
173 initLinearVar(llvm::IRBuilderBase &builder,
175 llvm::BasicBlock *loopPreHeader) {
176 builder.SetInsertPoint(loopPreHeader->getTerminator());
177 for (size_t index = 0; index < linearOrigVars.size(); index++) {
178 llvm::LoadInst *linearVarLoad = builder.CreateLoad(
179 linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
180 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
182 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
183 moduleTranslation.getOpenMPBuilder()->createBarrier(
184 builder.saveIP(), llvm::omp::OMPD_barrier);
185 return afterBarrierIP;
189 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
190 llvm::Value *loopInductionVar) {
191 builder.SetInsertPoint(loopBody->getTerminator());
192 for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
194 llvm::LoadInst *linearVarStart =
195 builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
197 linearPreconditionVars[index]);
198 auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
199 auto addInst = builder.CreateAdd(linearVarStart, mulInst);
200 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
206 void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
207 llvm::BasicBlock *loopExit) {
208 linearFinalizationBB = loopExit->splitBasicBlock(
209 loopExit->getTerminator(), "omp_loop.linear_finalization");
210 linearExitBB = linearFinalizationBB->splitBasicBlock(
211 linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
212 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
213 linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
217 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
218 finalizeLinearVar(llvm::IRBuilderBase &builder,
220 llvm::Value *lastIter) {
222 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
223 llvm::Value *loopLastIterLoad = builder.CreateLoad(
224 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
225 llvm::Value *isLast =
226 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
227 llvm::ConstantInt::get(
228 llvm::Type::getInt32Ty(builder.getContext()), 0));
230 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
231 for (size_t index = 0; index < linearOrigVars.size(); index++) {
232 llvm::LoadInst *linearVarTemp =
233 builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
234 linearLoopBodyTemps[index]);
235 builder.CreateStore(linearVarTemp, linearOrigVars[index]);
241 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
242 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
243 linearFinalizationBB->getTerminator()->eraseFromParent();
245 builder.SetInsertPoint(linearExitBB->getTerminator());
246 return moduleTranslation.getOpenMPBuilder()->createBarrier(
247 builder.saveIP(), llvm::omp::OMPD_barrier);
252 void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
255 for (llvm::User *user : linearOrigVal[varIndex]->users())
256 users.push_back(user);
257 for (auto *user : users) {
258 if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
259 if (userInst->getParent()->getName().str() == BBName)
260 user->replaceUsesOfWith(linearOrigVal[varIndex],
261 linearLoopBodyTemps[varIndex]);
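// Looks up the omp.private privatizer op referenced by the given symbol,
// starting the symbol-table search from the provided operation.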
272 SymbolRefAttr symbolName) {
273 omp::PrivateClauseOp privatizer =
274 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
276 assert(privatizer && "privatizer not found in the symbol table");
287 auto todo = [&op](StringRef clauseName) {
288 return op.emitError() << "not yet implemented: Unhandled clause "
289 << clauseName << " in " << op.getName()
293 auto checkAllocate = [&todo](auto op, LogicalResult &result) {
294 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
295 result = todo("allocate");
297 auto checkBare = [&todo](auto op, LogicalResult &result) {
299 result = todo("ompx_bare");
301 auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
302 omp::ClauseCancellationConstructType cancelledDirective =
303 op.getCancelDirective();
306 if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
313 if (isa_and_nonnull<omp::TaskloopOp>(parent))
314 result = todo("cancel directive inside of taskloop");
317 auto checkDepend = [&todo](auto op, LogicalResult &result) {
318 if (!op.getDependVars().empty() || op.getDependKinds())
319 result = todo("depend");
321 auto checkDevice = [&todo](auto op, LogicalResult &result) {
323 result = todo("device");
325 auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
326 if (op.getDistScheduleChunkSize())
327 result = todo("dist_schedule with chunk_size");
329 auto checkHint = [](auto op, LogicalResult &) {
333 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
334 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
335 op.getInReductionSyms())
336 result = todo("in_reduction");
338 auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
339 if (!op.getIsDevicePtrVars().empty())
340 result = todo("is_device_ptr");
342 auto checkLinear = [&todo](auto op, LogicalResult &result) {
343 if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
344 result = todo("linear");
346 auto checkNowait = [&todo](auto op, LogicalResult &result) {
348 result = todo("nowait");
350 auto checkOrder = [&todo](auto op, LogicalResult &result) {
351 if (op.getOrder() || op.getOrderMod())
352 result = todo("order");
354 auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
355 if (op.getParLevelSimd())
356 result = todo("parallelization-level");
358 auto checkPriority = [&todo](auto op, LogicalResult &result) {
359 if (op.getPriority())
360 result = todo("priority");
362 auto checkPrivate = [&todo](auto op, LogicalResult &result) {
363 if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
365 if (!op.getPrivateVars().empty() && op.getNowait())
366 result = todo("privatization for deferred target tasks");
368 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
369 result = todo("privatization");
372 auto checkReduction = [&todo](auto op, LogicalResult &result) {
373 if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op))
374 if (!op.getReductionVars().empty() || op.getReductionByref() ||
375 op.getReductionSyms())
376 result = todo("reduction");
377 if (op.getReductionMod() &&
378 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
379 result = todo("reduction with modifier");
381 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
382 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
383 op.getTaskReductionSyms())
384 result = todo("task_reduction");
386 auto checkUntied = [&todo](auto op, LogicalResult &result) {
388 result = todo("untied");
391 LogicalResult result = success();
393 .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
394 .Case([&](omp::CancellationPointOp op) {
395 checkCancelDirective(op, result);
397 .Case([&](omp::DistributeOp op) {
398 checkAllocate(op, result);
399 checkDistSchedule(op, result);
400 checkOrder(op, result);
402 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
403 .Case([&](omp::SectionsOp op) {
404 checkAllocate(op, result);
405 checkPrivate(op, result);
406 checkReduction(op, result);
408 .Case([&](omp::SingleOp op) {
409 checkAllocate(op, result);
410 checkPrivate(op, result);
412 .Case([&](omp::TeamsOp op) {
413 checkAllocate(op, result);
414 checkPrivate(op, result);
416 .Case([&](omp::TaskOp op) {
417 checkAllocate(op, result);
418 checkInReduction(op, result);
420 .Case([&](omp::TaskgroupOp op) {
421 checkAllocate(op, result);
422 checkTaskReduction(op, result);
424 .Case([&](omp::TaskwaitOp op) {
425 checkDepend(op, result);
426 checkNowait(op, result);
428 .Case([&](omp::TaskloopOp op) {
430 checkUntied(op, result);
431 checkPriority(op, result);
433 .Case([&](omp::WsloopOp op) {
434 checkAllocate(op, result);
435 checkLinear(op, result);
436 checkOrder(op, result);
437 checkReduction(op, result);
439 .Case([&](omp::ParallelOp op) {
440 checkAllocate(op, result);
441 checkReduction(op, result);
443 .Case([&](omp::SimdOp op) {
444 checkLinear(op, result);
445 checkReduction(op, result);
447 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
448 omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
449 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
450 [&](auto op) { checkDepend(op, result); })
451 .Case([&](omp::TargetOp op) {
452 checkAllocate(op, result);
453 checkBare(op, result);
454 checkDevice(op, result);
455 checkInReduction(op, result);
456 checkIsDevicePtr(op, result);
457 checkPrivate(op, result);
467 LogicalResult result = success();
469 llvm::handleAllErrors(
471 [&](const PreviouslyReportedError &) { result = failure(); },
472 [&](const llvm::ErrorInfoBase &err) {
479 template <typename T>
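// Determines the insertion point for allocas: the point recorded on the
// innermost OpenMPAllocaStackFrame if one exists, otherwise the entry block
// of the current function.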
489 static llvm::OpenMPIRBuilder::InsertPointTy
495 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
497 [&](OpenMPAllocaStackFrame &frame) {
498 allocaInsertPoint = frame.allocaInsertPoint;
502 return allocaInsertPoint;
511 if (builder.GetInsertBlock() ==
512 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
513 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
514 "Assuming end of basic block");
515 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
516 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
517 builder.GetInsertBlock()->getNextNode());
518 builder.CreateBr(entryBB);
519 builder.SetInsertPoint(entryBB);
522 llvm::BasicBlock &funcEntryBlock =
523 builder.GetInsertBlock()->getParent()->getEntryBlock();
524 return llvm::OpenMPIRBuilder::InsertPointTy(
525 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
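// Returns the CanonicalLoopInfo stored on the innermost
// OpenMPLoopInfoStackFrame, i.e. the loop nest currently being translated.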
531 static llvm::CanonicalLoopInfo *
533 llvm::CanonicalLoopInfo *loopInfo = nullptr;
534 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
535 [&](OpenMPLoopInfoStackFrame &frame) {
536 loopInfo = frame.loopInfo;
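// Translates the blocks of an MLIR OpenMP region into LLVM IR basic blocks,
// splitting off a continuation block and creating PHI nodes there for values
// yielded by omp.yield terminators.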
548 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
551 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
553 llvm::BasicBlock *continuationBlock =
554 splitBB(builder, true, "omp.region.cont");
555 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
557 llvm::LLVMContext &llvmContext = builder.getContext();
558 for (Block &bb : region) {
559 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
560 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
561 builder.GetInsertBlock()->getNextNode());
562 moduleTranslation.mapBlock(&bb, llvmBB);
565 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
572 unsigned numYields = 0;
574 if (!isLoopWrapper) {
575 bool operandsProcessed = false;
577 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
578 if (!operandsProcessed) {
579 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
580 continuationBlockPHITypes.push_back(
581 moduleTranslation.convertType(yield->getOperand(i).getType()));
583 operandsProcessed = true;
585 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
586 "mismatching number of values yielded from the region");
587 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
588 llvm::Type *operandType =
589 moduleTranslation.convertType(yield->getOperand(i).getType());
591 assert(continuationBlockPHITypes[i] == operandType &&
592 "values of mismatching types yielded from the region");
602 if (!continuationBlockPHITypes.empty())
603 assert(
604 continuationBlockPHIs &&
605 "expected continuation block PHIs if converted regions yield values");
606 if (continuationBlockPHIs) {
607 llvm::IRBuilderBase::InsertPointGuard guard(builder);
608 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
609 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
610 for (llvm::Type *ty : continuationBlockPHITypes)
611 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
617 for (Block *bb : blocks) {
618 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
621 if (bb->isEntryBlock()) {
622 assert(sourceTerminator->getNumSuccessors() == 1 &&
623 "provided entry block has multiple successors");
624 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
625 "ContinuationBlock is not the successor of the entry block");
626 sourceTerminator->setSuccessor(0, llvmBB);
629 llvm::IRBuilderBase::InsertPointGuard guard(builder);
630 if (failed(
631 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
632 return llvm::make_error<PreviouslyReportedError>();
637 builder.CreateBr(continuationBlock);
648 Operation *terminator = bb->getTerminator();
649 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
650 builder.CreateBr(continuationBlock);
652 for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
653 (*continuationBlockPHIs)[i]->addIncoming(
667 return continuationBlock;
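// Maps an MLIR 'omp' proc_bind clause kind to the matching LLVM OpenMP
// ProcBindKind.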
673 case omp::ClauseProcBindKind::Close:
674 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
675 case omp::ClauseProcBindKind::Master:
676 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
677 case omp::ClauseProcBindKind::Primary:
678 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
679 case omp::ClauseProcBindKind::Spread:
680 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
682 llvm_unreachable("Unknown ClauseProcBindKind kind");
692 omp::BlockArgOpenMPOpInterface blockArgIface) {
694 blockArgIface.getBlockArgsPairs(blockArgsPairs);
695 for (auto [var, arg] : blockArgsPairs)
712 .Case([&](omp::SimdOp op) {
714 cast<omp::BlockArgOpenMPOpInterface>(*op));
715 op.emitWarning() << "simd information on composite construct discarded";
719 return op->emitError() << "cannot ignore wrapper";
727 auto maskedOp = cast<omp::MaskedOp>(opInst);
728 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
733 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
735 auto &region = maskedOp.getRegion();
736 builder.restoreIP(codeGenIP);
744 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
746 llvm::Value *filterVal = nullptr;
747 if (auto filterVar = maskedOp.getFilteredThreadId()) {
748 filterVal = moduleTranslation.lookupValue(filterVar);
750 llvm::LLVMContext &llvmContext = builder.getContext();
754 assert(filterVal != nullptr);
755 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
756 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
763 builder.restoreIP(*afterIP);
771 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
772 auto masterOp = cast<omp::MasterOp>(opInst);
777 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
779 auto &region = masterOp.getRegion();
780 builder.restoreIP(codeGenIP);
788 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
790 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
791 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
798 builder.restoreIP(*afterIP);
806 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
807 auto criticalOp = cast<omp::CriticalOp>(opInst);
812 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
814 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
815 builder.restoreIP(codeGenIP);
823 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
825 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
826 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
827 llvm::Constant *hint = nullptr;
830 if (criticalOp.getNameAttr()) {
833 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
834 auto criticalDeclareOp =
835 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
839 static_cast<int>(criticalDeclareOp.getHint()));
841 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
843 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
848 builder.restoreIP(*afterIP);
855 template <typename OP>
858 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
859 mlirVars.reserve(blockArgs.size());
860 llvmVars.reserve(blockArgs.size());
861 collectPrivatizationDecls<OP>(op);
864 mlirVars.push_back(privateVar);
876 void collectPrivatizationDecls(OP op) {
877 std::optional<ArrayAttr> attr = op.getPrivateSyms();
881 privatizers.reserve(privatizers.size() + attr->size());
882 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
889 template <typename T>
893 std::optional<ArrayAttr> attr = op.getReductionSyms();
897 reductions.reserve(reductions.size() + op.getNumReductionVars());
898 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
899 reductions.push_back(
900 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
911 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
919 if (llvm::hasSingleElement(region)) {
920 llvm::Instruction *potentialTerminator =
921 builder.GetInsertBlock()->empty() ? nullptr
922 : &builder.GetInsertBlock()->back();
924 if (potentialTerminator && potentialTerminator->isTerminator())
925 potentialTerminator->removeFromParent();
926 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
928 if (failed(moduleTranslation.convertBlock(
929 region.front(), true, builder)))
933 if (continuationBlockArgs)
935 *continuationBlockArgs,
942 if (potentialTerminator && potentialTerminator->isTerminator()) {
943 llvm::BasicBlock *block = builder.GetInsertBlock();
944 if (block->empty()) {
950 potentialTerminator->insertInto(block, block->begin());
952 potentialTerminator->insertAfter(&block->back());
966 if (continuationBlockArgs)
967 llvm::append_range(*continuationBlockArgs, phis);
968 builder.SetInsertPoint(*continuationBlock,
969 (*continuationBlock)->getFirstInsertionPt());
976 using OwningReductionGen =
977 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
978 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
980 using OwningAtomicReductionGen =
981 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
982 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
989 static OwningReductionGen
995 OwningReductionGen gen =
996 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
997 llvm::Value *lhs, llvm::Value *rhs,
998 llvm::Value *&result) mutable
999 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1000 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
1001 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
1002 builder.restoreIP(insertPoint);
1005 "omp.reduction.nonatomic.body", builder,
1006 moduleTranslation, &phis)))
1007 return llvm::createStringError(
1008 "failed to inline `combiner` region of `omp.declare_reduction`");
1009 result = llvm::getSingleElement(phis);
1010 return builder.saveIP();
1019 static OwningAtomicReductionGen
1021 llvm::IRBuilderBase &builder,
1023 if (decl.getAtomicReductionRegion().empty())
1024 return OwningAtomicReductionGen();
1029 OwningAtomicReductionGen atomicGen =
1030 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1031 llvm::Value *lhs, llvm::Value *rhs) mutable
1032 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1033 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
1034 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
1035 builder.restoreIP(insertPoint);
1038 "omp.reduction.atomic.body", builder,
1039 moduleTranslation, &phis)))
1040 return llvm::createStringError(
1041 "failed to inline `atomic` region of `omp.declare_reduction`");
1042 assert(phis.empty());
1043 return builder.saveIP();
1049 static LogicalResult
1052 auto orderedOp = cast<omp::OrderedOp>(opInst);
1057 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1058 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1059 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1061 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1063 size_t indexVecValues = 0;
1064 while (indexVecValues < vecValues.size()) {
1066 storeValues.reserve(numLoops);
1067 for (unsigned i = 0; i < numLoops; i++) {
1068 storeValues.push_back(vecValues[indexVecValues]);
1071 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1073 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1074 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1075 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1082 static LogicalResult
1085 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1086 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1091 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1093 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1094 builder.restoreIP(codeGenIP);
1102 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1104 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1105 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1107 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1112 builder.restoreIP(*afterIP);
1118 struct DeferredStore {
1119 DeferredStore(llvm::Value *value, llvm::Value *address)
1120 : value(value), address(address) {}
1123 llvm::Value *address;
1130 template <typename T>
1131 static LogicalResult
1133 llvm::IRBuilderBase &builder,
1135 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1141 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1142 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1145 deferredStores.reserve(loop.getNumReductionVars());
1147 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1148 Region &allocRegion = reductionDecls[i].getAllocRegion();
1150 if (allocRegion.empty())
1155 builder, moduleTranslation, &phis)))
1156 return loop.emitError(
1157 "failed to inline `alloc` region of `omp.declare_reduction`");
1159 assert(phis.size() == 1 && "expected one allocation to be yielded");
1160 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1164 llvm::Value *var = builder.CreateAlloca(
1165 moduleTranslation.convertType(reductionDecls[i].getType()));
1167 llvm::Type *ptrTy = builder.getPtrTy();
1168 llvm::Value *castVar =
1169 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1170 llvm::Value *castPhi =
1171 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1173 deferredStores.emplace_back(castPhi, castVar);
1175 privateReductionVariables[i] = castVar;
1176 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1177 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1179 assert(allocRegion.empty() &&
1180 "allocation is implicit for by-val reduction");
1181 llvm::Value *var = builder.CreateAlloca(
1182 moduleTranslation.convertType(reductionDecls[i].getType()));
1184 llvm::Type *ptrTy = builder.getPtrTy();
1185 llvm::Value *castVar =
1186 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1188 moduleTranslation.mapValue(reductionArgs[i], castVar);
1189 privateReductionVariables[i] = castVar;
1190 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1198 template <typename T>
1205 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1206 Region &initializerRegion = reduction.getInitializerRegion();
1209 mlir::Value mlirSource = loop.getReductionVars()[i];
1210 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1211 assert(llvmSource && "lookup reduction var");
1212 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
1215 llvm::Value *allocation =
1216 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1217 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1223 llvm::BasicBlock *block = nullptr) {
1224 if (block == nullptr)
1225 block = builder.GetInsertBlock();
1227 if (block->empty() || block->getTerminator() == nullptr)
1228 builder.SetInsertPoint(block);
1230 builder.SetInsertPoint(block->getTerminator());
1238 template <typename OP>
1239 static LogicalResult
1241 llvm::IRBuilderBase &builder,
1243 llvm::BasicBlock *latestAllocaBlock,
1249 if (op.getNumReductionVars() == 0)
1252 llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
1253 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1254 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1255 builder.restoreIP(allocaIP);
1258 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1260 if (!reductionDecls[i].getAllocRegion().empty())
1266 byRefVars[i] = builder.CreateAlloca(
1267 moduleTranslation.convertType(reductionDecls[i].getType()));
1275 for (auto [data, addr] : deferredStores)
1276 builder.CreateStore(data, addr);
1281 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1286 reductionVariableMap, i);
1289 "omp.reduction.neutral", builder,
1290 moduleTranslation, &phis)))
1293 assert(phis.size() == 1 && "expected one value to be yielded from the "
1294 "reduction neutral element declaration region");
1299 if (!reductionDecls[i].getAllocRegion().empty())
1308 builder.CreateStore(phis[0], byRefVars[i]);
1310 privateReductionVariables[i] = byRefVars[i];
1311 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1312 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1315 builder.CreateStore(phis[0], privateReductionVariables[i]);
1322 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
1329 template <typename T>
1331 T loop, llvm::IRBuilderBase &builder,
1338 unsigned numReductions = loop.getNumReductionVars();
1340 for (unsigned i = 0; i < numReductions; ++i) {
1341 owningReductionGens.push_back(
1343 owningAtomicReductionGens.push_back(
1348 reductionInfos.reserve(numReductions);
1349 for (unsigned i = 0; i < numReductions; ++i) {
1350 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
1351 if (owningAtomicReductionGens[i])
1352 atomicGen = owningAtomicReductionGens[i];
1353 llvm::Value *variable =
1354 moduleTranslation.lookupValue(loop.getReductionVars()[i]);
1355 reductionInfos.push_back(
1357 privateReductionVariables[i],
1358 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1359 owningReductionGens[i],
1360 nullptr, atomicGen});
1365 static LogicalResult
1369 llvm::IRBuilderBase &builder, StringRef regionName,
1370 bool shouldLoadCleanupRegionArg = true) {
1372 if (cleanupRegion->empty())
1378 llvm::Instruction *potentialTerminator =
1379 builder.GetInsertBlock()->empty() ? nullptr
1380 : &builder.GetInsertBlock()->back();
1381 if (potentialTerminator && potentialTerminator->isTerminator())
1382 builder.SetInsertPoint(potentialTerminator);
1383 llvm::Value *privateVarValue =
1384 shouldLoadCleanupRegionArg
1385 ? builder.CreateLoad(
1387 privateVariables[i])
1388 : privateVariables[i];
1393 moduleTranslation)))
1406 OP op, llvm::IRBuilderBase &builder,
1408 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1411 bool isNowait = false, bool isTeamsReduction = false) {
1413 if (op.getNumReductionVars() == 0)
1425 owningReductionGens, owningAtomicReductionGens,
1426 privateReductionVariables, reductionInfos);
1431 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1432 builder.SetInsertPoint(tempTerminator);
1433 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1434 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1435 isByRef, isNowait, isTeamsReduction);
1440 if (!contInsertPoint->getBlock())
1441 return op->emitOpError() << "failed to convert reductions";
1443 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1444 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1449 tempTerminator->eraseFromParent();
1450 builder.restoreIP(*afterIP);
1454 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1455 [](omp::DeclareReductionOp reductionDecl) {
1456 return &reductionDecl.getCleanupRegion();
1459 moduleTranslation, builder,
1460 "omp.reduction.cleanup");
1471 template <typename OP>
1475 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1480 if (op.getNumReductionVars() == 0)
1486 allocaIP, reductionDecls,
1487 privateReductionVariables, reductionVariableMap,
1488 deferredStores, isByRef)))
1492 allocaIP.getBlock(), reductionDecls,
1493 privateReductionVariables, reductionVariableMap,
1494 isByRef, deferredStores);
1504 static llvm::Value *
1508 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1511 Value blockArg = (*mappedPrivateVars)[privateVar];
1514 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1515 "A block argument corresponding to a mapped var should have "
1518 if (privVarType == blockArgType)
1525 if (!isa<LLVM::LLVMPointerType>(privVarType))
1526 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1539 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1541 Region &initRegion = privDecl.getInitRegion();
1542 if (initRegion.empty())
1543 return llvmPrivateVar;
1547 mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1548 assert(nonPrivateVar);
1549 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1550 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1555 moduleTranslation, &phis)))
1556 return llvm::createStringError(
1557 "failed to inline `init` region of `omp.private`");
1559 assert(phis.size() == 1 && "expected one allocation to be yielded");
1578 return llvm::Error::success();
1580 llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
1586 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1588 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1589 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1592 return privVarOrErr.takeError();
1594 llvmPrivateVar = privVarOrErr.get();
1595 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1600 return llvm::Error::success();
1610 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1613 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1614 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1615 allocaTerminator->getIterator()),
1616 true, allocaTerminator->getStableDebugLoc(),
1617 "omp.region.after_alloca");
1619 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1621 allocaTerminator = allocaIP.getBlock()->getTerminator();
1622 builder.SetInsertPoint(allocaTerminator);
1624 assert(allocaTerminator->getNumSuccessors() == 1 &&
1625 "This is an unconditional branch created by splitBB");
1627 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1628 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1630 unsigned int allocaAS =
1631 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1634 .getProgramAddressSpace();
1636 for (auto [privDecl, mlirPrivVar, blockArg] :
1639 llvm::Type *llvmAllocType =
1640 moduleTranslation.convertType(privDecl.getType());
1641 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1642 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1643 llvmAllocType, nullptr, "omp.private.alloc");
1644 if (allocaAS != defaultAS)
1645 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1646 builder.getPtrTy(defaultAS));
1648 privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
1651 return afterAllocas;
1662 bool needsFirstprivate =
1663 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1664 return privOp.getDataSharingType() ==
1665 omp::DataSharingClauseType::FirstPrivate;
1668 if (!needsFirstprivate)
1671 llvm::BasicBlock *copyBlock =
1672 splitBB(builder, true, "omp.private.copy");
1675 for (auto [decl, mlirVar, llvmVar] :
1676 llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
1677 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1681 Region &copyRegion = decl.getCopyRegion();
1685 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1686 assert(nonPrivateVar);
1687 moduleTranslation.mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1690 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1694 moduleTranslation)))
1695 return decl.emitError("failed to inline `copy` region of `omp.private`");
1707 if (insertBarrier) {
1709 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1710 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1718 static LogicalResult
1725 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1726 [](omp::PrivateClauseOp privatizer) {
1727 return &privatizer.getDeallocRegion();
1731 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1732 "omp.private.dealloc",
false)))
1733 return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1734 "`omp.private` op in");
1746 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1753 static LogicalResult
1756 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1757 using StorableBodyGenCallbackTy =
1758 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1760 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1766 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1770 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1774 sectionsOp.getNumReductionVars());
1778 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1781 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1782 reductionDecls, privateReductionVariables, reductionVariableMap,
1789 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1793 Region &region = sectionOp.getRegion();
1794 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
1795 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1796 builder.restoreIP(codeGenIP);
1803 sectionsOp.getRegion().getNumArguments());
1804 for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
1805 sectionsOp.getRegion().getArguments(), region.getArguments())) {
1806 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
1808 moduleTranslation.mapValue(sectionArg, llvmVal);
1815 sectionCBs.push_back(sectionCB);
1821 if (sectionCBs.empty())
1824 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1829 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1830 llvm::Value &vPtr, llvm::Value *&replacementValue)
1831 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1832 replacementValue = &vPtr;
1838 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1842 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1843 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1845 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1846 sectionsOp.getNowait());
1851 builder.restoreIP(*afterIP);
1855 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1856 privateReductionVariables, isByRef, sectionsOp.getNowait());
1860 static LogicalResult
1863 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1864 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1869 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1870 builder.restoreIP(codegenIP);
1872 builder, moduleTranslation)
1875 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1879 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1882 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
1883 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1884 auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
1885 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1886 llvmCPFuncs.push_back(
1890 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1892 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1898 builder.restoreIP(*afterIP);
1904 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1909 for (auto ra : iface.getReductionBlockArgs())
1910 for (auto &use : ra.getUses()) {
1911 auto *useOp = use.getOwner();
1913 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1914 debugUses.push_back(useOp);
1918 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
1923 Operation *currentOp = currentDistOp.getOperation();
1924 if (distOp && (distOp != currentOp))
1933 for (auto use : debugUses)
1939 static LogicalResult
1942 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1947 unsigned numReductionVars = op.getNumReductionVars();
1951 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1957 if (doTeamsReduction) {
1958 isByRef = getIsByRef(op.getReductionByref());
1960 assert(isByRef.size() == op.getNumReductionVars());
1963 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
1968 op, reductionArgs, builder, moduleTranslation, allocaIP,
1969 reductionDecls, privateReductionVariables, reductionVariableMap,
1974 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1976 moduleTranslation, allocaIP);
1977 builder.restoreIP(codegenIP);
1983 llvm::Value *numTeamsLower = nullptr;
1984 if (Value numTeamsLowerVar = op.getNumTeamsLower())
1985 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
1987 llvm::Value *numTeamsUpper = nullptr;
1988 if (Value numTeamsUpperVar = op.getNumTeamsUpper())
1989 numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
1991 llvm::Value *threadLimit = nullptr;
1992 if (Value threadLimitVar = op.getThreadLimit())
1993 threadLimit = moduleTranslation.lookupValue(threadLimitVar);
1995 llvm::Value *ifExpr = nullptr;
1996 if (Value ifVar = op.getIfExpr())
1999 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2000 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2002 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2007 builder.restoreIP(*afterIP);
2008 if (doTeamsReduction) {
2011 op, builder, moduleTranslation, allocaIP, reductionDecls,
2012 privateReductionVariables, isByRef,
2022 if (dependVars.empty())
2024 for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2025 llvm::omp::RTLDependenceKindTy type;
2027 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2028 case mlir::omp::ClauseTaskDepend::taskdependin:
2029 type = llvm::omp::RTLDependenceKindTy::DepIn;
2034 case mlir::omp::ClauseTaskDepend::taskdependout:
2035 case mlir::omp::ClauseTaskDepend::taskdependinout:
2036 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2038 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2039 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2041 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2042 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2045 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2046 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2047 dds.emplace_back(dd);
2059 llvm::IRBuilderBase &llvmBuilder,
2061 llvm::omp::Directive cancelDirective) {
2062 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2063 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2067 llvmBuilder.restoreIP(ip);
2073 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2074 return llvm::Error::success();
2079 ompBuilder.pushFinalizationCB(
2089 llvm::OpenMPIRBuilder &ompBuilder,
2090 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2091 ompBuilder.popFinalizationCB();
2092 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2093 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2094 assert(cancelBranch->getNumSuccessors() == 1 &&
2095 "cancel branch should have one target");
2096 cancelBranch->setSuccessor(0, constructFini);
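// Manages a heap-allocated context struct for omp.task translation:
// privatizers that read from their original (mold) variable get a slot in the
// struct, GEPs expose those slots to the task body, and the struct is freed
// once the task has finished.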
2103 class TaskContextStructManager {
2105 TaskContextStructManager(llvm::IRBuilderBase &builder,
2108 : builder{builder}, moduleTranslation{moduleTranslation},
2109 privateDecls{privateDecls} {}
2115 void generateTaskContextStruct();
2121 void createGEPsToPrivateVars();
2124 void freeStructPtr();
2127 return llvmPrivateVarGEPs;
2130 llvm::Value *getStructPtr() { return structPtr; }
2133 llvm::IRBuilderBase &builder;
2145 llvm::Value *structPtr = nullptr;
2147 llvm::Type *structTy = nullptr;
2151 void TaskContextStructManager::generateTaskContextStruct() {
2152 if (privateDecls.empty())
2154 privateVarTypes.reserve(privateDecls.size());
2156 for (omp::PrivateClauseOp &privOp : privateDecls) {
2159 if (!privOp.readsFromMold())
2161 Type mlirType = privOp.getType();
2162 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2168 llvm::DataLayout dataLayout =
2169 builder.GetInsertBlock()->getModule()->getDataLayout();
2170 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2171 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2174 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2176 "omp.task.context_ptr");
2179 void TaskContextStructManager::createGEPsToPrivateVars() {
2181 assert(privateVarTypes.empty());
2186 llvmPrivateVarGEPs.clear();
2187 llvmPrivateVarGEPs.reserve(privateDecls.size());
2188 llvm::Value *zero = builder.getInt32(0);
2190 for (auto privDecl : privateDecls) {
2191 if (!privDecl.readsFromMold()) {
2193 llvmPrivateVarGEPs.push_back(nullptr);
2196 llvm::Value *iVal = builder.getInt32(i);
2197 llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
2198 llvmPrivateVarGEPs.push_back(gep);
2203 void TaskContextStructManager::freeStructPtr() {
2207 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2209 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2210 builder.CreateFree(structPtr);
2214 static LogicalResult
2217 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2222 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2234 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2239 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2240 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2241 builder.getContext(), "omp.task.start",
2242 builder.GetInsertBlock()->getParent());
2243 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2244 builder.SetInsertPoint(branchToTaskStartBlock);
2247 llvm::BasicBlock *copyBlock =
2248 splitBB(builder, true, "omp.private.copy");
2249 llvm::BasicBlock *initBlock =
2250 splitBB(builder, true, "omp.private.init");
2266 moduleTranslation, allocaIP);
2269 builder.SetInsertPoint(initBlock->getTerminator());
2272 taskStructMgr.generateTaskContextStruct();
2279 taskStructMgr.createGEPsToPrivateVars();
2281 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2284 taskStructMgr.getLLVMPrivateVarGEPs())) {
2286 if (!privDecl.readsFromMold())
2288 assert(llvmPrivateVarAlloc &&
2289 "reads from mold so shouldn't have been skipped");
2292 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2293 blockArg, llvmPrivateVarAlloc, initBlock);
2294 if (!privateVarOrErr)
2295 return handleError(privateVarOrErr, *taskOp.getOperation());
2297 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2298 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2305 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2306 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2307 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2309 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2310 llvmPrivateVarAlloc);
2312 assert(llvmPrivateVarAlloc->getType() ==
2313 moduleTranslation.convertType(blockArg.getType()));
2323 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2324 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2325 taskOp.getPrivateNeedsBarrier())))
2326 return llvm::failure();
2329 builder.SetInsertPoint(taskStartBlock);
2331 auto bodyCB = [&](InsertPointTy allocaIP,
2332 InsertPointTy codegenIP) -> llvm::Error {
2336 moduleTranslation, allocaIP);
2339 builder.restoreIP(codegenIP);
2341 llvm::BasicBlock *privInitBlock = nullptr;
2346 auto [blockArg, privDecl, mlirPrivVar] = zip;
2348 if (privDecl.readsFromMold())
2351 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2352 llvm::Type *llvmAllocType =
2353 moduleTranslation.convertType(privDecl.getType());
2354 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2355 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2356 llvmAllocType, nullptr, "omp.private.alloc");
2359 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2360 blockArg, llvmPrivateVar, privInitBlock);
2361 if (!privateVarOrError)
2362 return privateVarOrError.takeError();
2363 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2364 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2367 taskStructMgr.createGEPsToPrivateVars();
2368 for (auto [i, llvmPrivVar] :
2371 assert(privateVarsInfo.llvmVars[i] &&
2372 "This is added in the loop above");
2375 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2380 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2384 if (!privateDecl.readsFromMold())
2387 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2388 llvmPrivateVar = builder.CreateLoad(
2389 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2391 assert(llvmPrivateVar->getType() ==
2392 moduleTranslation.convertType(blockArg.getType()));
2393 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2397 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2398 if (failed(handleError(continuationBlockOrError, *taskOp)))
2399 return llvm::make_error<PreviouslyReportedError>();
2401 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2406 return llvm::make_error<PreviouslyReportedError>();
2409 taskStructMgr.freeStructPtr();
2411 return llvm::Error::success();
2420 llvm::omp::Directive::OMPD_taskgroup);
2424 moduleTranslation, dds);
2426 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2427 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2429 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2431 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
2432 taskOp.getMergeable(),
2433 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2434 moduleTranslation.lookupValue(taskOp.getPriority()));
2442 builder.restoreIP(*afterIP);
2447 static LogicalResult
2450 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2454 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2455 builder.restoreIP(codegenIP);
2457 builder, moduleTranslation)
2462 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2463 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2470 builder.restoreIP(*afterIP);
2474 static LogicalResult
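// Converts an omp.wsloop (worksharing loop), handling schedule, reduction,
// privatization, and linear clauses around the wrapped loop nest.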
2485 static LogicalResult
2489 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2493 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2495 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2499 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2502 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
2503 llvm::Type *ivType = step->getType();
2504 llvm::Value *chunk = nullptr;
2505 if (wsloopOp.getScheduleChunk()) {
2506 llvm::Value *chunkVar =
2507 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
2508 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2515 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2519 wsloopOp.getNumReductionVars());
2522 builder, moduleTranslation, privateVarsInfo, allocaIP);
2529 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2534 moduleTranslation, allocaIP, reductionDecls,
2535 privateReductionVariables, reductionVariableMap,
2536 deferredStores, isByRef)))
2545 wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2547 wsloopOp.getPrivateNeedsBarrier())))
2550 assert(afterAllocas.get()->getSinglePredecessor());
2553 afterAllocas.get()->getSinglePredecessor(),
2554 reductionDecls, privateReductionVariables,
2555 reductionVariableMap, isByRef, deferredStores)))
2559 bool isOrdered = wsloopOp.getOrdered().has_value();
2560 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2561 bool isSimd = wsloopOp.getScheduleSimd();
2562 bool loopNeedsBarrier = !wsloopOp.getNowait();
2567 llvm::omp::WorksharingLoopType workshareLoopType =
2568 llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
2569 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2570 : llvm::omp::WorksharingLoopType::ForStaticLoop;
2574 llvm::omp::Directive::OMPD_for);
2576 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2579 LinearClauseProcessor linearClauseProcessor;
2580 if (wsloopOp.getLinearVars().size()) {
2581 for (mlir::Value linearVar : wsloopOp.getLinearVars())
2582 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2584 for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
2585 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
2589 wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
2597 if (wsloopOp.getLinearVars().size()) {
2598 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2599 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2600 loopInfo->getPreheader());
2603 builder.restoreIP(*afterBarrierIP);
2604 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
2605 loopInfo->getIndVar());
2606 linearClauseProcessor.outlineLinearFinalizationBB(builder,
2607 loopInfo->getExit());
2610 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2611 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2612 ompBuilder->applyWorkshareLoop(
2613 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2614 convertToScheduleKind(schedule), chunk, isSimd,
2615 scheduleMod == omp::ScheduleModifier::monotonic,
2616 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2623 if (wsloopOp.getLinearVars().size()) {
2624 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
2625 assert(loopInfo->getLastIter() &&
2626 "`lastiter` in CanonicalLoopInfo is nullptr");
2627 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2628 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
2629 loopInfo->getLastIter());
2632 for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
2633 linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
2635 builder.restoreIP(oldIP);
2643 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2644 privateReductionVariables, isByRef, wsloopOp.getNowait(),
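// Converts an omp.parallel operation via OpenMPIRBuilder::createParallel,
// wiring up privatization, reductions, the if/num_threads/proc_bind clauses,
// and the reduction cleanup regions.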
2654 static LogicalResult
2657 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2659 assert(isByRef.size() == opInst.getNumReductionVars());
2671 opInst.getNumReductionVars());
2674 auto bodyGenCB = [&](InsertPointTy allocaIP,
2675 InsertPointTy codeGenIP) -> llvm::Error {
2677 builder, moduleTranslation, privateVarsInfo, allocaIP);
2679 return llvm::make_error<PreviouslyReportedError>();
2685 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
2688 InsertPointTy(allocaIP.getBlock(),
2689 allocaIP.getBlock()->getTerminator()->getIterator());
2692 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
2693 reductionDecls, privateReductionVariables, reductionVariableMap,
2694 deferredStores, isByRef)))
2695 return llvm::make_error<PreviouslyReportedError>();
2697 assert(afterAllocas.get()->getSinglePredecessor());
2698 builder.restoreIP(codeGenIP);
2704 return llvm::make_error<PreviouslyReportedError>();
2707 opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
2709 opInst.getPrivateNeedsBarrier())))
2710 return llvm::make_error<PreviouslyReportedError>();
2714 afterAllocas.get()->getSinglePredecessor(),
2715 reductionDecls, privateReductionVariables,
2716 reductionVariableMap, isByRef, deferredStores)))
2717 return llvm::make_error<PreviouslyReportedError>();
2722 moduleTranslation, allocaIP);
2726 opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
2728 return regionBlock.takeError();
2731 if (opInst.getNumReductionVars() > 0) {
2737 owningReductionGens, owningAtomicReductionGens,
2738 privateReductionVariables, reductionInfos);
2741 builder.SetInsertPoint((*regionBlock)->getTerminator());
2744 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
2745 builder.SetInsertPoint(tempTerminator);
2747 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
2748 ompBuilder->createReductions(
2749 builder.saveIP(), allocaIP, reductionInfos, isByRef,
2751 if (!contInsertPoint)
2752 return contInsertPoint.takeError();
2754 if (!contInsertPoint->getBlock())
2755 return llvm::make_error<PreviouslyReportedError>();
2757 tempTerminator->eraseFromParent();
2758 builder.restoreIP(*contInsertPoint);
2761 return llvm::Error::success();
2764 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2765 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
2774 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2775 InsertPointTy oldIP = builder.saveIP();
2776 builder.restoreIP(codeGenIP);
2781 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
2782 [](omp::DeclareReductionOp reductionDecl) {
2783 return &reductionDecl.getCleanupRegion();
2786 reductionCleanupRegions, privateReductionVariables,
2787 moduleTranslation, builder, "omp.reduction.cleanup")))
2788 return llvm::createStringError(
2789 "failed to inline `cleanup` region of `omp.declare_reduction`");
2794 return llvm::make_error<PreviouslyReportedError>();
2796 builder.restoreIP(oldIP);
2797 return llvm::Error::success();
2800 llvm::Value *ifCond = nullptr;
2801 if (auto ifVar = opInst.getIfExpr())
2803 llvm::Value *numThreads = nullptr;
2804 if (auto numThreadsVar = opInst.getNumThreads())
2805 numThreads = moduleTranslation.lookupValue(numThreadsVar);
2806 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
2807 if (auto bind = opInst.getProcBindKind())
2811 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2813 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2815 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2816 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
2817 ifCond, numThreads, pbKind, isCancellable);
2822 builder.restoreIP(*afterIP);
static llvm::omp::OrderKind
convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
  if (!o)
    return llvm::omp::OrderKind::OMP_ORDER_unknown;
  switch (*o) {
  case omp::ClauseOrderKind::Concurrent:
    return llvm::omp::OrderKind::OMP_ORDER_concurrent;
  }
  llvm_unreachable("Unknown ClauseOrderKind kind");
}
2839 static LogicalResult
2843 auto simdOp = cast<omp::SimdOp>(opInst);
2849 if (simdOp.isComposite()) {
2854 builder, moduleTranslation);
2862 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2866 builder, moduleTranslation, privateVarsInfo, allocaIP);
  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(simdlenVar.value());

  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(safelenVar.value());
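  // For instance, `simdlen(8)` materializes as the i64 constant 8 here; when a
  // clause is absent the corresponding pointer stays null and is forwarded
  // unchanged to applySimd below.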
  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
  for (size_t i = 0; i < operands.size(); ++i) {
    llvm::Value *alignment = nullptr;
    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
    llvm::Type *ty = llvmVal->getType();

    auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
    alignment = builder.getInt64(intAttr.getInt());
    assert(ty->isPointerTy() && "Invalid type for aligned variable");
    assert(alignment && "Invalid alignment value");
    auto curInsert = builder.saveIP();
    builder.SetInsertPoint(sourceBlock);
    llvmVal = builder.CreateLoad(ty, llvmVal);
    builder.restoreIP(curInsert);
    alignedVars[llvmVal] = alignment;
  }

  llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
      simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);

  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
  ompBuilder->applySimd(loopInfo, alignedVars,
                        simdOp.getIfExpr()
                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                            : nullptr,
                        order, simdlen, safelen);
static LogicalResult
convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  auto loopOp = cast<omp::LoopNestOp>(opInst);
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make the induction variable of this loop level visible to further
    // conversions of the loop body.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Remember the insertion point of this loop body for nested loops.
    bodyInsertPoints.push_back(ip);

    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the innermost loop.
    builder.restoreIP(ip);
    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
        loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
    return llvm::Error::success();
  };
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);

    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }

    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
        ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);

    if (failed(handleError(loopResult, *loopOp)))
      return failure();

    loopInfos.push_back(*loopResult);
  }
  llvm::OpenMPIRBuilder::InsertPointTy afterIP =
      loopInfos.front()->getAfterIP();

  // Update the stack frame created for this loop nest to point to the
  // collapsed loop.
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
        return WalkResult::interrupt();
      });

  // Continue building IR after the loop nest.
  builder.restoreIP(afterIP);
  return success();
}
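// The collapsed loop produced above is recorded on the enclosing
// OpenMPLoopInfoStackFrame, which is how wrapper constructs translated around
// this loop nest (e.g. the simd and distribute conversions in this file) later
// retrieve the llvm::CanonicalLoopInfo they transform.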
static llvm::AtomicOrdering
convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
  if (!ao)
    return llvm::AtomicOrdering::Monotonic; // Default memory ordering.

  switch (*ao) {
  case omp::ClauseMemoryOrderKind::Seq_cst:
    return llvm::AtomicOrdering::SequentiallyConsistent;
  case omp::ClauseMemoryOrderKind::Acq_rel:
    return llvm::AtomicOrdering::AcquireRelease;
  case omp::ClauseMemoryOrderKind::Acquire:
    return llvm::AtomicOrdering::Acquire;
  case omp::ClauseMemoryOrderKind::Release:
    return llvm::AtomicOrdering::Release;
  case omp::ClauseMemoryOrderKind::Relaxed:
    return llvm::AtomicOrdering::Monotonic;
  }
  llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
}
static LogicalResult
convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  auto readOp = cast<omp::AtomicReadOp>(opInst);

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
  llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
  llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());

  llvm::Type *elementType =
      moduleTranslation.convertType(readOp.getElementType());

  llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
  llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
  builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
  return success();
}
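// In the atomic read conversion above, `x` is the translated source location,
// `v` the translated destination, and the converted element type is what the
// AtomicOpValue descriptors hand to createAtomicRead.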
static LogicalResult
convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  auto writeOp = cast<omp::AtomicWriteOp>(opInst);

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
  llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
  llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
  llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
  llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, false, false};
  builder.restoreIP(
      ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
  return success();
}
      .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
      .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
      .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
      .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
      .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
      .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
      .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
      .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
      .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
      .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
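// Only the LLVM dialect operations listed above have a direct
// llvm::AtomicRMWInst opcode; any other operation in an atomic update region
// falls through the Default case and yields BAD_BINOP.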
3106 static LogicalResult
3108 llvm::IRBuilderBase &builder,
3115 auto &innerOpList = opInst.getRegion().front().getOperations();
3116 bool isXBinopExpr{
false};
3117 llvm::AtomicRMWInst::BinOp binop;
3119 llvm::Value *llvmExpr =
nullptr;
3120 llvm::Value *llvmX =
nullptr;
3121 llvm::Type *llvmXElementType =
nullptr;
3122 if (innerOpList.size() == 2) {
3128 opInst.getRegion().getArgument(0))) {
3129 return opInst.emitError(
"no atomic update operation with region argument"
3130 " as operand found inside atomic.update region");
3133 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
3135 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
3139 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3141 llvmX = moduleTranslation.
lookupValue(opInst.getX());
3143 opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  llvm::OpenMPIRBuilder::AtomicUpdateCallbackTy updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr);

  builder.restoreIP(*afterIP);
3184 static LogicalResult
3186 llvm::IRBuilderBase &builder,
3193 bool isXBinopExpr =
false, isPostfixUpdate =
false;
3194 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3196 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
3197 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
3199 assert((atomicUpdateOp || atomicWriteOp) &&
3200 "internal op must be an atomic.update or atomic.write op");
3202 if (atomicWriteOp) {
3203 isPostfixUpdate =
true;
3204 mlirExpr = atomicWriteOp.getExpr();
3206 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
3207 atomicCaptureOp.getAtomicUpdateOp().getOperation();
3208 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
3211 if (innerOpList.size() == 2) {
3214 atomicUpdateOp.getRegion().getArgument(0))) {
3215 return atomicUpdateOp.emitError(
3216 "no atomic update operation with region argument"
3217 " as operand found inside atomic.update region");
3221 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
3224 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3228 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
3229 llvm::Value *llvmX =
3230 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
3231 llvm::Value *llvmV =
3232 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
3233 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
3234 atomicCaptureOp.getAtomicReadOp().getElementType());
3235 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3238 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
3242 llvm::AtomicOrdering atomicOrdering =
3246 [&](llvm::Value *atomicx,
3249 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
3250 Block &bb = *atomicUpdateOp.getRegion().
begin();
3251 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
3253 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
3254 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
3255 return llvm::make_error<PreviouslyReportedError>();
3257 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
3258 assert(yieldop && yieldop.getResults().size() == 1 &&
3259 "terminator must be omp.yield op and it must have exactly one "
3261 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
3266 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3267 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3268 ompBuilder->createAtomicCapture(
3269 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
3270 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr);
3272 if (failed(
handleError(afterIP, *atomicCaptureOp)))
3275 builder.restoreIP(*afterIP);
    omp::ClauseCancellationConstructType directive) {
  switch (directive) {
  case omp::ClauseCancellationConstructType::Loop:
    return llvm::omp::Directive::OMPD_for;
  case omp::ClauseCancellationConstructType::Parallel:
    return llvm::omp::Directive::OMPD_parallel;
  case omp::ClauseCancellationConstructType::Sections:
    return llvm::omp::Directive::OMPD_sections;
  case omp::ClauseCancellationConstructType::Taskgroup:
    return llvm::omp::Directive::OMPD_taskgroup;
  }
}
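// Example: a `sections` cancellation construct type selects OMPD_sections
// here; the resulting directive is what createCancel and
// createCancellationPoint below are told to cancel.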
static LogicalResult
convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::Value *ifCond = nullptr;
  if (Value ifVar = op.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);

  llvm::omp::Directive cancelledDirective =
      convertCancellationConstructType(op.getCancelDirective());

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);

  if (failed(handleError(afterIP, *op.getOperation())))
    return failure();

  builder.restoreIP(afterIP.get());
  return success();
}
static LogicalResult
convertOmpCancellationPoint(omp::CancellationPointOp op,
                            llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::omp::Directive cancelledDirective =
      convertCancellationConstructType(op.getCancelDirective());

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);

  if (failed(handleError(afterIP, *op.getOperation())))
    return failure();

  builder.restoreIP(afterIP.get());
  return success();
}
static LogicalResult
convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  Value symAddr = threadprivateOp.getSymAddr();
  auto *symOp = symAddr.getDefiningOp();

  if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
    symOp = asCast.getOperand().getDefiningOp();

  if (!isa<LLVM::AddressOfOp>(symOp))
    return opInst.emitError("Addressing symbol not found");
  LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);

  LLVM::GlobalOp global =
      addressOfOp.getGlobal(moduleTranslation.symbolTable());
  llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);

  if (!ompBuilder->Config.isTargetDevice()) {
    llvm::Type *type = globalValue->getValueType();
    llvm::TypeSize typeSize =
        builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
            type);
    llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
    llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
        ompLoc, globalValue, size, global.getSymName() + ".cache");
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
  switch (deviceClause) {
  case mlir::omp::DeclareTargetDeviceType::host:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
  case mlir::omp::DeclareTargetDeviceType::nohost:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
  case mlir::omp::DeclareTargetDeviceType::any:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
  }
  llvm_unreachable("unhandled device clause");
}

static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
convertToCaptureClauseKind(
    mlir::omp::DeclareTargetCaptureClause captureClause) {
  switch (captureClause) {
  case mlir::omp::DeclareTargetCaptureClause::to:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
  case mlir::omp::DeclareTargetCaptureClause::link:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
  case mlir::omp::DeclareTargetCaptureClause::enter:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
  }
  llvm_unreachable("unhandled capture clause");
}
                             llvm::OpenMPIRBuilder &ompBuilder) {
  llvm::SmallString<64> suffix;
  llvm::raw_svector_ostream os(suffix);

  auto fileInfoCallBack = [&loc]() {
    return std::pair<std::string, uint64_t>(
        llvm::StringRef(loc.getFilename()), loc.getLine());
  };

  os << llvm::format(
      "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);

  os << "_decl_tgt_ref_ptr";

  return suffix;
}
static bool isDeclareTargetLink(mlir::Value value) {
  if (auto addressOfOp =
          llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
    auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
    Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
    if (auto declareTargetGlobal =
            llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
      if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
          mlir::omp::DeclareTargetCaptureClause::link)
        return true;
  }
  return false;
}
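// isDeclareTargetLink answers true only for values produced by an
// llvm.mlir.addressof of a global whose declare-target capture clause is
// `link`; everything else is reported as not link-captured.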
3453 static llvm::Value *
3460 if (
auto addressOfOp =
3461 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.
getDefiningOp())) {
3462 if (
auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
3463 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
3464 addressOfOp.getGlobalName()))) {
3466 if (
auto declareTargetGlobal =
3467 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
3468 gOp.getOperation())) {
3472 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
3473 mlir::omp::DeclareTargetCaptureClause::link) ||
3474 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3475 mlir::omp::DeclareTargetCaptureClause::to &&
3476 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
3480 if (gOp.getSymName().contains(suffix))
3485 (gOp.getSymName().str() + suffix.str()).str());
struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
  void append(MapInfosTy &curInfo) {
    Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
  }
};

struct MapInfoData : MapInfosTy {
  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    MapInfosTy::append(CurInfo);
  }
};
3536 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3537 arrTy.getElementType()))
3553 Operation *clauseOp, llvm::Value *basePointer,
3554 llvm::Type *baseType, llvm::IRBuilderBase &builder,
3556 if (
auto memberClause =
3557 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
3562 if (!memberClause.getBounds().empty()) {
3563 llvm::Value *elementCount = builder.getInt64(1);
3564 for (
auto bounds : memberClause.getBounds()) {
3565 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
3566 bounds.getDefiningOp())) {
3571 elementCount = builder.CreateMul(
3575 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
3576 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
3577 builder.getInt64(1)));
3584 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
3592 return builder.CreateMul(elementCount,
3593 builder.getInt64(underlyingTypeSzInBits / 8));
3606 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
3614 for (
Value mapValue : mapVars) {
3615 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3616 for (
auto member : map.getMembers())
3617 if (member == mapOp)
3624 for (
Value mapValue : mapVars) {
3625 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3627 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3628 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
3629 mapData.Pointers.push_back(mapData.OriginalValue.back());
3631 if (llvm::Value *refPtr =
3633 moduleTranslation)) {
3634 mapData.IsDeclareTarget.push_back(
true);
3635 mapData.BasePointers.push_back(refPtr);
3637 mapData.IsDeclareTarget.push_back(
false);
3638 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3641 mapData.BaseType.push_back(
3642 moduleTranslation.
convertType(mapOp.getVarType()));
3643 mapData.Sizes.push_back(
3644 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
3645 mapData.BaseType.back(), builder, moduleTranslation));
3646 mapData.MapClause.push_back(mapOp.getOperation());
3647 mapData.Types.push_back(
3648 llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType()));
3652 if (mapOp.getMapperId())
3653 mapData.Mappers.push_back(
3654 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3655 mapOp, mapOp.getMapperIdAttr()));
3657 mapData.Mappers.push_back(
nullptr);
3658 mapData.IsAMapping.push_back(
true);
3659 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
3662 auto findMapInfo = [&mapData](llvm::Value *val,
3663 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3666 for (llvm::Value *basePtr : mapData.OriginalValue) {
3667 if (basePtr == val && mapData.IsAMapping[index]) {
3669 mapData.Types[index] |=
3670 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
3671 mapData.DevicePointers[index] = devInfoTy;
3680 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3681 for (
Value mapValue : useDevOperands) {
3682 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3684 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3685 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
3688 if (!findMapInfo(origValue, devInfoTy)) {
3689 mapData.OriginalValue.push_back(origValue);
3690 mapData.Pointers.push_back(mapData.OriginalValue.back());
3691 mapData.IsDeclareTarget.push_back(
false);
3692 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3693 mapData.BaseType.push_back(
3694 moduleTranslation.
convertType(mapOp.getVarType()));
3695 mapData.Sizes.push_back(builder.getInt64(0));
3696 mapData.MapClause.push_back(mapOp.getOperation());
3697 mapData.Types.push_back(
3698 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
3701 mapData.DevicePointers.push_back(devInfoTy);
3702 mapData.Mappers.push_back(
nullptr);
3703 mapData.IsAMapping.push_back(
false);
3704 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
3709 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3710 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
3712 for (
Value mapValue : hasDevAddrOperands) {
3713 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3715 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3716 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
3718 static_cast<llvm::omp::OpenMPOffloadMappingFlags
>(mapOp.getMapType());
3719 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
3721 mapData.OriginalValue.push_back(origValue);
3722 mapData.BasePointers.push_back(origValue);
3723 mapData.Pointers.push_back(origValue);
3724 mapData.IsDeclareTarget.push_back(
false);
3725 mapData.BaseType.push_back(
3726 moduleTranslation.
convertType(mapOp.getVarType()));
3727 mapData.Sizes.push_back(
3728 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
3729 mapData.MapClause.push_back(mapOp.getOperation());
3730 if (llvm::to_underlying(mapType & mapTypeAlways)) {
3734 mapData.Types.push_back(mapType);
3738 if (mapOp.getMapperId()) {
3739 mapData.Mappers.push_back(
3740 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3741 mapOp, mapOp.getMapperIdAttr()));
3743 mapData.Mappers.push_back(
nullptr);
3746 mapData.Types.push_back(
3747 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
3748 mapData.Mappers.push_back(
nullptr);
3752 mapData.DevicePointers.push_back(
3753 llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3754 mapData.IsAMapping.push_back(
false);
3755 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
  auto *res = llvm::find(mapData.MapClause, memberOp);
  assert(res != mapData.MapClause.end() &&
         "MapInfoOp for member not found in MapData, cannot return index");
  return std::distance(mapData.MapClause.begin(), res);
}

static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
                                                    bool first) {
  ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
  if (indexAttr.size() == 1)
    return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());

  llvm::SmallVector<size_t> indices(indexAttr.size());
  std::iota(indices.begin(), indices.end(), 0);

  llvm::sort(indices.begin(), indices.end(),
             [&](const size_t a, const size_t b) {
               auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
               auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
               for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
                 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
                 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();

                 if (aIndex == bIndex)
                   continue;

                 if (aIndex < bIndex)
                   return first;

                 if (aIndex > bIndex)
                   return !first;
               }

               return memberIndicesA.size() < memberIndicesB.size();
             });

  return llvm::cast<omp::MapInfoOp>(
      mapInfo.getMembers()[indices.front()].getDefiningOp());
}
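// The comparator above walks the two member-index paths position by position;
// the `first` flag decides whether the lexicographically smaller or larger
// path sorts to the front, and equal prefixes fall back to comparing path
// lengths. The front of the resulting permutation is the member returned.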
3823 std::vector<llvm::Value *>
3825 llvm::IRBuilderBase &builder,
bool isArrayTy,
3827 std::vector<llvm::Value *> idx;
3838 idx.push_back(builder.getInt64(0));
3839 for (
int i = bounds.size() - 1; i >= 0; --i) {
3840 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3841 bounds[i].getDefiningOp())) {
3842 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
3864 std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
3865 for (
size_t i = 1; i < bounds.size(); ++i) {
3866 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3867 bounds[i].getDefiningOp())) {
3868 dimensionIndexSizeOffset.push_back(builder.CreateMul(
3869 moduleTranslation.
lookupValue(boundOp.getExtent()),
3870 dimensionIndexSizeOffset[i - 1]));
3878 for (
int i = bounds.size() - 1; i >= 0; --i) {
3879 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3880 bounds[i].getDefiningOp())) {
3882 idx.emplace_back(builder.CreateMul(
3883 moduleTranslation.
lookupValue(boundOp.getLowerBound()),
3884 dimensionIndexSizeOffset[i]));
3886 idx.back() = builder.CreateAdd(
3887 idx.back(), builder.CreateMul(moduleTranslation.
lookupValue(
3888 boundOp.getLowerBound()),
3889 dimensionIndexSizeOffset[i]));
3914 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
3915 MapInfoData &mapData, uint64_t mapDataIndex,
bool isTargetParams) {
3916 assert(!ompBuilder.Config.isTargetDevice() &&
3917 "function only supported for host device codegen");
3920 combinedInfo.Types.emplace_back(
3922 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
3923 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
3924 combinedInfo.DevicePointers.emplace_back(
3925 mapData.DevicePointers[mapDataIndex]);
3926 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
3928 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3929 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3939 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3941 llvm::Value *lowAddr, *highAddr;
3942 if (!parentClause.getPartialMap()) {
3943 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
3944 builder.getPtrTy());
3945 highAddr = builder.CreatePointerCast(
3946 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
3947 mapData.Pointers[mapDataIndex], 1),
3948 builder.getPtrTy());
3949 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3951 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3954 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
3955 builder.getPtrTy());
3958 highAddr = builder.CreatePointerCast(
3959 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
3960 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
3961 builder.getPtrTy());
3962 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
3965 llvm::Value *size = builder.CreateIntCast(
3966 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
3967 builder.getInt64Ty(),
3969 combinedInfo.Sizes.push_back(size);
3971 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
3972 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
3980 if (!parentClause.getPartialMap()) {
3985 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
3986 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3987 combinedInfo.Types.emplace_back(mapFlag);
3988 combinedInfo.DevicePointers.emplace_back(
3990 combinedInfo.Mappers.emplace_back(
nullptr);
3992 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3993 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3994 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3995 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
3997 return memberOfFlag;
4009 if (mapOp.getVarPtrPtr())
4024 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
4025 MapInfoData &mapData, uint64_t mapDataIndex,
4026 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
4027 assert(!ompBuilder.Config.isTargetDevice() &&
4028 "function only supported for host device codegen");
4031 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4033 for (
auto mappedMembers : parentClause.getMembers()) {
4035 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
4038 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
4049 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4050 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4051 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4052 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4053 combinedInfo.Types.emplace_back(mapFlag);
4054 combinedInfo.DevicePointers.emplace_back(
4056 combinedInfo.Mappers.emplace_back(
nullptr);
4057 combinedInfo.Names.emplace_back(
4059 combinedInfo.BasePointers.emplace_back(
4060 mapData.BasePointers[mapDataIndex]);
4061 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
4062 combinedInfo.Sizes.emplace_back(builder.getInt64(
4063 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
4069 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4070 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4071 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4072 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4074 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4076 combinedInfo.Types.emplace_back(mapFlag);
4077 combinedInfo.DevicePointers.emplace_back(
4078 mapData.DevicePointers[memberDataIdx]);
4079 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
4080 combinedInfo.Names.emplace_back(
4082 uint64_t basePointerIndex =
4084 combinedInfo.BasePointers.emplace_back(
4085 mapData.BasePointers[basePointerIndex]);
4086 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
4088 llvm::Value *size = mapData.Sizes[memberDataIdx];
4090 size = builder.CreateSelect(
4091 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
4092 builder.getInt64(0), size);
4095 combinedInfo.Sizes.emplace_back(size);
4100 MapInfosTy &combinedInfo,
bool isTargetParams,
4101 int mapDataParentIdx = -1) {
4105 auto mapFlag = mapData.Types[mapDataIdx];
4106 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
4110 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4112 if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
4113 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4115 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
4117 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4122 if (mapDataParentIdx >= 0)
4123 combinedInfo.BasePointers.emplace_back(
4124 mapData.BasePointers[mapDataParentIdx]);
4126 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
4128 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
4129 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
4130 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
4131 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
4132 combinedInfo.Types.emplace_back(mapFlag);
4133 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
4137 llvm::IRBuilderBase &builder,
4138 llvm::OpenMPIRBuilder &ompBuilder,
4140 MapInfoData &mapData, uint64_t mapDataIndex,
4141 bool isTargetParams) {
4142 assert(!ompBuilder.Config.isTargetDevice() &&
4143 "function only supported for host device codegen");
4146 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4151 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
4152 auto memberClause = llvm::cast<omp::MapInfoOp>(
4153 parentClause.getMembers()[0].getDefiningOp());
4170 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
4172 combinedInfo, mapData, mapDataIndex, isTargetParams);
4174 combinedInfo, mapData, mapDataIndex,
4175 memberOfParentFlag);
4185 llvm::IRBuilderBase &builder) {
4187 "function only supported for host device codegen");
4188 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4190 if (!mapData.IsDeclareTarget[i]) {
4191 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4192 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
4202 switch (captureKind) {
4203 case omp::VariableCaptureKind::ByRef: {
4204 llvm::Value *newV = mapData.Pointers[i];
4206 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
4209 newV = builder.CreateLoad(builder.getPtrTy(), newV);
4211 if (!offsetIdx.empty())
4212 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
4214 mapData.Pointers[i] = newV;
4216 case omp::VariableCaptureKind::ByCopy: {
4217 llvm::Type *type = mapData.BaseType[i];
4219 if (mapData.Pointers[i]->getType()->isPointerTy())
4220 newV = builder.CreateLoad(type, mapData.Pointers[i]);
4222 newV = mapData.Pointers[i];
4225 auto curInsert = builder.saveIP();
4227 auto *memTempAlloc =
4228 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
4229 builder.restoreIP(curInsert);
4231 builder.CreateStore(newV, memTempAlloc);
4232 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
4235 mapData.Pointers[i] = newV;
4236 mapData.BasePointers[i] = newV;
4238 case omp::VariableCaptureKind::This:
4239 case omp::VariableCaptureKind::VLAType:
4240 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
4251 MapInfoData &mapData,
bool isTargetParams =
false) {
4253 "function only supported for host device codegen");
4275 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4278 if (mapData.IsAMember[i])
4281 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
4282 if (!mapInfoOp.getMembers().empty()) {
4284 combinedInfo, mapData, i, isTargetParams);
4295 llvm::StringRef mapperFuncName);
4301 "function only supported for host device codegen");
4302 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4303 std::string mapperFuncName =
4305 {
"omp_mapper", declMapperOp.getSymName()});
4307 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
4317 llvm::StringRef mapperFuncName) {
4319 "function only supported for host device codegen");
4320 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4321 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
4324 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
4327 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4330 MapInfosTy combinedInfo;
4332 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
4333 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
4334 builder.restoreIP(codeGenIP);
4335 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
4336 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
4337 builder.GetInsertBlock());
4338 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
4341 return llvm::make_error<PreviouslyReportedError>();
4342 MapInfoData mapData;
4345 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);
4350 return combinedInfo;
4354 if (!combinedInfo.Mappers[i])
4361 genMapInfoCB, varType, mapperFuncName, customMapperCB);
4363 return newFn.takeError();
4364 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
4368 static LogicalResult
4371 llvm::Value *ifCond =
nullptr;
4372 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
4376 llvm::omp::RuntimeFunction RTLFn;
4380 llvm::OpenMPIRBuilder::TargetDataInfo info(
true,
4383 LogicalResult result =
4385 .Case([&](omp::TargetDataOp dataOp) {
4389 if (
auto ifVar = dataOp.getIfExpr())
4392 if (
auto devId = dataOp.getDevice())
4394 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4395 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4396 deviceID = intAttr.getInt();
4398 mapVars = dataOp.getMapVars();
4399 useDevicePtrVars = dataOp.getUseDevicePtrVars();
4400 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
4403 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
4407 if (
auto ifVar = enterDataOp.getIfExpr())
4410 if (
auto devId = enterDataOp.getDevice())
4412 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4413 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4414 deviceID = intAttr.getInt();
4416 enterDataOp.getNowait()
4417 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
4418 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
4419 mapVars = enterDataOp.getMapVars();
4420 info.HasNoWait = enterDataOp.getNowait();
4423 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
4427 if (
auto ifVar = exitDataOp.getIfExpr())
4430 if (
auto devId = exitDataOp.getDevice())
4432 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4433 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4434 deviceID = intAttr.getInt();
4436 RTLFn = exitDataOp.getNowait()
4437 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
4438 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
4439 mapVars = exitDataOp.getMapVars();
4440 info.HasNoWait = exitDataOp.getNowait();
4443 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
4447 if (
auto ifVar = updateDataOp.getIfExpr())
4450 if (
auto devId = updateDataOp.getDevice())
4452 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4453 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4454 deviceID = intAttr.getInt();
4457 updateDataOp.getNowait()
4458 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
4459 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
4460 mapVars = updateDataOp.getMapVars();
4461 info.HasNoWait = updateDataOp.getNowait();
4465 llvm_unreachable(
"unexpected operation");
4472 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4473 MapInfoData mapData;
4475 builder, useDevicePtrVars, useDeviceAddrVars);
4478 MapInfosTy combinedInfo;
4479 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
4480 builder.restoreIP(codeGenIP);
4481 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
4482 return combinedInfo;
4488 [&moduleTranslation](
4489 llvm::OpenMPIRBuilder::DeviceInfoTy type,
4493 for (
auto [arg, useDevVar] :
4494 llvm::zip_equal(blockArgs, useDeviceVars)) {
4496 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
4497 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
4498 : mapInfoOp.getVarPtr();
4501 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
4502 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
4503 mapInfoData.MapClause, mapInfoData.DevicePointers,
4504 mapInfoData.BasePointers)) {
4505 auto mapOp = cast<omp::MapInfoOp>(mapClause);
4506 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
4507 devicePointer != type)
4510 if (llvm::Value *devPtrInfoMap =
4511 mapper ? mapper(basePointer) : basePointer) {
4512 moduleTranslation.
mapValue(arg, devPtrInfoMap);
4519 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
4520 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
4521 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
4522 builder.restoreIP(codeGenIP);
4523 assert(isa<omp::TargetDataOp>(op) &&
4524 "BodyGen requested for non TargetDataOp");
4525 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
4526 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
4527 switch (bodyGenType) {
4528 case BodyGenTy::Priv:
4530 if (!info.DevicePtrInfoMap.empty()) {
4531 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4532 blockArgIface.getUseDeviceAddrBlockArgs(),
4533 useDeviceAddrVars, mapData,
4534 [&](llvm::Value *basePointer) -> llvm::Value * {
4535 if (!info.DevicePtrInfoMap[basePointer].second)
4537 return builder.CreateLoad(
4539 info.DevicePtrInfoMap[basePointer].second);
4541 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4542 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
4543 mapData, [&](llvm::Value *basePointer) {
4544 return info.DevicePtrInfoMap[basePointer].second;
4548 moduleTranslation)))
4549 return llvm::make_error<PreviouslyReportedError>();
4552 case BodyGenTy::DupNoPriv:
4555 builder.restoreIP(codeGenIP);
4557 case BodyGenTy::NoPriv:
4559 if (info.DevicePtrInfoMap.empty()) {
4562 if (ompBuilder->Config.IsTargetDevice.value_or(
false)) {
4563 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4564 blockArgIface.getUseDeviceAddrBlockArgs(),
4565 useDeviceAddrVars, mapData);
4566 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4567 blockArgIface.getUseDevicePtrBlockArgs(),
4568 useDevicePtrVars, mapData);
4572 moduleTranslation)))
4573 return llvm::make_error<PreviouslyReportedError>();
4577 return builder.saveIP();
4580 auto customMapperCB =
4582 if (!combinedInfo.Mappers[i])
4584 info.HasMapper =
true;
4589 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4590 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4592 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
4593 if (isa<omp::TargetDataOp>(op))
4594 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
4595 builder.getInt64(deviceID), ifCond,
4596 info, genMapInfoCB, customMapperCB,
4599 return ompBuilder->createTargetData(
4600 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
4601 info, genMapInfoCB, customMapperCB, &RTLFn);
4607 builder.restoreIP(*afterIP);
4611 static LogicalResult
4615 auto distributeOp = cast<omp::DistributeOp>(opInst);
4622 bool doDistributeReduction =
4626 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
4631 if (doDistributeReduction) {
4632 isByRef =
getIsByRef(teamsOp.getReductionByref());
4633 assert(isByRef.size() == teamsOp.getNumReductionVars());
4636 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4640 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
4641 .getReductionBlockArgs();
4644 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
4645 reductionDecls, privateReductionVariables, reductionVariableMap,
4650 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4651 auto bodyGenCB = [&](InsertPointTy allocaIP,
4652 InsertPointTy codeGenIP) -> llvm::Error {
4656 moduleTranslation, allocaIP);
4659 builder.restoreIP(codeGenIP);
4665 return llvm::make_error<PreviouslyReportedError>();
4670 return llvm::make_error<PreviouslyReportedError>();
4673 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
4675 distributeOp.getPrivateNeedsBarrier())))
4676 return llvm::make_error<PreviouslyReportedError>();
4679 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4682 builder, moduleTranslation);
4684 return regionBlock.takeError();
4685 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4690 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
4693 auto schedule = omp::ClauseScheduleKind::Static;
4694 bool isOrdered =
false;
4695 std::optional<omp::ScheduleModifier> scheduleMod;
4696 bool isSimd =
false;
4697 llvm::omp::WorksharingLoopType workshareLoopType =
4698 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
4699 bool loopNeedsBarrier =
false;
4700 llvm::Value *chunk =
nullptr;
4702 llvm::CanonicalLoopInfo *loopInfo =
4704 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4705 ompBuilder->applyWorkshareLoop(
4706 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
4707 convertToScheduleKind(schedule), chunk, isSimd,
4708 scheduleMod == omp::ScheduleModifier::monotonic,
4709 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4713 return wsloopIP.takeError();
4717 distributeOp.getLoc(), privVarsInfo.
llvmVars,
4719 return llvm::make_error<PreviouslyReportedError>();
4721 return llvm::Error::success();
4724 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4726 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4727 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4728 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
4733 builder.restoreIP(*afterIP);
4735 if (doDistributeReduction) {
4738 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
4739 privateReductionVariables, isByRef,
  if (!cast<mlir::ModuleOp>(op))
    return failure();

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
                              attribute.getOpenmpDeviceVersion());

  if (attribute.getNoGpuLib())
    return success();

  ompBuilder->createGlobalFlag(attribute.getDebugKind(),
                               "__omp_rtl_debug_kind");
  ompBuilder->createGlobalFlag(attribute.getAssumeTeamsOversubscription(),
                               "__omp_rtl_assume_teams_oversubscription");
  ompBuilder->createGlobalFlag(attribute.getAssumeThreadsOversubscription(),
                               "__omp_rtl_assume_threads_oversubscription");
  ompBuilder->createGlobalFlag(attribute.getAssumeNoThreadState(),
                               "__omp_rtl_assume_no_thread_state");
  ompBuilder->createGlobalFlag(attribute.getAssumeNoNestedParallelism(),
                               "__omp_rtl_assume_no_nested_parallelism");
                                     omp::TargetOp targetOp,
                                     llvm::StringRef parentName = "") {
  auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();

  assert(fileLoc && "No file found from location");
  StringRef fileName = fileLoc.getFilename().getValue();

  llvm::sys::fs::UniqueID id;
  uint64_t line = fileLoc.getLine();
  if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
    size_t fileHash = llvm::hash_value(fileLoc.getFilename().str());
    size_t deviceId = 0xdeadf17e;
    targetInfo =
        llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
  } else {
    targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
                                             id.getFile(), line);
  }
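// When the file's unique ID cannot be obtained, the target entry info above
// falls back to the placeholder device id 0xdeadf17e (paired with a hash of
// the file name) rather than the real device/file IDs of the else branch.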
4809 llvm::IRBuilderBase &builder, llvm::Function *func) {
4811 "function only supported for target device codegen");
4812 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4825 if (mapData.IsDeclareTarget[i]) {
4832 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
4833 convertUsersOfConstantsToInstructions(constant, func,
false);
4840 for (llvm::User *user : mapData.OriginalValue[i]->users())
4841 userVec.push_back(user);
4843 for (llvm::User *user : userVec) {
4844 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
4845 if (insn->getFunction() == func) {
4846 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
4847 mapData.BasePointers[i]);
4848 load->moveBefore(insn->getIterator());
4849 user->replaceUsesOfWith(mapData.OriginalValue[i], load);
4896 static llvm::IRBuilderBase::InsertPoint
4898 llvm::Value *input, llvm::Value *&retVal,
4899 llvm::IRBuilderBase &builder,
4900 llvm::OpenMPIRBuilder &ompBuilder,
4902 llvm::IRBuilderBase::InsertPoint allocaIP,
4903 llvm::IRBuilderBase::InsertPoint codeGenIP) {
4904 assert(ompBuilder.Config.isTargetDevice() &&
4905 "function only supported for target device codegen");
4906 builder.restoreIP(allocaIP);
4908 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
4910 ompBuilder.M.getContext());
4911 unsigned alignmentValue = 0;
4913 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
4914 if (mapData.OriginalValue[i] == input) {
4915 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4916 capture = mapOp.getMapCaptureType();
4919 mapOp.getVarType(), ompBuilder.M.getDataLayout());
4923 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
4924 unsigned int defaultAS =
4925 ompBuilder.M.getDataLayout().getProgramAddressSpace();
4928 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
4930 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
4931 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
4933 builder.CreateStore(&arg, v);
4935 builder.restoreIP(codeGenIP);
4938 case omp::VariableCaptureKind::ByCopy: {
4942 case omp::VariableCaptureKind::ByRef: {
4943 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
4945 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
4960 if (v->getType()->isPointerTy() && alignmentValue) {
4961 llvm::MDBuilder MDB(builder.getContext());
4962 loadInst->setMetadata(
4963 llvm::LLVMContext::MD_align,
4966 llvm::Type::getInt64Ty(builder.getContext()),
4973 case omp::VariableCaptureKind::This:
4974 case omp::VariableCaptureKind::VLAType:
4977 assert(
false &&
"Currently unsupported capture kind");
4981 return builder.saveIP();
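// Summary of the accessor built above: each kernel argument is spilled to an
// alloca in the alloca address space (address-space-cast back to the program
// address space for pointers when the two differ) and then re-read at the
// code-gen insertion point according to its capture kind; ByRef captures emit
// an aligned load and, for pointer values with a known alignment, attach
// !align metadata to that load.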
4998 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
4999 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
5000 blockArgIface.getHostEvalBlockArgs())) {
5001 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
5005 .Case([&](omp::TeamsOp teamsOp) {
5006 if (teamsOp.getNumTeamsLower() == blockArg)
5007 numTeamsLower = hostEvalVar;
5008 else if (teamsOp.getNumTeamsUpper() == blockArg)
5009 numTeamsUpper = hostEvalVar;
5010 else if (teamsOp.getThreadLimit() == blockArg)
5011 threadLimit = hostEvalVar;
5013 llvm_unreachable(
"unsupported host_eval use");
5015 .Case([&](omp::ParallelOp parallelOp) {
5016 if (parallelOp.getNumThreads() == blockArg)
5017 numThreads = hostEvalVar;
5019 llvm_unreachable(
"unsupported host_eval use");
5021 .Case([&](omp::LoopNestOp loopOp) {
5022 auto processBounds =
5027 if (lb == blockArg) {
5030 (*outBounds)[i] = hostEvalVar;
5036 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
5037 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
5039 found = processBounds(loopOp.getLoopSteps(), steps) || found;
5041 assert(found &&
"unsupported host_eval use");
5044 llvm_unreachable(
"unsupported host_eval use");
template <typename OpTy>
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
  if (OpTy casted = dyn_cast<OpTy>(op))
    return casted;

  if (immediateParent)
    return dyn_cast_if_present<OpTy>(op->getParentOp());

  return op->getParentOfType<OpTy>();
}

static std::optional<int64_t> extractConstInteger(Value value) {
  if (!value)
    return std::nullopt;

  if (auto constOp =
          dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
    if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
      return constAttr.getInt();

  return std::nullopt;
}
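// extractConstInteger only succeeds for values defined by an LLVM dialect
// constant op carrying an integer attribute; the kernel default-attribute
// logic below relies on it to turn compile-time clause values (e.g. an upper
// bound on the number of teams) into constant kernel limits.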
5087 uint64_t sizeInBytes = sizeInBits / 8;
5091 template <
typename OpTy>
5093 if (op.getNumReductionVars() > 0) {
5098 members.reserve(reductions.size());
5099 for (omp::DeclareReductionOp &red : reductions)
5100 members.push_back(red.getType());
5102 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
5118 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
5119 bool isTargetDevice,
bool isGPU) {
5122 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
5123 if (!isTargetDevice) {
5130 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5131 numTeamsLower = teamsOp.getNumTeamsLower();
5132 numTeamsUpper = teamsOp.getNumTeamsUpper();
5133 threadLimit = teamsOp.getThreadLimit();
5136 if (
auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5137 numThreads = parallelOp.getNumThreads();
5142 int32_t minTeamsVal = 1, maxTeamsVal = -1;
5143 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5146 if (numTeamsUpper) {
5148 minTeamsVal = maxTeamsVal = *val;
5150 minTeamsVal = maxTeamsVal = 0;
5152 }
else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
5154 castOrGetParentOfType<omp::SimdOp>(capturedOp,
5156 minTeamsVal = maxTeamsVal = 1;
5158 minTeamsVal = maxTeamsVal = -1;
5163 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &result) {
5177 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
5178 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
5179 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
5182 int32_t maxThreadsVal = -1;
5183 if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5184 setMaxValueFromClause(numThreads, maxThreadsVal);
5185 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
5192 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
5193 if (combinedMaxThreadsVal < 0 ||
5194 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
5195 combinedMaxThreadsVal = teamsThreadLimitVal;
5197 if (combinedMaxThreadsVal < 0 ||
5198 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
5199 combinedMaxThreadsVal = maxThreadsVal;
5201 int32_t reductionDataSize = 0;
5202 if (isGPU && capturedOp) {
5203 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
5208 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5210 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5211 omp::TargetRegionFlags::spmd) &&
5212 "invalid kernel flags");
5214 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5215 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5216 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5217 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5218 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5219 attrs.MinTeams = minTeamsVal;
5220 attrs.MaxTeams.front() = maxTeamsVal;
5221 attrs.MinThreads = 1;
5222 attrs.MaxThreads.front() = combinedMaxThreadsVal;
5223 attrs.ReductionDataSize = reductionDataSize;
5226 if (attrs.ReductionDataSize != 0)
5227 attrs.ReductionBufferLength = 1024;
5239 omp::TargetOp targetOp,
Operation *capturedOp,
5240 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
5241 omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
5242 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
5244 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
5248 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
5251 if (
Value targetThreadLimit = targetOp.getThreadLimit())
5252 attrs.TargetThreadLimit.front() =
5256 attrs.MinTeams = moduleTranslation.
lookupValue(numTeamsLower);
5259 attrs.MaxTeams.front() = moduleTranslation.
lookupValue(numTeamsUpper);
5261 if (teamsThreadLimit)
5262 attrs.TeamsThreadLimit.front() =
5266 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
5268 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
5269 omp::TargetRegionFlags::trip_count)) {
5271 attrs.LoopTripCount =
nullptr;
5276 for (
auto [loopLower, loopUpper, loopStep] :
5277 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
5278 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
5279 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
5280 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
5282 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
5283 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
5284 loc, lowerBound, upperBound, step,
true,
5285 loopOp.getLoopInclusive());
5287 if (!attrs.LoopTripCount) {
5288 attrs.LoopTripCount = tripCount;
5293 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
5299 static LogicalResult
5302 auto targetOp = cast<omp::TargetOp>(opInst);
5307 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
5308 bool isGPU = ompBuilder->Config.isGPU();
5311 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
5312 auto &targetRegion = targetOp.getRegion();
5329 llvm::Function *llvmOutlinedFn =
nullptr;
5333 bool isOffloadEntry =
5334 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
5341 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
5343 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
5344 std::optional<DenseI64ArrayAttr> privateMapIndices =
5345 targetOp.getPrivateMapsAttr();
5347 for (
auto [privVarIdx, privVarSymPair] :
5349 auto privVar = std::get<0>(privVarSymPair);
5350 auto privSym = std::get<1>(privVarSymPair);
5352 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
5353 omp::PrivateClauseOp privatizer =
5356 if (!privatizer.needsMap())
5360 targetOp.getMappedValueForPrivateVar(privVarIdx);
5361 assert(mappedValue &&
"Expected to find mapped value for a privatized "
5362 "variable that needs mapping");
5367 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
5368 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
5372 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
5374 varType == privVar.getType() &&
5375 "Type of private var doesn't match the type of the mapped value");
5379 mappedPrivateVars.insert(
5381 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
5382 (*privateMapIndices)[privVarIdx])});
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());

    // Forward the target-cpu and target-features attributes from the parent
    // function to the outlined kernel function.
    llvm::Function *llvmParentFn =
        moduleTranslation.lookupFunction(parentFn.getName());
    llvmOutlinedFn = codeGenIP.getBlock()->getParent();
    assert(llvmParentFn && llvmOutlinedFn &&
           "Both parent and outlined functions must exist at this point");

    if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    if (auto attr = llvmParentFn->getFnAttribute("target-features");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    // Map block arguments of the map and has_device_addr clauses to the LLVM
    // values of the underlying variables.
    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }
    for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }

    // Allocate, initialize and, where requested, firstprivate-copy the
    // delayed private variables.
    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateVarsInfo, allocaIP,
        &mappedPrivateVars);
    if (failed(handleError(afterAllocas, *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(codeGenIP);
    if (failed(handleError(initPrivateVars(builder, moduleTranslation,
                                           privateVarsInfo,
                                           &mappedPrivateVars),
                           *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
      return llvm::make_error<PreviouslyReportedError>();

    SmallVector<Region *> privateCleanupRegions;
    llvm::transform(privateVarsInfo.privatizers,
                    std::back_inserter(privateCleanupRegions),
                    [](omp::PrivateClauseOp privatizer) {
                      return &privatizer.getDeallocRegion();
                    });

    llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
        targetRegion, "omp.target", builder, moduleTranslation);
    if (!exitBlock)
      return exitBlock.takeError();

    builder.SetInsertPoint(*exitBlock);
    if (!privateCleanupRegions.empty()) {
      if (failed(inlineOmpRegionCleanup(
              privateCleanupRegions, privateVarsInfo.llvmVars,
              moduleTranslation, builder, "omp.targetop.private.cleanup",
              /*shouldLoadCleanupRegionArg=*/false)))
        return llvm::createStringError(
            "failed to inline `dealloc` region of `omp.private` "
            "op in the target region");
      return builder.saveIP();
    }

    return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
  };
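// Side note (not from this file): llvm::IRBuilderBase::InsertPointGuard, used
// at the top of bodyCB above, is an RAII helper that saves the builder's
// insertion point and debug location on construction and restores both when
// the guard goes out of scope, so the callback may move the builder freely.
static void emitIntoBlockSketch(llvm::IRBuilderBase &builder,
                                llvm::BasicBlock *scratchBlock) {
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  builder.SetInsertPoint(scratchBlock);
  builder.CreateUnreachable(); // emitted into scratchBlock only
} // guard restores the caller's insertion point here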
  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;

  MapInfoData mapData;

  MapInfosTy combinedInfos;
  auto genMapInfoCB =
      [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
                /*isTargetParams=*/true);
    return combinedInfos;
  };

  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());

    // On the host, the kernel argument is forwarded as-is.
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(&arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
                                        *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP);
  };

  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
  initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
                         isTargetDevice, isGPU);

  // Collect host-evaluated values needed to launch the kernel from the host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
                           targetCapturedOp, runtimeAttrs);
  for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(var);
    moduleTranslation.mapValue(arg, value);

    // Only non-constant host-evaluated values become kernel arguments.
    if (!llvm::isa<llvm::Constant>(value))
      kernelInput.push_back(value);
  }

  for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
    // Declare-target values and members of parent map entries are not passed
    // to the kernel as arguments.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(mapData.OriginalValue[i]);
  }

  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/false, /*SeparateBeginEndCalls=*/true);

  auto customMapperCB =
      [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfos.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
                                            moduleTranslation);
  };

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(targetIfCond);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createTarget(
          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
          argAccessorCB, customMapperCB, dds, targetOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  return success();
}
static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
                         LLVM::ModuleTranslation &moduleTranslation) {
  // When compiling for the device, delete the LLVM IR of functions that were
  // declared target for the host only.
  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
            op->getParentOfType<ModuleOp>().getOperation())) {
      if (!offloadMod.getIsTargetDevice())
        return success();

      omp::DeclareTargetDeviceType declareType =
          attribute.getDeviceType().getValue();

      if (declareType == omp::DeclareTargetDeviceType::host) {
        llvm::Function *llvmFunc =
            moduleTranslation.lookupFunction(funcOp.getName());
        llvmFunc->dropAllReferences();
        llvmFunc->eraseFromParent();
      }
    }
  }
  if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
    llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
    if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
      llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
      bool isDeclaration = gOp.isDeclaration();
      bool isExternallyVisible =
          gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
      mlir::FileLineColLoc loc =
          gOp.getLoc()->findInstanceOf<mlir::FileLineColLoc>();
      llvm::StringRef mangledName = gOp.getSymName();
      auto captureClause =
          convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
      auto deviceClause =
          convertToDeviceClauseKind(attribute.getDeviceType().getValue());
      std::vector<llvm::GlobalVariable *> generatedRefs;

      std::vector<llvm::Triple> targetTriple;
      auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
          op->getParentOfType<mlir::ModuleOp>()->getAttr(
              LLVM::LLVMDialect::getTargetTripleAttrName()));
      if (targetTripleAttr)
        targetTriple.emplace_back(targetTripleAttr.data());

      auto fileInfoCallBack = [&loc]() {
        std::string filename = "";
        std::uint64_t lineNo = 0;

        if (loc) {
          filename = loc.getFilename().str();
          lineNo = loc.getLine();
        }

        return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
                                                     lineNo);
      };

      ompBuilder->registerTargetGlobalVariable(
          captureClause, deviceClause, isDeclaration, isExternallyVisible,
          ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
          generatedRefs, /*OpenMPSIMD=*/false, targetTriple,
          /*GlobalInitializer=*/nullptr, /*VariableLinkage=*/nullptr,
          gVal->getType(), gVal);

      if (ompBuilder->Config.isTargetDevice() &&
          (attribute.getCaptureClause().getValue() !=
               mlir::omp::DeclareTargetCaptureClause::to ||
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        ompBuilder->getAddrOfDeclareTargetVar(
            captureClause, deviceClause, isDeclaration, isExternallyVisible,
            ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
            generatedRefs, /*OpenMPSIMD=*/false, targetTriple, gVal->getType(),
            /*GlobalInitializer=*/nullptr, /*VariableLinkage=*/nullptr);
      }
    }
  }

  return success();
}
static bool isTargetDeviceOp(Operation *op) {
  if (mlir::isa<omp::ThreadprivateOp>(op))
    return true;

  if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
    if (auto declareTargetIface =
            llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
                parentFn.getOperation()))
      if (declareTargetIface.isDeclareTarget() &&
          declareTargetIface.getDeclareTargetDeviceType() !=
              mlir::omp::DeclareTargetDeviceType::host)
        return true;

  return false;
}
static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // A single loop-info stack frame is shared by a loop wrapper nest; it is
  // pushed when entering the outermost wrapper and popped afterwards.
  bool isOutermostLoopWrapper =
      isa_and_present<omp::LoopWrapperInterface>(op) &&
      !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();

  auto result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::BarrierOp op) -> LogicalResult {
            llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
                ompBuilder->createBarrier(builder.saveIP(),
                                          llvm::omp::OMPD_barrier);
            LogicalResult res = handleError(afterIP, *op);
            if (res.succeeded()) {
              builder.restoreIP(*afterIP);
            }
            return res;
          })
          .Case([&](omp::TaskyieldOp op) {
            ompBuilder->createTaskyield(builder.saveIP());
          .Case([&](omp::FlushOp op) {
            ompBuilder->createFlush(builder.saveIP());
          .Case([&](omp::ParallelOp op) {
          .Case([&](omp::MaskedOp) {
          .Case([&](omp::MasterOp) {
          .Case([&](omp::CriticalOp) {
          .Case([&](omp::OrderedRegionOp) {
          .Case([&](omp::OrderedOp) {
          .Case([&](omp::WsloopOp) {
          .Case([&](omp::SimdOp) {
          .Case([&](omp::AtomicReadOp) {
          .Case([&](omp::AtomicWriteOp) {
          .Case([&](omp::AtomicUpdateOp op) {
          .Case([&](omp::AtomicCaptureOp op) {
          .Case([&](omp::CancelOp op) {
          .Case([&](omp::CancellationPointOp op) {
          .Case([&](omp::SectionsOp) {
          .Case([&](omp::SingleOp op) {
          .Case([&](omp::TeamsOp op) {
          .Case([&](omp::TaskOp op) {
          .Case([&](omp::TaskgroupOp op) {
          .Case([&](omp::TaskwaitOp op) {
          .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
                omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
                omp::CriticalDeclareOp>([](auto op) {
          .Case([&](omp::ThreadprivateOp) {
          .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
                omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
          .Case([&](omp::TargetOp) {
          .Case([&](omp::DistributeOp) {
          .Case([&](omp::LoopNestOp) {
          .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
          .Default([&](Operation *inst) {
            return inst->emitError()
                   << "not yet implemented: " << inst->getName();
          });

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPop();

  return result;
}
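// Condensed sketch of the llvm::TypeSwitch dispatch used above (the helper
// name and the two chosen cases are illustrative; the real chain covers every
// OpenMP operation and forwards to the matching convertOmp* function).
static LogicalResult
dispatchSketch(Operation *op, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
      .Case([&](omp::ParallelOp parallelOp) {
        return convertOmpParallel(parallelOp, builder, moduleTranslation);
      })
      .Case([&](omp::MasterOp) {
        return convertOmpMaster(*op, builder, moduleTranslation);
      })
      .Default([](Operation *other) {
        return other->emitError()
               << "not yet implemented: " << other->getName();
      });
}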
static LogicalResult
convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  return convertHostOrTargetOperation(op, builder, moduleTranslation);
}

static LogicalResult
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  if (isa<omp::TargetOp>(op))
    return convertOmpTarget(*op, builder, moduleTranslation);
  if (isa<omp::TargetDataOp>(op))
    return convertOmpTargetData(op, builder, moduleTranslation);

  bool interrupted =
      op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
          if (isa<omp::TargetOp>(oper)) {
            if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }
          if (isa<omp::TargetDataOp>(oper)) {
            if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }

          // Non-target OpenMP ops may still nest target regions, so their
          // regions are translated as "fake" regions to keep visiting them.
          if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
              oper->getParentOfType<LLVM::LLVMFuncOp>() &&
              !oper->getRegions().empty()) {
            if (auto blockArgsIface =
                    dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
              forwardArgs(moduleTranslation, blockArgsIface);

            if (isa<mlir::omp::AtomicUpdateOp>(oper))
              for (auto [operand, arg] :
                   llvm::zip_equal(oper->getOperands(),
                                   oper->getRegion(0).getArguments())) {
                moduleTranslation.mapValue(
                    arg, builder.CreateLoad(
                             moduleTranslation.convertType(arg.getType()),
                             moduleTranslation.lookupValue(operand)));
              }

            if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
              assert(builder.GetInsertBlock() &&
                     "No insert block is set for the builder");
              for (auto iv : loopNest.getIVs()) {
                // Each induction variable is mapped to a placeholder value so
                // the remaining region can still be translated.
              }
            }

            for (Region &region : oper->getRegions()) {
              SmallVector<llvm::PHINode *> phis;
              llvm::Expected<llvm::BasicBlock *> result = convertOmpOpRegions(
                  region,
                  oper->getName().getStringRef().str() + ".fake.region",
                  builder, moduleTranslation, &phis);
              if (failed(handleError(result, *oper)))
                return WalkResult::interrupt();

              builder.SetInsertPoint(result.get(), result.get()->end());
            }
            return WalkResult::skip();
          }

          return WalkResult::advance();
        }).wasInterrupted();

  return failure(interrupted);
}
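// Sketch (not from this file) of the WalkResult protocol relied on above:
// interrupt() aborts the traversal, skip() prunes the subtree of the current
// op, and advance() continues normally. Here it counts top-level omp.target
// regions without descending into them.
static unsigned countTargetRegionsSketch(Operation *root) {
  unsigned count = 0;
  root->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (isa<omp::TargetOp>(op)) {
      ++count;
      return WalkResult::skip(); // do not visit ops nested inside the region
    }
    return WalkResult::advance();
  });
  return count;
}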
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {

LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
    Operation *op, ArrayRef<llvm::Instruction *> instructions,
    NamedAttribute attribute,
    LLVM::ModuleTranslation &moduleTranslation) const {
  return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
             attribute.getName())
      .Case("omp.is_target_device",
            [&](Attribute attr) {
              if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsTargetDevice(deviceAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.is_gpu",
            [&](Attribute attr) {
              if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsGPU(gpuAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.host_ir_filepath",
            [&](Attribute attr) {
              if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.flags",
            [&](Attribute attr) {
              if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
                return convertFlagsAttr(op, rtlAttr, moduleTranslation);
              return failure();
            })
      .Case("omp.version",
            [&](Attribute attr) {
              if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
                                            versionAttr.getVersion());
                return success();
              }
              return failure();
            })
      .Case("omp.declare_target",
            [&](Attribute attr) {
              if (auto declareTargetAttr =
                      dyn_cast<omp::DeclareTargetAttr>(attr))
                return convertDeclareTargetAttr(op, declareTargetAttr,
                                                moduleTranslation);
              return failure();
            })
      .Case("omp.requires",
            [&](Attribute attr) {
              if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
                using Requires = omp::ClauseRequires;
                Requires flags = requiresAttr.getValue();
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setHasRequiresReverseOffload(
                    bitEnumContainsAll(flags, Requires::reverse_offload));
                config.setHasRequiresUnifiedAddress(
                    bitEnumContainsAll(flags, Requires::unified_address));
                config.setHasRequiresUnifiedSharedMemory(
                    bitEnumContainsAll(flags, Requires::unified_shared_memory));
                config.setHasRequiresDynamicAllocators(
                    bitEnumContainsAll(flags, Requires::dynamic_allocators));
                return success();
              }
              return failure();
            })
      .Case("omp.target_triples",
            [&](Attribute attr) {
              if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.TargetTriples.clear();
                config.TargetTriples.reserve(triplesAttr.size());
                for (Attribute tripleAttr : triplesAttr) {
                  if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
                    config.TargetTriples.emplace_back(tripleStrAttr.getValue());
                }
                return success();
              }
              return failure();
            })
      .Default([](Attribute) {
        // Fall through for omp attributes that do not require lowering.
        return success();
      })(attribute.getValue());
}
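// Illustration (not part of the interface above): the "omp.version" case
// records the OpenMP version in an LLVM module flag named "openmp", which can
// later be read back from the translated llvm::Module like this.
static unsigned readRecordedOpenMPVersion(const llvm::Module &m) {
  if (auto *ci = llvm::mdconst::extract_or_null<llvm::ConstantInt>(
          m.getModuleFlag("openmp")))
    return static_cast<unsigned>(ci->getZExtValue());
  return 0; // flag absent
}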
LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
    Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (ompBuilder->Config.isTargetDevice()) {

void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
  registry.insert<omp::OpenMPDialect>();
  registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
    dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
  });
}
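// Usage sketch, not part of this file: how a client typically consumes the
// registration above. Assumes the module holds LLVM + OpenMP dialect ops and
// that the builtin and LLVM dialect translations are registered as well.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include <memory>

static std::unique_ptr<llvm::Module>
lowerOpenMPModuleToLLVMIR(mlir::ModuleOp mod, llvm::LLVMContext &llvmCtx) {
  mlir::DialectRegistry registry;
  mlir::registerBuiltinDialectTranslation(registry);
  mlir::registerLLVMDialectTranslation(registry);
  mlir::registerOpenMPDialectTranslation(registry);
  mod.getContext()->appendDialectRegistry(registry);
  return mlir::translateModuleToLLVMIR(mod, llvmCtx);
}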