24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Frontend/OpenMP/OMPConstants.h"
28 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29 #include "llvm/IR/Constants.h"
30 #include "llvm/IR/DebugInfoMetadata.h"
31 #include "llvm/IR/DerivedTypes.h"
32 #include "llvm/IR/IRBuilder.h"
33 #include "llvm/IR/MDBuilder.h"
34 #include "llvm/IR/ReplaceConstant.h"
35 #include "llvm/Support/FileSystem.h"
36 #include "llvm/Support/VirtualFileSystem.h"
37 #include "llvm/TargetParser/Triple.h"
38 #include "llvm/Transforms/Utils/ModuleUtils.h"
/// Translates the optional MLIR schedule clause kind into the matching
/// llvm::omp::ScheduleKind. An absent value maps to OMP_SCHEDULE_Default.
49 static llvm::omp::ScheduleKind
50 convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
// No kind supplied: fall back to the default schedule.
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
63 return llvm::omp::OMP_SCHEDULE_Runtime;
// Only reached when none of the cases above returned.
65 llvm_unreachable(
"unhandled schedule clause argument");
/// Stack frame carrying an alloca insertion point. Elsewhere in this file the
/// stored point is read back through
/// ModuleTranslation::stackWalk<OpenMPAllocaStackFrame>.
70 class OpenMPAllocaStackFrame
// Captures the insertion point at construction; explicit to avoid implicit
// conversion from an InsertPointTy.
75 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
76 : allocaInsertPoint(allocaIP) {}
// The recorded insertion point for allocas.
77 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
/// Stack frame carrying a pointer to a llvm::CanonicalLoopInfo. Elsewhere in
/// this file the pointer is read back through
/// ModuleTranslation::stackWalk<OpenMPLoopInfoStackFrame>.
83 class OpenMPLoopInfoStackFrame
// Null until a loop has been recorded in this frame.
87 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
/// llvm::ErrorInfo subclass used as a marker error. In this file it is
/// returned when converting an MLIR block fails, and a handler that receives
/// it only sets the result to failure().
106 class PreviouslyReportedError
107 :
public llvm::ErrorInfo<PreviouslyReportedError> {
// ErrorInfo override; the stream parameter is unnamed (body not visible here).
109 void log(raw_ostream &)
const override {
// ErrorInfo override; per the message below, conversion to std::error_code
// is not supported for this error type.
113 std::error_code convertToErrorCode()
const override {
115 "PreviouslyReportedError doesn't support ECError conversion");
/// Collects per-variable state for linear variables of a loop and emits the
/// LLVM IR that snapshots, updates and finalizes them. All per-variable
/// containers (linearPreconditionVars, linearLoopBodyTemps, linearOrigVars,
/// linearOrigVal, linearSteps) are indexed in parallel.
133 class LinearClauseProcessor {
// Blocks produced by outlineLinearFinalizationBB() and consumed by
// finalizeLinearVar().
141 llvm::BasicBlock *linearFinalizationBB;
142 llvm::BasicBlock *linearExitBB;
143 llvm::BasicBlock *linearLastIterExitBB;
/// Registers one linear variable. Nothing is recorded unless the LLVM value
/// mapped to `linearVar` is an alloca.
147 void createLinearVar(llvm::IRBuilderBase &builder,
148 LLVM::ModuleTranslation &moduleTranslation,
150 if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
151 moduleTranslation.lookupValue(linearVar))) {
// ".linear_var" receives the snapshot written by initLinearVar().
152 linearPreconditionVars.push_back(builder.CreateAlloca(
153 linearVarAlloca->getAllocatedType(),
nullptr,
".linear_var"));
// ".linear_result" receives the per-iteration value written by
// updateLinearVar().
154 llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
155 linearVarAlloca->getAllocatedType(),
nullptr,
".linear_result");
156 linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
157 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
158 linearOrigVars.push_back(linearVarAlloca);
/// Records the LLVM value mapped to `linearStep`.
163 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
165 linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
/// Before the terminator of `loopPreHeader`, copies the current value of each
/// original variable into its ".linear_var" slot, then emits a barrier.
/// Returns the insertion point (or error) produced by createBarrier.
169 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
170 initLinearVar(llvm::IRBuilderBase &builder,
171 LLVM::ModuleTranslation &moduleTranslation,
172 llvm::BasicBlock *loopPreHeader) {
173 builder.SetInsertPoint(loopPreHeader->getTerminator());
174 for (
size_t index = 0; index < linearOrigVars.size(); index++) {
175 llvm::LoadInst *linearVarLoad = builder.CreateLoad(
176 linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
177 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
179 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
180 moduleTranslation.getOpenMPBuilder()->createBarrier(
181 builder.saveIP(), llvm::omp::OMPD_barrier);
182 return afterBarrierIP;
/// Before the terminator of `loopBody`, stores
/// `snapshot + loopInductionVar * step` into each ".linear_result" slot.
186 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
187 llvm::Value *loopInductionVar) {
188 builder.SetInsertPoint(loopBody->getTerminator());
189 for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
// The start value is read from the snapshot slot, not the original variable.
191 llvm::LoadInst *linearVarStart =
192 builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
194 linearPreconditionVars[index]);
195 auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
196 auto addInst = builder.CreateAdd(linearVarStart, mulInst);
197 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
/// Splits `loopExit` at its terminator into the finalization block, then
/// splits the finalization block twice more at its terminator to create the
/// exit block and the last-iteration exit block.
203 void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
204 llvm::BasicBlock *loopExit) {
205 linearFinalizationBB = loopExit->splitBasicBlock(
206 loopExit->getTerminator(),
"omp_loop.linear_finalization");
207 linearExitBB = linearFinalizationBB->splitBasicBlock(
208 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
209 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
210 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
/// Emits the write-back of the linear results. Requires that
/// outlineLinearFinalizationBB() has already populated the three blocks.
/// Returns the insertion point (or error) produced by the trailing barrier.
214 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
215 finalizeLinearVar(llvm::IRBuilderBase &builder,
216 LLVM::ModuleTranslation &moduleTranslation,
217 llvm::Value *lastIter) {
// Load the i32 flag behind `lastIter` and test it for non-zero.
219 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
220 llvm::Value *loopLastIterLoad = builder.CreateLoad(
221 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
222 llvm::Value *isLast =
223 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
225 llvm::Type::getInt32Ty(builder.getContext()), 0));
// In the last-iteration block, copy each ".linear_result" back into the
// original variable.
227 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
228 for (
size_t index = 0; index < linearOrigVars.size(); index++) {
229 llvm::LoadInst *linearVarTemp =
230 builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
231 linearLoopBodyTemps[index]);
232 builder.CreateStore(linearVarTemp, linearOrigVars[index]);
// The conditional branch is inserted before the old terminator of the
// finalization block; the old terminator, still last in the block, is then
// erased.
238 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
239 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
240 linearFinalizationBB->getTerminator()->eraseFromParent();
// Emit a barrier before the terminator of the exit block.
242 builder.SetInsertPoint(linearExitBB->getTerminator());
243 return moduleTranslation.getOpenMPBuilder()->createBarrier(
244 builder.saveIP(), llvm::omp::OMPD_barrier);
/// Within instructions whose parent block is named `BBName`, replaces uses of
/// the original value at `varIndex` with the matching ".linear_result" slot.
249 void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
// Users are copied into a separate list first; the replacement below changes
// the value's use list, so it is not iterated directly while rewriting.
252 for (llvm::User *user : linearOrigVal[varIndex]->users())
253 users.push_back(user);
254 for (
auto *user : users) {
// Non-instruction users are left untouched.
255 if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
// Blocks are matched by name (string comparison).
256 if (userInst->getParent()->getName().str() == BBName)
257 user->replaceUsesOfWith(linearOrigVal[varIndex],
258 linearLoopBodyTemps[varIndex]);
// Tail of a helper whose name and leading parameter(s) are outside the visible
// lines. It looks up the omp::PrivateClauseOp nearest to `from` in the symbol
// table and asserts that the lookup succeeded.
269 SymbolRefAttr symbolName) {
270 omp::PrivateClauseOp privatizer =
271 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
// A missing privatizer is treated as a programmer error.
273 assert(privatizer &&
"privatizer not found in the symbol table");
// Local helpers that report clauses which are not yet supported. Each
// `check*` lambda takes an op and a LogicalResult by reference and overwrites
// the result with the diagnostic returned by `todo` when the clause is in use.
// (The enclosing function's signature is outside the visible lines.)
//
// `todo` emits an error on `op` that starts with
// "not yet implemented: Unhandled clause <clauseName> in <op name>".
284 auto todo = [&op](StringRef clauseName) {
285 return op.
emitError() <<
"not yet implemented: Unhandled clause "
286 << clauseName <<
" in " << op.
getName()
// allocate: flagged when either the allocate or the allocator list is non-empty.
290 auto checkAllocate = [&todo](
auto op, LogicalResult &result) {
291 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
292 result = todo(
"allocate");
// ompx_bare: the guarding condition is not among the visible lines.
294 auto checkBare = [&todo](
auto op, LogicalResult &result) {
296 result = todo(
"ompx_bare");
// cancel directive: for a Taskgroup cancellation, flagged when `parent` is an
// omp::TaskloopOp (how `parent` is obtained is not visible here).
298 auto checkCancelDirective = [&todo](
auto op, LogicalResult &result) {
299 omp::ClauseCancellationConstructType cancelledDirective =
300 op.getCancelDirective();
303 if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
// isa_and_nonnull tolerates a null `parent`.
310 if (isa_and_nonnull<omp::TaskloopOp>(parent))
311 result = todo(
"cancel directive inside of taskloop");
// depend: flagged when depend vars are present or depend kinds are set.
314 auto checkDepend = [&todo](
auto op, LogicalResult &result) {
315 if (!op.getDependVars().empty() || op.getDependKinds())
316 result = todo(
"depend");
// device: the guarding condition is not among the visible lines.
318 auto checkDevice = [&todo](
auto op, LogicalResult &result) {
320 result = todo(
"device");
// dist_schedule: flagged only when a chunk size is given.
322 auto checkDistSchedule = [&todo](
auto op, LogicalResult &result) {
323 if (op.getDistScheduleChunkSize())
324 result = todo(
"dist_schedule with chunk_size");
// hint: captures nothing and leaves the result parameter unnamed, so it cannot
// report a failure through it (body not visible here).
326 auto checkHint = [](
auto op, LogicalResult &) {
// in_reduction: flagged when any of vars / byref / syms is set.
330 auto checkInReduction = [&todo](
auto op, LogicalResult &result) {
331 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
332 op.getInReductionSyms())
333 result = todo(
"in_reduction");
// is_device_ptr: flagged when the variable list is non-empty.
335 auto checkIsDevicePtr = [&todo](
auto op, LogicalResult &result) {
336 if (!op.getIsDevicePtrVars().empty())
337 result = todo(
"is_device_ptr");
// linear: flagged when either the variable or the step list is non-empty.
339 auto checkLinear = [&todo](
auto op, LogicalResult &result) {
340 if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
341 result = todo(
"linear");
// nowait: the guarding condition is not among the visible lines.
343 auto checkNowait = [&todo](
auto op, LogicalResult &result) {
345 result = todo(
"nowait");
// order: flagged when an order or an order modifier is present.
347 auto checkOrder = [&todo](
auto op, LogicalResult &result) {
348 if (op.getOrder() || op.getOrderMod())
349 result = todo(
"order");
// parallelization-level: flagged when getParLevelSimd() is set.
351 auto checkParLevelSimd = [&todo](
auto op, LogicalResult &result) {
352 if (op.getParLevelSimd())
353 result = todo(
"parallelization-level");
// priority: flagged when a priority is present.
355 auto checkPriority = [&todo](
auto op, LogicalResult &result) {
356 if (op.getPriority())
357 result = todo(
"priority");
// private: omp::TargetOp is distinguished at compile time; for it, private
// variables combined with nowait are flagged. The second check below flags any
// privatization.
// NOTE(review): it presumably sits in the non-TargetOp branch, but the line
// separating the two checks is not visible - confirm.
359 auto checkPrivate = [&todo](
auto op, LogicalResult &result) {
360 if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
362 if (!op.getPrivateVars().empty() && op.getNowait())
363 result = todo(
"privatization for deferred target tasks");
365 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
366 result = todo(
"privatization");
// reduction: for omp::TeamsOp any reduction is flagged. Independently of the
// op type, a reduction modifier other than `defaultmod` is flagged.
369 auto checkReduction = [&todo](
auto op, LogicalResult &result) {
// The outer `if` has no braces, so it governs only the nested `if` below.
370 if (isa<omp::TeamsOp>(op))
371 if (!op.getReductionVars().empty() || op.getReductionByref() ||
372 op.getReductionSyms())
373 result = todo(
"reduction");
374 if (op.getReductionMod() &&
375 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
376 result = todo(
"reduction with modifier");
// task_reduction: flagged when any of vars / byref / syms is set.
378 auto checkTaskReduction = [&todo](
auto op, LogicalResult &result) {
379 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
380 op.getTaskReductionSyms())
381 result = todo(
"task_reduction");
// untied: the guarding condition is not among the visible lines.
383 auto checkUntied = [&todo](
auto op, LogicalResult &result) {
385 result = todo(
"untied");
// Per-operation dispatch: `result` starts as success and each `.Case` handler
// runs the clause checks relevant to that operation type. Every check may
// overwrite `result`. (The expression the `.Case` chain is applied to is on a
// line that is not visible here.)
388 LogicalResult result = success();
390 .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
391 .Case([&](omp::CancellationPointOp op) {
392 checkCancelDirective(op, result);
394 .Case([&](omp::DistributeOp op) {
395 checkAllocate(op, result);
396 checkDistSchedule(op, result);
397 checkOrder(op, result);
399 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
400 .Case([&](omp::SectionsOp op) {
401 checkAllocate(op, result);
402 checkPrivate(op, result);
403 checkReduction(op, result);
405 .Case([&](omp::SingleOp op) {
406 checkAllocate(op, result);
407 checkPrivate(op, result);
409 .Case([&](omp::TeamsOp op) {
410 checkAllocate(op, result);
411 checkPrivate(op, result);
413 .Case([&](omp::TaskOp op) {
414 checkAllocate(op, result);
415 checkInReduction(op, result);
417 .Case([&](omp::TaskgroupOp op) {
418 checkAllocate(op, result);
419 checkTaskReduction(op, result);
421 .Case([&](omp::TaskwaitOp op) {
422 checkDepend(op, result);
423 checkNowait(op, result);
425 .Case([&](omp::TaskloopOp op) {
427 checkUntied(op, result);
428 checkPriority(op, result);
430 .Case([&](omp::WsloopOp op) {
431 checkAllocate(op, result);
432 checkLinear(op, result);
433 checkOrder(op, result);
434 checkReduction(op, result);
436 .Case([&](omp::ParallelOp op) {
437 checkAllocate(op, result);
438 checkReduction(op, result);
440 .Case([&](omp::SimdOp op) {
441 checkLinear(op, result);
442 checkReduction(op, result);
// The atomic ops share one generic handler that only runs the hint check.
444 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
445 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op, result); })
// The target data-movement ops share one generic handler that only runs the
// depend check.
446 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
447 [&](
auto op) { checkDepend(op, result); })
448 .Case([&](omp::TargetOp op) {
449 checkAllocate(op, result);
450 checkBare(op, result);
451 checkDevice(op, result);
452 checkInReduction(op, result);
453 checkIsDevicePtr(op, result);
454 checkPrivate(op, result);
// Converts an llvm::Error into a LogicalResult (the enclosing signature is
// outside the visible lines). `result` starts as success and is changed by the
// handlers passed to llvm::handleAllErrors.
464 LogicalResult result = success();
466 llvm::handleAllErrors(
// A PreviouslyReportedError only flips the result to failure in this handler.
468 [&](
const PreviouslyReportedError &) { result = failure(); },
// Handler for every other error type (body not visible here).
469 [&](
const llvm::ErrorInfoBase &err) {
476 template <
typename T>
// Finds an insertion point suitable for allocas (function name and first
// parameter are on a line that is not visible here). A point recorded in an
// OpenMPAllocaStackFrame is preferred; otherwise the first insertion point of
// the current function's entry block is returned.
486 static llvm::OpenMPIRBuilder::InsertPointTy
488 LLVM::ModuleTranslation &moduleTranslation) {
// Walk the translation stack and copy the point stored in a frame.
492 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
493 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
494 [&](OpenMPAllocaStackFrame &frame) {
495 allocaInsertPoint = frame.allocaInsertPoint;
// The recorded point is returned only when its block belongs to the same
// function as the builder's current block (the start of this condition is not
// visible here).
503 allocaInsertPoint.getBlock()->getParent() ==
504 builder.GetInsertBlock()->getParent())
505 return allocaInsertPoint;
// If the builder is currently inside the function's entry block, create a new
// block named "entry" right after it, branch to it and continue emitting
// there.
514 if (builder.GetInsertBlock() ==
515 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
516 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
517 "Assuming end of basic block");
518 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
519 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
520 builder.GetInsertBlock()->getNextNode());
521 builder.CreateBr(entryBB);
522 builder.SetInsertPoint(entryBB);
// Fallback: the first insertion point of the function's entry block.
525 llvm::BasicBlock &funcEntryBlock =
526 builder.GetInsertBlock()->getParent()->getEntryBlock();
527 return llvm::OpenMPIRBuilder::InsertPointTy(
528 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
// Retrieves a llvm::CanonicalLoopInfo pointer from an OpenMPLoopInfoStackFrame
// on the translation stack (function name and parameters are on a line that is
// not visible here). `loopInfo` stays null unless the walk callback assigns it.
534 static llvm::CanonicalLoopInfo *
536 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
537 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
538 [&](OpenMPLoopInfoStackFrame &frame) {
539 loopInfo = frame.loopInfo;
// Converts every block of `region` into an LLVM basic block named `blockName`
// and wires them between the builder's current block and a freshly split
// continuation block ("omp.region.cont"). (The function name and return type
// are on a line that is not visible here.)
//
// Returns the continuation block; a PreviouslyReportedError is returned when
// converting one of the blocks fails. When the region yields values,
// `continuationBlockPHIs` must be non-null and receives one PHI per yielded
// value.
551 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
552 LLVM::ModuleTranslation &moduleTranslation,
// Whether the op owning the region implements omp::LoopWrapperInterface.
554 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
// Split at the current insertion point; the tail becomes the continuation.
556 llvm::BasicBlock *continuationBlock =
557 splitBB(builder,
true,
"omp.region.cont");
558 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
// Create one LLVM block per MLIR block, placed after the current block, and
// register the mapping.
560 llvm::LLVMContext &llvmContext = builder.getContext();
561 for (
Block &bb : region) {
562 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
563 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
564 builder.GetInsertBlock()->getNextNode());
565 moduleTranslation.mapBlock(&bb, llvmBB);
568 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
// For regions that are not loop wrappers, inspect omp.yield terminators: the
// first yield seen defines the PHI types, later yields are asserted to agree
// in count and type.
575 unsigned numYields = 0;
577 if (!isLoopWrapper) {
578 bool operandsProcessed =
false;
580 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
581 if (!operandsProcessed) {
582 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
583 continuationBlockPHITypes.push_back(
584 moduleTranslation.convertType(yield->getOperand(i).getType()));
586 operandsProcessed =
true;
588 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
589 "mismatching number of values yielded from the region");
590 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
591 llvm::Type *operandType =
592 moduleTranslation.convertType(yield->getOperand(i).getType());
594 assert(continuationBlockPHITypes[i] == operandType &&
595 "values of mismatching types yielded from the region");
// Yielded values need somewhere to land: the caller must pass a PHI list.
605 if (!continuationBlockPHITypes.empty())
607 continuationBlockPHIs &&
608 "expected continuation block PHIs if converted regions yield values");
// Create the PHIs at the top of the continuation block; the guard restores the
// builder's insertion point afterwards.
609 if (continuationBlockPHIs) {
610 llvm::IRBuilderBase::InsertPointGuard guard(builder);
611 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
612 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
613 for (llvm::Type *ty : continuationBlockPHITypes)
614 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
// Convert the blocks in the order given by `blocks` (how `blocks` is computed
// is not visible here).
620 for (
Block *bb : blocks) {
621 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
// Redirect the source block's single successor from the continuation block to
// the converted entry block.
624 if (bb->isEntryBlock()) {
625 assert(sourceTerminator->getNumSuccessors() == 1 &&
626 "provided entry block has multiple successors");
627 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
628 "ContinuationBlock is not the successor of the entry block");
629 sourceTerminator->setSuccessor(0, llvmBB);
632 llvm::IRBuilderBase::InsertPointGuard guard(builder);
// A failed block conversion is propagated as PreviouslyReportedError.
634 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
635 return llvm::make_error<PreviouslyReportedError>();
// NOTE(review): the condition guarding this branch is not visible here.
640 builder.CreateBr(continuationBlock);
// omp.terminator / omp.yield end the region: branch to the continuation block
// and feed each operand of the terminator into the matching PHI.
651 Operation *terminator = bb->getTerminator();
652 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
653 builder.CreateBr(continuationBlock);
655 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
656 (*continuationBlockPHIs)[i]->addIncoming(
657 moduleTranslation.lookupValue(terminator->
getOperand(i)), llvmBB);
// Drop the value/block mappings created for this region.
668 moduleTranslation.forgetMapping(region);
670 return continuationBlock;
// Cases of a switch that maps omp::ClauseProcBindKind values onto the
// corresponding llvm::omp::ProcBindKind enumerators (the function header and
// the switch statement are outside the visible lines).
676 case omp::ClauseProcBindKind::Close:
677 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
678 case omp::ClauseProcBindKind::Master:
679 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
680 case omp::ClauseProcBindKind::Primary:
681 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
682 case omp::ClauseProcBindKind::Spread:
683 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
// Only reached when none of the cases above returned.
685 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
/// For every (variable, block argument) pair reported by `blockArgIface`, maps
/// the block argument to the LLVM value already associated with the variable.
694 static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
695 omp::BlockArgOpenMPOpInterface blockArgIface) {
// Fills `blockArgsPairs` (declared on a line that is not visible here).
697 blockArgIface.getBlockArgsPairs(blockArgsPairs);
698 for (
auto [var, arg] : blockArgsPairs)
699 moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
// Lowers an omp::MaskedOp through OpenMPIRBuilder::createMasked (the function
// name and leading parameters are on a line that is not visible here).
705 LLVM::ModuleTranslation &moduleTranslation) {
706 auto maskedOp = cast<omp::MaskedOp>(opInst);
707 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// Body callback: emission resumes at `codeGenIP` for the op's region.
712 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
// Bind by reference; the region belongs to the op.
714 auto &region = maskedOp.getRegion();
715 builder.restoreIP(codeGenIP);
// Finalization callback: always reports success.
723 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
// Use the translated filter operand when the op provides one.
725 llvm::Value *filterVal =
nullptr;
726 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
727 filterVal = moduleTranslation.lookupValue(filterVar);
// NOTE(review): presumably the no-filter path assigns a default using this
// context on lines not visible here - confirm.
729 llvm::LLVMContext &llvmContext = builder.getContext();
// A filter value must be available by this point.
733 assert(filterVal !=
nullptr);
734 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
735 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
736 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
// Continue emitting after the generated construct.
742 builder.restoreIP(*afterIP);
// Lowers an omp::MasterOp through OpenMPIRBuilder::createMaster (the function
// name and leading parameters are on a line that is not visible here).
749 LLVM::ModuleTranslation &moduleTranslation) {
750 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
751 auto masterOp = cast<omp::MasterOp>(opInst);
// Body callback: emission resumes at `codeGenIP` for the op's region.
756 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
// Bind by reference; the region belongs to the op.
758 auto &region = masterOp.getRegion();
759 builder.restoreIP(codeGenIP);
// Finalization callback: always reports success.
767 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
769 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
770 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
771 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
// Continue emitting after the generated construct.
777 builder.restoreIP(*afterIP);
// Lowers an omp::CriticalOp through OpenMPIRBuilder::createCritical (the
// function name and leading parameters are on a line that is not visible
// here).
784 LLVM::ModuleTranslation &moduleTranslation) {
785 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
786 auto criticalOp = cast<omp::CriticalOp>(opInst);
// Body callback: emission resumes at `codeGenIP` for the op's region.
791 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
// Bind by reference; the region belongs to the op.
793 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
794 builder.restoreIP(codeGenIP);
// Finalization callback: always reports success.
802 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
804 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
805 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
// Stays null unless the critical section is named.
806 llvm::Constant *hint =
nullptr;
// For a named critical section, find the matching omp::CriticalDeclareOp and
// take its hint, converted to int.
809 if (criticalOp.getNameAttr()) {
812 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
813 auto criticalDeclareOp =
814 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
818 static_cast<int>(criticalDeclareOp.getHint()));
// An unnamed critical section is passed as the empty string.
820 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
821 moduleTranslation.getOpenMPBuilder()->createCritical(
822 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
// Continue emitting after the generated construct.
827 builder.restoreIP(*afterIP);
// Gathers privatization data of an op (the enclosing type declaration and the
// constructor's name line are outside the visible lines). The constructor
// takes the private block arguments from the op's BlockArgOpenMPOpInterface,
// pre-sizes the variable lists to match, collects the privatizer declarations
// and records the private variables.
834 template <
typename OP>
837 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
838 mlirVars.reserve(blockArgs.size());
839 llvmVars.reserve(blockArgs.size());
840 collectPrivatizationDecls<OP>(op);
843 mlirVars.push_back(privateVar);
// Appends to `privatizers` based on the symbols in the op's private_syms
// attribute (the loop body is not visible here).
855 void collectPrivatizationDecls(OP op) {
856 std::optional<ArrayAttr> attr = op.getPrivateSyms();
// Grow once for all symbols instead of reallocating per element.
860 privatizers.reserve(privatizers.size() + attr->size());
861 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
// Appends to `reductions` the omp::DeclareReductionOp found for each symbol in
// the op's reduction_syms attribute (the function name and parameters are on
// lines that are not visible here).
868 template <
typename T>
872 std::optional<ArrayAttr> attr = op.getReductionSyms();
// Capacity is reserved from the number of reduction variables.
876 reductions.reserve(reductions.size() + op.getNumReductionVars());
877 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
878 reductions.push_back(
879 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
// Converts `region` at the builder's current position (the function name and
// return type are on a line that is not visible here). The visible fast path
// converts the region's front block directly into the current LLVM block.
// The tail of the function handles a continuation block instead.
890 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
891 LLVM::ModuleTranslation &moduleTranslation,
// If the current block already ends in a terminator, detach it for the
// duration of the in-place conversion.
899 llvm::Instruction *potentialTerminator =
900 builder.GetInsertBlock()->empty() ? nullptr
901 : &builder.GetInsertBlock()->back();
903 if (potentialTerminator && potentialTerminator->isTerminator())
904 potentialTerminator->removeFromParent();
// Map the region's front block onto the current LLVM block and convert it
// there.
905 moduleTranslation.mapBlock(&region.
front(), builder.GetInsertBlock());
907 if (
failed(moduleTranslation.convertBlock(
908 region.
front(),
true, builder)))
// Hand results back to the caller if it asked for them (arguments partly not
// visible here).
912 if (continuationBlockArgs)
914 *continuationBlockArgs,
// Drop the mappings created for this region.
919 moduleTranslation.forgetMapping(region);
// Re-attach the detached terminator at the end of the block the builder now
// points into.
921 if (potentialTerminator && potentialTerminator->isTerminator()) {
922 llvm::BasicBlock *block = builder.GetInsertBlock();
923 if (block->empty()) {
929 potentialTerminator->insertInto(block, block->begin());
931 potentialTerminator->insertAfter(&block->back());
// Continuation-block path: forward the PHIs to the caller and continue
// emitting at the first insertion point of the continuation block.
945 if (continuationBlockArgs)
946 llvm::append_range(*continuationBlockArgs, phis);
947 builder.SetInsertPoint(*continuationBlock,
948 (*continuationBlock)->getFirstInsertionPt());
// std::function aliases for reduction generators. Both return
// llvm::OpenMPIRBuilder::InsertPointOrErrorTy and take an insertion point
// first (the ends of the parameter lists are not visible here).
955 using OwningReductionGen =
956 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
957 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
// The atomic variant additionally receives a llvm::Type *.
959 using OwningAtomicReductionGen =
960 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
961 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
// Builds the non-atomic reduction generator for a reduction declaration
// (the function name and first parameters are on a line that is not visible
// here).
968 static OwningReductionGen
970 LLVM::ModuleTranslation &moduleTranslation) {
// `decl` is captured by value, everything else by reference, so the captured
// builder and moduleTranslation must stay alive while the generator is used.
974 OwningReductionGen gen =
975 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
976 llvm::Value *lhs, llvm::Value *rhs,
977 llvm::Value *&result)
mutable
978 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
// Bind the combiner's two arguments to the incoming LLVM values.
979 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
980 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
981 builder.restoreIP(insertPoint);
// Inline the combiner; failure becomes a string error.
984 "omp.reduction.nonatomic.body", builder,
985 moduleTranslation, &phis)))
986 return llvm::createStringError(
987 "failed to inline `combiner` region of `omp.declare_reduction`");
// The single value collected in `phis` is the combined result.
988 result = llvm::getSingleElement(phis);
989 return builder.saveIP();
// Builds the atomic reduction generator for a reduction declaration (the
// function name and first parameter are on a line that is not visible here).
// An empty (default-constructed) generator is returned when the declaration
// has no atomic region.
998 static OwningAtomicReductionGen
1000 llvm::IRBuilderBase &builder,
1001 LLVM::ModuleTranslation &moduleTranslation) {
1002 if (decl.getAtomicReductionRegion().empty())
1003 return OwningAtomicReductionGen();
// `decl` is captured by value, everything else by reference, so the captured
// builder and moduleTranslation must stay alive while the generator is used.
// The llvm::Type * parameter is unnamed and unused.
1008 OwningAtomicReductionGen atomicGen =
1009 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1010 llvm::Value *lhs, llvm::Value *rhs)
mutable
1011 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
// Bind the atomic region's two arguments to the incoming LLVM values.
1012 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
1013 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
1014 builder.restoreIP(insertPoint);
// Inline the atomic region; failure becomes a string error.
1017 "omp.reduction.atomic.body", builder,
1018 moduleTranslation, &phis)))
1019 return llvm::createStringError(
1020 "failed to inline `atomic` region of `omp.declare_reduction`");
// The atomic region is expected to yield nothing.
1021 assert(phis.empty());
1022 return builder.saveIP();
// Lowers an omp::OrderedOp carrying doacross dependence information through
// OpenMPIRBuilder::createOrderedDepend (the function name and first
// parameters are on a line that is not visible here).
1028 static LogicalResult
1030 LLVM::ModuleTranslation &moduleTranslation) {
1031 auto orderedOp = cast<omp::OrderedOp>(opInst);
// Both optionals are dereferenced unconditionally here.
// NOTE(review): presumably their presence is verified on lines not visible
// here - confirm.
1036 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1037 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1038 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1040 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
// The flat list of dependence values is consumed in groups of `numLoops`
// (the increment of `indexVecValues` is on a line that is not visible here).
1042 size_t indexVecValues = 0;
1043 while (indexVecValues < vecValues.size()) {
1045 storeValues.reserve(numLoops);
1046 for (
unsigned i = 0; i < numLoops; i++) {
1047 storeValues.push_back(vecValues[indexVecValues]);
// One createOrderedDepend call per group; the builder continues at the
// returned insertion point.
1050 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1052 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1053 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1054 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
// Lowers an omp::OrderedRegionOp through
// OpenMPIRBuilder::createOrderedThreadsSimd (the function name and first
// parameters are on a line that is not visible here).
1061 static LogicalResult
1063 LLVM::ModuleTranslation &moduleTranslation) {
1064 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1065 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
// Body callback: emission resumes at `codeGenIP` for the op's region.
1070 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
// Bind by reference; the region belongs to the op.
1072 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1073 builder.restoreIP(codeGenIP);
// Finalization callback: always reports success.
1081 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
// The last argument is the negation of the op's par_level_simd flag.
1083 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1084 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1085 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1086 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
// Continue emitting after the generated construct.
1091 builder.restoreIP(*afterIP);
/// A (value, address) pair describing a store to be emitted later. Elsewhere
/// in this file the pairs are turned into stores with
/// builder.CreateStore(data, addr).
1097 struct DeferredStore {
1098 DeferredStore(llvm::Value *value, llvm::Value *address)
1099 : value(value), address(address) {}
// Destination of the store.
1102 llvm::Value *address;
// Allocates the private storage for every reduction variable of `loop` in the
// block referenced by `allocaIP` (the function name and several parameters
// are on lines that are not visible here). Two shapes are visible. When an
// `alloc` region is inlined, the store of its result is deferred. Otherwise a
// plain alloca is used directly.
1109 template <
typename T>
1110 static LogicalResult
1112 llvm::IRBuilderBase &builder,
1113 LLVM::ModuleTranslation &moduleTranslation,
1114 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// Emit before the alloca block's terminator; the guard restores the builder's
// insertion point on exit.
1120 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1121 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1124 deferredStores.reserve(loop.getNumReductionVars());
1126 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1127 Region &allocRegion = reductionDecls[i].getAllocRegion();
1129 if (allocRegion.
empty())
// Inline the declaration's `alloc` region; exactly one value must be yielded.
1134 builder, moduleTranslation, &phis)))
1135 return loop.emitError(
1136 "failed to inline `alloc` region of `omp.declare_reduction`");
1138 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
// Inlining may have moved the builder; return to the alloca block.
1139 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1143 llvm::Value *var = builder.CreateAlloca(
1144 moduleTranslation.convertType(reductionDecls[i].getType()));
// Both the new slot and the yielded allocation are cast to the pointer type
// returned by builder.getPtrTy().
1146 llvm::Type *ptrTy = builder.getPtrTy();
1147 llvm::Value *castVar =
1148 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1149 llvm::Value *castPhi =
1150 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
// Storing the yielded pointer into the slot is postponed: value = castPhi,
// address = castVar.
1152 deferredStores.emplace_back(castPhi, castVar);
1154 privateReductionVariables[i] = castVar;
1155 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1156 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
// Second shape: no `alloc` region is expected here, so an alloca of the
// reduction type serves as the private variable.
1158 assert(allocRegion.
empty() &&
1159 "allocaction is implicit for by-val reduction");
1160 llvm::Value *var = builder.CreateAlloca(
1161 moduleTranslation.convertType(reductionDecls[i].getType()));
1163 llvm::Type *ptrTy = builder.getPtrTy();
1164 llvm::Value *castVar =
1165 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1167 moduleTranslation.mapValue(reductionArgs[i], castVar);
1168 privateReductionVariables[i] = castVar;
1169 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
// Maps the block arguments of the i-th reduction declaration's initializer
// region before it is inlined (the function name and parameters are on lines
// that are not visible here).
1177 template <
typename T>
1184 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1185 Region &initializerRegion = reduction.getInitializerRegion();
// The "mold" argument is bound to the LLVM value of the original reduction
// variable.
1188 mlir::Value mlirSource = loop.getReductionVars()[i];
1189 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1190 assert(llvmSource &&
"lookup reduction var");
1191 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
// The "alloc" argument is bound to the value recorded for this variable in
// `reductionVariableMap`.
1194 llvm::Value *allocation =
1195 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1196 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
// Positions `builder` in `block` (the function name and first parameter are on
// a line that is not visible here). A null `block` means the builder's current
// block.
1202 llvm::BasicBlock *block =
nullptr) {
1203 if (block ==
nullptr)
1204 block = builder.GetInsertBlock();
// Without a terminator, append at the end of the block; otherwise insert just
// before the terminator.
1206 if (block->empty() || block->getTerminator() ==
nullptr)
1207 builder.SetInsertPoint(block);
1209 builder.SetInsertPoint(block->getTerminator());
// Emits the initialization of the private reduction variables of `op` (the
// function name and several parameters are on lines that are not visible
// here). Allocas go before the terminator of `latestAllocaBlock`; the
// initializer regions are inlined afterwards.
1217 template <
typename OP>
1218 static LogicalResult
1220 llvm::IRBuilderBase &builder,
1221 LLVM::ModuleTranslation &moduleTranslation,
1222 llvm::BasicBlock *latestAllocaBlock,
// Guard for ops without reduction variables (the guarded statement is not
// visible here).
1228 if (op.getNumReductionVars() == 0)
// Split off a dedicated block for the initialization code, then move the
// builder to the alloca position.
1231 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1232 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1233 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1234 builder.restoreIP(allocaIP);
// First pass: create an alloca of the reduction type into byRefVars[i]
// (the conditions selecting which variables get one are only partly visible).
1237 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1239 if (!reductionDecls[i].getAllocRegion().empty())
1245 byRefVars[i] = builder.CreateAlloca(
1246 moduleTranslation.convertType(reductionDecls[i].getType()));
// Emit the stores that were postponed while allocating.
1254 for (
auto [data, addr] : deferredStores)
1255 builder.CreateStore(data, addr);
// Second pass: inline each declaration's initializer region and store the
// single yielded value.
1260 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1265 reductionVariableMap, i);
1273 "omp.reduction.neutral", builder,
1274 moduleTranslation, &phis)))
1277 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1278 "reduction neutral element declaration region");
1283 if (!reductionDecls[i].getAllocRegion().empty())
// The yielded value is stored into the by-ref slot, which then becomes the
// private variable.
1292 builder.CreateStore(phis[0], byRefVars[i]);
1294 privateReductionVariables[i] = byRefVars[i];
1295 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1296 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
// Otherwise the yielded value is stored into the already-allocated private
// variable.
1299 builder.CreateStore(phis[0], privateReductionVariables[i]);
// Drop the mappings of the initializer region.
1306 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
// Builds the per-variable generator callbacks and the `reductionInfos` entries
// for the reductions of `loop` (the function name and several parameters are
// on lines that are not visible here).
1313 template <
typename T>
1315 T loop, llvm::IRBuilderBase &builder,
1316 LLVM::ModuleTranslation &moduleTranslation,
1322 unsigned numReductions = loop.getNumReductionVars();
// First create the owning generator objects (constructor arguments not visible
// here).
1324 for (
unsigned i = 0; i < numReductions; ++i) {
1325 owningReductionGens.push_back(
1327 owningAtomicReductionGens.push_back(
// Then assemble one info entry per reduction variable.
1332 reductionInfos.reserve(numReductions);
1333 for (
unsigned i = 0; i < numReductions; ++i) {
// The atomic callback stays null when the owning generator is empty.
1334 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen =
nullptr;
1335 if (owningAtomicReductionGens[i])
1336 atomicGen = owningAtomicReductionGens[i];
1337 llvm::Value *variable =
1338 moduleTranslation.lookupValue(loop.getReductionVars()[i]);
// Entry: element type, original variable, private variable, Scalar evaluation
// kind, non-atomic generator, a null argument, atomic generator.
1339 reductionInfos.push_back(
1340 {moduleTranslation.convertType(reductionDecls[i].
getType()), variable,
1341 privateReductionVariables[i],
1342 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1343 owningReductionGens[i],
1344 nullptr, atomicGen});
// Inlines cleanup regions for a list of private variables (the function name
// and first parameters are on lines that are not visible here). When
// `shouldLoadCleanupRegionArg` is true, the value passed to a region is loaded
// from the private variable. Otherwise the variable itself is passed.
1349 static LogicalResult
1352 LLVM::ModuleTranslation &moduleTranslation,
1353 llvm::IRBuilderBase &builder, StringRef regionName,
1354 bool shouldLoadCleanupRegionArg =
true) {
// Guard for empty cleanup regions (the guarded statement is not visible here).
1356 if (cleanupRegion->empty())
// If the current block already ends in a terminator, emit before it.
1362 llvm::Instruction *potentialTerminator =
1363 builder.GetInsertBlock()->empty() ? nullptr
1364 : &builder.GetInsertBlock()->back();
1365 if (potentialTerminator && potentialTerminator->isTerminator())
1366 builder.SetInsertPoint(potentialTerminator);
1367 llvm::Value *privateVarValue =
1368 shouldLoadCleanupRegionArg
1369 ? builder.CreateLoad(
1371 privateVariables[i])
1372 : privateVariables[i];
// The region's first entry-block argument receives the value.
1374 moduleTranslation.mapValue(entry.
getArgument(0), privateVarValue);
1377 moduleTranslation)))
// Drop the mappings of the cleanup region.
1382 moduleTranslation.forgetMapping(*cleanupRegion);
// Emits the reduction code for `op` via OpenMPIRBuilder::createReductions,
// follows it with a barrier, and then inlines the reduction cleanup regions
// (the function name and several parameters are on lines that are not visible
// here).
1390 OP op, llvm::IRBuilderBase &builder,
1391 LLVM::ModuleTranslation &moduleTranslation,
1392 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1395 bool isNowait =
false,
bool isTeamsReduction =
false) {
// Guard for ops without reduction variables (the guarded statement is not
// visible here).
1397 if (op.getNumReductionVars() == 0)
1404 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1409 owningReductionGens, owningAtomicReductionGens,
1410 privateReductionVariables, reductionInfos);
// A temporary `unreachable` terminates the current block so code can be
// inserted before it; it is erased again below.
1415 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1416 builder.SetInsertPoint(tempTerminator);
1417 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1418 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1419 isByRef, isNowait, isTeamsReduction);
// An insertion point without a block is reported as a conversion failure.
1424 if (!contInsertPoint->getBlock())
1425 return op->emitOpError() <<
"failed to convert reductions";
// Barrier after the reductions, created with the OMPD_for directive kind.
1427 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1428 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1433 tempTerminator->eraseFromParent();
1434 builder.restoreIP(*afterIP);
// Collect a pointer to each declaration's cleanup region and inline them under
// the name "omp.reduction.cleanup".
1438 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1439 [](omp::DeclareReductionOp reductionDecl) {
1440 return &reductionDecl.getCleanupRegion();
1443 moduleTranslation, builder,
1444 "omp.reduction.cleanup");
1455 template <
typename OP>
1458 LLVM::ModuleTranslation &moduleTranslation,
1459 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1464 if (op.getNumReductionVars() == 0)
1470 allocaIP, reductionDecls,
1471 privateReductionVariables, reductionVariableMap,
1472 deferredStores, isByRef)))
1476 allocaIP.getBlock(), reductionDecls,
1477 privateReductionVariables, reductionVariableMap,
1478 isByRef, deferredStores);
1488 static llvm::Value *
1490 LLVM::ModuleTranslation &moduleTranslation,
1492 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1493 return moduleTranslation.lookupValue(privateVar);
1495 Value blockArg = (*mappedPrivateVars)[privateVar];
1498 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1499 "A block argument corresponding to a mapped var should have "
1502 if (privVarType == blockArgType)
1503 return moduleTranslation.lookupValue(blockArg);
1509 if (!isa<LLVM::LLVMPointerType>(privVarType))
1510 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1511 moduleTranslation.lookupValue(blockArg));
1513 return moduleTranslation.lookupValue(privateVar);
1521 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1523 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1525 Region &initRegion = privDecl.getInitRegion();
1526 if (initRegion.
empty())
1527 return llvmPrivateVar;
1531 mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1532 assert(nonPrivateVar);
1533 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1534 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1539 moduleTranslation, &phis)))
1540 return llvm::createStringError(
1541 "failed to inline `init` region of `omp.private`");
1543 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1548 moduleTranslation.forgetMapping(initRegion);
1558 LLVM::ModuleTranslation &moduleTranslation,
1562 return llvm::Error::success();
1564 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1570 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1572 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1573 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1576 return privVarOrErr.takeError();
1578 llvmPrivateVar = privVarOrErr.get();
1579 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1584 return llvm::Error::success();
1592 LLVM::ModuleTranslation &moduleTranslation,
1594 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1597 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1598 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1599 allocaTerminator->getIterator()),
1600 true, allocaTerminator->getStableDebugLoc(),
1601 "omp.region.after_alloca");
1603 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1605 allocaTerminator = allocaIP.getBlock()->getTerminator();
1606 builder.SetInsertPoint(allocaTerminator);
1608 assert(allocaTerminator->getNumSuccessors() == 1 &&
1609 "This is an unconditional branch created by splitBB");
1611 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1612 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1614 unsigned int allocaAS =
1615 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1616 unsigned int defaultAS = moduleTranslation.getLLVMModule()
1618 .getProgramAddressSpace();
1620 for (
auto [privDecl, mlirPrivVar, blockArg] :
1623 llvm::Type *llvmAllocType =
1624 moduleTranslation.convertType(privDecl.getType());
1625 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1626 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1627 llvmAllocType,
nullptr,
"omp.private.alloc");
1628 if (allocaAS != defaultAS)
1629 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1630 builder.getPtrTy(defaultAS));
1632 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1635 return afterAllocas;
1640 LLVM::ModuleTranslation &moduleTranslation,
1646 bool needsFirstprivate =
1647 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1648 return privOp.getDataSharingType() ==
1649 omp::DataSharingClauseType::FirstPrivate;
1652 if (!needsFirstprivate)
1655 llvm::BasicBlock *copyBlock =
1656 splitBB(builder,
true,
"omp.private.copy");
1659 for (
auto [decl, mlirVar, llvmVar] :
1660 llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
1661 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
// Copy region of the current privatizer. It is bound by reference and its
// value mappings are dropped again via forgetMapping(copyRegion) further down.
1665 Region &copyRegion = decl.getCopyRegion();
1669 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1670 assert(nonPrivateVar);
1671 moduleTranslation.mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1674 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1678 moduleTranslation)))
1679 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1688 moduleTranslation.forgetMapping(copyRegion);
1691 if (insertBarrier) {
1692 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1693 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1694 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1702 static LogicalResult
1704 LLVM::ModuleTranslation &moduleTranslation,
Location loc,
1709 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1710 [](omp::PrivateClauseOp privatizer) {
1711 return &privatizer.getDeallocRegion();
1715 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1716 "omp.private.dealloc",
false)))
1717 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1718 "`omp.private` op in");
1730 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1737 static LogicalResult
1739 LLVM::ModuleTranslation &moduleTranslation) {
1740 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1741 using StorableBodyGenCallbackTy =
1742 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1744 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1750 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1754 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1758 sectionsOp.getNumReductionVars());
1762 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1765 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1766 reductionDecls, privateReductionVariables, reductionVariableMap,
1773 auto sectionOp = dyn_cast<omp::SectionOp>(op);
// Region of the current `omp.section`; bound by reference so the callback
// below can read its block arguments.
1777 Region &region = sectionOp.getRegion();
// Callback for this section. Everything is captured by reference: the
// enclosing sections op, the section's region, the IR builder and the module
// translation state.
1778 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
1779 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1780 builder.restoreIP(codeGenIP);
1787 sectionsOp.getRegion().getNumArguments());
1788 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1789 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1790 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
1792 moduleTranslation.mapValue(sectionArg, llvmVal);
1799 sectionCBs.push_back(sectionCB);
1805 if (sectionCBs.empty())
1808 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1813 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1814 llvm::Value &vPtr, llvm::Value *&replacementValue)
1815 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1816 replacementValue = &vPtr;
1822 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1826 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1827 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1828 moduleTranslation.getOpenMPBuilder()->createSections(
1829 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1830 sectionsOp.getNowait());
1835 builder.restoreIP(*afterIP);
1839 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1840 privateReductionVariables, isByRef, sectionsOp.getNowait());
1844 static LogicalResult
1846 LLVM::ModuleTranslation &moduleTranslation) {
1847 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1848 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1853 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1854 builder.restoreIP(codegenIP);
1856 builder, moduleTranslation)
1859 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1863 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1866 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1867 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1868 auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
1869 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1870 llvmCPFuncs.push_back(
1871 moduleTranslation.lookupFunction(llvmFuncOp.getName()));
1874 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1875 moduleTranslation.getOpenMPBuilder()->createSingle(
1876 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1882 builder.restoreIP(*afterIP);
1888 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1893 for (
auto ra : iface.getReductionBlockArgs())
1894 for (
auto &use : ra.getUses()) {
1895 auto *useOp = use.getOwner();
1897 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1898 debugUses.push_back(useOp);
1902 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
1907 Operation *currentOp = currentDistOp.getOperation();
1908 if (distOp && (distOp != currentOp))
1917 for (
auto use : debugUses)
1923 static LogicalResult
1925 LLVM::ModuleTranslation &moduleTranslation) {
1926 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1931 unsigned numReductionVars = op.getNumReductionVars();
1935 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1941 if (doTeamsReduction) {
1942 isByRef =
getIsByRef(op.getReductionByref());
1944 assert(isByRef.size() == op.getNumReductionVars());
1947 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
1952 op, reductionArgs, builder, moduleTranslation, allocaIP,
1953 reductionDecls, privateReductionVariables, reductionVariableMap,
1958 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1960 moduleTranslation, allocaIP);
1961 builder.restoreIP(codegenIP);
1967 llvm::Value *numTeamsLower =
nullptr;
1968 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
1969 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
1971 llvm::Value *numTeamsUpper =
nullptr;
1972 if (
Value numTeamsUpperVar = op.getNumTeamsUpper())
1973 numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
1975 llvm::Value *threadLimit =
nullptr;
1976 if (
Value threadLimitVar = op.getThreadLimit())
1977 threadLimit = moduleTranslation.lookupValue(threadLimitVar);
1979 llvm::Value *ifExpr =
nullptr;
1980 if (
Value ifVar = op.getIfExpr())
1981 ifExpr = moduleTranslation.lookupValue(ifVar);
1983 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1984 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1985 moduleTranslation.getOpenMPBuilder()->createTeams(
1986 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
1991 builder.restoreIP(*afterIP);
1992 if (doTeamsReduction) {
1995 op, builder, moduleTranslation, allocaIP, reductionDecls,
1996 privateReductionVariables, isByRef,
2004 LLVM::ModuleTranslation &moduleTranslation,
2006 if (dependVars.empty())
2008 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2009 llvm::omp::RTLDependenceKindTy type;
2011 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2012 case mlir::omp::ClauseTaskDepend::taskdependin:
2013 type = llvm::omp::RTLDependenceKindTy::DepIn;
2018 case mlir::omp::ClauseTaskDepend::taskdependout:
2019 case mlir::omp::ClauseTaskDepend::taskdependinout:
2020 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2022 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2023 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2025 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2026 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2029 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2030 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2031 dds.emplace_back(dd);
2043 llvm::IRBuilderBase &llvmBuilder,
2045 llvm::omp::Directive cancelDirective) {
2046 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2047 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2051 llvmBuilder.restoreIP(ip);
2057 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2058 return llvm::Error::success();
2063 ompBuilder.pushFinalizationCB(
2073 llvm::OpenMPIRBuilder &ompBuilder,
2074 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2075 ompBuilder.popFinalizationCB();
2076 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2077 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2078 assert(cancelBranch->getNumSuccessors() == 1 &&
2079 "cancel branch should have one target");
2080 cancelBranch->setSuccessor(0, constructFini);
// Bookkeeping helper for a task's context structure: it tracks the structure's
// LLVM type (`structTy`), the pointer to its storage (`structPtr`) and the
// per-privatizer GEPs into it (`llvmPrivateVarGEPs`).
2087 class TaskContextStructManager {
// Stores references to the IR builder and module translation, plus the
// privatizer declarations the structure is built from. Nothing is emitted
// here.
2089 TaskContextStructManager(llvm::IRBuilderBase &builder,
2090 LLVM::ModuleTranslation &moduleTranslation,
2092 : builder{builder}, moduleTranslation{moduleTranslation},
2093 privateDecls{privateDecls} {}
// Converts the types of the privatizers that read from their mold and emits
// the malloc producing `structPtr` (named "omp.task.context_ptr").
2099 void generateTaskContextStruct();
// Rebuilds `llvmPrivateVarGEPs`: a GEP into the structure for privatizers that
// read from their mold, nullptr for the others.
2105 void createGEPsToPrivateVars();
// Emits a free of `structPtr` in front of the terminator of the builder's
// current insert block.
2108 void freeStructPtr();
// Accessor for the GEP list filled in by createGEPsToPrivateVars().
2111 return llvmPrivateVarGEPs;
// Returns the pointer to the context structure; it is nullptr until
// generateTaskContextStruct() assigns it.
2114 llvm::Value *getStructPtr() {
return structPtr; }
// Borrowed references; this class does not own the builder or the translation.
2117 llvm::IRBuilderBase &builder;
2118 LLVM::ModuleTranslation &moduleTranslation;
// Pointer to the allocated context structure (nullptr when none exists).
2129 llvm::Value *structPtr =
nullptr;
// LLVM type used as the element type for the malloc and the GEPs.
2131 llvm::Type *structTy =
nullptr;
// Collects the converted LLVM types of the privatizers and emits a malloc for
// the task context structure, storing the resulting pointer in `structPtr`.
2135 void TaskContextStructManager::generateTaskContextStruct() {
// NOTE(review): presumably an early exit when there is nothing to privatize --
// confirm.
2136 if (privateDecls.empty())
// Upper bound on the number of field types; privatizers may be filtered below.
2138 privateVarTypes.reserve(privateDecls.size());
2140 for (omp::PrivateClauseOp &privOp : privateDecls) {
// NOTE(review): privatizers that do not read from their mold appear to be
// skipped here -- confirm.
2143 if (!privOp.readsFromMold())
// Record the LLVM type corresponding to the privatizer's MLIR type.
2145 Type mlirType = privOp.getType();
2146 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
// The allocation size is sizeof(structTy), expressed as a constant expression;
// the pointer-sized integer type comes from the module's data layout.
// NOTE(review): `structTy` is presumably built from `privateVarTypes` before
// this point -- confirm.
2152 llvm::DataLayout dataLayout =
2153 builder.GetInsertBlock()->getModule()->getDataLayout();
2154 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2155 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
// Emit the malloc call; the result is named "omp.task.context_ptr".
2158 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2160 "omp.task.context_ptr");
// Rebuilds `llvmPrivateVarGEPs` with one entry per privatizer: nullptr for a
// privatizer that does not read from its mold, otherwise a GEP into the
// context structure pointed to by `structPtr`.
2163 void TaskContextStructManager::createGEPsToPrivateVars() {
// Sanity check that no field types were recorded.
// NOTE(review): presumably only evaluated when no structure was allocated --
// confirm.
2165 assert(privateVarTypes.empty());
// Start from an empty list so repeated calls do not accumulate stale GEPs.
2170 llvmPrivateVarGEPs.clear();
2171 llvmPrivateVarGEPs.reserve(privateDecls.size());
// First GEP index.
2172 llvm::Value *zero = builder.getInt32(0);
2174 for (
auto privDecl : privateDecls) {
2175 if (!privDecl.readsFromMold()) {
// Placeholder entry so the list stays index-aligned with `privateDecls`.
2177 llvmPrivateVarGEPs.push_back(
nullptr);
// NOTE(review): `i` is presumably the running index of the structure field for
// this privatizer -- confirm.
2180 llvm::Value *iVal = builder.getInt32(i);
2181 llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
2182 llvmPrivateVarGEPs.push_back(gep);
// Emits a free of the context structure pointed to by `structPtr`.
2187 void TaskContextStructManager::freeStructPtr() {
// Restores the builder's insertion point when this scope ends.
2191 llvm::IRBuilderBase::InsertPointGuard guard{builder};
// Insert in front of the current block's terminator so the free is emitted
// before control leaves the block.
2193 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2194 builder.CreateFree(structPtr);
2198 static LogicalResult
2200 LLVM::ModuleTranslation &moduleTranslation) {
2201 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2206 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2218 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2223 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2224 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2225 builder.getContext(),
"omp.task.start",
2226 builder.GetInsertBlock()->getParent());
2227 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2228 builder.SetInsertPoint(branchToTaskStartBlock);
2231 llvm::BasicBlock *copyBlock =
2232 splitBB(builder,
true,
"omp.private.copy");
2233 llvm::BasicBlock *initBlock =
2234 splitBB(builder,
true,
"omp.private.init");
2250 moduleTranslation, allocaIP);
2253 builder.SetInsertPoint(initBlock->getTerminator());
2256 taskStructMgr.generateTaskContextStruct();
2263 taskStructMgr.createGEPsToPrivateVars();
2265 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2268 taskStructMgr.getLLVMPrivateVarGEPs())) {
2270 if (!privDecl.readsFromMold())
2272 assert(llvmPrivateVarAlloc &&
2273 "reads from mold so shouldn't have been skipped");
2276 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2277 blockArg, llvmPrivateVarAlloc, initBlock);
2278 if (!privateVarOrErr)
2279 return handleError(privateVarOrErr, *taskOp.getOperation());
2288 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2289 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2290 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2292 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2293 llvmPrivateVarAlloc);
2295 assert(llvmPrivateVarAlloc->getType() ==
2296 moduleTranslation.convertType(blockArg.getType()));
2306 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2307 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2308 taskOp.getPrivateNeedsBarrier())))
2309 return llvm::failure();
2312 builder.SetInsertPoint(taskStartBlock);
2314 auto bodyCB = [&](InsertPointTy allocaIP,
2315 InsertPointTy codegenIP) -> llvm::Error {
2319 moduleTranslation, allocaIP);
2322 builder.restoreIP(codegenIP);
2324 llvm::BasicBlock *privInitBlock =
nullptr;
2329 auto [blockArg, privDecl, mlirPrivVar] = zip;
2331 if (privDecl.readsFromMold())
2334 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2335 llvm::Type *llvmAllocType =
2336 moduleTranslation.convertType(privDecl.getType());
2337 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2338 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2339 llvmAllocType,
nullptr,
"omp.private.alloc");
2342 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2343 blockArg, llvmPrivateVar, privInitBlock);
2344 if (!privateVarOrError)
2345 return privateVarOrError.takeError();
2346 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2347 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2350 taskStructMgr.createGEPsToPrivateVars();
2351 for (
auto [i, llvmPrivVar] :
2354 assert(privateVarsInfo.
llvmVars[i] &&
2355 "This is added in the loop above");
2358 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2363 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2367 if (!privateDecl.readsFromMold())
2370 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2371 llvmPrivateVar = builder.CreateLoad(
2372 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2374 assert(llvmPrivateVar->getType() ==
2375 moduleTranslation.convertType(blockArg.getType()));
2376 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2380 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
2382 return llvm::make_error<PreviouslyReportedError>();
2384 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2389 return llvm::make_error<PreviouslyReportedError>();
2392 taskStructMgr.freeStructPtr();
2394 return llvm::Error::success();
2397 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2403 llvm::omp::Directive::OMPD_taskgroup);
2407 moduleTranslation, dds);
2409 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2410 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2411 moduleTranslation.getOpenMPBuilder()->createTask(
2412 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2413 moduleTranslation.lookupValue(taskOp.getFinal()),
2414 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
2415 taskOp.getMergeable(),
2416 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2417 moduleTranslation.lookupValue(taskOp.getPriority()));
2425 builder.restoreIP(*afterIP);
2430 static LogicalResult
2432 LLVM::ModuleTranslation &moduleTranslation) {
2433 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2437 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2438 builder.restoreIP(codegenIP);
2440 builder, moduleTranslation)
2445 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2446 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2447 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
2453 builder.restoreIP(*afterIP);
2457 static LogicalResult
2459 LLVM::ModuleTranslation &moduleTranslation) {
2463 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
2468 static LogicalResult
2470 LLVM::ModuleTranslation &moduleTranslation) {
2471 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2472 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2476 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2478 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2482 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2485 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
2486 llvm::Type *ivType = step->getType();
2487 llvm::Value *chunk =
nullptr;
2488 if (wsloopOp.getScheduleChunk()) {
2489 llvm::Value *chunkVar =
2490 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
2491 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2498 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2502 wsloopOp.getNumReductionVars());
2505 builder, moduleTranslation, privateVarsInfo, allocaIP);
2512 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2517 moduleTranslation, allocaIP, reductionDecls,
2518 privateReductionVariables, reductionVariableMap,
2519 deferredStores, isByRef)))
2528 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2530 wsloopOp.getPrivateNeedsBarrier())))
2533 assert(afterAllocas.get()->getSinglePredecessor());
2536 afterAllocas.get()->getSinglePredecessor(),
2537 reductionDecls, privateReductionVariables,
2538 reductionVariableMap, isByRef, deferredStores)))
2542 bool isOrdered = wsloopOp.getOrdered().has_value();
2543 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2544 bool isSimd = wsloopOp.getScheduleSimd();
2545 bool loopNeedsBarrier = !wsloopOp.getNowait();
2550 llvm::omp::WorksharingLoopType workshareLoopType =
2551 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
2552 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2553 : llvm::omp::WorksharingLoopType::ForStaticLoop;
2557 llvm::omp::Directive::OMPD_for);
2559 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2562 LinearClauseProcessor linearClauseProcessor;
2563 if (!wsloopOp.getLinearVars().empty()) {
2564 for (
mlir::Value linearVar : wsloopOp.getLinearVars())
2565 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2567 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
2568 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
2572 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
2580 if (!wsloopOp.getLinearVars().empty()) {
2581 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2582 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2583 loopInfo->getPreheader());
2586 builder.restoreIP(*afterBarrierIP);
2587 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
2588 loopInfo->getIndVar());
2589 linearClauseProcessor.outlineLinearFinalizationBB(builder,
2590 loopInfo->getExit());
2593 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2596 bool noLoopMode =
false;
2597 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
2599 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
2603 if (loopOp == targetCapturedOp) {
2604 omp::TargetRegionFlags kernelFlags =
2605 targetOp.getKernelExecFlags(targetCapturedOp);
2606 if (omp::bitEnumContainsAll(kernelFlags,
2607 omp::TargetRegionFlags::spmd |
2608 omp::TargetRegionFlags::no_loop) &&
2609 !omp::bitEnumContainsAny(kernelFlags,
2610 omp::TargetRegionFlags::generic))
2615 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2616 ompBuilder->applyWorkshareLoop(
2617 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2618 convertToScheduleKind(schedule), chunk, isSimd,
2619 scheduleMod == omp::ScheduleModifier::monotonic,
2620 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2621 workshareLoopType, noLoopMode);
2627 if (!wsloopOp.getLinearVars().empty()) {
2628 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
2629 assert(loopInfo->getLastIter() &&
2630 "`lastiter` in CanonicalLoopInfo is nullptr");
2631 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2632 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
2633 loopInfo->getLastIter());
2636 for (
size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
2637 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
2639 builder.restoreIP(oldIP);
2647 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2648 privateReductionVariables, isByRef, wsloopOp.getNowait(),
2658 static LogicalResult
2660 LLVM::ModuleTranslation &moduleTranslation) {
2661 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2663 assert(isByRef.size() == opInst.getNumReductionVars());
2664 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2675 opInst.getNumReductionVars());
2678 auto bodyGenCB = [&](InsertPointTy allocaIP,
2679 InsertPointTy codeGenIP) -> llvm::Error {
2681 builder, moduleTranslation, privateVarsInfo, allocaIP);
2683 return llvm::make_error<PreviouslyReportedError>();
2689 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
2692 InsertPointTy(allocaIP.getBlock(),
2693 allocaIP.getBlock()->getTerminator()->getIterator());
2696 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
2697 reductionDecls, privateReductionVariables, reductionVariableMap,
2698 deferredStores, isByRef)))
2699 return llvm::make_error<PreviouslyReportedError>();
2701 assert(afterAllocas.get()->getSinglePredecessor());
2702 builder.restoreIP(codeGenIP);
2708 return llvm::make_error<PreviouslyReportedError>();
2711 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2713 opInst.getPrivateNeedsBarrier())))
2714 return llvm::make_error<PreviouslyReportedError>();
2718 afterAllocas.get()->getSinglePredecessor(),
2719 reductionDecls, privateReductionVariables,
2720 reductionVariableMap, isByRef, deferredStores)))
2721 return llvm::make_error<PreviouslyReportedError>();
2726 moduleTranslation, allocaIP);
2730 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
2732 return regionBlock.takeError();
2735 if (opInst.getNumReductionVars() > 0) {
2741 owningReductionGens, owningAtomicReductionGens,
2742 privateReductionVariables, reductionInfos);
2745 builder.SetInsertPoint((*regionBlock)->getTerminator());
2748 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
2749 builder.SetInsertPoint(tempTerminator);
2751 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
2752 ompBuilder->createReductions(
2753 builder.saveIP(), allocaIP, reductionInfos, isByRef,
2755 if (!contInsertPoint)
2756 return contInsertPoint.takeError();
2758 if (!contInsertPoint->getBlock())
2759 return llvm::make_error<PreviouslyReportedError>();
2761 tempTerminator->eraseFromParent();
2762 builder.restoreIP(*contInsertPoint);
2765 return llvm::Error::success();
2768 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2769 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
2778 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2779 InsertPointTy oldIP = builder.saveIP();
2780 builder.restoreIP(codeGenIP);
2785 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
2786 [](omp::DeclareReductionOp reductionDecl) {
2787 return &reductionDecl.getCleanupRegion();
2790 reductionCleanupRegions, privateReductionVariables,
2791 moduleTranslation, builder,
"omp.reduction.cleanup")))
2792 return llvm::createStringError(
2793 "failed to inline `cleanup` region of `omp.declare_reduction`");
2798 return llvm::make_error<PreviouslyReportedError>();
2800 builder.restoreIP(oldIP);
2801 return llvm::Error::success();
2804 llvm::Value *ifCond =
nullptr;
2805 if (
auto ifVar = opInst.getIfExpr())
2806 ifCond = moduleTranslation.lookupValue(ifVar);
2807 llvm::Value *numThreads =
nullptr;
2808 if (
auto numThreadsVar = opInst.getNumThreads())
2809 numThreads = moduleTranslation.lookupValue(numThreadsVar);
2810 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
2811 if (
auto bind = opInst.getProcBindKind())
2815 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2817 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2819 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2820 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
2821 ifCond, numThreads, pbKind, isCancellable);
2826 builder.restoreIP(*afterIP);
// Translates an MLIR `omp::ClauseOrderKind` into the matching
// `llvm::omp::OrderKind`. One path yields OMP_ORDER_unknown, `Concurrent`
// maps to OMP_ORDER_concurrent, and any other enumerator is treated as
// unreachable.
2831 static llvm::omp::OrderKind
2834 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2836 case omp::ClauseOrderKind::Concurrent:
2837 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2839 llvm_unreachable(
"Unknown ClauseOrderKind kind");
// Lowers an `omp::SimdOp` to LLVM IR through the OpenMPIRBuilder: collects the
// simdlen/safelen/aligned clause data, converts the region, applies
// `applySimd` to the loop, and then folds each private reduction value back
// into its original variable.
2843 static LogicalResult
2845 LLVM::ModuleTranslation &moduleTranslation) {
2846 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2847 auto simdOp = cast<omp::SimdOp>(opInst);
2855 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2858 simdOp.getNumReductionVars());
// There must be exactly one by-ref flag per reduction variable.
2863 assert(isByRef.size() == simdOp.getNumReductionVars());
2865 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2869 builder, moduleTranslation, privateVarsInfo, allocaIP);
2874 moduleTranslation, allocaIP, reductionDecls,
2875 privateReductionVariables, reductionVariableMap,
2876 deferredStores, isByRef)))
2887 assert(afterAllocas.get()->getSinglePredecessor());
2890 afterAllocas.get()->getSinglePredecessor(),
2891 reductionDecls, privateReductionVariables,
2892 reductionVariableMap, isByRef, deferredStores)))
// simdlen and safelen stay null unless the op carries the clause; when
// present they are materialised as i64 constants.
2895 llvm::ConstantInt *simdlen =
nullptr;
2896 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
2897 simdlen = builder.getInt64(simdlenVar.value());
2899 llvm::ConstantInt *safelen =
nullptr;
2900 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
2901 safelen = builder.getInt64(safelenVar.value());
// Maps each loaded aligned pointer value to its alignment constant.
2903 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
// Remember the block that is current at this point; the loads of the aligned
// pointers below are emitted into it.
2906 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
2907 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
2909 for (
size_t i = 0; i < operands.size(); ++i) {
2910 llvm::Value *alignment =
nullptr;
2911 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
2912 llvm::Type *ty = llvmVal->
getType();
// The i-th alignment attribute belongs to the i-th operand.
2914 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
2915 alignment = builder.getInt64(intAttr.getInt());
2916 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
2917 assert(alignment &&
"Invalid alignment value");
2921 if (!intAttr.getValue().isPowerOf2())
// Temporarily retarget the builder at `sourceBlock` to emit the load, then
// restore the previous insertion point. The loaded value is the key under
// which the alignment is recorded.
2924 auto curInsert = builder.saveIP();
2925 builder.SetInsertPoint(sourceBlock);
2926 llvmVal = builder.CreateLoad(ty, llvmVal);
2927 builder.restoreIP(curInsert);
2928 alignedVars[llvmVal] = alignment;
2932 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
// Continue emitting at the start of the block returned by the region
// conversion.
2937 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2939 ompBuilder->applySimd(loopInfo, alignedVars,
2941 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
2943 order, simdlen, safelen);
// For every reduction: load the original and the private value, combine them
// with the generator built from the reduction declaration, and store the
// result back into the original variable.
2950 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
2951 privateReductionVariables))) {
2952 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
2954 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
2955 llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
2956 llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());
2960 llvm::Value *redValue = originalVariable;
2963 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
2964 llvm::Value *privateRedValue = builder.CreateLoad(
2965 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
// Out-parameter filled in by the reduction generator.
2966 llvm::Value *reduced;
2968 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
2971 builder.restoreIP(res.get());
2975 builder.CreateStore(reduced, originalVariable);
// Gather the cleanup region of every reduction declaration so they can be
// inlined under the "omp.reduction.cleanup" name.
2980 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
2981 [](omp::DeclareReductionOp reductionDecl) {
2982 return &reductionDecl.getCleanupRegion();
2985 moduleTranslation, builder,
2986 "omp.reduction.cleanup")))
// Lowers an `omp::LoopNestOp`: creates one canonical loop per nest level via
// the OpenMPIRBuilder, optionally tiles and collapses the resulting loops, and
// leaves the builder positioned after the generated loop structure.
2995 static LogicalResult
2997 LLVM::ModuleTranslation &moduleTranslation) {
2998 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2999 auto loopOp = cast<omp::LoopNestOp>(opInst);
3002 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Body callback passed to createCanonicalLoop for every loop level.
3007 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
3008 llvm::Value *iv) -> llvm::Error {
// `loopInfos.size()` is the index of the loop currently being created, so it
// selects the region block argument that this induction variable stands for.
3010 moduleTranslation.mapValue(
3011 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
3016 bodyInsertPoints.push_back(ip);
// Only the last loop of the nest goes on to convert the region body; every
// other level returns here.
3018 if (loopInfos.size() != loopOp.getNumLoops() - 1)
3019 return llvm::Error::success();
3022 builder.restoreIP(ip);
3024 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
3026 return regionBlock.takeError();
3028 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3029 return llvm::Error::success();
// Create the canonical loops one level at a time.
3037 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3038 llvm::Value *lowerBound =
3039 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
3040 llvm::Value *upperBound =
3041 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
3042 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
// Default location and compute insertion point come from `ompLoc`; the
// assignments below override them with the most recent body insertion point
// and the first loop's preheader.
3047 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3048 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3050 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3052 computeIP = loopInfos.front()->getPreheaderIP();
3056 ompBuilder->createCanonicalLoop(
3057 loc, bodyGen, lowerBound, upperBound, step,
3058 true, loopOp.getLoopInclusive(), computeIP);
3063 loopInfos.push_back(*loopResult);
// By default, code following the nest is emitted after the first loop.
3066 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
3067 loopInfos.front()->getAfterIP();
// Optional tiling.
3070 if (
const auto &tiles = loopOp.getTileSizes()) {
3071 llvm::Type *ivType = loopInfos.front()->getIndVarType();
3074 for (
auto tile : tiles.value()) {
3076 tileSizes.push_back(tileVal);
3079 std::vector<llvm::CanonicalLoopInfo *> newLoops =
3080 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
// After tiling, the continuation point becomes the start of the single
// successor of the first new loop's "after" block.
3084 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
3085 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
3086 afterIP = {afterAfterBB, afterAfterBB->begin()};
3090 for (
const auto &newLoop : newLoops)
3091 loopInfos.push_back(newLoop);
// Collapse the first `numCollapse` loops into a single loop.
3095 const auto &numCollapse = loopOp.getCollapseNumLoops();
3097 loopInfos.begin(), loopInfos.begin() + (numCollapse));
3099 auto newTopLoopInfo =
3100 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
3102 assert(newTopLoopInfo &&
"New top loop information is missing");
// Publish the collapsed loop through the OpenMPLoopInfoStackFrame found by the
// stack walk.
3103 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
3104 [&](OpenMPLoopInfoStackFrame &frame) {
3105 frame.loopInfo = newTopLoopInfo;
3113 builder.restoreIP(afterIP);
// Lowers a canonical-loop op: builds an OpenMPIRBuilder canonical loop named
// "omp.loop" from the translated trip count, maps the MLIR induction variable
// to the LLVM one inside the body callback, and records the resulting
// CanonicalLoopInfo for the op's CLI value when it has one.
3118 static LogicalResult
3120 LLVM::ModuleTranslation &moduleTranslation) {
3121 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3123 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
3124 Value loopIV = op.getInductionVar();
3125 Value loopTC = op.getTripCount();
3127 llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
3130 ompBuilder->createCanonicalLoop(
3132 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
// Make the LLVM induction variable visible to the translated loop body.
3135 moduleTranslation.mapValue(loopIV, llvmIV);
3137 builder.restoreIP(ip);
3142 return bodyGenStatus.takeError();
3144 llvmTC,
"omp.loop");
3148 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
// Subsequent code is emitted after the loop.
3149 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
3150 builder.restoreIP(afterIP);
// Associate the MLIR CLI value (if any) with the generated loop so later
// transformation ops can look it up.
3153 if (
Value cli = op.getCli())
3154 moduleTranslation.mapOmpLoop(cli, llvmCLI);
// Applies heuristic unrolling to the canonical loop referenced by the op's
// applyee, then invalidates that loop mapping.
3161 static LogicalResult
3163 LLVM::ModuleTranslation &moduleTranslation) {
3164 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3166 Value applyee = op.getApplyee();
3167 assert(applyee &&
"Loop to apply unrolling on required");
// Fetch the CanonicalLoopInfo previously registered for the applyee.
3169 llvm::CanonicalLoopInfo *consBuilderCLI =
3170 moduleTranslation.lookupOMPLoop(applyee);
3171 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3172 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
// The loop has been consumed by the transformation; drop its mapping.
3174 moduleTranslation.invalidateOmpLoop(applyee);
// Applies a tile transformation: translates the `sizes` operands and the
// applyee loops, calls `OpenMPIRBuilder::tileLoops`, maps the generated loops
// to the op's generatees (when present), and invalidates the applyees.
3180 static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3181 LLVM::ModuleTranslation &moduleTranslation) {
3182 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3183 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
// Collect the already-translated tile sizes.
3188 for (
Value size : op.getSizes()) {
3189 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
3190 assert(translatedSize &&
3191 "sizes clause arguments must already be translated");
3192 translatedSizes.push_back(translatedSize);
// Collect the CanonicalLoopInfo of every loop the tiling applies to.
3195 for (
Value applyee : op.getApplyees()) {
3196 llvm::CanonicalLoopInfo *consBuilderCLI =
3197 moduleTranslation.lookupOMPLoop(applyee);
// NOTE(review): this asserts the MLIR value `applyee`, not the lookup result
// `consBuilderCLI`, so a failed lookup is not caught here — confirm intent.
3198 assert(applyee &&
"Canonical loop must already been translated");
3199 translatedLoops.push_back(consBuilderCLI);
3202 auto generatedLoops =
3203 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
// Pair each generatee with the corresponding generated loop; zip_equal
// requires both ranges to have the same length.
3204 if (!op.getGeneratees().empty()) {
3205 for (
auto [mlirLoop,
genLoop] :
3206 zip_equal(op.getGeneratees(), generatedLoops))
3207 moduleTranslation.mapOmpLoop(mlirLoop,
genLoop);
// The applyee loops have been consumed by the tiling; drop their mappings.
3211 for (
Value applyee : op.getApplyees())
3212 moduleTranslation.invalidateOmpLoop(applyee);
// Translates an `omp::ClauseMemoryOrderKind` into an `llvm::AtomicOrdering`.
// Both the early return and `Relaxed` yield Monotonic; Seq_cst, Acq_rel,
// Acquire and Release map to their LLVM counterparts. Any other enumerator is
// treated as unreachable.
3218 static llvm::AtomicOrdering
3221 return llvm::AtomicOrdering::Monotonic;
3224 case omp::ClauseMemoryOrderKind::Seq_cst:
3225 return llvm::AtomicOrdering::SequentiallyConsistent;
3226 case omp::ClauseMemoryOrderKind::Acq_rel:
3227 return llvm::AtomicOrdering::AcquireRelease;
3228 case omp::ClauseMemoryOrderKind::Acquire:
3229 return llvm::AtomicOrdering::Acquire;
3230 case omp::ClauseMemoryOrderKind::Release:
3231 return llvm::AtomicOrdering::Release;
3232 case omp::ClauseMemoryOrderKind::Relaxed:
3233 return llvm::AtomicOrdering::Monotonic;
3235 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
// Lowers an `omp::AtomicReadOp` by translating its `x` and `v` operands and
// calling `OpenMPIRBuilder::createAtomicRead`; the builder is then moved to
// the insertion point that call returns.
3239 static LogicalResult
3241 LLVM::ModuleTranslation &moduleTranslation) {
3242 auto readOp = cast<omp::AtomicReadOp>(opInst);
3246 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3247 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3250 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3253 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3254 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
// Both atomic operands use the element type declared on the op.
3256 llvm::Type *elementType =
3257 moduleTranslation.convertType(readOp.getElementType());
3259 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
3260 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
3261 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
// Lowers an `omp::AtomicWriteOp`: translates the expression and destination
// operands and emits the store through `OpenMPIRBuilder::createAtomicWrite`.
3266 static LogicalResult
3268 LLVM::ModuleTranslation &moduleTranslation) {
3269 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3273 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3274 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3277 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3279 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3280 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
// The element type of the destination is taken from the written expression.
3281 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3282 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
3285 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
// Type-switch from an LLVM dialect arithmetic op to the corresponding
// `llvm::AtomicRMWInst::BinOp`. Ops without a listed case fall through to
// BAD_BINOP.
3293 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
3294 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
3295 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
3296 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
3297 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
3298 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
3299 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
3300 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
3301 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
3302 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
// Reads the atomic-control attribute of an atomic update op into three
// out-parameters. All three flags are reset to false first, so they are
// well-defined even when the op is null or carries no such attribute.
3306 bool &isIgnoreDenormalMode,
3307 bool &isFineGrainedMemory,
3308 bool &isRemoteMemory) {
3309 isIgnoreDenormalMode =
false;
3310 isFineGrainedMemory =
false;
3311 isRemoteMemory =
false;
// Only consult the attribute when the op exists and actually has it.
3312 if (atomicUpdateOp &&
3313 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3314 mlir::omp::AtomicControlAttr atomicControlAttr =
3315 atomicUpdateOp.getAtomicControlAttr();
3316 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3317 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3318 isRemoteMemory = atomicControlAttr.getRemoteMemory();
// Lowers an atomic update op through `OpenMPIRBuilder::createAtomicUpdate`.
// When the update region consists of exactly two operations, the code inspects
// the inner op to pick an atomicrmw binop and the expression operand;
// BAD_BINOP is also assigned on one visible path. The region itself is
// translated inside the update callback.
3323 static LogicalResult
3325 llvm::IRBuilderBase &builder,
3326 LLVM::ModuleTranslation &moduleTranslation) {
3327 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3332 auto &innerOpList = opInst.getRegion().front().getOperations();
3333 bool isXBinopExpr{
false};
3334 llvm::AtomicRMWInst::BinOp binop;
3336 llvm::Value *llvmExpr =
nullptr;
3337 llvm::Value *llvmX =
nullptr;
3338 llvm::Type *llvmXElementType =
nullptr;
// Two operations in the region body (size() == 2) is the shape handled by the
// inspection below.
3339 if (innerOpList.size() == 2) {
3345 opInst.getRegion().getArgument(0))) {
3346 return opInst.emitError(
"no atomic update operation with region argument"
3347 " as operand found inside atomic.update region");
// True when the region argument is the first operand of the inner op.
3350 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
3352 llvmExpr = moduleTranslation.lookupValue(mlirExpr);
3356 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3358 llvmX = moduleTranslation.lookupValue(opInst.getX());
// The element type of x is the type of the region's block argument.
3359 llvmXElementType = moduleTranslation.convertType(
3360 opInst.getRegion().getArgument(0).getType());
3361 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3365 llvm::AtomicOrdering atomicOrdering =
// Update callback: binds the region argument to the current atomic value,
// translates the region block in place, and returns the yielded value.
3370 [&opInst, &moduleTranslation](
3371 llvm::Value *atomicx,
3374 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
3375 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3376 if (
failed(moduleTranslation.convertBlock(bb,
true, builder)))
3377 return llvm::make_error<PreviouslyReportedError>();
3379 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
3380 assert(yieldop && yieldop.getResults().size() == 1 &&
3381 "terminator must be omp.yield op and it must have exactly one "
3383 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
// Flags passed on to createAtomicUpdate below.
3386 bool isIgnoreDenormalMode;
3387 bool isFineGrainedMemory;
3388 bool isRemoteMemory;
3393 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3394 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3395 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
3396 atomicOrdering, binop, updateFn,
3397 isXBinopExpr, isIgnoreDenormalMode,
3398 isFineGrainedMemory, isRemoteMemory);
3403 builder.restoreIP(*afterIP);
// Lowers an atomic capture op through `OpenMPIRBuilder::createAtomicCapture`.
// The capture wraps an atomic read plus either an atomic write or an atomic
// update; the code derives the expression, the binop and the postfix flag from
// whichever of the two is present.
3407 static LogicalResult
3409 llvm::IRBuilderBase &builder,
3410 LLVM::ModuleTranslation &moduleTranslation) {
3411 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3416 bool isXBinopExpr =
false, isPostfixUpdate =
false;
3417 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3419 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
3420 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
3422 assert((atomicUpdateOp || atomicWriteOp) &&
3423 "internal op must be an atomic.update or atomic.write op");
// Write form: always a postfix update, and the expression is the value being
// written.
3425 if (atomicWriteOp) {
3426 isPostfixUpdate =
true;
3427 mlirExpr = atomicWriteOp.getExpr();
// Update form: postfix exactly when the update is the second op of the
// capture.
3429 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
3430 atomicCaptureOp.getAtomicUpdateOp().getOperation();
3431 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
3434 if (innerOpList.size() == 2) {
3437 atomicUpdateOp.getRegion().getArgument(0))) {
3438 return atomicUpdateOp.emitError(
3439 "no atomic update operation with region argument"
3440 " as operand found inside atomic.update region");
3444 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
3447 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
// x and v both come from the capture's atomic read, and share its element
// type.
3451 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
3452 llvm::Value *llvmX =
3453 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
3454 llvm::Value *llvmV =
3455 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
3456 llvm::Type *llvmXElementType = moduleTranslation.convertType(
3457 atomicCaptureOp.getAtomicReadOp().getElementType());
3458 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3461 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
3465 llvm::AtomicOrdering atomicOrdering =
// Update callback: one path simply returns the written expression; the other
// binds the update region's argument, translates the region block in place and
// returns the yielded value.
3469 [&](llvm::Value *atomicx,
3472 return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
3473 Block &bb = *atomicUpdateOp.getRegion().
begin();
3474 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
3476 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3477 if (
failed(moduleTranslation.convertBlock(bb,
true, builder)))
3478 return llvm::make_error<PreviouslyReportedError>();
3480 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
3481 assert(yieldop && yieldop.getResults().size() == 1 &&
3482 "terminator must be omp.yield op and it must have exactly one "
3484 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
// Flags passed on to createAtomicCapture below.
3487 bool isIgnoreDenormalMode;
3488 bool isFineGrainedMemory;
3489 bool isRemoteMemory;
3491 isFineGrainedMemory, isRemoteMemory);
3494 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3495 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3496 ompBuilder->createAtomicCapture(
3497 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
3498 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
3499 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
3504 builder.restoreIP(*afterIP);
// Translates the construct named in a cancellation clause into the
// `llvm::omp::Directive` it refers to: Loop -> OMPD_for, Parallel ->
// OMPD_parallel, Sections -> OMPD_sections, Taskgroup -> OMPD_taskgroup.
// Any other enumerator is treated as unreachable.
3509 omp::ClauseCancellationConstructType directive) {
3510 switch (directive) {
3511 case omp::ClauseCancellationConstructType::Loop:
3512 return llvm::omp::Directive::OMPD_for;
3513 case omp::ClauseCancellationConstructType::Parallel:
3514 return llvm::omp::Directive::OMPD_parallel;
3515 case omp::ClauseCancellationConstructType::Sections:
3516 return llvm::omp::Directive::OMPD_sections;
3517 case omp::ClauseCancellationConstructType::Taskgroup:
3518 return llvm::omp::Directive::OMPD_taskgroup;
3520 llvm_unreachable(
"Unhandled cancellation construct type");
// Lowers a cancel op via `OpenMPIRBuilder::createCancel`, passing the
// translated `if` condition (null when the op has none) and the cancelled
// directive, then continues emitting at the returned insertion point.
3523 static LogicalResult
3525 LLVM::ModuleTranslation &moduleTranslation) {
3529 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3530 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
// Optional `if` clause: stays null when absent.
3532 llvm::Value *ifCond =
nullptr;
3533 if (
Value ifVar = op.getIfExpr())
3534 ifCond = moduleTranslation.lookupValue(ifVar);
3536 llvm::omp::Directive cancelledDirective =
3539 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3540 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
3545 builder.restoreIP(afterIP.get());
// Lowers a cancellation point op via
// `OpenMPIRBuilder::createCancellationPoint` for the cancelled directive, then
// continues emitting at the returned insertion point.
3550 static LogicalResult
3552 llvm::IRBuilderBase &builder,
3553 LLVM::ModuleTranslation &moduleTranslation) {
3557 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3558 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3560 llvm::omp::Directive cancelledDirective =
3563 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3564 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
3569 builder.restoreIP(afterIP.get());
// Lowers an `omp::ThreadprivateOp`. The symbol address must resolve to an
// `LLVM::AddressOfOp` (an address-space cast in front of it is looked at
// first). When not compiling for a target device, the op result is mapped to a
// cached thread-private call; otherwise it is mapped directly to the global.
3576 static LogicalResult
3578 LLVM::ModuleTranslation &moduleTranslation) {
3579 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3580 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3581 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
3586 Value symAddr = threadprivateOp.getSymAddr();
3589 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3592 if (!isa<LLVM::AddressOfOp>(symOp))
3593 return opInst.
emitError(
"Addressing symbol not found");
// The isa<> guard above has already rejected every other op kind, so this
// dyn_cast cannot produce a null op here.
3594 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3596 LLVM::GlobalOp global =
3597 addressOfOp.getGlobal(moduleTranslation.symbolTable());
3598 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
// Host path: size the cache from the global's store size and name it
// "<symbol>.cache".
3600 if (!ompBuilder->Config.isTargetDevice()) {
3601 llvm::Type *type = globalValue->getValueType();
3602 llvm::TypeSize typeSize =
3603 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
3605 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
3606 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
3607 ompLoc, globalValue, size, global.getSymName() +
".cache");
3608 moduleTranslation.mapValue(opInst.
getResult(0), callInst);
// Target-device path: use the global itself.
3610 moduleTranslation.mapValue(opInst.
getResult(0), globalValue);
// Translates an `mlir::omp::DeclareTargetDeviceType` into the corresponding
// `OffloadEntriesInfoManager::OMPTargetDeviceClauseKind` (host, nohost, any).
// Any other enumerator is treated as unreachable.
3616 static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
3618 switch (deviceClause) {
3619 case mlir::omp::DeclareTargetDeviceType::host:
3620 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
3622 case mlir::omp::DeclareTargetDeviceType::nohost:
3623 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
3625 case mlir::omp::DeclareTargetDeviceType::any:
3626 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
3629 llvm_unreachable(
"unhandled device clause");
// Translates an `mlir::omp::DeclareTargetCaptureClause` into the corresponding
// `OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind` (to, link, enter).
// Any other enumerator is treated as unreachable.
3632 static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
3634 mlir::omp::DeclareTargetCaptureClause captureClause) {
3635 switch (captureClause) {
3636 case mlir::omp::DeclareTargetCaptureClause::to:
3637 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
3638 case mlir::omp::DeclareTargetCaptureClause::link:
3639 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
3640 case mlir::omp::DeclareTargetCaptureClause::enter:
3641 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
3643 llvm_unreachable(
"unhandled capture clause");
// Builds a name suffix for a declare-target reference pointer. The suffix is
// streamed into `suffix`, uses the FileID of the target-entry unique info
// derived from the location's file name and line, and ends in
// "_decl_tgt_ref_ptr".
3648 llvm::OpenMPIRBuilder &ompBuilder) {
3650 llvm::raw_svector_ostream os(suffix);
// Callback handing (filename, line) of `loc` to the OpenMPIRBuilder.
3653 auto fileInfoCallBack = [&loc]() {
3654 return std::pair<std::string, uint64_t>(
3655 llvm::StringRef(loc.getFilename()), loc.getLine());
// The unique-info query goes through the real file system.
3658 auto vfs = llvm::vfs::getRealFileSystem();
3661 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
3663 os <<
"_decl_tgt_ref_ptr";
// Checks whether `value` is defined by an `LLVM::AddressOfOp` whose referenced
// symbol implements DeclareTargetInterface with a `link` capture clause.
3669 if (
auto addressOfOp = value.
getDefiningOp<LLVM::AddressOfOp>()) {
// Resolve the referenced global through the enclosing module's symbol table.
3670 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
3671 Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
3672 if (
auto declareTargetGlobal =
3673 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
3674 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3675 mlir::omp::DeclareTargetCaptureClause::link)
// Looks up the LLVM module value that serves as the reference pointer for a
// declare-target global. The lookup only happens when the value traces back to
// an `LLVM::AddressOfOp` of a global that is declare-target with capture
// clause `link`, or `to` under unified shared memory.
3684 static llvm::Value *
3686 LLVM::ModuleTranslation &moduleTranslation) {
3687 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
// An address-space cast in front of the address-of is inspected first.
3689 if (
auto addrCast = llvm::dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
3694 if (
auto addressOfOp = llvm::dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
3695 if (
auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
3696 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
3697 addressOfOp.getGlobalName()))) {
3699 if (
auto declareTargetGlobal =
3700 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
3701 gOp.getOperation())) {
// `link` always qualifies; `to` qualifies only with the
// requires-unified-shared-memory configuration.
3705 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
3706 mlir::omp::DeclareTargetCaptureClause::link) ||
3707 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3708 mlir::omp::DeclareTargetCaptureClause::to &&
3709 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
// If the symbol name already contains the suffix it is not appended again;
// otherwise the value is looked up under "<symbol name><suffix>".
3713 if (gOp.getSymName().contains(suffix))
3714 return moduleTranslation.getLLVMModule()->getNamedValue(
3717 return moduleTranslation.getLLVMModule()->getNamedValue(
3718 (gOp.getSymName().str() + suffix.str()).str());
// Extension of `llvm::OpenMPIRBuilder::MapInfosTy` that additionally tracks a
// `Mappers` list.
3729 struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
// Appends `curInfo` to this object: first the Mappers entries, then everything
// handled by the base-class append.
3733 void append(MapInfosTy &curInfo) {
3734 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
3735 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
// Extension of MapInfosTy carrying extra per-entry data: IsDeclareTarget,
// MapClause, OriginalValue and BaseType.
3744 struct MapInfoData : MapInfosTy {
// Appends every per-entry list of `CurInfo` to this object, then delegates the
// remaining fields to MapInfosTy::append.
3756 void append(MapInfoData &CurInfo) {
3757 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
3758 CurInfo.IsDeclareTarget.end());
3759 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
3760 OriginalValue.append(CurInfo.OriginalValue.begin(),
3761 CurInfo.OriginalValue.end());
3762 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
3763 MapInfosTy::append(CurInfo);
// Special-case an array whose element type is itself an LLVM array type
// (i.e. a nested array).
3770 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3771 arrTy.getElementType()))
// Computes the size in bytes for a map clause. When the clause is an
// `omp::MapInfoOp` that carries bounds, an element count is accumulated over
// the bounds and multiplied by the underlying type size in bytes.
3788 llvm::Value *basePointer,
3789 llvm::Type *baseType,
3790 llvm::IRBuilderBase &builder,
3791 LLVM::ModuleTranslation &moduleTranslation) {
3792 if (
auto memberClause =
3793 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
3798 if (!memberClause.getBounds().empty()) {
// Start from a count of one and multiply in a factor per bound; each factor is
// built from that bound's upper bound, lower bound and the constant 1.
3799 llvm::Value *elementCount = builder.getInt64(1);
3800 for (
auto bounds : memberClause.getBounds()) {
3801 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
3802 bounds.getDefiningOp())) {
3807 elementCount = builder.CreateMul(
3811 moduleTranslation.lookupValue(boundOp.getUpperBound()),
3812 moduleTranslation.lookupValue(boundOp.getLowerBound())),
3813 builder.getInt64(1)));
3820 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
// Convert the bit size to bytes (integer division by 8) before scaling by the
// element count.
3828 return builder.CreateMul(elementCount,
3829 builder.getInt64(underlyingTypeSzInBits / 8));
// Fills `mapData` from several groups of map operands. Regular map variables
// get one fully populated entry each. use_device_addr / use_device_ptr
// operands either tag an existing entry or add a new zero-sized RETURN_PARAM
// entry. has_device_addr operands always add a new entry. All parallel arrays
// in `mapData` receive exactly one element per added entry.
3838 LLVM::ModuleTranslation &moduleTranslation,
DataLayout &dl,
// Reports whether `mapOp` appears in the member list of any map operand in
// `mapVars`.
3842 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
3850 for (
Value mapValue : mapVars) {
3851 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3852 for (
auto member : map.getMembers())
3853 if (member == mapOp)
// Regular map variables.
3860 for (
Value mapValue : mapVars) {
3861 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
// Prefer var_ptr_ptr when the op has one, otherwise use var_ptr.
3863 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3864 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
3865 mapData.Pointers.push_back(mapData.OriginalValue.back());
// Declare-target values use their reference pointer as base pointer; all
// others use the original value.
3867 if (llvm::Value *refPtr =
3869 moduleTranslation)) {
3870 mapData.IsDeclareTarget.push_back(
true);
3871 mapData.BasePointers.push_back(refPtr);
3873 mapData.IsDeclareTarget.push_back(
false);
3874 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3877 mapData.BaseType.push_back(
3878 moduleTranslation.convertType(mapOp.getVarType()));
3879 mapData.Sizes.push_back(
3880 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
3881 mapData.BaseType.back(), builder, moduleTranslation));
3882 mapData.MapClause.push_back(mapOp.getOperation());
3883 mapData.Types.push_back(
3884 llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType()));
3886 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
// Record the user-defined mapper symbol when the op names one, otherwise a
// null placeholder.
3888 if (mapOp.getMapperId())
3889 mapData.Mappers.push_back(
3890 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3891 mapOp, mapOp.getMapperIdAttr()));
3893 mapData.Mappers.push_back(
nullptr);
3894 mapData.IsAMapping.push_back(
true);
3895 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
// Searches the existing entries for `val`; a hit on a real mapping is tagged
// with OMP_MAP_RETURN_PARAM and the given device info.
3898 auto findMapInfo = [&mapData](llvm::Value *val,
3899 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3902 for (llvm::Value *basePtr : mapData.OriginalValue) {
3903 if (basePtr == val && mapData.IsAMapping[index]) {
3905 mapData.Types[index] |=
3906 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
3907 mapData.DevicePointers[index] = devInfoTy;
// Handles one group of use_device_* operands with the given device info kind.
3916 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3917 for (
Value mapValue : useDevOperands) {
3918 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3920 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3921 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
// No existing mapping for this value: add a zero-sized RETURN_PARAM entry that
// is flagged as not being a real mapping.
3924 if (!findMapInfo(origValue, devInfoTy)) {
3925 mapData.OriginalValue.push_back(origValue);
3926 mapData.Pointers.push_back(mapData.OriginalValue.back());
3927 mapData.IsDeclareTarget.push_back(
false);
3928 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3929 mapData.BaseType.push_back(
3930 moduleTranslation.convertType(mapOp.getVarType()));
3931 mapData.Sizes.push_back(builder.getInt64(0));
3932 mapData.MapClause.push_back(mapOp.getOperation());
3933 mapData.Types.push_back(
3934 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
3936 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
3937 mapData.DevicePointers.push_back(devInfoTy);
3938 mapData.Mappers.push_back(
nullptr);
3939 mapData.IsAMapping.push_back(
false);
3940 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
3945 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3946 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
// has_device_addr operands: the original value serves as original value, base
// pointer and pointer alike.
3948 for (
Value mapValue : hasDevAddrOperands) {
3949 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3951 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3952 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
3954 static_cast<llvm::omp::OpenMPOffloadMappingFlags
>(mapOp.getMapType());
3955 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
3957 mapData.OriginalValue.push_back(origValue);
3958 mapData.BasePointers.push_back(origValue);
3959 mapData.Pointers.push_back(origValue);
3960 mapData.IsDeclareTarget.push_back(
false);
3961 mapData.BaseType.push_back(
3962 moduleTranslation.convertType(mapOp.getVarType()));
3963 mapData.Sizes.push_back(
3964 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
3965 mapData.MapClause.push_back(mapOp.getOperation());
// With the ALWAYS bit set the op's own map type and mapper are kept; the other
// visible path records OMP_MAP_LITERAL with no mapper.
3966 if (llvm::to_underlying(mapType & mapTypeAlways)) {
3970 mapData.Types.push_back(mapType);
3974 if (mapOp.getMapperId()) {
3975 mapData.Mappers.push_back(
3976 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3977 mapOp, mapOp.getMapperIdAttr()));
3979 mapData.Mappers.push_back(
nullptr);
3982 mapData.Types.push_back(
3983 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
3984 mapData.Mappers.push_back(
nullptr);
3987 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
3988 mapData.DevicePointers.push_back(
3989 llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3990 mapData.IsAMapping.push_back(
false);
3991 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
// Returns the position of `memberOp` within `mapData.MapClause`. The op is
// required to be present; this is only checked by the assertion.
3996 auto *res = llvm::find(mapData.MapClause, memberOp);
3997 assert(res != mapData.MapClause.end() &&
3998 "MapInfoOp for member not found in MapData, cannot return index");
3999 return std::distance(mapData.MapClause.begin(), res);
// Selects one member `omp::MapInfoOp` of `mapInfo` by ordering the members'
// index paths (taken from the members-index attribute) and returning the
// member that sorts to the front.
4004 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
// With a single member there is nothing to order.
4006 if (indexAttr.size() == 1)
4007 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
// `indices` starts as 0, 1, 2, ... and is sorted by comparing the index paths
// of the corresponding members element by element.
4010 std::iota(indices.begin(), indices.end(), 0);
4012 llvm::sort(indices, [&](
const size_t a,
const size_t b) {
4013 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
4014 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
4015 for (
const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
4016 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
4017 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
4019 if (aIndex == bIndex)
4022 if (aIndex < bIndex)
4025 if (aIndex > bIndex)
// The element-wise comparison did not decide the order: fall back to comparing
// the lengths of the two index paths.
4032 return memberIndicesA.size() < memberIndicesB.size();
4035 return llvm::cast<omp::MapInfoOp>(
4036 mapInfo.getMembers()[indices.front()].getDefiningOp());
// Builds the list of index values derived from the lower bounds of a set of
// `omp::MapBoundsOp` values. Two shapes are visible: one that emits a leading
// zero followed by one lower bound per dimension, and one that folds all
// dimensions into a single linear offset.
4058 static std::vector<llvm::Value *>
4060 llvm::IRBuilderBase &builder,
bool isArrayTy,
4062 std::vector<llvm::Value *> idx;
// Per-dimension form: a leading 0 index, then each bound's lower bound,
// visiting the bounds from last to first.
4073 idx.push_back(builder.getInt64(0));
4074 for (
int i = bounds.size() - 1; i >= 0; --i) {
4075 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4076 bounds[i].getDefiningOp())) {
4077 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
// Linearised form: dimensionIndexSizeOffset[i] is the running product of the
// extents of bounds 1..i (entry 0 is the constant 1).
4099 std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
4100 for (
size_t i = 1; i < bounds.size(); ++i) {
4101 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4102 bounds[i].getDefiningOp())) {
4103 dimensionIndexSizeOffset.push_back(builder.CreateMul(
4104 moduleTranslation.lookupValue(boundOp.getExtent()),
4105 dimensionIndexSizeOffset[i - 1]));
// Accumulate lowerBound * dimensionIndexSizeOffset[i] over the bounds, last to
// first: one path creates the entry, the other adds onto the last one.
4113 for (
int i = bounds.size() - 1; i >= 0; --i) {
4114 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4115 bounds[i].getDefiningOp())) {
4117 idx.emplace_back(builder.CreateMul(
4118 moduleTranslation.lookupValue(boundOp.getLowerBound()),
4119 dimensionIndexSizeOffset[i]));
4121 idx.back() = builder.CreateAdd(
4122 idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
4123 boundOp.getLowerBound()),
4124 dimensionIndexSizeOffset[i]));
// Appends the combinedInfo entries for a parent map clause that has members
// and returns the MEMBER_OF flag that refers to the entry added for it.
// Host-side codegen only (asserted below).
// NOTE(review): the function name line and some statements are missing from
// this excerpt.
4148 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
4149 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
4150 MapInfoData &mapData, uint64_t mapDataIndex,
bool isTargetParams) {
4151 assert(!ompBuilder.Config.isTargetDevice() &&
4152 "function only supported for host device codegen");
// Base flag is either TARGET_PARAM or NONE; the selecting condition is not
// visible here (presumably derived from `isTargetParams`).
4158 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
4160 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
4161 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
// When a user-defined mapper is attached to the parent, carry over the
// TO/FROM/ALWAYS/CLOSE/PRESENT/OMPX_HOLD bits of the parent's map type.
4164 bool hasUserMapper = mapData.Mappers[mapDataIndex] !=
nullptr;
4165 if (hasUserMapper) {
4166 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
4170 mapFlags parentFlags = mapData.Types[mapDataIndex];
4171 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
4172 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
4173 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
4174 baseFlag |= (parentFlags & preserve);
// First entry for the parent: type, device-pointer kind, mapper and base
// pointer are pushed here; pointer and size follow below.
4177 combinedInfo.Types.emplace_back(baseFlag);
4178 combinedInfo.DevicePointers.emplace_back(
4179 mapData.DevicePointers[mapDataIndex]);
4180 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4182 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4183 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
4193 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
// [lowAddr, highAddr) delimits the memory covered by this entry.
4195 llvm::Value *lowAddr, *highAddr;
4196 if (!parentClause.getPartialMap()) {
// Full map: from the parent's pointer to one element of BaseType past it.
4197 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4198 builder.getPtrTy());
4199 highAddr = builder.CreatePointerCast(
4200 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4201 mapData.Pointers[mapDataIndex], 1),
4202 builder.getPtrTy());
4203 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
// Partial map: from the pointer at firstMemberIdx to one element past the
// pointer at lastMemberIdx (the computation of both indices is not visible).
// NOTE(review): `mapOp` is obtained with dyn_cast; no null check is visible
// in this excerpt.
4205 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4208 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
4209 builder.getPtrTy());
4212 highAddr = builder.CreatePointerCast(
4213 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
4214 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
4215 builder.getPtrTy());
4216 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
// Size is the pointer difference measured in i8 units, converted to i64.
4219 llvm::Value *size = builder.CreateIntCast(
4220 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
4221 builder.getInt64Ty(),
4223 combinedInfo.Sizes.push_back(size);
// The MEMBER_OF flag is derived from the position of the base pointer entry
// that was just appended.
4225 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
4226 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
// For a full map a second entry is added that carries the parent's own map
// type, tagged as a member of the entry above.
4234 if (!parentClause.getPartialMap()) {
4239 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
4240 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4241 combinedInfo.Types.emplace_back(mapFlag);
4242 combinedInfo.DevicePointers.emplace_back(
4244 combinedInfo.Mappers.emplace_back(
nullptr);
4246 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4247 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
4248 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4249 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
4251 return memberOfFlag;
// Tests whether the map op carries a var_ptr_ptr operand.
// NOTE(review): only this condition of the enclosing helper is visible; its
// signature and return statements are missing from this excerpt.
4263 if (mapOp.getVarPtrPtr())
// Appends one or more combinedInfo entries for every member of the parent map
// clause at `mapDataIndex`, each tagged with `memberOfFlag`.
// Host-side codegen only (asserted below).
// NOTE(review): the function name line and the conditions guarding several
// statements are missing from this excerpt.
4277 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
4278 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
4279 MapInfoData &mapData, uint64_t mapDataIndex,
4280 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
4281 assert(!ompBuilder.Config.isTargetDevice() &&
4282 "function only supported for host device codegen");
4285 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4287 for (
auto mappedMembers : parentClause.getMembers()) {
4289 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
4292 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
// First kind of entry: uses the parent's base pointer as base, the member's
// base pointer as pointer, and the target pointer size as size. TARGET_PARAM
// is cleared and MEMBER_OF set on the member's map type. The condition under
// which this entry is emitted is not visible here.
4303 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4304 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4305 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4306 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4307 combinedInfo.Types.emplace_back(mapFlag);
4308 combinedInfo.DevicePointers.emplace_back(
4310 combinedInfo.Mappers.emplace_back(
nullptr);
4311 combinedInfo.Names.emplace_back(
4313 combinedInfo.BasePointers.emplace_back(
4314 mapData.BasePointers[mapDataIndex]);
4315 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
4316 combinedInfo.Sizes.emplace_back(builder.getInt64(
4317 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
// Second kind of entry: the member's own data, again with TARGET_PARAM
// cleared and MEMBER_OF set.
4323 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4324 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4325 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4326 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
// PTR_AND_OBJ is added under a condition that is not visible in this excerpt.
4328 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4330 combinedInfo.Types.emplace_back(mapFlag);
4331 combinedInfo.DevicePointers.emplace_back(
4332 mapData.DevicePointers[memberDataIdx]);
4333 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
4334 combinedInfo.Names.emplace_back(
// The base pointer is taken from `basePointerIndex`, whose initializer is not
// visible here.
4336 uint64_t basePointerIndex =
4338 combinedInfo.BasePointers.emplace_back(
4339 mapData.BasePointers[basePointerIndex]);
4340 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
// The size may be replaced by a select that yields 0 when the member pointer
// is null; the guard for that replacement is not visible here.
4342 llvm::Value *size = mapData.Sizes[memberDataIdx];
4344 size = builder.CreateSelect(
4345 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
4346 builder.getInt64(0), size);
4349 combinedInfo.Sizes.emplace_back(size);
// Appends a single combinedInfo entry for the map clause at `mapDataIdx`.
// `mapDataParentIdx` defaults to -1; when it is >= 0 the base pointer of that
// parent entry is used instead of the clause's own base pointer.
// NOTE(review): the first line(s) of the signature and some guard conditions
// are missing from this excerpt.
4354 MapInfosTy &combinedInfo,
bool isTargetParams,
4355 int mapDataParentIdx = -1) {
4359 auto mapFlag = mapData.Types[mapDataIdx];
4360 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
// PTR_AND_OBJ is added under a condition that is not visible here.
4364 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
// Declare-target entries are never marked as TARGET_PARAM.
4366 if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
4367 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
// By-copy captures may additionally be marked LITERAL; the second half of
// this condition is not visible here.
4369 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
4371 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4376 if (mapDataParentIdx >= 0)
4377 combinedInfo.BasePointers.emplace_back(
4378 mapData.BasePointers[mapDataParentIdx]);
4380 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
// The remaining per-entry arrays are copied straight from mapData.
4382 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
4383 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
4384 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
4385 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
4386 combinedInfo.Types.emplace_back(mapFlag);
4387 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
// Emits the combinedInfo entries for a parent map clause together with its
// members. Host-side codegen only (asserted below).
// NOTE(review): the function name line and the callee names of the two calls
// near the end are missing from this excerpt.
4391 llvm::IRBuilderBase &builder,
4392 llvm::OpenMPIRBuilder &ompBuilder,
4394 MapInfoData &mapData, uint64_t mapDataIndex,
4395 bool isTargetParams) {
4396 assert(!ompBuilder.Config.isTargetDevice() &&
4397 "function only supported for host device codegen");
4400 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
// Special case: a partial map with exactly one member. What is done with
// `memberClause` is not visible here.
4405 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
4406 auto memberClause = llvm::cast<omp::MapInfoOp>(
4407 parentClause.getMembers()[0].getDefiningOp());
// General case: one call produces the parent's MEMBER_OF flag, which is then
// passed on to a second call together with the same map data.
4424 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
4426 combinedInfo, mapData, mapDataIndex, isTargetParams);
4428 combinedInfo, mapData, mapDataIndex,
4429 memberOfParentFlag);
// Rewrites mapData.Pointers / mapData.BasePointers in place for every map
// clause that is not declare-target, depending on its capture kind.
// Host-side codegen only (asserted below).
// NOTE(review): the function name line and several guards are missing from
// this excerpt.
4438 LLVM::ModuleTranslation &moduleTranslation,
4439 llvm::IRBuilderBase &builder) {
4440 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4441 "function only supported for host device codegen");
4442 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4444 if (!mapData.IsDeclareTarget[i]) {
4445 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4446 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
4456 switch (captureKind) {
// ByRef: optionally load the pointer, then apply the bounds offset (if any)
// with an in-bounds GEP; only Pointers[i] is updated.
4457 case omp::VariableCaptureKind::ByRef: {
4458 llvm::Value *newV = mapData.Pointers[i];
4460 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
4463 newV = builder.CreateLoad(builder.getPtrTy(), newV);
4465 if (!offsetIdx.empty())
4466 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
4468 mapData.Pointers[i] = newV;
// ByCopy: obtain the value (loading it when the stored pointer is of pointer
// type) and update both Pointers[i] and BasePointers[i] with it.
4470 case omp::VariableCaptureKind::ByCopy: {
4471 llvm::Type *type = mapData.BaseType[i];
4473 if (mapData.Pointers[i]->getType()->isPointerTy())
4474 newV = builder.CreateLoad(type, mapData.Pointers[i]);
4476 newV = mapData.Pointers[i];
// The value is round-tripped through a pointer-typed ".casted" alloca: stored
// as-is and reloaded as a pointer. The insertion point and debug location are
// saved and restored around the alloca; the guard around this section and the
// place where the alloca is inserted are not visible here.
4479 auto curInsert = builder.saveIP();
4480 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
4482 auto *memTempAlloc =
4483 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
4484 builder.SetCurrentDebugLocation(DbgLoc);
4485 builder.restoreIP(curInsert);
4487 builder.CreateStore(newV, memTempAlloc);
4488 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
4491 mapData.Pointers[i] = newV;
4492 mapData.BasePointers[i] = newV;
// Remaining capture kinds are reported as an op error.
4494 case omp::VariableCaptureKind::This:
4495 case omp::VariableCaptureKind::VLAType:
4496 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
// Populates `combinedInfo` from `mapData`, one top-level map clause at a time.
// `isTargetParams` defaults to false. Host-side codegen only (asserted below).
// NOTE(review): the function name line and parts of the loop body are missing
// from this excerpt; the name is taken from the call sites elsewhere in it.
4505 LLVM::ModuleTranslation &moduleTranslation,
4507 MapInfoData &mapData,
bool isTargetParams =
false) {
4508 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4509 "function only supported for host device codegen");
4524 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4531 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
// Member entries are not handled on their own in this loop; the statement
// under this condition is not visible (presumably `continue`).
4534 if (mapData.IsAMember[i])
// NOTE(review): the dyn_cast result is dereferenced on the next line without
// a null check; `cast` would state the expectation explicitly.
4537 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
// Clauses with members take a dedicated path; the callee name is not visible.
4538 if (!mapInfoOp.getMembers().empty()) {
4540 combinedInfo, mapData, i, isTargetParams);
// Tail of a forward declaration taking the module translation and the name of
// the mapper function.
// NOTE(review): the return type and function name are on lines missing from
// this excerpt.
4550 LLVM::ModuleTranslation &moduleTranslation,
4551 llvm::StringRef mapperFuncName);
// Resolves the LLVM function for an omp::DeclareMapperOp, first checking
// whether one was already registered under the mapper's platform-specific
// name. Host-side codegen only (asserted below).
// NOTE(review): the function name line and both return paths are missing from
// this excerpt.
4555 LLVM::ModuleTranslation &moduleTranslation) {
4556 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4557 "function only supported for host device codegen");
4558 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
// The function name is built from the "omp_mapper" prefix and the op's symbol
// name.
4559 std::string mapperFuncName =
4560 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
4561 {
"omp_mapper", declMapperOp.getSymName()});
// Lookup of a previously mapped function with that name.
4563 if (
auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
// Emits the LLVM function for an omp::DeclareMapperOp and registers it under
// `mapperFuncName`. Host-side codegen only (asserted below).
// NOTE(review): the function name line, the OpenMPIRBuilder call that creates
// the function, and the second callback's header are missing from this
// excerpt.
4572 LLVM::ModuleTranslation &moduleTranslation,
4573 llvm::StringRef mapperFuncName) {
4574 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4575 "function only supported for host device codegen");
4576 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4577 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
4579 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4580 llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
4583 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4586 MapInfosTy combinedInfo;
// Callback generating the map information: binds the mapper's symbol value to
// `ptrPHI`, translates the mapper region's entry block at `codeGenIP`, then
// fills `combinedInfo` via genMapInfos.
4588 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
4589 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
4590 builder.restoreIP(codeGenIP);
4591 moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
4592 moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
4593 builder.GetInsertBlock());
// A failed block conversion is surfaced as PreviouslyReportedError.
4594 if (
failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
4597 return llvm::make_error<PreviouslyReportedError>();
4598 MapInfoData mapData;
4601 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);
// The value/block mappings created for the region are dropped again before
// the callback returns.
4605 moduleTranslation.forgetMapping(declMapperOp.getRegion());
4606 return combinedInfo;
// Part of the custom-mapper callback: entries without a mapper are handled
// separately (the statement under this condition is not visible).
4610 if (!combinedInfo.Mappers[i])
4617 genMapInfoCB, varType, mapperFuncName, customMapperCB);
// Errors from function creation are propagated; on success the new function
// is registered under `mapperFuncName`.
4619 return newFn.takeError();
4620 moduleTranslation.mapFunction(mapperFuncName, *newFn);
// Translates omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp
// and omp::TargetUpdateOp into a call to OpenMPIRBuilder::createTargetData.
// Returns a LogicalResult.
// NOTE(review): the function name line and many statements are missing from
// this excerpt; comments describe only what is visible.
4624 static LogicalResult
4626 LLVM::ModuleTranslation &moduleTranslation) {
// Defaults used when the op has no `if` / `device` clause.
4627 llvm::Value *ifCond =
nullptr;
4628 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
4632 llvm::omp::RuntimeFunction RTLFn;
4635 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4636 llvm::OpenMPIRBuilder::TargetDataInfo info(
true,
4638 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
// Offloading is considered active on the device side, or on the host when at
// least one target triple is configured.
4639 bool isOffloadEntry =
4640 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
// Per-op extraction of clauses. For every op kind the `if` value is looked up
// and the device id is taken only when it is defined by an LLVM::ConstantOp
// holding an IntegerAttr; otherwise deviceID keeps its default.
4642 LogicalResult result =
4644 .Case([&](omp::TargetDataOp dataOp) {
4648 if (
auto ifVar = dataOp.getIfExpr())
4649 ifCond = moduleTranslation.lookupValue(ifVar);
4651 if (
auto devId = dataOp.getDevice())
4652 if (
auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4653 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4654 deviceID = intAttr.getInt();
4656 mapVars = dataOp.getMapVars();
4657 useDevicePtrVars = dataOp.getUseDevicePtrVars();
4658 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
// Enter data: selects the begin mapper runtime entry (nowait variant when the
// op has nowait).
4661 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
4665 if (
auto ifVar = enterDataOp.getIfExpr())
4666 ifCond = moduleTranslation.lookupValue(ifVar);
4668 if (
auto devId = enterDataOp.getDevice())
4669 if (
auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4670 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4671 deviceID = intAttr.getInt();
4673 enterDataOp.getNowait()
4674 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
4675 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
4676 mapVars = enterDataOp.getMapVars();
4677 info.HasNoWait = enterDataOp.getNowait();
// Exit data: selects the end mapper runtime entry.
4680 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
4684 if (
auto ifVar = exitDataOp.getIfExpr())
4685 ifCond = moduleTranslation.lookupValue(ifVar);
4687 if (
auto devId = exitDataOp.getDevice())
4688 if (
auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4689 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4690 deviceID = intAttr.getInt();
4692 RTLFn = exitDataOp.getNowait()
4693 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
4694 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
4695 mapVars = exitDataOp.getMapVars();
4696 info.HasNoWait = exitDataOp.getNowait();
// Update: selects the update mapper runtime entry.
4699 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
4703 if (
auto ifVar = updateDataOp.getIfExpr())
4704 ifCond = moduleTranslation.lookupValue(ifVar);
4706 if (
auto devId = updateDataOp.getDevice())
4707 if (
auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4708 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4709 deviceID = intAttr.getInt();
4712 updateDataOp.getNowait()
4713 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
4714 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
4715 mapVars = updateDataOp.getMapVars();
4716 info.HasNoWait = updateDataOp.getNowait();
4719 .DefaultUnreachable(
"unexpected operation");
// Without offloading the `if` condition is forced to false.
4724 if (!isOffloadEntry)
4725 ifCond = builder.getFalse();
4727 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4728 MapInfoData mapData;
4730 builder, useDevicePtrVars, useDeviceAddrVars);
// Callback handed to createTargetData: fills and returns `combinedInfo` at
// the requested insertion point.
4733 MapInfosTy combinedInfo;
4734 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
4735 builder.restoreIP(codeGenIP);
4736 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
4737 return combinedInfo;
// Helper lambda (bound to `mapUseDevice`, judging by the calls below): for
// each (block argument, use_device variable) pair, finds the map entries with
// the same base pointer operand and the requested device-info kind, and maps
// the block argument to the (optionally transformed) base pointer.
4743 [&moduleTranslation](
4744 llvm::OpenMPIRBuilder::DeviceInfoTy type,
4748 for (
auto [arg, useDevVar] :
4749 llvm::zip_equal(blockArgs, useDeviceVars)) {
// The base pointer operand of a map op is var_ptr_ptr when present, var_ptr
// otherwise.
4751 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
4752 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
4753 : mapInfoOp.getVarPtr();
4756 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
4757 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
4758 mapInfoData.MapClause, mapInfoData.DevicePointers,
4759 mapInfoData.BasePointers)) {
4760 auto mapOp = cast<omp::MapInfoOp>(mapClause);
4761 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
4762 devicePointer != type)
// `mapper` is optional; a null result from it leaves the argument unmapped.
4765 if (llvm::Value *devPtrInfoMap =
4766 mapper ? mapper(basePointer) : basePointer) {
4767 moduleTranslation.mapValue(arg, devPtrInfoMap);
// Body generation callback; only valid for omp::TargetDataOp (asserted).
4774 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
4775 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
4776 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
4779 builder.restoreIP(codeGenIP);
4780 assert(isa<omp::TargetDataOp>(op) &&
4781 "BodyGen requested for non TargetDataOp");
4782 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
// NOTE(review): `®ion` below looks like a mis-encoded `&region` (an HTML
// `&reg` entity substitution); left untouched here because it is code.
4783 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
4784 switch (bodyGenType) {
// Priv: when device pointer info exists, block arguments are mapped through
// info.DevicePtrInfoMap (use_device_addr values are loaded, use_device_ptr
// values are used directly).
4785 case BodyGenTy::Priv:
4787 if (!info.DevicePtrInfoMap.empty()) {
4788 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4789 blockArgIface.getUseDeviceAddrBlockArgs(),
4790 useDeviceAddrVars, mapData,
4791 [&](llvm::Value *basePointer) -> llvm::Value * {
4792 if (!info.DevicePtrInfoMap[basePointer].second)
4794 return builder.CreateLoad(
4796 info.DevicePtrInfoMap[basePointer].second);
4798 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4799 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
4800 mapData, [&](llvm::Value *basePointer) {
4801 return info.DevicePtrInfoMap[basePointer].second;
4805 moduleTranslation)))
4806 return llvm::make_error<PreviouslyReportedError>();
// DupNoPriv: with no device pointer info and when not compiling for the
// device, block arguments are mapped straight to the base pointers.
4809 case BodyGenTy::DupNoPriv:
4810 if (info.DevicePtrInfoMap.empty()) {
4813 if (!ompBuilder->Config.IsTargetDevice.value_or(
false)) {
4814 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4815 blockArgIface.getUseDeviceAddrBlockArgs(),
4816 useDeviceAddrVars, mapData);
4817 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4818 blockArgIface.getUseDevicePtrBlockArgs(),
4819 useDevicePtrVars, mapData);
// NoPriv: same direct mapping, but only when compiling for the device.
4823 case BodyGenTy::NoPriv:
4825 if (info.DevicePtrInfoMap.empty()) {
4828 if (ompBuilder->Config.IsTargetDevice.value_or(
false)) {
4829 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4830 blockArgIface.getUseDeviceAddrBlockArgs(),
4831 useDeviceAddrVars, mapData);
4832 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4833 blockArgIface.getUseDevicePtrBlockArgs(),
4834 useDevicePtrVars, mapData);
4838 moduleTranslation)))
4839 return llvm::make_error<PreviouslyReportedError>();
4843 return builder.saveIP();
// Custom mapper callback: records in `info` that at least one mapper is used.
4846 auto customMapperCB =
4848 if (!combinedInfo.Mappers[i])
4850 info.HasMapper =
true;
4855 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4856 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// TargetDataOp is emitted with a body callback; the standalone directives
// pass the selected runtime function (`&RTLFn`) instead.
4858 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
4859 if (isa<omp::TargetDataOp>(op))
4860 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
4861 builder.getInt64(deviceID), ifCond,
4862 info, genMapInfoCB, customMapperCB,
4865 return ompBuilder->createTargetData(
4866 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
4867 info, genMapInfoCB, customMapperCB, &RTLFn);
// Code generation continues after the emitted construct.
4873 builder.restoreIP(*afterIP);
// Translates an omp::DistributeOp via OpenMPIRBuilder::createDistribute,
// including reduction handling taken from an associated teams op when
// `doDistributeReduction` is set. Returns a LogicalResult.
// NOTE(review): the function name line and many statements (including the
// definitions of `teamsOp` and `doDistributeReduction`) are missing from this
// excerpt.
4877 static LogicalResult
4879 LLVM::ModuleTranslation &moduleTranslation) {
4880 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4881 auto distributeOp = cast<omp::DistributeOp>(opInst);
4888 bool doDistributeReduction =
// A null `teamsOp` contributes no reduction variables.
4892 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
4897 if (doDistributeReduction) {
4898 isByRef =
getIsByRef(teamsOp.getReductionByref());
4899 assert(isByRef.size() == teamsOp.getNumReductionVars());
4902 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4906 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
4907 .getReductionBlockArgs();
4910 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
4911 reductionDecls, privateReductionVariables, reductionVariableMap,
// Body callback for createDistribute. Failures of the helper steps inside are
// surfaced as PreviouslyReportedError.
4916 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4917 auto bodyGenCB = [&](InsertPointTy allocaIP,
4918 InsertPointTy codeGenIP) -> llvm::Error {
4922 moduleTranslation, allocaIP);
4925 builder.restoreIP(codeGenIP);
4931 return llvm::make_error<PreviouslyReportedError>();
4936 return llvm::make_error<PreviouslyReportedError>();
4939 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
4941 distributeOp.getPrivateNeedsBarrier())))
4942 return llvm::make_error<PreviouslyReportedError>();
// NOTE(review): this declaration appears to shadow the `ompBuilder` declared
// at the top of the function (assuming it sits inside the lambda -- confirm).
4944 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4945 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4948 builder, moduleTranslation);
4950 return regionBlock.takeError();
4951 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// When the distribute is not wrapping a wsloop, the loop is workshared here
// with fixed settings: static schedule, no chunk, not ordered, not simd, no
// barrier, DistributeStaticLoop kind.
4956 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
4959 auto schedule = omp::ClauseScheduleKind::Static;
4960 bool isOrdered =
false;
4961 std::optional<omp::ScheduleModifier> scheduleMod;
4962 bool isSimd =
false;
4963 llvm::omp::WorksharingLoopType workshareLoopType =
4964 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
4965 bool loopNeedsBarrier =
false;
4966 llvm::Value *chunk =
nullptr;
4968 llvm::CanonicalLoopInfo *loopInfo =
// `scheduleMod` is an empty optional here, so both modifier comparisons below
// evaluate to false.
4970 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4971 ompBuilder->applyWorkshareLoop(
4972 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
4973 convertToScheduleKind(schedule), chunk, isSimd,
4974 scheduleMod == omp::ScheduleModifier::monotonic,
4975 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4979 return wsloopIP.takeError();
4983 distributeOp.getLoc(), privVarsInfo.
llvmVars,
4985 return llvm::make_error<PreviouslyReportedError>();
4987 return llvm::Error::success();
// Emit the distribute construct and continue after it.
4990 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4992 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4993 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4994 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
4999 builder.restoreIP(*afterIP);
// Reduction finalization on behalf of the teams op; the callee name is not
// visible here.
5001 if (doDistributeReduction) {
5004 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
5005 privateReductionVariables, isByRef,
// Applies an OpenMP flags attribute to the LLVM module: sets the
// "openmp-device" module flag and creates the __omp_rtl_* global flags.
// Returns a LogicalResult.
// NOTE(review): the function name line, the parameter carrying `attribute`,
// and parts of several calls are missing from this excerpt.
5014 static LogicalResult
5016 LLVM::ModuleTranslation &moduleTranslation) {
// NOTE(review): `cast<>` asserts on a type mismatch instead of returning a
// null op, so this condition can never be true; `dyn_cast<>` looks like what
// was intended -- confirm against the full file.
5017 if (!cast<mlir::ModuleOp>(op))
5020 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
// Module::Max merge behaviour is requested for the device version flag.
5022 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
5023 attribute.getOpenmpDeviceVersion());
// The statement under this condition is not visible here.
5025 if (attribute.getNoGpuLib())
// One global flag per runtime assumption / debug setting.
5028 ompBuilder->createGlobalFlag(
5029 attribute.getDebugKind() ,
5030 "__omp_rtl_debug_kind");
5031 ompBuilder->createGlobalFlag(
5033 .getAssumeTeamsOversubscription()
5035 "__omp_rtl_assume_teams_oversubscription");
5036 ompBuilder->createGlobalFlag(
5038 .getAssumeThreadsOversubscription()
5040 "__omp_rtl_assume_threads_oversubscription");
5041 ompBuilder->createGlobalFlag(
5042 attribute.getAssumeNoThreadState() ,
5043 "__omp_rtl_assume_no_thread_state");
5044 ompBuilder->createGlobalFlag(
5046 .getAssumeNoNestedParallelism()
5048 "__omp_rtl_assume_no_nested_parallelism");
// Fills `targetInfo` with a llvm::TargetRegionEntryInfo identifying the target
// region by parent name, file identity and source line. `parentName` defaults
// to the empty string.
// NOTE(review): the first signature line and the computation of `fileHash`
// are missing from this excerpt.
5053 omp::TargetOp targetOp,
5054 llvm::StringRef parentName =
"") {
5055 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
// A location without a file/line/column component is treated as a
// programming error (assert only).
5057 assert(fileLoc &&
"No file found from location");
5058 StringRef fileName = fileLoc.getFilename().getValue();
5060 llvm::sys::fs::UniqueID id;
5061 uint64_t line = fileLoc.getLine();
// If the filesystem cannot provide a unique id for the file (getUniqueID
// returns an error code), fall back to a fixed placeholder device id combined
// with `fileHash`.
5062 if (
auto ec = llvm::sys::fs::getUniqueID(fileName,
id)) {
5064 size_t deviceId = 0xdeadf17e;
5066 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
// Otherwise the device and file numbers reported by the filesystem are used.
5068 targetInfo = llvm::TargetRegionEntryInfo(parentName,
id.
getDevice(),
5069 id.getFile(), line);
// For declare-target map entries, rewrites the uses of the original value
// inside `func` so that they go through a load of the entry's base pointer.
// Target-device codegen only (asserted below).
// NOTE(review): the function name line and the declaration of `userVec` are
// missing from this excerpt.
5075 LLVM::ModuleTranslation &moduleTranslation,
5076 llvm::IRBuilderBase &builder, llvm::Function *func) {
5077 assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5078 "function only supported for target device codegen");
// The builder's insertion point is restored when this guard leaves scope.
5079 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5080 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5093 if (mapData.IsDeclareTarget[i]) {
// Constant users are first turned into instructions within `func` so that
// they can be rewritten like any other instruction user.
5100 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
5101 convertUsersOfConstantsToInstructions(constant, func,
false);
// The users are copied into `userVec` first because the use list is modified
// by replaceUsesOfWith in the loop below.
5108 for (llvm::User *user : mapData.OriginalValue[i]->users())
5109 userVec.push_back(user);
5111 for (llvm::User *user : userVec) {
5112 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
// Only instructions that belong to `func` are rewritten.
5113 if (insn->getFunction() == func) {
5114 builder.SetCurrentDebugLocation(insn->getDebugLoc());
// A fresh load per user, placed immediately before that user.
5115 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
5116 mapData.BasePointers[i]);
5117 load->moveBefore(insn->getIterator());
5118 user->replaceUsesOfWith(mapData.OriginalValue[i], load);
// Emits the device-side access code for one kernel argument: the argument is
// spilled to an alloca and `retVal` is produced from it according to the
// capture kind of the matching map clause. Returns the builder's insertion
// point after the emitted code. Target-device codegen only (asserted below).
// NOTE(review): the function name line, the `arg` parameter declaration and
// the bodies of the switch cases are partly missing from this excerpt.
5165 static llvm::IRBuilderBase::InsertPoint
5167 llvm::Value *input, llvm::Value *&retVal,
5168 llvm::IRBuilderBase &builder,
5169 llvm::OpenMPIRBuilder &ompBuilder,
5170 LLVM::ModuleTranslation &moduleTranslation,
5171 llvm::IRBuilderBase::InsertPoint allocaIP,
5172 llvm::IRBuilderBase::InsertPoint codeGenIP) {
5173 assert(ompBuilder.Config.isTargetDevice() &&
5174 "function only supported for target device codegen");
5175 builder.restoreIP(allocaIP);
// Defaults used when no map clause matches `input`: by-reference capture and
// no alignment information.
5177 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
5178 LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
5179 ompBuilder.M.getContext());
5180 unsigned alignmentValue = 0;
// Look up the map clause whose original value is `input` to obtain its
// capture kind and the preferred alignment of its variable type.
5182 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
5183 if (mapData.OriginalValue[i] == input) {
5184 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5185 capture = mapOp.getMapCaptureType();
5187 alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
5188 mapOp.getVarType(), ompBuilder.M.getDataLayout());
5192 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
5193 unsigned int defaultAS =
5194 ompBuilder.M.getDataLayout().getProgramAddressSpace();
// The argument is stored into an alloca in the alloca address space; for
// pointer-typed arguments the alloca is cast to the program address space
// when the two address spaces differ.
5197 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
5199 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
5200 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
5202 builder.CreateStore(&arg, v);
5204 builder.restoreIP(codeGenIP);
5207 case omp::VariableCaptureKind::ByCopy: {
// ByRef: the stored pointer is reloaded with the preferred alignment of its
// type.
5211 case omp::VariableCaptureKind::ByRef: {
5212 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
5214 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
// When alignment information was found above, it is attached to the load as
// `align` metadata.
5229 if (v->getType()->isPointerTy() && alignmentValue) {
5230 llvm::MDBuilder MDB(builder.getContext());
5231 loadInst->setMetadata(
5232 llvm::LLVMContext::MD_align,
5235 llvm::Type::getInt64Ty(builder.getContext()),
// Unsupported capture kinds only trip an assertion.
// NOTE(review): `assert(false && ...)` compiles out in release builds;
// llvm_unreachable is the usual idiom -- confirm intent.
5242 case omp::VariableCaptureKind::This:
5243 case omp::VariableCaptureKind::VLAType:
5246 assert(
false &&
"Currently unsupported capture kind");
5250 return builder.saveIP();
// For every host_eval variable of the target op, finds which clause of a
// nested op uses the corresponding block argument and stores the host value
// in the matching output (num_teams lower/upper, thread_limit, num_threads,
// loop bounds and steps).
// NOTE(review): the signature and the loop over the block argument's users
// are missing from this excerpt.
5267 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
5268 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
5269 blockArgIface.getHostEvalBlockArgs())) {
5270 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
// teams: num_teams lower/upper bound or thread_limit.
5274 .Case([&](omp::TeamsOp teamsOp) {
5275 if (teamsOp.getNumTeamsLower() == blockArg)
5276 numTeamsLower = hostEvalVar;
5277 else if (teamsOp.getNumTeamsUpper() == blockArg)
5278 numTeamsUpper = hostEvalVar;
5279 else if (teamsOp.getThreadLimit() == blockArg)
5280 threadLimit = hostEvalVar;
5282 llvm_unreachable(
"unsupported host_eval use");
// parallel: num_threads only.
5284 .Case([&](omp::ParallelOp parallelOp) {
5285 if (parallelOp.getNumThreads() == blockArg)
5286 numThreads = hostEvalVar;
5288 llvm_unreachable(
"unsupported host_eval use");
// loop nest: the block argument may appear among the lower bounds, upper
// bounds and/or steps; each matching position is replaced by the host value.
5290 .Case([&](omp::LoopNestOp loopOp) {
5291 auto processBounds =
5296 if (lb == blockArg) {
5299 (*outBounds)[i] = hostEvalVar;
// NOTE(review): how the result of the first processBounds call is used is not
// visible here (presumably it initializes `found`).
5305 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
5306 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
5308 found = processBounds(loopOp.getLoopSteps(), steps) || found;
5310 assert(found &&
"unsupported host_eval use");
5312 .DefaultUnreachable(
"unsupported host_eval use");
// Returns `op` itself when it is an OpTy; otherwise looks at its parent.
// With `immediateParent` set, only the direct parent op is considered.
// NOTE(review): the signature and the final (non-immediate) return are
// missing from this excerpt; the name is taken from its call sites.
5324 template <
typename OpTy>
5329 if (OpTy casted = dyn_cast<OpTy>(op))
5332 if (immediateParent)
5333 return dyn_cast_if_present<OpTy>(op->
getParentOp());
// Yields the integer held by a constant op's IntegerAttr value, or
// std::nullopt when no such constant is available.
// NOTE(review): the signature and the conditions guarding the first return
// are missing from this excerpt.
5342 return std::nullopt;
5345 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5346 return constAttr.getInt();
5348 return std::nullopt;
// Converts a size in bits to bytes using integer division (any remainder
// below 8 bits is dropped).
// NOTE(review): only this statement of the enclosing helper is visible.
5353 uint64_t sizeInBytes = sizeInBits / 8;
// For an op with reduction variables, gathers the types of its reduction
// declarations and forms a literal LLVM struct type from them.
// NOTE(review): the signature, the collection of `reductions`, and the use of
// `structType` are missing from this excerpt.
5357 template <
typename OpTy>
5359 if (op.getNumReductionVars() > 0) {
5364 members.reserve(reductions.size());
5365 for (omp::DeclareReductionOp &red : reductions)
5366 members.push_back(red.getType());
5368 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
// Fills the compile-time kernel attributes `attrs` (exec mode, min/max teams,
// min/max threads, reduction data size) for a target region, based on the
// clauses of the ops found via `capturedOp`.
// NOTE(review): the first signature line(s) and several conditions are
// missing from this excerpt.
5384 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
5385 bool isTargetDevice,
bool isGPU) {
5388 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
// Host and device obtain the clause values differently; on the device they
// are read directly from the teams / parallel ops. The host branch is not
// visible here.
5389 if (!isTargetDevice) {
5396 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5397 numTeamsLower = teamsOp.getNumTeamsLower();
5398 numTeamsUpper = teamsOp.getNumTeamsUpper();
5399 threadLimit = teamsOp.getThreadLimit();
5402 if (
auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5403 numThreads = parallelOp.getNumThreads();
// Team bounds: min starts at 1 and max at -1. Inside a teams construct a
// num_teams upper bound pins both to a value (the constant, or 0 on the
// alternative path whose condition is not visible); parallel/simd regions pin
// both to 1; the remaining case pins both to -1.
5408 int32_t minTeamsVal = 1, maxTeamsVal = -1;
5409 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5412 if (numTeamsUpper) {
5414 minTeamsVal = maxTeamsVal = *val;
5416 minTeamsVal = maxTeamsVal = 0;
5418 }
else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
5420 castOrGetParentOfType<omp::SimdOp>(capturedOp,
5422 minTeamsVal = maxTeamsVal = 1;
5424 minTeamsVal = maxTeamsVal = -1;
// Helper that writes a clause value into `result`; its body is not visible.
5429 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &result) {
// Thread limits start at -1 and are filled in from the target and teams
// thread_limit clauses.
5443 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
5444 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
5445 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
5448 int32_t maxThreadsVal = -1;
5449 if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5450 setMaxValueFromClause(numThreads, maxThreadsVal);
5451 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
// Combined maximum: start from the target thread limit, then let the teams
// thread limit and the max-threads value replace it whenever the current
// value is negative or the candidate is non-negative and smaller.
5458 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
5459 if (combinedMaxThreadsVal < 0 ||
5460 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
5461 combinedMaxThreadsVal = teamsThreadLimitVal;
5463 if (combinedMaxThreadsVal < 0 ||
5464 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
5465 combinedMaxThreadsVal = maxThreadsVal;
// Reduction data size is only looked at for GPU targets with a captured op.
5467 int32_t reductionDataSize = 0;
5468 if (isGPU && capturedOp) {
5469 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
// Exec mode: generic+spmd -> GENERIC_SPMD, generic only -> GENERIC,
// otherwise SPMD. The assignment target of the ternary is not visible here
// (presumably attrs.ExecFlags).
5474 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5476 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5477 omp::TargetRegionFlags::spmd) &&
5478 "invalid kernel flags");
5480 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5481 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5482 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5483 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5484 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
// spmd + no_loop without generic overrides the mode with SPMD_NO_LOOP.
5485 if (omp::bitEnumContainsAll(kernelFlags,
5486 omp::TargetRegionFlags::spmd |
5487 omp::TargetRegionFlags::no_loop) &&
5488 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
5489 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
// Publish the computed values; MinThreads is always 1.
5491 attrs.MinTeams = minTeamsVal;
5492 attrs.MaxTeams.front() = maxTeamsVal;
5493 attrs.MinThreads = 1;
5494 attrs.MaxThreads.front() = combinedMaxThreadsVal;
5495 attrs.ReductionDataSize = reductionDataSize;
// A fixed reduction buffer length of 1024 is used whenever there is any
// reduction data.
5498 if (attrs.ReductionDataSize != 0)
5499 attrs.ReductionBufferLength = 1024;
5510 LLVM::ModuleTranslation &moduleTranslation,
5511 omp::TargetOp targetOp,
Operation *capturedOp,
5512 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
5513 omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
5514 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
5516 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
5520 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
5523 if (
Value targetThreadLimit = targetOp.getThreadLimit())
5524 attrs.TargetThreadLimit.front() =
5525 moduleTranslation.lookupValue(targetThreadLimit);
5528 attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);
5531 attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);
5533 if (teamsThreadLimit)
5534 attrs.TeamsThreadLimit.front() =
5535 moduleTranslation.lookupValue(teamsThreadLimit);
5538 attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
5540 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
5541 omp::TargetRegionFlags::trip_count)) {
5542 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5543 attrs.LoopTripCount =
nullptr;
5548 for (
auto [loopLower, loopUpper, loopStep] :
5549 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
5550 llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
5551 llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
5552 llvm::Value *step = moduleTranslation.lookupValue(loopStep);
5554 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
5555 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
5556 loc, lowerBound, upperBound, step,
true,
5557 loopOp.getLoopInclusive());
5559 if (!attrs.LoopTripCount) {
5560 attrs.LoopTripCount = tripCount;
5565 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
5571 static LogicalResult
5573 LLVM::ModuleTranslation &moduleTranslation) {
5574 auto targetOp = cast<omp::TargetOp>(opInst);
5578 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
5587 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
5588 assert(parentBB &&
"No insert block is set for the builder");
5589 llvm::Function *parentLLVMFn = parentBB->getParent();
5590 assert(parentLLVMFn &&
"Parent Function must be valid");
5591 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
5593 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
5594 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
5596 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5597 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
5598 bool isGPU = ompBuilder->Config.isGPU();
5601 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
5602 auto &targetRegion = targetOp.getRegion();
5619 llvm::Function *llvmOutlinedFn =
nullptr;
5623 bool isOffloadEntry =
5624 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
5631 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
5633 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
5634 std::optional<DenseI64ArrayAttr> privateMapIndices =
5635 targetOp.getPrivateMapsAttr();
5637 for (
auto [privVarIdx, privVarSymPair] :
5639 auto privVar = std::get<0>(privVarSymPair);
5640 auto privSym = std::get<1>(privVarSymPair);
5642 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
5643 omp::PrivateClauseOp privatizer =
5646 if (!privatizer.needsMap())
5650 targetOp.getMappedValueForPrivateVar(privVarIdx);
5651 assert(mappedValue &&
"Expected to find mapped value for a privatized "
5652 "variable that needs mapping");
5657 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
5658 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
5662 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
5664 varType == privVar.getType() &&
5665 "Type of private var doesn't match the type of the mapped value");
5669 mappedPrivateVars.insert(
5671 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
5672 (*privateMapIndices)[privVarIdx])});
5676 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5677 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
5678 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5679 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5680 builder.SetCurrentDebugLocation(llvm::DebugLoc());
5683 llvm::Function *llvmParentFn =
5684 moduleTranslation.lookupFunction(parentFn.getName());
5685 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
5686 assert(llvmParentFn && llvmOutlinedFn &&
5687 "Both parent and outlined functions must exist at this point");
5689 if (outlinedFnLoc && llvmParentFn->getSubprogram())
5690 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
5692 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
5693 attr.isStringAttribute())
5694 llvmOutlinedFn->addFnAttr(attr);
5696 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
5697 attr.isStringAttribute())
5698 llvmOutlinedFn->addFnAttr(attr);
5700 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
5701 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5702 llvm::Value *mapOpValue =
5703 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
5704 moduleTranslation.mapValue(arg, mapOpValue);
5706 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
5707 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5708 llvm::Value *mapOpValue =
5709 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
5710 moduleTranslation.mapValue(arg, mapOpValue);
5719 allocaIP, &mappedPrivateVars);
5722 return llvm::make_error<PreviouslyReportedError>();
5724 builder.restoreIP(codeGenIP);
5726 &mappedPrivateVars),
5729 return llvm::make_error<PreviouslyReportedError>();
5732 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
5734 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
5735 return llvm::make_error<PreviouslyReportedError>();
5739 std::back_inserter(privateCleanupRegions),
5740 [](omp::PrivateClauseOp privatizer) {
5741 return &privatizer.getDeallocRegion();
5745 targetRegion,
"omp.target", builder, moduleTranslation);
5748 return exitBlock.takeError();
5750 builder.SetInsertPoint(*exitBlock);
5751 if (!privateCleanupRegions.empty()) {
5753 privateCleanupRegions, privateVarsInfo.
llvmVars,
5754 moduleTranslation, builder,
"omp.targetop.private.cleanup",
5756 return llvm::createStringError(
5757 "failed to inline `dealloc` region of `omp.private` "
5758 "op in the target region");
5760 return builder.saveIP();
5763 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
5766 StringRef parentName = parentFn.getName();
5768 llvm::TargetRegionEntryInfo entryInfo;
5772 MapInfoData mapData;
5777 MapInfosTy combinedInfos;
5779 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
5780 builder.restoreIP(codeGenIP);
5781 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
true);
5782 return combinedInfos;
5785 auto argAccessorCB = [&](
llvm::Argument &arg, llvm::Value *input,
5786 llvm::Value *&retVal, InsertPointTy allocaIP,
5787 InsertPointTy codeGenIP)
5788 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5789 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5790 builder.SetCurrentDebugLocation(llvm::DebugLoc());
5796 if (!isTargetDevice) {
5797 retVal = cast<llvm::Value>(&arg);
5802 *ompBuilder, moduleTranslation,
5803 allocaIP, codeGenIP);
5806 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
5807 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
5808 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
5810 isTargetDevice, isGPU);
5814 if (!isTargetDevice)
5816 targetCapturedOp, runtimeAttrs);
5824 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
5825 llvm::Value *value = moduleTranslation.lookupValue(var);
5826 moduleTranslation.mapValue(arg, value);
5828 if (!llvm::isa<llvm::Constant>(value))
5829 kernelInput.push_back(value);
5832 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
5839 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
5840 kernelInput.push_back(mapData.OriginalValue[i]);
5845 moduleTranslation, dds);
5847 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5849 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5851 llvm::OpenMPIRBuilder::TargetDataInfo info(
5855 auto customMapperCB =
5857 if (!combinedInfos.Mappers[i])
5859 info.HasMapper =
true;
5864 llvm::Value *ifCond =
nullptr;
5865 if (
Value targetIfCond = targetOp.getIfExpr())
5866 ifCond = moduleTranslation.lookupValue(targetIfCond);
5868 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5869 moduleTranslation.getOpenMPBuilder()->createTarget(
5870 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
5871 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
5872 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
5877 builder.restoreIP(*afterIP);
5881 if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
5888 static LogicalResult
5890 LLVM::ModuleTranslation &moduleTranslation) {
5898 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
5899 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
5901 if (!offloadMod.getIsTargetDevice())
5904 omp::DeclareTargetDeviceType declareType =
5905 attribute.getDeviceType().getValue();
5907 if (declareType == omp::DeclareTargetDeviceType::host) {
5908 llvm::Function *llvmFunc =
5909 moduleTranslation.lookupFunction(funcOp.getName());
5910 llvmFunc->dropAllReferences();
5911 llvmFunc->eraseFromParent();
5917 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
5918 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
5919 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
5920 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5921 bool isDeclaration = gOp.isDeclaration();
5922 bool isExternallyVisible =
5925 llvm::StringRef mangledName = gOp.getSymName();
5926 auto captureClause =
5932 std::vector<llvm::GlobalVariable *> generatedRefs;
5934 std::vector<llvm::Triple> targetTriple;
5935 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
5937 LLVM::LLVMDialect::getTargetTripleAttrName()));
5938 if (targetTripleAttr)
5939 targetTriple.emplace_back(targetTripleAttr.data());
5941 auto fileInfoCallBack = [&loc]() {
5942 std::string filename =
"";
5943 std::uint64_t lineNo = 0;
5946 filename = loc.getFilename().str();
5947 lineNo = loc.getLine();
5950 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
5954 auto vfs = llvm::vfs::getRealFileSystem();
5956 ompBuilder->registerTargetGlobalVariable(
5957 captureClause, deviceClause, isDeclaration, isExternallyVisible,
5958 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
5959 mangledName, generatedRefs,
false, targetTriple,
5961 gVal->getType(), gVal);
5963 if (ompBuilder->Config.isTargetDevice() &&
5964 (attribute.getCaptureClause().getValue() !=
5965 mlir::omp::DeclareTargetCaptureClause::to ||
5966 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
5967 ompBuilder->getAddrOfDeclareTargetVar(
5968 captureClause, deviceClause, isDeclaration, isExternallyVisible,
5969 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
5970 mangledName, generatedRefs,
false, targetTriple,
5971 gVal->getType(),
nullptr,
5992 if (mlir::isa<omp::ThreadprivateOp>(op))
5995 if (mlir::isa<omp::TargetAllocMemOp>(op) ||
5996 mlir::isa<omp::TargetFreeMemOp>(op))
6000 if (
auto declareTargetIface =
6001 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
6002 parentFn.getOperation()))
6003 if (declareTargetIface.isDeclareTarget() &&
6004 declareTargetIface.getDeclareTargetDeviceType() !=
6005 mlir::omp::DeclareTargetDeviceType::host)
6012 llvm::Module *llvmModule) {
6013 llvm::Type *i64Ty = builder.getInt64Ty();
6014 llvm::Type *i32Ty = builder.getInt32Ty();
6015 llvm::Type *returnType = builder.getPtrTy(0);
6016 llvm::FunctionType *fnType =
6018 llvm::Function *func = cast<llvm::Function>(
6019 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
6023 static LogicalResult
6025 LLVM::ModuleTranslation &moduleTranslation) {
6026 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
6031 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6035 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
6037 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
6038 mlir::Type heapTy = allocMemOp.getAllocatedType();
6039 llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
6040 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
6041 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
6042 for (
auto typeParam : allocMemOp.getTypeparams())
6044 builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
6046 llvm::CallInst *call =
6047 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
6048 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
6051 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
6056 llvm::Module *llvmModule) {
6057 llvm::Type *ptrTy = builder.getPtrTy(0);
6058 llvm::Type *i32Ty = builder.getInt32Ty();
6059 llvm::Type *voidTy = builder.getVoidTy();
6060 llvm::FunctionType *fnType =
6062 llvm::Function *func = dyn_cast<llvm::Function>(
6063 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
6067 static LogicalResult
6069 LLVM::ModuleTranslation &moduleTranslation) {
6070 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
6075 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6079 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
6082 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
6084 llvm::Value *intToPtr =
6085 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
6086 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
6092 static LogicalResult
6094 LLVM::ModuleTranslation &moduleTranslation) {
6095 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6103 bool isOutermostLoopWrapper =
6104 isa_and_present<omp::LoopWrapperInterface>(op) &&
6105 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
6107 if (isOutermostLoopWrapper)
6108 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
6112 .Case([&](omp::BarrierOp op) -> LogicalResult {
6116 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6117 ompBuilder->createBarrier(builder.saveIP(),
6118 llvm::omp::OMPD_barrier);
6120 if (res.succeeded()) {
6123 builder.restoreIP(*afterIP);
6127 .Case([&](omp::TaskyieldOp op) {
6131 ompBuilder->createTaskyield(builder.saveIP());
6134 .Case([&](omp::FlushOp op) {
6146 ompBuilder->createFlush(builder.saveIP());
6149 .Case([&](omp::ParallelOp op) {
6152 .Case([&](omp::MaskedOp) {
6155 .Case([&](omp::MasterOp) {
6158 .Case([&](omp::CriticalOp) {
6161 .Case([&](omp::OrderedRegionOp) {
6164 .Case([&](omp::OrderedOp) {
6167 .Case([&](omp::WsloopOp) {
6170 .Case([&](omp::SimdOp) {
6173 .Case([&](omp::AtomicReadOp) {
6176 .Case([&](omp::AtomicWriteOp) {
6179 .Case([&](omp::AtomicUpdateOp op) {
6182 .Case([&](omp::AtomicCaptureOp op) {
6185 .Case([&](omp::CancelOp op) {
6188 .Case([&](omp::CancellationPointOp op) {
6191 .Case([&](omp::SectionsOp) {
6194 .Case([&](omp::SingleOp op) {
6197 .Case([&](omp::TeamsOp op) {
6200 .Case([&](omp::TaskOp op) {
6203 .Case([&](omp::TaskgroupOp op) {
6206 .Case([&](omp::TaskwaitOp op) {
6209 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
6210 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
6211 omp::CriticalDeclareOp>([](
auto op) {
6224 .Case([&](omp::ThreadprivateOp) {
6227 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
6228 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
6231 .Case([&](omp::TargetOp) {
6234 .Case([&](omp::DistributeOp) {
6237 .Case([&](omp::LoopNestOp) {
6240 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
6247 .Case([&](omp::NewCliOp op) {
6252 .Case([&](omp::CanonicalLoopOp op) {
6255 .Case([&](omp::UnrollHeuristicOp op) {
6264 .Case([&](omp::TileOp op) {
6265 return applyTile(op, builder, moduleTranslation);
6267 .Case([&](omp::TargetAllocMemOp) {
6270 .Case([&](omp::TargetFreeMemOp) {
6275 <<
"not yet implemented: " << inst->
getName();
6278 if (isOutermostLoopWrapper)
6279 moduleTranslation.stackPop();
6284 static LogicalResult
6286 LLVM::ModuleTranslation &moduleTranslation) {
6290 static LogicalResult
6292 LLVM::ModuleTranslation &moduleTranslation) {
6293 if (isa<omp::TargetOp>(op))
6295 if (isa<omp::TargetDataOp>(op))
6299 if (isa<omp::TargetOp>(oper)) {
6304 if (isa<omp::TargetDataOp>(oper)) {
6314 if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
6315 oper->getParentOfType<LLVM::LLVMFuncOp>() &&
6316 !oper->getRegions().empty()) {
6317 if (
auto blockArgsIface =
6318 dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
6324 if (isa<mlir::omp::AtomicUpdateOp>(oper))
6325 for (
auto [operand, arg] :
6326 llvm::zip_equal(oper->getOperands(),
6327 oper->getRegion(0).getArguments())) {
6328 moduleTranslation.mapValue(
6329 arg, builder.CreateLoad(
6330 moduleTranslation.convertType(arg.getType()),
6331 moduleTranslation.lookupValue(operand)));
6335 if (
auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
6336 assert(builder.GetInsertBlock() &&
6337 "No insert block is set for the builder");
6338 for (
auto iv : loopNest.getIVs()) {
6340 moduleTranslation.mapValue(
6342 moduleTranslation.convertType(iv.getType())));
6346 for (
Region &region : oper->getRegions()) {
6353 region, oper->getName().getStringRef().str() +
".fake.region",
6354 builder, moduleTranslation, &phis);
6358 builder.SetInsertPoint(result.get(), result.get()->end());
6365 }).wasInterrupted();
6366 return failure(interrupted);
6373 class OpenMPDialectLLVMIRTranslationInterface
6381 convertOperation(
Operation *op, llvm::IRBuilderBase &builder,
6382 LLVM::ModuleTranslation &moduleTranslation)
const final;
6389 LLVM::ModuleTranslation &moduleTranslation)
const final;
6394 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6397 LLVM::ModuleTranslation &moduleTranslation)
const {
6400 .Case(
"omp.is_target_device",
6402 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6403 llvm::OpenMPIRBuilderConfig &
config =
6404 moduleTranslation.getOpenMPBuilder()->Config;
6405 config.setIsTargetDevice(deviceAttr.getValue());
6412 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6413 llvm::OpenMPIRBuilderConfig &
config =
6414 moduleTranslation.getOpenMPBuilder()->Config;
6415 config.setIsGPU(gpuAttr.getValue());
6420 .Case(
"omp.host_ir_filepath",
6422 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6423 llvm::OpenMPIRBuilder *ompBuilder =
6424 moduleTranslation.getOpenMPBuilder();
6425 auto VFS = llvm::vfs::getRealFileSystem();
6426 ompBuilder->loadOffloadInfoMetadata(*VFS,
6427 filepathAttr.getValue());
6434 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6438 .Case(
"omp.version",
6440 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6441 llvm::OpenMPIRBuilder *ompBuilder =
6442 moduleTranslation.getOpenMPBuilder();
6443 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
6444 versionAttr.getVersion());
6449 .Case(
"omp.declare_target",
6451 if (
auto declareTargetAttr =
6452 dyn_cast<omp::DeclareTargetAttr>(attr))
6457 .Case(
"omp.requires",
6459 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6460 using Requires = omp::ClauseRequires;
6461 Requires flags = requiresAttr.getValue();
6462 llvm::OpenMPIRBuilderConfig &
config =
6463 moduleTranslation.getOpenMPBuilder()->Config;
6464 config.setHasRequiresReverseOffload(
6465 bitEnumContainsAll(flags, Requires::reverse_offload));
6466 config.setHasRequiresUnifiedAddress(
6467 bitEnumContainsAll(flags, Requires::unified_address));
6468 config.setHasRequiresUnifiedSharedMemory(
6469 bitEnumContainsAll(flags, Requires::unified_shared_memory));
6470 config.setHasRequiresDynamicAllocators(
6471 bitEnumContainsAll(flags, Requires::dynamic_allocators));
6476 .Case(
"omp.target_triples",
6478 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
6479 llvm::OpenMPIRBuilderConfig &
config =
6480 moduleTranslation.getOpenMPBuilder()->Config;
6481 config.TargetTriples.clear();
6482 config.TargetTriples.reserve(triplesAttr.size());
6483 for (
Attribute tripleAttr : triplesAttr) {
6484 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
6485 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
6503 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
6504 Operation *op, llvm::IRBuilderBase &builder,
6505 LLVM::ModuleTranslation &moduleTranslation)
const {
6507 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6508 if (ompBuilder->Config.isTargetDevice()) {
6518 registry.
insert<omp::OpenMPDialect>();
6520 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static std::string toString(bytecode::Section::ID sectionID)
Stringify the given section ID.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
union mlir::linalg::@1252::ArityGroupAndKind::Kind kind
static llvm::Value * getRefPtrIfDeclareTarget(mlir::Value value, LLVM::ModuleTranslation &moduleTranslation)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName)
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable.
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type.
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct.
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables.
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, bool isTargetParams, int mapDataParentIdx=-1)
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static bool isTargetDeviceOp(Operation *op)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static LogicalResult initReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::BasicBlock *latestAllocaBlock, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef, SmallVectorImpl< DeferredStore > &deferredStores)
Inline reductions' init regions.
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, bool isTargetParams=false)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation, omp::BlockArgOpenMPOpInterface blockArgIface)
Maps block arguments from blockArgIface (which are MLIR values) to the corresponding LLVM values of t...
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool >> attr)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static LogicalResult convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< mlir::Value > &mlirPrivateVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static uint64_t getReductionDataSize(OpTy &op)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static void collectReductionInfo(T loop, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< OwningReductionGen > &owningReductionGens, SmallVectorImpl< OwningAtomicReductionGen > &owningAtomicReductionGens, const ArrayRef< llvm::Value * > privateReductionVariables, SmallVectorImpl< llvm::OpenMPIRBuilder::ReductionInfo > &reductionInfos)
Collect reduction info.
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Given an OpenMP MLIR operation, create the corresponding LLVM IR (including OpenMP runtime calls).
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static LogicalResult convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static bool isDeclareTargetLink(mlir::Value value)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Base class for dialect interfaces providing translation to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
unsigned getNumArguments()
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Runtime
Potential runtimes for AMD GPU kernels.
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars
RAII object calling stackPush/stackPop on construction/destruction.