#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
static llvm::omp::ScheduleKind
convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
  if (!schedKind.has_value())
    return llvm::omp::OMP_SCHEDULE_Default;
  switch (schedKind.value()) {
  case omp::ClauseScheduleKind::Static:
    return llvm::omp::OMP_SCHEDULE_Static;
  case omp::ClauseScheduleKind::Dynamic:
    return llvm::omp::OMP_SCHEDULE_Dynamic;
  case omp::ClauseScheduleKind::Guided:
    return llvm::omp::OMP_SCHEDULE_Guided;
  case omp::ClauseScheduleKind::Auto:
    return llvm::omp::OMP_SCHEDULE_Auto;
  case omp::ClauseScheduleKind::Runtime:
    return llvm::omp::OMP_SCHEDULE_Runtime;
  case omp::ClauseScheduleKind::Distribute:
    return llvm::omp::OMP_SCHEDULE_Distribute;
  }
  llvm_unreachable("unhandled schedule clause argument");
}
/// ModuleTranslation stack frame recording the insertion point to use for
/// allocas while translating the current OpenMP region.
class OpenMPAllocaStackFrame
  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;

/// ModuleTranslation stack frame recording the CanonicalLoopInfo of the loop
/// currently being translated.
class OpenMPLoopInfoStackFrame
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
/// Error signalling that a diagnostic has already been emitted for the failing
/// operation, so callers should not report it a second time.
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
public:
  void log(raw_ostream &) const override {}

  std::error_code convertToErrorCode() const override {
    llvm_unreachable(
        "PreviouslyReportedError doesn't support ECError conversion");
  }

  static char ID;
};

char PreviouslyReportedError::ID = 0;
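// Typical use: a callee that has already issued op.emitError() returns
// llvm::make_error<PreviouslyReportedError>() so that callers can propagate
// failure without emitting a second, redundant diagnostic; the
// llvm::handleAllErrors call in the clause-support check below turns this
// error back into a plain LogicalResult failure.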
/// Helper that materializes the storage and control flow needed to lower the
/// OpenMP `linear` clause.
class LinearClauseProcessor {
  SmallVector<llvm::Value *> linearPreconditionVars;
  SmallVector<llvm::Value *> linearLoopBodyTemps;
  SmallVector<llvm::Value *> linearOrigVal;
  SmallVector<llvm::Value *> linearSteps;
  SmallVector<llvm::Type *> linearVarTypes;
  llvm::BasicBlock *linearFinalizationBB;
  llvm::BasicBlock *linearExitBB;
  llvm::BasicBlock *linearLastIterExitBB;
  void registerType(LLVM::ModuleTranslation &moduleTranslation,
                    mlir::Attribute &ty) {
    linearVarTypes.push_back(moduleTranslation.convertType(
        mlir::cast<mlir::TypeAttr>(ty).getValue()));
  }
  void createLinearVar(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       mlir::Value &linearVar, int idx) {
    linearPreconditionVars.push_back(
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
    llvm::Value *linearLoopBodyTemp =
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
    linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
    linearLoopBodyTemps.push_back(linearLoopBodyTemp);
  }
  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
                             mlir::Value &linearStep) {
    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
  }
  void initLinearVar(llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     llvm::BasicBlock *loopPreHeader) {
    builder.SetInsertPoint(loopPreHeader->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarLoad =
          builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
    }
  }
  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
                       llvm::Value *loopInductionVar) {
    builder.SetInsertPoint(loopBody->getTerminator());
    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
      llvm::LoadInst *linearVarStart = builder.CreateLoad(
          linearVarTypes[index], linearPreconditionVars[index]);
      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
      if (linearVarTypes[index]->isIntegerTy()) {
        auto addInst = builder.CreateAdd(linearVarStart, mulInst);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else if (linearVarTypes[index]->isFloatingPointTy()) {
        auto cvt = builder.CreateSIToFP(mulInst, linearVarTypes[index]);
        auto addInst = builder.CreateFAdd(linearVarStart, cvt);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      }
    }
  }
  void splitLinearFiniBB(llvm::IRBuilderBase &builder,
                         llvm::BasicBlock *loopExit) {
    linearFinalizationBB = loopExit->splitBasicBlock(
        loopExit->getTerminator(), "omp_loop.linear_finalization");
    linearExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
  }
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  finalizeLinearVar(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::Value *lastIter) {
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    llvm::Value *loopLastIterLoad = builder.CreateLoad(
        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
    llvm::Value *isLast =
        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
                          llvm::ConstantInt::get(
                              llvm::Type::getInt32Ty(builder.getContext()), 0));
    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
    linearFinalizationBB->getTerminator()->eraseFromParent();
    builder.SetInsertPoint(linearExitBB->getTerminator());
    return moduleTranslation.getOpenMPBuilder()->createBarrier(
        builder.saveIP(), llvm::omp::OMPD_barrier);
  }
  void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
                      size_t varIndex) {
    llvm::SmallVector<llvm::User *> users;
    for (llvm::User *user : linearOrigVal[varIndex]->users())
      users.push_back(user);
    for (auto *user : users) {
      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
        if (userInst->getParent()->getName().str().find(BBName) !=
            std::string::npos)
          user->replaceUsesOfWith(linearOrigVal[varIndex],
                                  linearLoopBodyTemps[varIndex]);
      }
    }
  }
};
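// Shape of the IR produced for a `linear` clause (summary of the helpers
// above): each linear variable gets a ".linear_var" precondition alloca and a
// ".linear_result" temporary; the loop body stores start + iv * step into the
// temporary; after the loop, splitLinearFiniBB/finalizeLinearVar branch on the
// last-iteration flag so that "omp_loop.linear_lastiter_exit" copies the
// temporaries back into the original variables before the barrier emitted in
// "omp_loop.linear_exit".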
                                            SymbolRefAttr symbolName) {
  omp::PrivateClauseOp privatizer =
  assert(privatizer && "privatizer not found in the symbol table");
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  auto checkBare = [&todo](auto op, LogicalResult &result) {
      result = todo("ompx_bare");
  auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
    omp::ClauseCancellationConstructType cancelledDirective =
        op.getCancelDirective();
    if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
      if (isa_and_nonnull<omp::TaskloopOp>(parent))
        result = todo("cancel directive inside of taskloop");
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
  auto checkDevice = [&todo](auto op, LogicalResult &result) {
  auto checkHint = [](auto op, LogicalResult &) {
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
  auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
    if (op.getParLevelSimd())
      result = todo("parallelization-level");
  auto checkPriority = [&todo](auto op, LogicalResult &result) {
    if (op.getPriority())
      result = todo("priority");
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if (!op.getPrivateVars().empty() || op.getPrivateSyms())
      result = todo("privatization");
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    if (isa<omp::TeamsOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  auto checkUntied = [&todo](auto op, LogicalResult &result) {
      .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
      .Case([&](omp::CancellationPointOp op) {
        checkCancelDirective(op, result);
      .Case([&](omp::DistributeOp op) {
        checkAllocate(op, result);
      .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
      .Case([&](omp::TaskOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      .Case([&](omp::TaskwaitOp op) {
      .Case([&](omp::TaskloopOp op) {
        checkPriority(op, result);
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      .Case([&](omp::SimdOp op) { checkReduction(op, result); })
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
  llvm::handleAllErrors(
      [&](const PreviouslyReportedError &) { result = failure(); },
      [&](const llvm::ErrorInfoBase &err) {
static llvm::OpenMPIRBuilder::InsertPointTy
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
      [&](OpenMPAllocaStackFrame &frame) {
        allocaInsertPoint = frame.allocaInsertPoint;
      allocaInsertPoint.getBlock()->getParent() ==
          builder.GetInsertBlock()->getParent())
    return allocaInsertPoint;

  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);
  }

  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
static llvm::CanonicalLoopInfo *
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        loopInfo = frame.loopInfo;
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());

  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  unsigned numYields = 0;

  if (!isLoopWrapper) {
    bool operandsProcessed = false;
      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));
          operandsProcessed = true;
          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());
            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");

  if (!continuationBlockPHITypes.empty())
    assert(continuationBlockPHIs &&
           "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));

  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    if (failed(
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      return llvm::make_error<PreviouslyReportedError>();

      builder.CreateBr(continuationBlock);

    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);

      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(

  return continuationBlock;
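// In summary: every MLIR block of `region` gets an LLVM basic block named
// `blockName`, the entry block is spliced in front of the freshly split
// "omp.region.cont" continuation block, and (outside of loop wrappers) the
// operands of each `omp.yield` become incoming values of the continuation
// block's PHI nodes.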
  case omp::ClauseProcBindKind::Close:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
  case omp::ClauseProcBindKind::Master:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
  case omp::ClauseProcBindKind::Primary:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
  case omp::ClauseProcBindKind::Spread:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
  }
  llvm_unreachable("Unknown ClauseProcBindKind kind");

                        omp::BlockArgOpenMPOpInterface blockArgIface) {
  blockArgIface.getBlockArgsPairs(blockArgsPairs);
  for (auto [var, arg] : blockArgsPairs)
  auto maskedOp = cast<omp::MaskedOp>(opInst);
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = maskedOp.getRegion();
    builder.restoreIP(codeGenIP);

  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::Value *filterVal = nullptr;
  if (auto filterVar = maskedOp.getFilteredThreadId()) {
    filterVal = moduleTranslation.lookupValue(filterVar);
    llvm::LLVMContext &llvmContext = builder.getContext();
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
  assert(filterVal != nullptr);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =

  builder.restoreIP(*afterIP);
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto masterOp = cast<omp::MasterOp>(opInst);

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = masterOp.getRegion();
    builder.restoreIP(codeGenIP);

  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =

  builder.restoreIP(*afterIP);
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto criticalOp = cast<omp::CriticalOp>(opInst);

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = cast<omp::CriticalOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);

  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
  llvm::Constant *hint = nullptr;

  if (criticalOp.getNameAttr()) {
    auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
    auto criticalDeclareOp =
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
                               static_cast<int>(criticalDeclareOp.getHint()));
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);

  builder.restoreIP(*afterIP);
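// The elided builder call above goes through OpenMPIRBuilder's critical-region
// entry point, which wraps the inlined region in the OpenMP runtime's
// critical-section entry/exit calls and forwards the optional `hint` constant
// built from the matching omp.critical.declare.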
  template <typename OP>
           cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
      collectPrivatizationDecls<OP>(op);

  void collectPrivatizationDecls(OP op) {
    std::optional<ArrayAttr> attr = op.getPrivateSyms();
    for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {

  std::optional<ArrayAttr> attr = op.getReductionSyms();
  reductions.reserve(reductions.size() + op.getNumReductionVars());
  for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
    reductions.push_back(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();
    if (potentialTerminator && potentialTerminator->isTerminator())
      potentialTerminator->removeFromParent();
    moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
            region.front(), true, builder)))
    if (continuationBlockArgs)
          *continuationBlockArgs,
    if (potentialTerminator && potentialTerminator->isTerminator()) {
      llvm::BasicBlock *block = builder.GetInsertBlock();
      if (block->empty()) {
        potentialTerminator->insertInto(block, block->begin());
        potentialTerminator->insertAfter(&block->back());

  if (continuationBlockArgs)
    llvm::append_range(*continuationBlockArgs, phis);
  builder.SetInsertPoint(*continuationBlock,
                         (*continuationBlock)->getFirstInsertionPt());
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
using OwningDataPtrPtrReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
static OwningReductionGen
  OwningReductionGen gen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *lhs, llvm::Value *rhs,
                llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
            "omp.reduction.nonatomic.body", builder,
            moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `combiner` region of `omp.declare_reduction`");
    result = llvm::getSingleElement(phis);
    return builder.saveIP();
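// The callback built here maps the declaration's combiner block arguments to
// the lhs/rhs values supplied by OpenMPIRBuilder, inlines the `combiner`
// region at the given insertion point, and reports the single yielded value
// back through `result`; the atomic and data_ptr_ptr variants below follow
// the same pattern for their respective regions.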
static OwningAtomicReductionGen
                         llvm::IRBuilderBase &builder,
  if (decl.getAtomicReductionRegion().empty())
    return OwningAtomicReductionGen();

      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
                llvm::Value *lhs, llvm::Value *rhs) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
            "omp.reduction.atomic.body", builder,
            moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `atomic` region of `omp.declare_reduction`");
    assert(phis.empty());
    return builder.saveIP();
static OwningDataPtrPtrReductionGen
    return OwningDataPtrPtrReductionGen();

  OwningDataPtrPtrReductionGen refDataPtrGen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *byRefVal, llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
    builder.restoreIP(insertPoint);
            "omp.data_ptr_ptr.body", builder,
            moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
    result = llvm::getSingleElement(phis);
    return builder.saveIP();

  return refDataPtrGen;
  auto orderedOp = cast<omp::OrderedOp>(opInst);

  omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
  bool isDependSource = dependType == omp::ClauseDepend::dependsource;
  unsigned numLoops = *orderedOp.getDoacrossNumLoops();
      moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());

  size_t indexVecValues = 0;
  while (indexVecValues < vecValues.size()) {
    storeValues.reserve(numLoops);
    for (unsigned i = 0; i < numLoops; i++) {
      storeValues.push_back(vecValues[indexVecValues]);
    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
    builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
        ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);

  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
          ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());

  builder.restoreIP(*afterIP);
struct DeferredStore {
  DeferredStore(llvm::Value *value, llvm::Value *address)
      : value(value), address(address) {}

  llvm::Value *address;
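// A store of `value` to `address` whose emission is postponed: by-ref
// reduction allocations record their initial stores here while the alloca
// block is being populated, and the stores are only materialized later (see
// the loop over `deferredStores` in the reduction-initialization code below).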
1120template <
typename T>
1123 llvm::IRBuilderBase &builder,
1125 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1131 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1132 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1135 deferredStores.reserve(loop.getNumReductionVars());
1137 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1138 Region &allocRegion = reductionDecls[i].getAllocRegion();
1140 if (allocRegion.
empty())
1145 builder, moduleTranslation, &phis)))
1146 return loop.emitError(
1147 "failed to inline `alloc` region of `omp.declare_reduction`");
1149 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1150 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1154 llvm::Value *var = builder.CreateAlloca(
1155 moduleTranslation.
convertType(reductionDecls[i].getType()));
1157 llvm::Type *ptrTy = builder.getPtrTy();
1158 llvm::Value *castVar =
1159 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1160 llvm::Value *castPhi =
1161 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1163 deferredStores.emplace_back(castPhi, castVar);
1165 privateReductionVariables[i] = castVar;
1166 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1167 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1169 assert(allocRegion.
empty() &&
1170 "allocaction is implicit for by-val reduction");
1171 llvm::Value *var = builder.CreateAlloca(
1172 moduleTranslation.
convertType(reductionDecls[i].getType()));
1174 llvm::Type *ptrTy = builder.getPtrTy();
1175 llvm::Value *castVar =
1176 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1178 moduleTranslation.
mapValue(reductionArgs[i], castVar);
1179 privateReductionVariables[i] = castVar;
1180 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1188template <
typename T>
1191 llvm::IRBuilderBase &builder,
1196 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1197 Region &initializerRegion = reduction.getInitializerRegion();
1200 mlir::Value mlirSource = loop.getReductionVars()[i];
1201 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1202 llvm::Value *origVal = llvmSource;
1204 if (!isa<LLVM::LLVMPointerType>(
1205 reduction.getInitializerMoldArg().getType()) &&
1206 isa<LLVM::LLVMPointerType>(mlirSource.
getType())) {
1209 reduction.getInitializerMoldArg().getType()),
1210 llvmSource,
"omp_orig");
1212 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), origVal);
1215 llvm::Value *allocation =
1216 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1217 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1223 llvm::BasicBlock *block =
nullptr) {
1224 if (block ==
nullptr)
1225 block = builder.GetInsertBlock();
1227 if (block->empty() || block->getTerminator() ==
nullptr)
1228 builder.SetInsertPoint(block);
1230 builder.SetInsertPoint(block->getTerminator());
1238template <
typename OP>
1241 llvm::IRBuilderBase &builder,
1243 llvm::BasicBlock *latestAllocaBlock,
1249 if (op.getNumReductionVars() == 0)
1252 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1253 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1254 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1255 builder.restoreIP(allocaIP);
1258 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1260 if (!reductionDecls[i].getAllocRegion().empty())
1266 byRefVars[i] = builder.CreateAlloca(
1267 moduleTranslation.
convertType(reductionDecls[i].getType()));
1275 for (
auto [data, addr] : deferredStores)
1276 builder.CreateStore(data, addr);
1281 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1286 reductionVariableMap, i);
1294 "omp.reduction.neutral", builder,
1295 moduleTranslation, &phis)))
1298 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1299 "reduction neutral element declaration region");
1304 if (!reductionDecls[i].getAllocRegion().empty())
1313 builder.CreateStore(phis[0], byRefVars[i]);
1315 privateReductionVariables[i] = byRefVars[i];
1316 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1317 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1320 builder.CreateStore(phis[0], privateReductionVariables[i]);
1327 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
1334template <
typename T>
1335static void collectReductionInfo(
1336 T loop, llvm::IRBuilderBase &builder,
1345 unsigned numReductions = loop.getNumReductionVars();
1347 for (
unsigned i = 0; i < numReductions; ++i) {
1350 owningAtomicReductionGens.push_back(
1353 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1357 reductionInfos.reserve(numReductions);
1358 for (
unsigned i = 0; i < numReductions; ++i) {
1359 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1360 if (owningAtomicReductionGens[i])
1361 atomicGen = owningAtomicReductionGens[i];
1362 llvm::Value *variable =
1363 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
1366 if (
auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1367 allocatedType = alloca.getElemType();
1374 reductionInfos.push_back(
1376 privateReductionVariables[i],
1377 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1381 allocatedType ? moduleTranslation.
convertType(allocatedType) :
nullptr,
1382 reductionDecls[i].getByrefElementType()
1384 *reductionDecls[i].getByrefElementType())
1394 llvm::IRBuilderBase &builder, StringRef regionName,
1395 bool shouldLoadCleanupRegionArg =
true) {
1396 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1397 if (cleanupRegion->empty())
1403 llvm::Instruction *potentialTerminator =
1404 builder.GetInsertBlock()->empty() ?
nullptr
1405 : &builder.GetInsertBlock()->back();
1406 if (potentialTerminator && potentialTerminator->isTerminator())
1407 builder.SetInsertPoint(potentialTerminator);
1408 llvm::Value *privateVarValue =
1409 shouldLoadCleanupRegionArg
1410 ? builder.CreateLoad(
1412 privateVariables[i])
1413 : privateVariables[i];
1418 moduleTranslation)))
1431 OP op, llvm::IRBuilderBase &builder,
1433 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1436 bool isNowait =
false,
bool isTeamsReduction =
false) {
1438 if (op.getNumReductionVars() == 0)
1450 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1452 owningReductionGenRefDataPtrGens,
1453 privateReductionVariables, reductionInfos, isByRef);
1458 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1459 builder.SetInsertPoint(tempTerminator);
1460 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1461 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1462 isByRef, isNowait, isTeamsReduction);
1467 if (!contInsertPoint->getBlock())
1468 return op->emitOpError() <<
"failed to convert reductions";
1470 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1471 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1476 tempTerminator->eraseFromParent();
1477 builder.restoreIP(*afterIP);
1481 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1482 [](omp::DeclareReductionOp reductionDecl) {
1483 return &reductionDecl.getCleanupRegion();
1486 moduleTranslation, builder,
1487 "omp.reduction.cleanup");
1498template <
typename OP>
1502 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1507 if (op.getNumReductionVars() == 0)
1513 allocaIP, reductionDecls,
1514 privateReductionVariables, reductionVariableMap,
1515 deferredStores, isByRef)))
1518 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1519 allocaIP.getBlock(), reductionDecls,
1520 privateReductionVariables, reductionVariableMap,
1521 isByRef, deferredStores);
1535 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1538 Value blockArg = (*mappedPrivateVars)[privateVar];
1541 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1542 "A block argument corresponding to a mapped var should have "
1545 if (privVarType == blockArgType)
1552 if (!isa<LLVM::LLVMPointerType>(privVarType))
1553 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1566 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1568 Region &initRegion = privDecl.getInitRegion();
1569 if (initRegion.
empty())
1570 return llvmPrivateVar;
1574 mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1575 assert(nonPrivateVar);
1576 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1577 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1582 moduleTranslation, &phis)))
1583 return llvm::createStringError(
1584 "failed to inline `init` region of `omp.private`");
1586 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1605 return llvm::Error::success();
1607 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1610 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1613 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1615 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1616 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1619 return privVarOrErr.takeError();
1621 llvmPrivateVar = privVarOrErr.get();
1622 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1627 return llvm::Error::success();
1637 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1640 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1641 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1642 allocaTerminator->getIterator()),
1643 true, allocaTerminator->getStableDebugLoc(),
1644 "omp.region.after_alloca");
1646 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1648 allocaTerminator = allocaIP.getBlock()->getTerminator();
1649 builder.SetInsertPoint(allocaTerminator);
1651 assert(allocaTerminator->getNumSuccessors() == 1 &&
1652 "This is an unconditional branch created by splitBB");
1654 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1655 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1657 unsigned int allocaAS =
1658 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1661 .getProgramAddressSpace();
1663 for (
auto [privDecl, mlirPrivVar, blockArg] :
1666 llvm::Type *llvmAllocType =
1667 moduleTranslation.
convertType(privDecl.getType());
1668 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1669 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1670 llvmAllocType,
nullptr,
"omp.private.alloc");
1671 if (allocaAS != defaultAS)
1672 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1673 builder.getPtrTy(defaultAS));
1675 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1678 return afterAllocas;
1689 bool needsFirstprivate =
1690 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1691 return privOp.getDataSharingType() ==
1692 omp::DataSharingClauseType::FirstPrivate;
1695 if (!needsFirstprivate)
1698 llvm::BasicBlock *copyBlock =
1699 splitBB(builder,
true,
"omp.private.copy");
1702 for (
auto [decl, mlirVar, llvmVar] :
1703 llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
1704 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1708 Region ©Region = decl.getCopyRegion();
1712 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1713 assert(nonPrivateVar);
1714 moduleTranslation.
mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1717 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1721 moduleTranslation)))
1722 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1734 if (insertBarrier) {
1736 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1737 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1752 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1753 [](omp::PrivateClauseOp privatizer) {
1754 return &privatizer.getDeallocRegion();
1758 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1759 "omp.private.dealloc",
false)))
1760 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1761 "`omp.private` op in");
1773 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1783 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1784 using StorableBodyGenCallbackTy =
1785 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1787 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1793 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1797 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1801 sectionsOp.getNumReductionVars());
1805 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1808 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1809 reductionDecls, privateReductionVariables, reductionVariableMap,
1816 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1820 Region ®ion = sectionOp.getRegion();
1821 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1822 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1823 builder.restoreIP(codeGenIP);
1830 sectionsOp.getRegion().getNumArguments());
1831 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1832 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1833 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
1835 moduleTranslation.
mapValue(sectionArg, llvmVal);
1842 sectionCBs.push_back(sectionCB);
1848 if (sectionCBs.empty())
1851 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1856 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1857 llvm::Value &vPtr, llvm::Value *&replacementValue)
1858 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1859 replacementValue = &vPtr;
1865 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1869 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1870 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1872 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1873 sectionsOp.getNowait());
1878 builder.restoreIP(*afterIP);
1882 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1883 privateReductionVariables, isByRef, sectionsOp.getNowait());
1890 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1891 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1896 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1897 builder.restoreIP(codegenIP);
1899 builder, moduleTranslation)
1902 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1906 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1909 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1910 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
1912 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1913 llvmCPFuncs.push_back(
1917 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1919 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1925 builder.restoreIP(*afterIP);
1931 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1936 for (
auto ra : iface.getReductionBlockArgs())
1937 for (
auto &use : ra.getUses()) {
1938 auto *useOp = use.getOwner();
1940 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1941 debugUses.push_back(useOp);
1945 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
1950 Operation *currentOp = currentDistOp.getOperation();
1951 if (distOp && (distOp != currentOp))
1960 for (
auto *use : debugUses)
1969 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1974 unsigned numReductionVars = op.getNumReductionVars();
1978 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1984 if (doTeamsReduction) {
1985 isByRef =
getIsByRef(op.getReductionByref());
1987 assert(isByRef.size() == op.getNumReductionVars());
1990 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
1995 op, reductionArgs, builder, moduleTranslation, allocaIP,
1996 reductionDecls, privateReductionVariables, reductionVariableMap,
2001 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2003 moduleTranslation, allocaIP);
2004 builder.restoreIP(codegenIP);
2010 llvm::Value *numTeamsLower =
nullptr;
2011 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
2012 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
2014 llvm::Value *numTeamsUpper =
nullptr;
2015 if (
Value numTeamsUpperVar = op.getNumTeamsUpper())
2016 numTeamsUpper = moduleTranslation.
lookupValue(numTeamsUpperVar);
2018 llvm::Value *threadLimit =
nullptr;
2019 if (
Value threadLimitVar = op.getThreadLimit())
2020 threadLimit = moduleTranslation.
lookupValue(threadLimitVar);
2022 llvm::Value *ifExpr =
nullptr;
2023 if (
Value ifVar = op.getIfExpr())
2026 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2027 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2029 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2034 builder.restoreIP(*afterIP);
2035 if (doTeamsReduction) {
2038 op, builder, moduleTranslation, allocaIP, reductionDecls,
2039 privateReductionVariables, isByRef,
2049 if (dependVars.empty())
2051 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2052 llvm::omp::RTLDependenceKindTy type;
2054 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2055 case mlir::omp::ClauseTaskDepend::taskdependin:
2056 type = llvm::omp::RTLDependenceKindTy::DepIn;
2061 case mlir::omp::ClauseTaskDepend::taskdependout:
2062 case mlir::omp::ClauseTaskDepend::taskdependinout:
2063 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2065 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2066 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2068 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2069 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2072 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2073 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2074 dds.emplace_back(dd);
                                     llvm::IRBuilderBase &llvmBuilder,
                                     llvm::omp::Directive cancelDirective) {
  auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
    llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
    llvmBuilder.restoreIP(ip);
    cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
    return llvm::Error::success();
  ompBuilder.pushFinalizationCB(

                                    llvm::OpenMPIRBuilder &ompBuilder,
                                    const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
  ompBuilder.popFinalizationCB();
  llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
  for (llvm::BranchInst *cancelBranch : cancelTerminators) {
    assert(cancelBranch->getNumSuccessors() == 1 &&
           "cancel branch should have one target");
    cancelBranch->setSuccessor(0, constructFini);
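// Together these two helpers bracket the emission of a cancellable construct:
// the pushed finalization callback records a placeholder branch for every
// cancellation point, and once the construct body has been built the pop side
// retargets each recorded branch at the construct's finalization block (the
// single predecessor of `afterIP`).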
class TaskContextStructManager {
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  void generateTaskContextStruct();

  void createGEPsToPrivateVars();

  void freeStructPtr();

  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;
  }

  llvm::Value *getStructPtr() { return structPtr; }

  llvm::IRBuilderBase &builder;
  LLVM::ModuleTranslation &moduleTranslation;
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  SmallVector<llvm::Type *> privateVarTypes;

  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  llvm::Value *structPtr = nullptr;
  llvm::Type *structTy = nullptr;
void TaskContextStructManager::generateTaskContextStruct() {
  if (privateDecls.empty())
    return;
  privateVarTypes.reserve(privateDecls.size());

  for (omp::PrivateClauseOp &privOp : privateDecls) {
    if (!privOp.readsFromMold())
      continue;
    Type mlirType = privOp.getType();
    privateVarTypes.push_back(moduleTranslation.convertType(mlirType));

  structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),

  llvm::DataLayout dataLayout =
      builder.GetInsertBlock()->getModule()->getDataLayout();
  llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
  llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);

  structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
                                   "omp.task.context_ptr");
void TaskContextStructManager::createGEPsToPrivateVars() {
    assert(privateVarTypes.empty());

  llvmPrivateVarGEPs.clear();
  llvmPrivateVarGEPs.reserve(privateDecls.size());
  llvm::Value *zero = builder.getInt32(0);
  for (auto privDecl : privateDecls) {
    if (!privDecl.readsFromMold()) {
      llvmPrivateVarGEPs.push_back(nullptr);
      continue;
    }
    llvm::Value *iVal = builder.getInt32(i);
    llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
    llvmPrivateVarGEPs.push_back(gep);
void TaskContextStructManager::freeStructPtr() {
  llvm::IRBuilderBase::InsertPointGuard guard{builder};
  builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
  builder.CreateFree(structPtr);
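// Lifecycle of the task context struct: generateTaskContextStruct()
// heap-allocates one struct holding every private variable that reads from
// its mold, createGEPsToPrivateVars() recomputes field pointers wherever the
// builder currently sits (including inside the outlined task body), and
// freeStructPtr() releases the allocation once the task body is done with it.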
2244 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2249 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2261 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2266 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2267 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2268 builder.getContext(),
"omp.task.start",
2269 builder.GetInsertBlock()->getParent());
2270 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2271 builder.SetInsertPoint(branchToTaskStartBlock);
2274 llvm::BasicBlock *copyBlock =
2275 splitBB(builder,
true,
"omp.private.copy");
2276 llvm::BasicBlock *initBlock =
2277 splitBB(builder,
true,
"omp.private.init");
2293 moduleTranslation, allocaIP);
2296 builder.SetInsertPoint(initBlock->getTerminator());
2299 taskStructMgr.generateTaskContextStruct();
2306 taskStructMgr.createGEPsToPrivateVars();
2308 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2311 taskStructMgr.getLLVMPrivateVarGEPs())) {
2313 if (!privDecl.readsFromMold())
2315 assert(llvmPrivateVarAlloc &&
2316 "reads from mold so shouldn't have been skipped");
2319 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2320 blockArg, llvmPrivateVarAlloc, initBlock);
2321 if (!privateVarOrErr)
2322 return handleError(privateVarOrErr, *taskOp.getOperation());
2331 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2332 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2333 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2335 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2336 llvmPrivateVarAlloc);
2338 assert(llvmPrivateVarAlloc->getType() ==
2339 moduleTranslation.
convertType(blockArg.getType()));
2349 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2350 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2351 taskOp.getPrivateNeedsBarrier())))
2352 return llvm::failure();
2355 builder.SetInsertPoint(taskStartBlock);
2357 auto bodyCB = [&](InsertPointTy allocaIP,
2358 InsertPointTy codegenIP) -> llvm::Error {
2362 moduleTranslation, allocaIP);
2365 builder.restoreIP(codegenIP);
2367 llvm::BasicBlock *privInitBlock =
nullptr;
2369 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2372 auto [blockArg, privDecl, mlirPrivVar] = zip;
2374 if (privDecl.readsFromMold())
2377 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2378 llvm::Type *llvmAllocType =
2379 moduleTranslation.
convertType(privDecl.getType());
2380 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2381 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2382 llvmAllocType,
nullptr,
"omp.private.alloc");
2385 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2386 blockArg, llvmPrivateVar, privInitBlock);
2387 if (!privateVarOrError)
2388 return privateVarOrError.takeError();
2389 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2390 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2393 taskStructMgr.createGEPsToPrivateVars();
2394 for (
auto [i, llvmPrivVar] :
2395 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2397 assert(privateVarsInfo.
llvmVars[i] &&
2398 "This is added in the loop above");
2401 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2406 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2410 if (!privateDecl.readsFromMold())
2413 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2414 llvmPrivateVar = builder.CreateLoad(
2415 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2417 assert(llvmPrivateVar->getType() ==
2418 moduleTranslation.
convertType(blockArg.getType()));
2419 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2423 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
2424 if (failed(
handleError(continuationBlockOrError, *taskOp)))
2425 return llvm::make_error<PreviouslyReportedError>();
2427 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2432 return llvm::make_error<PreviouslyReportedError>();
2435 taskStructMgr.freeStructPtr();
2437 return llvm::Error::success();
2446 llvm::omp::Directive::OMPD_taskgroup);
2450 moduleTranslation, dds);
2452 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2453 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2455 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2457 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dds,
2458 taskOp.getMergeable(),
2459 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
2460 moduleTranslation.
lookupValue(taskOp.getPriority()));
2468 builder.restoreIP(*afterIP);
2476 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2480 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2481 builder.restoreIP(codegenIP);
2483 builder, moduleTranslation)
2488 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2489 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2496 builder.restoreIP(*afterIP);
2515 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2519 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2521 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2525 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2528 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
2529 llvm::Type *ivType = step->getType();
2530 llvm::Value *chunk =
nullptr;
2531 if (wsloopOp.getScheduleChunk()) {
2532 llvm::Value *chunkVar =
2533 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
2534 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2537 omp::DistributeOp distributeOp =
nullptr;
2538 llvm::Value *distScheduleChunk =
nullptr;
2539 bool hasDistSchedule =
false;
2540 if (llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())) {
2541 distributeOp = cast<omp::DistributeOp>(opInst.
getParentOp());
2542 hasDistSchedule = distributeOp.getDistScheduleStatic();
2543 if (distributeOp.getDistScheduleChunkSize()) {
2544 llvm::Value *chunkVar = moduleTranslation.
lookupValue(
2545 distributeOp.getDistScheduleChunkSize());
2546 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2554 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2558 wsloopOp.getNumReductionVars());
2561 builder, moduleTranslation, privateVarsInfo, allocaIP);
2568 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2573 moduleTranslation, allocaIP, reductionDecls,
2574 privateReductionVariables, reductionVariableMap,
2575 deferredStores, isByRef)))
2584 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2586 wsloopOp.getPrivateNeedsBarrier())))
2589 assert(afterAllocas.get()->getSinglePredecessor());
2590 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
2592 afterAllocas.get()->getSinglePredecessor(),
2593 reductionDecls, privateReductionVariables,
2594 reductionVariableMap, isByRef, deferredStores)))
2598 bool isOrdered = wsloopOp.getOrdered().has_value();
2599 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2600 bool isSimd = wsloopOp.getScheduleSimd();
2601 bool loopNeedsBarrier = !wsloopOp.getNowait();
2606 llvm::omp::WorksharingLoopType workshareLoopType =
2607 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
2608 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2609 : llvm::omp::WorksharingLoopType::ForStaticLoop;
2613 llvm::omp::Directive::OMPD_for);
2615 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2618 LinearClauseProcessor linearClauseProcessor;
2620 if (!wsloopOp.getLinearVars().empty()) {
2621 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
2623 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
2625 for (
auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
2626 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2628 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
2629 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
2633 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
2641 if (!wsloopOp.getLinearVars().empty()) {
2642 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2643 loopInfo->getPreheader());
2644 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2646 builder.saveIP(), llvm::omp::OMPD_barrier);
2649 builder.restoreIP(*afterBarrierIP);
2650 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
2651 loopInfo->getIndVar());
2652 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
2655 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2658 bool noLoopMode =
false;
2659 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
2661 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
2665 if (loopOp == targetCapturedOp) {
2666 omp::TargetRegionFlags kernelFlags =
2667 targetOp.getKernelExecFlags(targetCapturedOp);
2668 if (omp::bitEnumContainsAll(kernelFlags,
2669 omp::TargetRegionFlags::spmd |
2670 omp::TargetRegionFlags::no_loop) &&
2671 !omp::bitEnumContainsAny(kernelFlags,
2672 omp::TargetRegionFlags::generic))
2677 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2678 ompBuilder->applyWorkshareLoop(
2679 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2680 convertToScheduleKind(schedule), chunk, isSimd,
2681 scheduleMod == omp::ScheduleModifier::monotonic,
2682 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2683 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
2689 if (!wsloopOp.getLinearVars().empty()) {
2690 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
2691 assert(loopInfo->getLastIter() &&
2692 "`lastiter` in CanonicalLoopInfo is nullptr");
2693 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2694 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
2695 loopInfo->getLastIter());
2698 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
2699 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
2701 builder.restoreIP(oldIP);
2709 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2710 privateReductionVariables, isByRef, wsloopOp.getNowait(),
2723 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2725 assert(isByRef.size() == opInst.getNumReductionVars());
2738 opInst.getNumReductionVars());
2741 auto bodyGenCB = [&](InsertPointTy allocaIP,
2742 InsertPointTy codeGenIP) -> llvm::Error {
2744 builder, moduleTranslation, privateVarsInfo, allocaIP);
2746 return llvm::make_error<PreviouslyReportedError>();
2752 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
2755 InsertPointTy(allocaIP.getBlock(),
2756 allocaIP.getBlock()->getTerminator()->getIterator());
2759 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
2760 reductionDecls, privateReductionVariables, reductionVariableMap,
2761 deferredStores, isByRef)))
2762 return llvm::make_error<PreviouslyReportedError>();
2764 assert(afterAllocas.get()->getSinglePredecessor());
2765 builder.restoreIP(codeGenIP);
2771 return llvm::make_error<PreviouslyReportedError>();
2774 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2776 opInst.getPrivateNeedsBarrier())))
2777 return llvm::make_error<PreviouslyReportedError>();
2780 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
2781 afterAllocas.get()->getSinglePredecessor(),
2782 reductionDecls, privateReductionVariables,
2783 reductionVariableMap, isByRef, deferredStores)))
2784 return llvm::make_error<PreviouslyReportedError>();
2789 moduleTranslation, allocaIP);
2793 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
2795 return regionBlock.takeError();
2798 if (opInst.getNumReductionVars() > 0) {
2803 owningReductionGenRefDataPtrGens;
2805 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
2807 owningReductionGenRefDataPtrGens,
2808 privateReductionVariables, reductionInfos, isByRef);
2811 builder.SetInsertPoint((*regionBlock)->getTerminator());
2814 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
2815 builder.SetInsertPoint(tempTerminator);
2817 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
2818 ompBuilder->createReductions(
2819 builder.saveIP(), allocaIP, reductionInfos, isByRef,
2821 if (!contInsertPoint)
2822 return contInsertPoint.takeError();
2824 if (!contInsertPoint->getBlock())
2825 return llvm::make_error<PreviouslyReportedError>();
2827 tempTerminator->eraseFromParent();
2828 builder.restoreIP(*contInsertPoint);
2831 return llvm::Error::success();
2834 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2835 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
2844 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2845 InsertPointTy oldIP = builder.saveIP();
2846 builder.restoreIP(codeGenIP);
2851 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
2852 [](omp::DeclareReductionOp reductionDecl) {
2853 return &reductionDecl.getCleanupRegion();
2856 reductionCleanupRegions, privateReductionVariables,
2857 moduleTranslation, builder, "omp.reduction.cleanup")))
2858 return llvm::createStringError(
2859 "failed to inline `cleanup` region of `omp.declare_reduction`");
2864 return llvm::make_error<PreviouslyReportedError>();
2868 if (isCancellable) {
2869 auto IPOrErr = ompBuilder->createBarrier(
2870 llvm::OpenMPIRBuilder::LocationDescription(builder),
2871 llvm::omp::Directive::OMPD_unknown,
2875 return IPOrErr.takeError();
2878 builder.restoreIP(oldIP);
2879 return llvm::Error::success();
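// Collect the parallel clause operands (`if` condition, num_threads, and
// proc_bind) before handing the region off to
// OpenMPIRBuilder::createParallel below.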
2882 llvm::Value *ifCond = nullptr;
2883 if (auto ifVar = opInst.getIfExpr())
2884 ifCond = moduleTranslation.lookupValue(ifVar);
2885 llvm::Value *numThreads = nullptr;
2886 if (auto numThreadsVar = opInst.getNumThreads())
2887 numThreads = moduleTranslation.lookupValue(numThreadsVar);
2888 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
2889 if (auto bind = opInst.getProcBindKind())
2892 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2894 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2896 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2897 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
2898 ifCond, numThreads, pbKind, isCancellable);
2903 builder.restoreIP(*afterIP);
2908static llvm::omp::OrderKind
2911 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2913 case omp::ClauseOrderKind::Concurrent:
2914 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2916 llvm_unreachable("Unknown ClauseOrderKind kind");
2924 auto simdOp = cast<omp::SimdOp>(opInst);
2932 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2935 simdOp.getNumReductionVars());
2940 assert(isByRef.size() == simdOp.getNumReductionVars());
2942 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2946 LinearClauseProcessor linearClauseProcessor;
2948 if (!simdOp.getLinearVars().empty()) {
2949 auto linearVarTypes = simdOp.getLinearVarTypes().value();
2951 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
2952 for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars()))
2953 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2954 linearVar, idx);
2955 for (mlir::Value linearStep : simdOp.getLinearStepVars())
2956 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
2960 builder, moduleTranslation, privateVarsInfo, allocaIP);
2965 moduleTranslation, allocaIP, reductionDecls,
2966 privateReductionVariables, reductionVariableMap,
2967 deferredStores, isByRef)))
2978 assert(afterAllocas.get()->getSinglePredecessor());
2979 if (failed(initReductionVars(simdOp, reductionArgs, builder,
2981 afterAllocas.get()->getSinglePredecessor(),
2982 reductionDecls, privateReductionVariables,
2983 reductionVariableMap, isByRef, deferredStores)))
2986 llvm::ConstantInt *simdlen = nullptr;
2987 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
2988 simdlen = builder.getInt64(simdlenVar.value());
2990 llvm::ConstantInt *safelen = nullptr;
2991 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
2992 safelen = builder.getInt64(safelenVar.value());
2994 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
2997 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
2998 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
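// Build the (pointer, alignment) map for the aligned clause: only
// power-of-two integer constant alignments attached to pointer-typed values
// are recorded, and the pointer is loaded in the entry block so applySimd
// can emit the corresponding alignment assumption.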
3000 for (size_t i = 0; i < operands.size(); ++i) {
3001 llvm::Value *alignment = nullptr;
3002 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
3003 llvm::Type *ty = llvmVal->getType();
3005 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
3006 alignment = builder.getInt64(intAttr.getInt());
3007 assert(ty->isPointerTy() && "Invalid type for aligned variable");
3008 assert(alignment && "Invalid alignment value");
3012 if (!intAttr.getValue().isPowerOf2())
3015 auto curInsert = builder.saveIP();
3016 builder.SetInsertPoint(sourceBlock);
3017 llvmVal = builder.CreateLoad(ty, llvmVal);
3018 builder.restoreIP(curInsert);
3019 alignedVars[llvmVal] = alignment;
3023 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
3030 if (simdOp.getLinearVars().size()) {
3031 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3032 loopInfo->getPreheader());
3034 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3035 loopInfo->getIndVar());
3037 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3039 ompBuilder->applySimd(loopInfo, alignedVars,
3040 simdOp.getIfExpr()
3041 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
3042 : nullptr,
3043 order, simdlen, safelen);
3045 for (size_t index = 0; index < simdOp.getLinearVars().size(); index++)
3046 linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
3047 index);
3053 for (auto [i, tuple] : llvm::enumerate(
3054 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
3055 privateReductionVariables))) {
3056 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
3058 OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
3059 llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
3060 llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());
3064 llvm::Value *redValue = originalVariable;
3067 builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
3068 llvm::Value *privateRedValue = builder.CreateLoad(
3069 reductionType, privateReductionVar, "red.private.value." + Twine(i));
3070 llvm::Value *reduced;
3072 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
3075 builder.restoreIP(res.get());
3079 builder.CreateStore(reduced, originalVariable);
3084 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
3085 [](omp::DeclareReductionOp reductionDecl) {
3086 return &reductionDecl.getCleanupRegion();
3089 moduleTranslation, builder,
3090 "omp.reduction.cleanup")))
3103 auto loopOp = cast<omp::LoopNestOp>(opInst);
3106 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3111 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
3112 llvm::Value *iv) -> llvm::Error {
3115 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
3120 bodyInsertPoints.push_back(ip);
3122 if (loopInfos.size() != loopOp.getNumLoops() - 1)
3123 return llvm::Error::success();
3126 builder.restoreIP(ip);
3128 loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
3130 return regionBlock.takeError();
3132 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3133 return llvm::Error::success();
3141 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3142 llvm::Value *lowerBound =
3143 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
3144 llvm::Value *upperBound =
3145 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
3146 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
3151 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3152 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3154 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3156 computeIP = loopInfos.front()->getPreheaderIP();
3160 ompBuilder->createCanonicalLoop(
3161 loc, bodyGen, lowerBound, upperBound, step,
3162 true, loopOp.getLoopInclusive(), computeIP);
3167 loopInfos.push_back(*loopResult);
3170 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
3171 loopInfos.front()->getAfterIP();
3174 if (const auto &tiles = loopOp.getTileSizes()) {
3175 llvm::Type *ivType = loopInfos.front()->getIndVarType();
3178 for (auto tile : tiles.value()) {
3179 llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
3180 tileSizes.push_back(tileVal);
3183 std::vector<llvm::CanonicalLoopInfo *> newLoops =
3184 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
3188 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
3189 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
3190 afterIP = {afterAfterBB, afterAfterBB->begin()};
3194 for (const auto &newLoop : newLoops)
3195 loopInfos.push_back(newLoop);
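// When a collapse is requested, merge the first `numCollapse` canonical
// loops into a single loop and record it in the enclosing
// OpenMPLoopInfoStackFrame so that wrapper constructs (e.g. omp.wsloop)
// operate on the collapsed loop.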
3199 const auto &numCollapse = loopOp.getCollapseNumLoops();
3201 loopInfos.begin(), loopInfos.begin() + (numCollapse));
3203 auto newTopLoopInfo =
3204 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
3206 assert(newTopLoopInfo && "New top loop information is missing");
3207 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
3208 [&](OpenMPLoopInfoStackFrame &frame) {
3209 frame.loopInfo = newTopLoopInfo;
3217 builder.restoreIP(afterIP);
3227 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
3228 Value loopIV = op.getInductionVar();
3229 Value loopTC = op.getTripCount();
3231 llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
3234 ompBuilder->createCanonicalLoop(
3236 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
3239 moduleTranslation.mapValue(loopIV, llvmIV);
3241 builder.restoreIP(ip);
3246 return bodyGenStatus.takeError();
3248 llvmTC, "omp.loop");
3250 return op.emitError(llvm::toString(llvmOrError.takeError()));
3252 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
3253 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
3254 builder.restoreIP(afterIP);
3257 if (Value cli = op.getCli())
3270 Value applyee = op.getApplyee();
3271 assert(applyee && "Loop to apply unrolling on required");
3273 llvm::CanonicalLoopInfo *consBuilderCLI =
3275 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3276 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
3284static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3287 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3292 for (Value size : op.getSizes()) {
3293 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
3294 assert(translatedSize &&
3295 "sizes clause arguments must already be translated");
3296 translatedSizes.push_back(translatedSize);
3299 for (Value applyee : op.getApplyees()) {
3300 llvm::CanonicalLoopInfo *consBuilderCLI =
3302 assert(applyee && "Canonical loop must already been translated");
3303 translatedLoops.push_back(consBuilderCLI);
3306 auto generatedLoops =
3307 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3308 if (!op.getGeneratees().empty()) {
3309 for (auto [mlirLoop, genLoop] :
3310 zip_equal(op.getGeneratees(), generatedLoops))
3315 for (Value applyee : op.getApplyees())
3322static llvm::AtomicOrdering
3325 return llvm::AtomicOrdering::Monotonic;
3328 case omp::ClauseMemoryOrderKind::Seq_cst:
3329 return llvm::AtomicOrdering::SequentiallyConsistent;
3330 case omp::ClauseMemoryOrderKind::Acq_rel:
3331 return llvm::AtomicOrdering::AcquireRelease;
3332 case omp::ClauseMemoryOrderKind::Acquire:
3333 return llvm::AtomicOrdering::Acquire;
3334 case omp::ClauseMemoryOrderKind::Release:
3335 return llvm::AtomicOrdering::Release;
3336 case omp::ClauseMemoryOrderKind::Relaxed:
3337 return llvm::AtomicOrdering::Monotonic;
3339 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3346 auto readOp = cast<omp::AtomicReadOp>(opInst);
3351 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3354 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3357 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3358 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
3360 llvm::Type *elementType =
3361 moduleTranslation.convertType(readOp.getElementType());
3363 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
3364 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
3365 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
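// Translate omp.atomic.write, i.e. an atomic store such as
//   #pragma omp atomic write
//   x = expr;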
3373 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3378 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3381 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3383 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3384 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
3385 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3386 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, false, false};
3389 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
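// Map the single LLVM dialect operation inside an atomic update region onto
// an atomicrmw opcode. Operations without a direct atomicrmw equivalent map
// to BAD_BINOP, in which case the OpenMPIRBuilder lowers the update through
// a compare-exchange loop built from the region body instead.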
3397 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
3398 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
3399 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
3400 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
3401 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
3402 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
3403 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
3404 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
3405 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
3406 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3410 bool &isIgnoreDenormalMode,
3411 bool &isFineGrainedMemory,
3412 bool &isRemoteMemory) {
3413 isIgnoreDenormalMode = false;
3414 isFineGrainedMemory = false;
3415 isRemoteMemory = false;
3416 if (atomicUpdateOp &&
3417 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3418 mlir::omp::AtomicControlAttr atomicControlAttr =
3419 atomicUpdateOp.getAtomicControlAttr();
3420 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3421 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3422 isRemoteMemory = atomicControlAttr.getRemoteMemory();
3429 llvm::IRBuilderBase &builder,
3436 auto &innerOpList = opInst.getRegion().front().getOperations();
3437 bool isXBinopExpr{false};
3438 llvm::AtomicRMWInst::BinOp binop;
3440 llvm::Value *llvmExpr = nullptr;
3441 llvm::Value *llvmX = nullptr;
3442 llvm::Type *llvmXElementType = nullptr;
3443 if (innerOpList.size() == 2) {
3449 opInst.getRegion().getArgument(0))) {
3450 return opInst.emitError("no atomic update operation with region argument"
3451 " as operand found inside atomic.update region");
3454 isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
3456 llvmExpr = moduleTranslation.lookupValue(mlirExpr);
3460 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3462 llvmX = moduleTranslation.lookupValue(opInst.getX());
3464 opInst.getRegion().getArgument(0).getType());
3465 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3469 llvm::AtomicOrdering atomicOrdering =
3474 [&opInst, &moduleTranslation](
3475 llvm::Value *atomicx,
3478 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
3479 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3480 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
3481 return llvm::make_error<PreviouslyReportedError>();
3483 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
3484 assert(yieldop && yieldop.getResults().size() == 1 &&
3485 "terminator must be omp.yield op and it must have exactly one "
3487 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
3490 bool isIgnoreDenormalMode;
3491 bool isFineGrainedMemory;
3492 bool isRemoteMemory;
3497 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3498 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3499 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
3500 atomicOrdering, binop, updateFn,
3501 isXBinopExpr, isIgnoreDenormalMode,
3502 isFineGrainedMemory, isRemoteMemory);
3507 builder.restoreIP(*afterIP);
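// Translate omp.atomic.capture, which pairs an atomic read with either an
// atomic update or an atomic write, e.g.
//   #pragma omp atomic capture
//   { v = x; x = x + expr; }
// Whether the captured value is the old or the new one (postfix vs. prefix
// update) is determined by the order of the two nested operations.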
3513 llvm::IRBuilderBase &builder,
3520 bool isXBinopExpr = false, isPostfixUpdate = false;
3521 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3523 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
3524 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
3526 assert((atomicUpdateOp || atomicWriteOp) &&
3527 "internal op must be an atomic.update or atomic.write op");
3529 if (atomicWriteOp) {
3530 isPostfixUpdate = true;
3531 mlirExpr = atomicWriteOp.getExpr();
3533 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
3534 atomicCaptureOp.getAtomicUpdateOp().getOperation();
3535 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
3538 if (innerOpList.size() == 2) {
3541 atomicUpdateOp.getRegion().getArgument(0))) {
3542 return atomicUpdateOp.emitError(
3543 "no atomic update operation with region argument"
3544 " as operand found inside atomic.update region");
3548 innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
3551 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3555 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
3556 llvm::Value *llvmX =
3557 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
3558 llvm::Value *llvmV =
3559 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
3560 llvm::Type *llvmXElementType = moduleTranslation.convertType(
3561 atomicCaptureOp.getAtomicReadOp().getElementType());
3562 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3565 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
3569 llvm::AtomicOrdering atomicOrdering =
3573 [&](llvm::Value *atomicx,
3576 return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
3577 Block &bb = *atomicUpdateOp.getRegion().begin();
3578 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
3580 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3581 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
3582 return llvm::make_error<PreviouslyReportedError>();
3584 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
3585 assert(yieldop && yieldop.getResults().size() == 1 &&
3586 "terminator must be omp.yield op and it must have exactly one "
3588 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
3591 bool isIgnoreDenormalMode;
3592 bool isFineGrainedMemory;
3593 bool isRemoteMemory;
3595 isFineGrainedMemory, isRemoteMemory);
3598 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3599 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3600 ompBuilder->createAtomicCapture(
3601 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
3602 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
3603 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
3605 if (failed(handleError(afterIP, *atomicCaptureOp)))
3608 builder.restoreIP(*afterIP);
3613 omp::ClauseCancellationConstructType directive) {
3614 switch (directive) {
3615 case omp::ClauseCancellationConstructType::Loop:
3616 return llvm::omp::Directive::OMPD_for;
3617 case omp::ClauseCancellationConstructType::Parallel:
3618 return llvm::omp::Directive::OMPD_parallel;
3619 case omp::ClauseCancellationConstructType::Sections:
3620 return llvm::omp::Directive::OMPD_sections;
3621 case omp::ClauseCancellationConstructType::Taskgroup:
3622 return llvm::omp::Directive::OMPD_taskgroup;
3624 llvm_unreachable("Unhandled cancellation construct type");
3633 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3636 llvm::Value *ifCond = nullptr;
3637 if (Value ifVar = op.getIfExpr())
3640 llvm::omp::Directive cancelledDirective =
3643 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3644 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
3646 if (failed(handleError(afterIP, *op.getOperation())))
3649 builder.restoreIP(afterIP.get());
3656 llvm::IRBuilderBase &builder,
3661 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3664 llvm::omp::Directive cancelledDirective =
3667 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3668 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
3670 if (failed(handleError(afterIP, *op.getOperation())))
3673 builder.restoreIP(afterIP.get());
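// Translate omp.threadprivate. On the host, the cached per-thread copy of
// the global is obtained through createCachedThreadPrivate (which emits a
// call to __kmpc_threadprivate_cached); for target devices the runtime call
// is not emitted.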
3683 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3685 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
3690 Value symAddr = threadprivateOp.getSymAddr();
3693 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3696 if (!isa<LLVM::AddressOfOp>(symOp))
3697 return opInst.emitError("Addressing symbol not found");
3698 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3700 LLVM::GlobalOp global =
3701 addressOfOp.getGlobal(moduleTranslation.symbolTable());
3702 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
3704 if (!ompBuilder->Config.isTargetDevice()) {
3705 llvm::Type *type = globalValue->getValueType();
3706 llvm::TypeSize typeSize =
3707 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
3709 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
3710 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
3711 ompLoc, globalValue, size, global.getSymName() + ".cache");
3720static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
3722 switch (deviceClause) {
3723 case mlir::omp::DeclareTargetDeviceType::host:
3724 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
3726 case mlir::omp::DeclareTargetDeviceType::nohost:
3727 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
3729 case mlir::omp::DeclareTargetDeviceType::any:
3730 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
3733 llvm_unreachable("unhandled device clause");
3736static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
3738 mlir::omp::DeclareTargetCaptureClause captureClause) {
3739 switch (captureClause) {
3740 case mlir::omp::DeclareTargetCaptureClause::to:
3741 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
3742 case mlir::omp::DeclareTargetCaptureClause::link:
3743 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
3744 case mlir::omp::DeclareTargetCaptureClause::enter:
3745 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
3746 case mlir::omp::DeclareTargetCaptureClause::none:
3747 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
3749 llvm_unreachable("unhandled capture clause");
3754 if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
3756 if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
3757 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
3758 return modOp.lookupSymbol(addressOfOp.getGlobalName());
3763static llvm::SmallString<64>
3765 llvm::OpenMPIRBuilder &ompBuilder) {
3767 llvm::raw_svector_ostream os(suffix);
3770 auto fileInfoCallBack = [&loc]() {
3771 return std::pair<std::string, uint64_t>(
3772 llvm::StringRef(loc.getFilename()), loc.getLine());
3775 auto vfs = llvm::vfs::getRealFileSystem();
3778 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
3780 os << "_decl_tgt_ref_ptr";
3786 if (auto declareTargetGlobal =
3787 dyn_cast_if_present<omp::DeclareTargetInterface>(
3789 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3790 omp::DeclareTargetCaptureClause::link)
3796 if (auto declareTargetGlobal =
3797 dyn_cast_if_present<omp::DeclareTargetInterface>(
3799 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3800 omp::DeclareTargetCaptureClause::to ||
3801 declareTargetGlobal.getDeclareTargetCaptureClause() ==
3802 omp::DeclareTargetCaptureClause::enter)
3816 if (auto declareTargetGlobal =
3817 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
3820 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
3821 omp::DeclareTargetCaptureClause::link) ||
3822 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3823 omp::DeclareTargetCaptureClause::to &&
3824 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
3828 if (gOp.getSymName().contains(suffix))
3833 (gOp.getSymName().str() + suffix.str()).str());
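// MapInfosTy extends the OpenMPIRBuilder's map info vectors with the
// omp.declare_mapper operation attached to each entry, and MapInfoData adds
// the MLIR-side bookkeeping (original values, map clauses, base types, and
// declare-target / membership flags) used while the map entries are built.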
3842struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
3843 SmallVector<Operation *, 4> Mappers;
3846 void append(MapInfosTy &curInfo) {
3847 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
3848 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
3857struct MapInfoData : MapInfosTy {
3858 llvm::SmallVector<bool, 4> IsDeclareTarget;
3859 llvm::SmallVector<bool, 4> IsAMember;
3861 llvm::SmallVector<bool, 4> IsAMapping;
3862 llvm::SmallVector<mlir::Operation *, 4> MapClause;
3863 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
3866 llvm::SmallVector<llvm::Type *, 4> BaseType;
3869 void append(MapInfoData &CurInfo) {
3870 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
3871 CurInfo.IsDeclareTarget.end());
3872 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
3873 OriginalValue.append(CurInfo.OriginalValue.begin(),
3874 CurInfo.OriginalValue.end());
3875 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
3876 MapInfosTy::append(CurInfo);
3880enum class TargetDirectiveEnumTy : uint32_t {
3884 TargetEnterData = 3,
3889static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
3890 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
3891 .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
3892 .Case([](omp::TargetEnterDataOp) {
3893 return TargetDirectiveEnumTy::TargetEnterData;
3895 .Case([&](omp::TargetExitDataOp) {
3896 return TargetDirectiveEnumTy::TargetExitData;
3898 .Case([&](omp::TargetUpdateOp) {
3899 return TargetDirectiveEnumTy::TargetUpdate;
3901 .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
3902 .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
3909 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3910 arrTy.getElementType()))
3927 llvm::Value *basePointer,
3928 llvm::Type *baseType,
3929 llvm::IRBuilderBase &builder,
3931 if (auto memberClause =
3932 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
3937 if (!memberClause.getBounds().empty()) {
3938 llvm::Value *elementCount = builder.getInt64(1);
3939 for (auto bounds : memberClause.getBounds()) {
3940 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
3941 bounds.getDefiningOp())) {
3946 elementCount = builder.CreateMul(
3950 moduleTranslation.lookupValue(boundOp.getUpperBound()),
3951 moduleTranslation.lookupValue(boundOp.getLowerBound())),
3952 builder.getInt64(1)));
3959 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
3967 return builder.CreateMul(elementCount,
3968 builder.getInt64(underlyingTypeSzInBits / 8));
3979static llvm::omp::OpenMPOffloadMappingFlags
3981 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
3982 return (mlirFlags & flag) == flag;
3984 const bool hasExplicitMap =
3985 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
3986 omp::ClauseMapFlags::none;
3988 llvm::omp::OpenMPOffloadMappingFlags mapType =
3989 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
3992 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
3995 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
3998 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4001 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4004 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4007 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4010 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4013 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4016 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4019 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4022 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4025 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4028 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4029 if (!hasExplicitMap)
4030 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4040 ArrayRef<Value> useDevAddrOperands = {},
4041 ArrayRef<Value> hasDevAddrOperands = {}) {
4042 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
4050 for (Value mapValue : mapVars) {
4051 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4052 for (auto member : map.getMembers())
4053 if (member == mapOp)
4060 for (Value mapValue : mapVars) {
4061 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4063 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4064 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
4065 mapData.Pointers.push_back(mapData.OriginalValue.back());
4067 if (llvm::Value *refPtr =
4069 mapData.IsDeclareTarget.push_back(true);
4070 mapData.BasePointers.push_back(refPtr);
4072 mapData.IsDeclareTarget.push_back(true);
4073 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4075 mapData.IsDeclareTarget.push_back(false);
4076 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4079 mapData.BaseType.push_back(
4080 moduleTranslation.convertType(mapOp.getVarType()));
4081 mapData.Sizes.push_back(
4082 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
4083 mapData.BaseType.back(), builder, moduleTranslation));
4084 mapData.MapClause.push_back(mapOp.getOperation());
4086 mapData.Names.push_back(LLVM::createMappingInformation(
4088 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4089 if (mapOp.getMapperId())
4090 mapData.Mappers.push_back(
4092 mapOp, mapOp.getMapperIdAttr()));
4094 mapData.Mappers.push_back(nullptr);
4095 mapData.IsAMapping.push_back(true);
4096 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
4099 auto findMapInfo = [&mapData](llvm::Value *val,
4100 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4103 for (llvm::Value *basePtr : mapData.OriginalValue) {
4104 if (basePtr == val && mapData.IsAMapping[index]) {
4106 mapData.Types[index] |=
4107 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4108 mapData.DevicePointers[index] = devInfoTy;
4116 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
4117 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4118 for (Value mapValue : useDevOperands) {
4119 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4121 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4122 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
4125 if (!findMapInfo(origValue, devInfoTy)) {
4126 mapData.OriginalValue.push_back(origValue);
4127 mapData.Pointers.push_back(mapData.OriginalValue.back());
4128 mapData.IsDeclareTarget.push_back(false);
4129 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4130 mapData.BaseType.push_back(
4131 moduleTranslation.convertType(mapOp.getVarType()));
4132 mapData.Sizes.push_back(builder.getInt64(0));
4133 mapData.MapClause.push_back(mapOp.getOperation());
4134 mapData.Types.push_back(
4135 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
4136 mapData.Names.push_back(LLVM::createMappingInformation(
4138 mapData.DevicePointers.push_back(devInfoTy);
4139 mapData.Mappers.push_back(nullptr);
4140 mapData.IsAMapping.push_back(false);
4141 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
4146 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4147 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
4149 for (Value mapValue : hasDevAddrOperands) {
4150 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4152 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4153 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
4155 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4157 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4158 omp::ClauseMapFlags::none;
4160 mapData.OriginalValue.push_back(origValue);
4161 mapData.BasePointers.push_back(origValue);
4162 mapData.Pointers.push_back(origValue);
4163 mapData.IsDeclareTarget.push_back(false);
4164 mapData.BaseType.push_back(
4165 moduleTranslation.convertType(mapOp.getVarType()));
4166 mapData.Sizes.push_back(
4167 builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
4168 mapData.MapClause.push_back(mapOp.getOperation());
4169 if (llvm::to_underlying(mapType & mapTypeAlways)) {
4173 mapData.Types.push_back(mapType);
4177 if (mapOp.getMapperId()) {
4178 mapData.Mappers.push_back(
4180 mapOp, mapOp.getMapperIdAttr()));
4182 mapData.Mappers.push_back(nullptr);
4187 mapData.Types.push_back(
4188 isDevicePtr ? mapType
4189 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4190 mapData.Mappers.push_back(nullptr);
4192 mapData.Names.push_back(LLVM::createMappingInformation(
4194 mapData.DevicePointers.push_back(
4195 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4196 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4197 mapData.IsAMapping.push_back(false);
4198 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
4203 auto *res = llvm::find(mapData.MapClause, memberOp);
4204 assert(res != mapData.MapClause.end() &&
4205 "MapInfoOp for member not found in MapData, cannot return index");
4206 return std::distance(mapData.MapClause.begin(), res);
4210 omp::MapInfoOp mapInfo) {
4211 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4221 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
4222 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
4224 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
4225 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
4226 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
4228 if (aIndex == bIndex)
4231 if (aIndex < bIndex)
4234 if (aIndex > bIndex)
4241 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
4243 occludedChildren.push_back(
b);
4245 occludedChildren.push_back(a);
4246 return memberAParent;
4252 for (auto v : occludedChildren)
4259 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4261 if (indexAttr.size() == 1)
4262 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
4266 return llvm::cast<omp::MapInfoOp>(
4291static std::vector<llvm::Value *>
4293 llvm::IRBuilderBase &builder,
bool isArrayTy,
4295 std::vector<llvm::Value *> idx;
4306 idx.push_back(builder.getInt64(0));
4307 for (int i = bounds.size() - 1; i >= 0; --i) {
4308 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4309 bounds[i].getDefiningOp())) {
4310 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
4328 std::vector<llvm::Value *> dimensionIndexSizeOffset;
4329 for (int i = bounds.size() - 1; i >= 0; --i) {
4330 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4331 bounds[i].getDefiningOp())) {
4332 if (i == ((int)bounds.size() - 1))
4334 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4336 idx.back() = builder.CreateAdd(
4337 builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
4338 boundOp.getExtent())),
4339 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4348 llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
4349 return cast<IntegerAttr>(value).getInt();
4357 omp::MapInfoOp parentOp) {
4359 if (parentOp.getMembers().empty())
4363 if (parentOp.getMembers().size() == 1) {
4364 overlapMapDataIdxs.push_back(0);
4370 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
4371 for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
4372 memberByIndex.push_back(
4373 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
4378 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
4379 [&](auto a, auto b) { return a.second.size() < b.second.size(); });
4385 for (auto v : memberByIndex) {
4389 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
4392 llvm::SmallVector<int64_t> xArr(x.second.size());
4393 getAsIntegers(x.second, xArr);
4394 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
4395 xArr.size() >= vArr.size();
4401 for (
auto v : memberByIndex)
4402 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
4403 overlapMapDataIdxs.push_back(v.first);
4415 if (mapOp.getVarPtrPtr())
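// Emit the combined entry for a parent map that has member maps: the
// parent's base pointer is used for every entry, the mapped extent runs from
// the lowest to the highest member address, and member entries are tagged
// with the parent's MEMBER_OF flag.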
4444 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
4445 MapInfoData &mapData, uint64_t mapDataIndex,
4446 TargetDirectiveEnumTy targetDirective) {
4447 assert(!ompBuilder.Config.isTargetDevice() &&
4448 "function only supported for host device codegen");
4454 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
4455 (targetDirective == TargetDirectiveEnumTy::Target &&
4456 !mapData.IsDeclareTarget[mapDataIndex])
4457 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
4458 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4461 bool hasUserMapper = mapData.Mappers[mapDataIndex] != nullptr;
4462 if (hasUserMapper) {
4463 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
4467 mapFlags parentFlags = mapData.Types[mapDataIndex];
4468 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
4469 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
4470 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
4471 baseFlag |= (parentFlags & preserve);
4474 combinedInfo.Types.emplace_back(baseFlag);
4475 combinedInfo.DevicePointers.emplace_back(
4476 mapData.DevicePointers[mapDataIndex]);
4477 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4479 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4480 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
4490 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4491 llvm::Value *lowAddr, *highAddr;
4492 if (!parentClause.getPartialMap()) {
4493 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4494 builder.getPtrTy());
4495 highAddr = builder.CreatePointerCast(
4496 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4497 mapData.Pointers[mapDataIndex], 1),
4498 builder.getPtrTy());
4499 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4501 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4504 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
4505 builder.getPtrTy());
4508 highAddr = builder.CreatePointerCast(
4509 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
4510 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
4511 builder.getPtrTy());
4512 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
4515 llvm::Value *size = builder.CreateIntCast(
4516 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
4517 builder.getInt64Ty(),
4519 combinedInfo.Sizes.push_back(size);
4521 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
4522 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
4530 if (!parentClause.getPartialMap()) {
4535 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
4536 bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
4537 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
4538 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4539 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4541 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
4542 combinedInfo.Types.emplace_back(mapFlag);
4543 combinedInfo.DevicePointers.emplace_back(
4544 mapData.DevicePointers[mapDataIndex]);
4546 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4547 combinedInfo.BasePointers.emplace_back(
4548 mapData.BasePointers[mapDataIndex]);
4549 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4550 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
4551 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4562 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4563 builder.getPtrTy());
4564 highAddr = builder.CreatePointerCast(
4565 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4566 mapData.Pointers[mapDataIndex], 1),
4567 builder.getPtrTy());
4574 for (auto v : overlapIdxs) {
4577 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
4578 combinedInfo.Types.emplace_back(mapFlag);
4579 combinedInfo.DevicePointers.emplace_back(
4580 mapData.DevicePointers[mapDataOverlapIdx]);
4582 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4583 combinedInfo.BasePointers.emplace_back(
4584 mapData.BasePointers[mapDataIndex]);
4585 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4586 combinedInfo.Pointers.emplace_back(lowAddr);
4587 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
4588 builder.CreatePtrDiff(builder.getInt8Ty(),
4589 mapData.OriginalValue[mapDataOverlapIdx],
4591 builder.getInt64Ty(),
true));
4592 lowAddr = builder.CreateConstGEP1_32(
4594 mapData.MapClause[mapDataOverlapIdx]))
4595 ? builder.getPtrTy()
4596 : mapData.BaseType[mapDataOverlapIdx],
4597 mapData.BasePointers[mapDataOverlapIdx], 1);
4600 combinedInfo.Types.emplace_back(mapFlag);
4601 combinedInfo.DevicePointers.emplace_back(
4602 mapData.DevicePointers[mapDataIndex]);
4604 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4605 combinedInfo.BasePointers.emplace_back(
4606 mapData.BasePointers[mapDataIndex]);
4607 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4608 combinedInfo.Pointers.emplace_back(lowAddr);
4609 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
4610 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
4611 builder.getInt64Ty(),
true));
4614 return memberOfFlag;
4620 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
4621 MapInfoData &mapData, uint64_t mapDataIndex,
4622 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
4623 TargetDirectiveEnumTy targetDirective) {
4624 assert(!ompBuilder.Config.isTargetDevice() &&
4625 "function only supported for host device codegen");
4628 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4630 for (auto mappedMembers : parentClause.getMembers()) {
4632 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
4635 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
4646 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4647 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4648 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4649 combinedInfo.Types.emplace_back(mapFlag);
4650 combinedInfo.DevicePointers.emplace_back(
4651 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4652 combinedInfo.Mappers.emplace_back(
nullptr);
4653 combinedInfo.Names.emplace_back(
4655 combinedInfo.BasePointers.emplace_back(
4656 mapData.BasePointers[mapDataIndex]);
4657 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
4658 combinedInfo.Sizes.emplace_back(builder.getInt64(
4659 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
4665 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4666 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4667 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4669 ? parentClause.getVarPtr()
4670 : parentClause.getVarPtrPtr());
4673 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
4674 targetDirective != TargetDirectiveEnumTy::TargetData))) {
4675 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4678 combinedInfo.Types.emplace_back(mapFlag);
4679 combinedInfo.DevicePointers.emplace_back(
4680 mapData.DevicePointers[memberDataIdx]);
4681 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
4682 combinedInfo.Names.emplace_back(
4684 uint64_t basePointerIndex =
4686 combinedInfo.BasePointers.emplace_back(
4687 mapData.BasePointers[basePointerIndex]);
4688 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
4690 llvm::Value *size = mapData.Sizes[memberDataIdx];
4692 size = builder.CreateSelect(
4693 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
4694 builder.getInt64(0), size);
4697 combinedInfo.Sizes.emplace_back(size);
4702 MapInfosTy &combinedInfo,
4703 TargetDirectiveEnumTy targetDirective,
4704 int mapDataParentIdx = -1) {
4708 auto mapFlag = mapData.Types[mapDataIdx];
4709 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
4713 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4715 if (targetDirective == TargetDirectiveEnumTy::Target &&
4716 !mapData.IsDeclareTarget[mapDataIdx])
4717 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4719 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
4721 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4726 if (mapDataParentIdx >= 0)
4727 combinedInfo.BasePointers.emplace_back(
4728 mapData.BasePointers[mapDataParentIdx]);
4730 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
4732 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
4733 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
4734 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
4735 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
4736 combinedInfo.Types.emplace_back(mapFlag);
4737 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
4741 llvm::IRBuilderBase &builder,
4742 llvm::OpenMPIRBuilder &ompBuilder,
4744 MapInfoData &mapData, uint64_t mapDataIndex,
4745 TargetDirectiveEnumTy targetDirective) {
4746 assert(!ompBuilder.Config.isTargetDevice() &&
4747 "function only supported for host device codegen");
4750 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4755 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
4756 auto memberClause = llvm::cast<omp::MapInfoOp>(
4757 parentClause.getMembers()[0].getDefiningOp());
4774 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
4776 combinedInfo, mapData, mapDataIndex,
4779 combinedInfo, mapData, mapDataIndex,
4780 memberOfParentFlag, targetDirective);
4790 llvm::IRBuilderBase &builder) {
4792 "function only supported for host device codegen");
4793 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4795 if (!mapData.IsDeclareTarget[i]) {
4796 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4797 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
4807 switch (captureKind) {
4808 case omp::VariableCaptureKind::ByRef: {
4809 llvm::Value *newV = mapData.Pointers[i];
4811 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
4814 newV = builder.CreateLoad(builder.getPtrTy(), newV);
4816 if (!offsetIdx.empty())
4817 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
4819 mapData.Pointers[i] = newV;
4821 case omp::VariableCaptureKind::ByCopy: {
4822 llvm::Type *type = mapData.BaseType[i];
4824 if (mapData.Pointers[i]->getType()->isPointerTy())
4825 newV = builder.CreateLoad(type, mapData.Pointers[i]);
4827 newV = mapData.Pointers[i];
4830 auto curInsert = builder.saveIP();
4831 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
4833 auto *memTempAlloc =
4834 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
4835 builder.SetCurrentDebugLocation(DbgLoc);
4836 builder.restoreIP(curInsert);
4838 builder.CreateStore(newV, memTempAlloc);
4839 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
4842 mapData.Pointers[i] = newV;
4843 mapData.BasePointers[i] = newV;
4845 case omp::VariableCaptureKind::This:
4846 case omp::VariableCaptureKind::VLAType:
4847 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
4858 MapInfoData &mapData,
4859 TargetDirectiveEnumTy targetDirective) {
4861 "function only supported for host device codegen");
4882 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4885 if (mapData.IsAMember[i])
4888 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
4889 if (!mapInfoOp.getMembers().empty()) {
4891 combinedInfo, mapData, i, targetDirective);
4899static llvm::Expected<llvm::Function *>
4901 LLVM::ModuleTranslation &moduleTranslation,
4902 llvm::StringRef mapperFuncName,
4903 TargetDirectiveEnumTy targetDirective);
4905static llvm::Expected<llvm::Function *>
4908 TargetDirectiveEnumTy targetDirective) {
4910 "function only supported for host device codegen");
4911 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4912 std::string mapperFuncName =
4914 {"omp_mapper", declMapperOp.getSymName()});
4916 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
4920 mapperFuncName, targetDirective);
4923static llvm::Expected<llvm::Function *>
4926 llvm::StringRef mapperFuncName,
4927 TargetDirectiveEnumTy targetDirective) {
4929 "function only supported for host device codegen");
4930 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4931 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
4934 llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
4937 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4940 MapInfosTy combinedInfo;
4942 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
4943 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
4944 builder.restoreIP(codeGenIP);
4945 moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
4946 moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
4947 builder.GetInsertBlock());
4948 if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
4951 return llvm::make_error<PreviouslyReportedError>();
4952 MapInfoData mapData;
4955 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
4961 return combinedInfo;
4965 if (!combinedInfo.Mappers[i])
4968 moduleTranslation, targetDirective);
4972 genMapInfoCB, varType, mapperFuncName, customMapperCB);
4974 return newFn.takeError();
4975 moduleTranslation.mapFunction(mapperFuncName, *newFn);
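// Translate the standalone data-movement constructs (omp.target_data,
// omp.target_enter_data, omp.target_exit_data, omp.target_update): select
// the matching __tgt_target_data_* runtime entry point, gather the map
// clauses into MapInfoData, and emit either the data region (target data) or
// the bare runtime call through createTargetData.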
4982 llvm::Value *ifCond = nullptr;
4983 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
4987 llvm::omp::RuntimeFunction RTLFn;
4989 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
4992 llvm::OpenMPIRBuilder::TargetDataInfo info(
4995 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
4996 bool isOffloadEntry =
4997 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
5001 .Case([&](omp::TargetDataOp dataOp) {
5005 if (auto ifVar = dataOp.getIfExpr())
5008 if (auto devId = dataOp.getDevice())
5009 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
5010 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5011 deviceID = intAttr.getInt();
5013 mapVars = dataOp.getMapVars();
5014 useDevicePtrVars = dataOp.getUseDevicePtrVars();
5015 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
5018 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
5022 if (auto ifVar = enterDataOp.getIfExpr())
5025 if (auto devId = enterDataOp.getDevice())
5026 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
5027 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5028 deviceID = intAttr.getInt();
5030 enterDataOp.getNowait()
5031 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
5032 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
5033 mapVars = enterDataOp.getMapVars();
5034 info.HasNoWait = enterDataOp.getNowait();
5037 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
5041 if (auto ifVar = exitDataOp.getIfExpr())
5044 if (auto devId = exitDataOp.getDevice())
5045 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
5046 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5047 deviceID = intAttr.getInt();
5049 RTLFn = exitDataOp.getNowait()
5050 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
5051 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
5052 mapVars = exitDataOp.getMapVars();
5053 info.HasNoWait = exitDataOp.getNowait();
5056 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
5060 if (auto ifVar = updateDataOp.getIfExpr())
5063 if (auto devId = updateDataOp.getDevice())
5064 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
5065 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5066 deviceID = intAttr.getInt();
5069 updateDataOp.getNowait()
5070 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
5071 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
5072 mapVars = updateDataOp.getMapVars();
5073 info.HasNoWait = updateDataOp.getNowait();
5076 .DefaultUnreachable("unexpected operation");
5081 if (!isOffloadEntry)
5082 ifCond = builder.getFalse();
5084 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5085 MapInfoData mapData;
5087 builder, useDevicePtrVars, useDeviceAddrVars);
5090 MapInfosTy combinedInfo;
5091 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
5092 builder.restoreIP(codeGenIP);
5093 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
5095 return combinedInfo;
5101 [&moduleTranslation](
5102 llvm::OpenMPIRBuilder::DeviceInfoTy type,
5106 for (auto [arg, useDevVar] :
5107 llvm::zip_equal(blockArgs, useDeviceVars)) {
5109 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
5110 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
5111 : mapInfoOp.getVarPtr();
5114 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
5115 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
5116 mapInfoData.MapClause, mapInfoData.DevicePointers,
5117 mapInfoData.BasePointers)) {
5118 auto mapOp = cast<omp::MapInfoOp>(mapClause);
5119 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
5120 devicePointer != type)
5123 if (llvm::Value *devPtrInfoMap =
5124 mapper ? mapper(basePointer) : basePointer) {
5125 moduleTranslation.mapValue(arg, devPtrInfoMap);
5132 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
5133 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
5134 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5137 builder.restoreIP(codeGenIP);
5138 assert(isa<omp::TargetDataOp>(op) &&
5139 "BodyGen requested for non TargetDataOp");
5140 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
5141 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
5142 switch (bodyGenType) {
5143 case BodyGenTy::Priv:
5145 if (!info.DevicePtrInfoMap.empty()) {
5146 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5147 blockArgIface.getUseDeviceAddrBlockArgs(),
5148 useDeviceAddrVars, mapData,
5149 [&](llvm::Value *basePointer) -> llvm::Value * {
5150 if (!info.DevicePtrInfoMap[basePointer].second)
5152 return builder.CreateLoad(
5154 info.DevicePtrInfoMap[basePointer].second);
5156 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5157 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5158 mapData, [&](llvm::Value *basePointer) {
5159 return info.DevicePtrInfoMap[basePointer].second;
5163 moduleTranslation)))
5164 return llvm::make_error<PreviouslyReportedError>();
5167 case BodyGenTy::DupNoPriv:
5168 if (info.DevicePtrInfoMap.empty()) {
5171 if (!ompBuilder->Config.IsTargetDevice.value_or(false)) {
5172 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5173 blockArgIface.getUseDeviceAddrBlockArgs(),
5174 useDeviceAddrVars, mapData);
5175 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5176 blockArgIface.getUseDevicePtrBlockArgs(),
5177 useDevicePtrVars, mapData);
5181 case BodyGenTy::NoPriv:
5183 if (info.DevicePtrInfoMap.empty()) {
5186 if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
5187 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5188 blockArgIface.getUseDeviceAddrBlockArgs(),
5189 useDeviceAddrVars, mapData);
5190 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5191 blockArgIface.getUseDevicePtrBlockArgs(),
5192 useDevicePtrVars, mapData);
5196 moduleTranslation)))
5197 return llvm::make_error<PreviouslyReportedError>();
5201 return builder.saveIP();
5204 auto customMapperCB =
5206 if (!combinedInfo.Mappers[i])
5208 info.HasMapper = true;
5210 moduleTranslation, targetDirective);
5213 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5214 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5216 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
5217 if (isa<omp::TargetDataOp>(op))
5218 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5219 builder.getInt64(deviceID), ifCond,
5220 info, genMapInfoCB, customMapperCB,
5223 return ompBuilder->createTargetData(
5224 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
5225 info, genMapInfoCB, customMapperCB, &RTLFn);
5231 builder.restoreIP(*afterIP);
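// convertOmpDistribute: lowers omp.distribute. When the parent omp.teams
// reduction is performed inside the distribute region, reduction storage is
// allocated here; when the op is not wrapped by omp.wsloop, a dist_schedule
// workshare loop (DistributeStaticLoop) is applied to the generated loop.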
5239 auto distributeOp = cast<omp::DistributeOp>(opInst);
5246 bool doDistributeReduction =
5250 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
5255 if (doDistributeReduction) {
5256 isByRef = getIsByRef(teamsOp.getReductionByref());
5257 assert(isByRef.size() == teamsOp.getNumReductionVars());
5260 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5264 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
5265 .getReductionBlockArgs();
5268 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
5269 reductionDecls, privateReductionVariables, reductionVariableMap,
5274 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5275 auto bodyGenCB = [&](InsertPointTy allocaIP,
5276 InsertPointTy codeGenIP) -> llvm::Error {
5280 moduleTranslation, allocaIP);
5283 builder.restoreIP(codeGenIP);
5289 return llvm::make_error<PreviouslyReportedError>();
5294 return llvm::make_error<PreviouslyReportedError>();
5297 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
5299 distributeOp.getPrivateNeedsBarrier())))
5300 return llvm::make_error<PreviouslyReportedError>();
5303 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5306 builder, moduleTranslation);
5308 return regionBlock.takeError();
5309 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
5314 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
5317 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
5318 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
5319 : omp::ClauseScheduleKind::Static;
5321 bool isOrdered = hasDistSchedule;
5322 std::optional<omp::ScheduleModifier> scheduleMod;
5323 bool isSimd = false;
5324 llvm::omp::WorksharingLoopType workshareLoopType =
5325 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
5326 bool loopNeedsBarrier = false;
5327 llvm::Value *chunk = moduleTranslation.lookupValue(
5328 distributeOp.getDistScheduleChunkSize());
5329 llvm::CanonicalLoopInfo *loopInfo =
5331 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
5332 ompBuilder->applyWorkshareLoop(
5333 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
5334 convertToScheduleKind(schedule), chunk, isSimd,
5335 scheduleMod == omp::ScheduleModifier::monotonic,
5336 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
5337 workshareLoopType, false, hasDistSchedule, chunk);
5340 return wsloopIP.takeError();
5343 distributeOp.getLoc(), privVarsInfo.llvmVars,
5345 return llvm::make_error<PreviouslyReportedError>();
5347 return llvm::Error::success();
5350 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5352 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5353 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5354 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
5359 builder.restoreIP(*afterIP);
5361 if (doDistributeReduction) {
5364 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
5365 privateReductionVariables, isByRef,
5377 if (!cast<mlir::ModuleOp>(op))
5382 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
5383 attribute.getOpenmpDeviceVersion());
5385 if (attribute.getNoGpuLib())
5388 ompBuilder->createGlobalFlag(
5389 attribute.getDebugKind(),
5390 "__omp_rtl_debug_kind");
5391 ompBuilder->createGlobalFlag(
5393 .getAssumeTeamsOversubscription()
5395 "__omp_rtl_assume_teams_oversubscription");
5396 ompBuilder->createGlobalFlag(
5398 .getAssumeThreadsOversubscription()
5400 "__omp_rtl_assume_threads_oversubscription");
5401 ompBuilder->createGlobalFlag(
5402 attribute.getAssumeNoThreadState(),
5403 "__omp_rtl_assume_no_thread_state");
5404 ompBuilder->createGlobalFlag(
5406 .getAssumeNoNestedParallelism()
5408 "__omp_rtl_assume_no_nested_parallelism");
5413 omp::TargetOp targetOp,
5414 llvm::StringRef parentName = "") {
5415 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
5417 assert(fileLoc && "No file found from location");
5418 StringRef fileName = fileLoc.getFilename().getValue();
5420 llvm::sys::fs::UniqueID id;
5421 uint64_t line = fileLoc.getLine();
5422 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
5423 size_t fileHash = llvm::hash_value(fileName.str());
5424 size_t deviceId = 0xdeadf17e;
5426 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
5428 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
5429 id.getFile(), line);
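// handleDeclareTargetMapVar: inside the device-side outlined function,
// rewrites uses of declare-target values so they go through the mapped base
// pointer, loading through it first where an indirect reference is required.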
5436 llvm::IRBuilderBase &builder, llvm::Function *func) {
5438 "function only supported for target device codegen");
5439 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5440 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5453 if (mapData.IsDeclareTarget[i]) {
5460 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
5461 convertUsersOfConstantsToInstructions(constant, func, false);
5468 for (llvm::User *user : mapData.OriginalValue[i]->users())
5469 userVec.push_back(user);
5471 for (llvm::User *user : userVec) {
5472 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
5473 if (insn->getFunction() == func) {
5474 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5475 llvm::Value *substitute = mapData.BasePointers[i];
5477 : mapOp.getVarPtr())) {
5478 builder.SetCurrentDebugLocation(insn->getDebugLoc());
5479 substitute = builder.CreateLoad(
5480 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
5481 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
5483 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
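// createDeviceArgumentAccessor: for target-device codegen, stores an incoming
// kernel argument into an alloca (address-space-cast to the program address
// space when needed) and materializes the value used by the region body
// according to the capture kind (ByCopy vs. ByRef), attaching align metadata
// to pointer loads when an alignment is known.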
5530static llvm::IRBuilderBase::InsertPoint
5532 llvm::Value *input, llvm::Value *&retVal,
5533 llvm::IRBuilderBase &builder,
5534 llvm::OpenMPIRBuilder &ompBuilder,
5536 llvm::IRBuilderBase::InsertPoint allocaIP,
5537 llvm::IRBuilderBase::InsertPoint codeGenIP) {
5538 assert(ompBuilder.Config.isTargetDevice() &&
5539 "function only supported for target device codegen");
5540 builder.restoreIP(allocaIP);
5542 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
5544 ompBuilder.M.getContext());
5545 unsigned alignmentValue = 0;
5547 for (size_t i = 0; i < mapData.MapClause.size(); ++i)
5548 if (mapData.OriginalValue[i] == input) {
5549 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5550 capture = mapOp.getMapCaptureType();
5553 mapOp.getVarType(), ompBuilder.M.getDataLayout());
5557 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
5558 unsigned int defaultAS =
5559 ompBuilder.M.getDataLayout().getProgramAddressSpace();
5562 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
5564 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
5565 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
5567 builder.CreateStore(&arg, v);
5569 builder.restoreIP(codeGenIP);
5572 case omp::VariableCaptureKind::ByCopy: {
5576 case omp::VariableCaptureKind::ByRef: {
5577 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
5579 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
5594 if (v->getType()->isPointerTy() && alignmentValue) {
5595 llvm::MDBuilder MDB(builder.getContext());
5596 loadInst->setMetadata(
5597 llvm::LLVMContext::MD_align,
5598 llvm::MDNode::get(builder.getContext(),
5599 MDB.createConstant(llvm::ConstantInt::get(
5600 llvm::Type::getInt64Ty(builder.getContext()),
5607 case omp::VariableCaptureKind::This:
5608 case omp::VariableCaptureKind::VLAType:
5611 assert(false && "Currently unsupported capture kind");
5615 return builder.saveIP();
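// extractHostEvalClauses: follows the uses of the target op's host_eval block
// arguments into the nested teams, parallel and loop-nest ops to recover the
// MLIR values for num_teams, thread_limit, num_threads and the loop bounds
// and steps evaluated on the host.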
5632 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
5633 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
5634 blockArgIface.getHostEvalBlockArgs())) {
5635 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
5639 .Case([&](omp::TeamsOp teamsOp) {
5640 if (teamsOp.getNumTeamsLower() == blockArg)
5641 numTeamsLower = hostEvalVar;
5642 else if (teamsOp.getNumTeamsUpper() == blockArg)
5643 numTeamsUpper = hostEvalVar;
5644 else if (teamsOp.getThreadLimit() == blockArg)
5645 threadLimit = hostEvalVar;
5647 llvm_unreachable("unsupported host_eval use");
5649 .Case([&](omp::ParallelOp parallelOp) {
5650 if (parallelOp.getNumThreads() == blockArg)
5651 numThreads = hostEvalVar;
5653 llvm_unreachable("unsupported host_eval use");
5655 .Case([&](omp::LoopNestOp loopOp) {
5656 auto processBounds =
5660 for (auto [i, lb] : llvm::enumerate(opBounds)) {
5661 if (lb == blockArg) {
5664 (*outBounds)[i] = hostEvalVar;
5670 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
5671 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
5673 found = processBounds(loopOp.getLoopSteps(), steps) || found;
5675 assert(found && "unsupported host_eval use");
5677 .DefaultUnreachable("unsupported host_eval use");
5689template <typename OpTy>
5694 if (OpTy casted = dyn_cast<OpTy>(op))
5697 if (immediateParent)
5698 return dyn_cast_if_present<OpTy>(op->getParentOp());
5707 return std::nullopt;
5710 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5711 return constAttr.getInt();
5713 return std::nullopt;
5718 uint64_t sizeInBytes = sizeInBits / 8;
5722template <typename OpTy>
5724 if (op.getNumReductionVars() > 0) {
5729 members.reserve(reductions.size());
5730 for (omp::DeclareReductionOp &red : reductions)
5731 members.push_back(red.getType());
5733 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
5749 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
5750 bool isTargetDevice, bool isGPU) {
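// Derive the kernel's compile-time default launch attributes: min/max teams
// and threads from the num_teams/thread_limit/num_threads clauses, the
// execution mode (generic, SPMD, generic-SPMD, or SPMD no-loop) from the
// kernel flags, and the reduction data/buffer sizes used on GPU targets.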
5753 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
5754 if (!isTargetDevice) {
5762 numTeamsLower = teamsOp.getNumTeamsLower();
5763 numTeamsUpper = teamsOp.getNumTeamsUpper();
5764 threadLimit = teamsOp.getThreadLimit();
5768 numThreads = parallelOp.getNumThreads();
5773 int32_t minTeamsVal = 1, maxTeamsVal = -1;
5777 if (numTeamsUpper) {
5779 minTeamsVal = maxTeamsVal = *val;
5781 minTeamsVal = maxTeamsVal = 0;
5787 minTeamsVal = maxTeamsVal = 1;
5789 minTeamsVal = maxTeamsVal = -1;
5794 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
5808 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
5809 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
5810 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
5813 int32_t maxThreadsVal = -1;
5815 setMaxValueFromClause(numThreads, maxThreadsVal);
5823 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
5824 if (combinedMaxThreadsVal < 0 ||
5825 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
5826 combinedMaxThreadsVal = teamsThreadLimitVal;
5828 if (combinedMaxThreadsVal < 0 ||
5829 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
5830 combinedMaxThreadsVal = maxThreadsVal;
5832 int32_t reductionDataSize = 0;
5833 if (isGPU && capturedOp) {
5839 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5841 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5842 omp::TargetRegionFlags::spmd) &&
5843 "invalid kernel flags");
5845 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5846 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5847 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5848 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5849 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5850 if (omp::bitEnumContainsAll(kernelFlags,
5851 omp::TargetRegionFlags::spmd |
5852 omp::TargetRegionFlags::no_loop) &&
5853 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
5854 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
5856 attrs.MinTeams = minTeamsVal;
5857 attrs.MaxTeams.front() = maxTeamsVal;
5858 attrs.MinThreads = 1;
5859 attrs.MaxThreads.front() = combinedMaxThreadsVal;
5860 attrs.ReductionDataSize = reductionDataSize;
5863 if (attrs.ReductionDataSize != 0)
5864 attrs.ReductionBufferLength = 1024;
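// initTargetRuntimeAttrs: on the host, looks up the LLVM values for the
// host-evaluated clauses (target/teams thread limits, num_teams, num_threads)
// and, when the kernel flags request it, multiplies the per-loop trip counts
// into a single kernel trip count.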
5876 omp::TargetOp targetOp, Operation *capturedOp,
5877 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
5879 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
5881 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
5885 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
5888 if (Value targetThreadLimit = targetOp.getThreadLimit())
5889 attrs.TargetThreadLimit.front() =
5893 attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);
5896 attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);
5898 if (teamsThreadLimit)
5899 attrs.TeamsThreadLimit.front() =
5903 attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
5905 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
5906 omp::TargetRegionFlags::trip_count)) {
5908 attrs.LoopTripCount = nullptr;
5913 for (auto [loopLower, loopUpper, loopStep] :
5914 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
5915 llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
5916 llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
5917 llvm::Value *step = moduleTranslation.lookupValue(loopStep);
5919 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
5920 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
5921 loc, lowerBound, upperBound, step, true,
5922 loopOp.getLoopInclusive());
5924 if (!attrs.LoopTripCount) {
5925 attrs.LoopTripCount = tripCount;
5930 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
5939 auto targetOp = cast<omp::TargetOp>(opInst);
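// convertOmpTarget: collects map, host_eval and private clause information,
// builds the body and argument-accessor callbacks, and invokes the
// OpenMPIRBuilder to emit the outlined target function and, on the host, the
// offloading launch sequence.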
5943 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
5952 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
5953 assert(parentBB && "No insert block is set for the builder");
5954 llvm::Function *parentLLVMFn = parentBB->getParent();
5955 assert(parentLLVMFn && "Parent Function must be valid");
5956 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
5957 builder.SetCurrentDebugLocation(llvm::DILocation::get(
5958 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
5959 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
5962 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
5963 bool isGPU = ompBuilder->Config.isGPU();
5966 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
5967 auto &targetRegion = targetOp.getRegion();
5984 llvm::Function *llvmOutlinedFn = nullptr;
5985 TargetDirectiveEnumTy targetDirective =
5986 getTargetDirectiveEnumTyFromOp(&opInst);
5990 bool isOffloadEntry =
5991 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
5998 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
6000 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
6001 std::optional<DenseI64ArrayAttr> privateMapIndices =
6002 targetOp.getPrivateMapsAttr();
6004 for (auto [privVarIdx, privVarSymPair] :
6005 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
6006 auto privVar = std::get<0>(privVarSymPair);
6007 auto privSym = std::get<1>(privVarSymPair);
6009 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
6010 omp::PrivateClauseOp privatizer =
6013 if (!privatizer.needsMap())
6017 targetOp.getMappedValueForPrivateVar(privVarIdx);
6018 assert(mappedValue && "Expected to find mapped value for a privatized "
6019 "variable that needs mapping");
6024 auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
6025 [[maybe_unused]] Type varType = mapInfoOp.getVarType();
6029 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
6031 varType == privVar.getType() &&
6032 "Type of private var doesn't match the type of the mapped value");
6036 mappedPrivateVars.insert(
6038 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
6039 (*privateMapIndices)[privVarIdx])});
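// Body generation callback for the outlined target region: it propagates
// debug info and target-cpu/target-features attributes from the parent
// function, maps the map/has_device_addr block arguments, allocates and
// initializes delayed private variables, inlines the region, and finally runs
// the private `dealloc` cleanup regions.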
6043 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6044 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
6045 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6046 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6047 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6050 llvm::Function *llvmParentFn =
6052 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
6053 assert(llvmParentFn && llvmOutlinedFn &&
6054 "Both parent and outlined functions must exist at this point");
6056 if (outlinedFnLoc && llvmParentFn->getSubprogram())
6057 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
6059 if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
6060 attr.isStringAttribute())
6061 llvmOutlinedFn->addFnAttr(attr);
6063 if (auto attr = llvmParentFn->getFnAttribute("target-features");
6064 attr.isStringAttribute())
6065 llvmOutlinedFn->addFnAttr(attr);
6067 for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
6068 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6069 llvm::Value *mapOpValue =
6070 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
6071 moduleTranslation.mapValue(arg, mapOpValue);
6073 for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
6074 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6075 llvm::Value *mapOpValue =
6076 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
6077 moduleTranslation.mapValue(arg, mapOpValue);
6086 allocaIP, &mappedPrivateVars);
6089 return llvm::make_error<PreviouslyReportedError>();
6091 builder.restoreIP(codeGenIP);
6093 &mappedPrivateVars),
6096 return llvm::make_error<PreviouslyReportedError>();
6099 targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
6101 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
6102 return llvm::make_error<PreviouslyReportedError>();
6106 std::back_inserter(privateCleanupRegions),
6107 [](omp::PrivateClauseOp privatizer) {
6108 return &privatizer.getDeallocRegion();
6112 targetRegion, "omp.target", builder, moduleTranslation);
6115 return exitBlock.takeError();
6117 builder.SetInsertPoint(*exitBlock);
6118 if (!privateCleanupRegions.empty()) {
6120 privateCleanupRegions, privateVarsInfo.llvmVars,
6121 moduleTranslation, builder, "omp.targetop.private.cleanup",
6123 return llvm::createStringError(
6124 "failed to inline `dealloc` region of `omp.private` "
6125 "op in the target region");
6127 return builder.saveIP();
6130 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
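// Host-side setup after the body callback: compute the target entry info,
// collect map data and build the map-info callback, gather kernel arguments
// (host_eval values plus non-member map operands), resolve dependencies and
// the optional `if` condition, then create the target kernel.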
6133 StringRef parentName = parentFn.getName();
6135 llvm::TargetRegionEntryInfo entryInfo;
6139 MapInfoData mapData;
6144 MapInfosTy combinedInfos;
6146 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
6147 builder.restoreIP(codeGenIP);
6148 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
6150 return combinedInfos;
6153 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
6154 llvm::Value *&retVal, InsertPointTy allocaIP,
6155 InsertPointTy codeGenIP)
6156 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6157 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6158 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6164 if (!isTargetDevice) {
6165 retVal = cast<llvm::Value>(&arg);
6170 *ompBuilder, moduleTranslation,
6171 allocaIP, codeGenIP);
6174 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
6175 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
6176 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
6178 isTargetDevice, isGPU);
6182 if (!isTargetDevice)
6184 targetCapturedOp, runtimeAttrs);
6192 for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
6193 llvm::Value *value = moduleTranslation.lookupValue(var);
6194 moduleTranslation.mapValue(arg, value);
6196 if (!llvm::isa<llvm::Constant>(value))
6197 kernelInput.push_back(value);
6200 for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
6207 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
6208 kernelInput.push_back(mapData.OriginalValue[i]);
6213 moduleTranslation, dds);
6215 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6217 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6219 llvm::OpenMPIRBuilder::TargetDataInfo info(
6223 auto customMapperCB =
6225 if (!combinedInfos.Mappers[i])
6227 info.HasMapper = true;
6229 moduleTranslation, targetDirective);
6232 llvm::Value *ifCond = nullptr;
6233 if (Value targetIfCond = targetOp.getIfExpr())
6234 ifCond = moduleTranslation.lookupValue(targetIfCond);
6236 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6238 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
6239 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
6240 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
6245 builder.restoreIP(*afterIP);
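// convertDeclareTargetAttr: host-only declare target functions are erased
// from the device module; declare target globals are registered with the
// offload entry manager and, on the device, a reference pointer is
// materialized when the capture is not "to" or unified shared memory is
// required.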
6266 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
6267 if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
6269 if (!offloadMod.getIsTargetDevice())
6272 omp::DeclareTargetDeviceType declareType =
6273 attribute.getDeviceType().getValue();
6275 if (declareType == omp::DeclareTargetDeviceType::host) {
6276 llvm::Function *llvmFunc =
6278 llvmFunc->dropAllReferences();
6279 llvmFunc->eraseFromParent();
6285 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
6286 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6287 if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
6289 bool isDeclaration = gOp.isDeclaration();
6290 bool isExternallyVisible =
6293 llvm::StringRef mangledName = gOp.getSymName();
6294 auto captureClause =
6300 std::vector<llvm::GlobalVariable *> generatedRefs;
6302 std::vector<llvm::Triple> targetTriple;
6303 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
6305 LLVM::LLVMDialect::getTargetTripleAttrName()));
6306 if (targetTripleAttr)
6307 targetTriple.emplace_back(targetTripleAttr.data());
6309 auto fileInfoCallBack = [&loc]() {
6310 std::string filename = "";
6311 std::uint64_t lineNo = 0;
6314 filename = loc.getFilename().str();
6315 lineNo = loc.getLine();
6318 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
6322 auto vfs = llvm::vfs::getRealFileSystem();
6324 ompBuilder->registerTargetGlobalVariable(
6325 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6326 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6327 mangledName, generatedRefs, false, targetTriple,
6329 gVal->getType(), gVal);
6331 if (ompBuilder->Config.isTargetDevice() &&
6332 (attribute.getCaptureClause().getValue() !=
6333 mlir::omp::DeclareTargetCaptureClause::to ||
6334 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
6335 ompBuilder->getAddrOfDeclareTargetVar(
6336 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6337 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6338 mangledName, generatedRefs, false, targetTriple,
6339 gVal->getType(), nullptr,
6360 if (mlir::isa<omp::ThreadprivateOp>(op))
6363 if (mlir::isa<omp::TargetAllocMemOp>(op) ||
6364 mlir::isa<omp::TargetFreeMemOp>(op))
6368 if (auto declareTargetIface =
6369 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
6370 parentFn.getOperation()))
6371 if (declareTargetIface.isDeclareTarget() &&
6372 declareTargetIface.getDeclareTargetDeviceType() !=
6373 mlir::omp::DeclareTargetDeviceType::host)
6380 llvm::Module *llvmModule) {
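// getOmpTargetAlloc / convertTargetAllocMemOp: declare the runtime entry
// `omp_target_alloc(i64 size, i32 device)` returning a pointer, compute the
// allocation size from the allocated type (scaled by any type parameters),
// emit the call, and record the result as an i64.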
6381 llvm::Type *i64Ty = builder.getInt64Ty();
6382 llvm::Type *i32Ty = builder.getInt32Ty();
6383 llvm::Type *returnType = builder.getPtrTy(0);
6384 llvm::FunctionType *fnType =
6385 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
6386 llvm::Function *func = cast<llvm::Function>(
6387 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
6394 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
6399 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6403 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
6405 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
6406 mlir::Type heapTy = allocMemOp.getAllocatedType();
6407 llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
6408 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
6409 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
6410 for (auto typeParam : allocMemOp.getTypeparams())
6412 builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
6414 llvm::CallInst *call =
6415 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
6416 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
6419 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
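// getOmpTargetFree / convertTargetFreeMemOp: declare `omp_target_free(ptr,
// i32 device)`, convert the stored i64 heap reference back into a pointer,
// and emit the runtime call that releases the device allocation.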
6424 llvm::Module *llvmModule) {
6425 llvm::Type *ptrTy = builder.getPtrTy(0);
6426 llvm::Type *i32Ty = builder.getInt32Ty();
6427 llvm::Type *voidTy = builder.getVoidTy();
6428 llvm::FunctionType *fnType =
6429 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
6430 llvm::Function *func = dyn_cast<llvm::Function>(
6431 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
6438 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
6443 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6447 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
6450 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
6452 llvm::Value *intToPtr =
6453 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
6454 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
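// convertHostOrTargetOperation: the main dispatch for OpenMP ops. It pushes a
// loop-info stack frame for outermost loop wrappers and switches over the op
// kind to the dedicated conversion helpers.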
6471 bool isOutermostLoopWrapper =
6472 isa_and_present<omp::LoopWrapperInterface>(op) &&
6473 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
6475 if (isOutermostLoopWrapper)
6476 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
6480 .Case([&](omp::BarrierOp op) -> LogicalResult {
6484 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6485 ompBuilder->createBarrier(builder.saveIP(),
6486 llvm::omp::OMPD_barrier);
6488 if (res.succeeded()) {
6491 builder.restoreIP(*afterIP);
6495 .Case([&](omp::TaskyieldOp op) {
6499 ompBuilder->createTaskyield(builder.saveIP());
6502 .Case([&](omp::FlushOp op) {
6514 ompBuilder->createFlush(builder.saveIP());
6517 .Case([&](omp::ParallelOp op) {
6520 .Case([&](omp::MaskedOp) {
6523 .Case([&](omp::MasterOp) {
6526 .Case([&](omp::CriticalOp) {
6529 .Case([&](omp::OrderedRegionOp) {
6532 .Case([&](omp::OrderedOp) {
6535 .Case([&](omp::WsloopOp) {
6538 .Case([&](omp::SimdOp) {
6541 .Case([&](omp::AtomicReadOp) {
6544 .Case([&](omp::AtomicWriteOp) {
6547 .Case([&](omp::AtomicUpdateOp op) {
6550 .Case([&](omp::AtomicCaptureOp op) {
6553 .Case([&](omp::CancelOp op) {
6556 .Case([&](omp::CancellationPointOp op) {
6559 .Case([&](omp::SectionsOp) {
6562 .Case([&](omp::SingleOp op) {
6565 .Case([&](omp::TeamsOp op) {
6568 .Case([&](omp::TaskOp op) {
6571 .Case([&](omp::TaskgroupOp op) {
6574 .Case([&](omp::TaskwaitOp op) {
6577 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
6578 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
6579 omp::CriticalDeclareOp>([](auto op) {
6592 .Case([&](omp::ThreadprivateOp) {
6595 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
6596 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
6599 .Case([&](omp::TargetOp) {
6602 .Case([&](omp::DistributeOp) {
6605 .Case([&](omp::LoopNestOp) {
6608 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
6615 .Case([&](omp::NewCliOp op) {
6620 .Case([&](omp::CanonicalLoopOp op) {
6623 .Case([&](omp::UnrollHeuristicOp op) {
6632 .Case([&](omp::TileOp op) {
6633 return applyTile(op, builder, moduleTranslation);
6635 .Case([&](omp::TargetAllocMemOp) {
6638 .Case([&](omp::TargetFreeMemOp) {
6643 << "not yet implemented: " << inst->getName();
6646 if (isOutermostLoopWrapper)
6661 if (isa<omp::TargetOp>(op))
6663 if (isa<omp::TargetDataOp>(op))
6667 if (isa<omp::TargetOp>(oper)) {
6672 if (isa<omp::TargetDataOp>(oper)) {
6682 if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
6683 oper->getParentOfType<LLVM::LLVMFuncOp>() &&
6684 !oper->getRegions().empty()) {
6685 if (auto blockArgsIface =
6686 dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
6692 if (isa<mlir::omp::AtomicUpdateOp>(oper))
6693 for (auto [operand, arg] :
6694 llvm::zip_equal(oper->getOperands(),
6695 oper->getRegion(0).getArguments())) {
6697 arg, builder.CreateLoad(
6703 if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
6704 assert(builder.GetInsertBlock() &&
6705 "No insert block is set for the builder");
6706 for (auto iv : loopNest.getIVs()) {
6709 iv, llvm::PoisonValue::get(
6714 for (Region &region : oper->getRegions()) {
6721 region, oper->getName().getStringRef().str() + ".fake.region",
6722 builder, moduleTranslation, &phis);
6726 builder.SetInsertPoint(result.get(), result.get()->end());
6733 }).wasInterrupted();
6734 return failure(interrupted);
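// OpenMPDialectLLVMIRTranslationInterface: dialect interface hooking OpenMP
// operations (convertOperation) and module-level OpenMP attributes
// (amendOperation) into the MLIR-to-LLVM-IR translation.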
6741class OpenMPDialectLLVMIRTranslationInterface
6742 : public LLVMTranslationDialectInterface {
6749 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
6750 LLVM::ModuleTranslation &moduleTranslation) const final;
6755 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
6756 NamedAttribute attribute,
6757 LLVM::ModuleTranslation &moduleTranslation) const final;
6762LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6763 Operation *op, ArrayRef<llvm::Instruction *> instructions,
6764 NamedAttribute attribute,
6765 LLVM::ModuleTranslation &moduleTranslation) const {
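// Dispatch on the attribute name ("omp.is_target_device",
// "omp.host_ir_filepath", "omp.version", "omp.declare_target", "omp.requires",
// "omp.target_triples", ...) and update the OpenMPIRBuilder configuration or
// module flags accordingly.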
6766 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
6768 .Case("omp.is_target_device",
6769 [&](Attribute attr) {
6770 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6771 llvm::OpenMPIRBuilderConfig &config =
6773 config.setIsTargetDevice(deviceAttr.getValue());
6779 [&](Attribute attr) {
6780 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6781 llvm::OpenMPIRBuilderConfig &config =
6783 config.setIsGPU(gpuAttr.getValue());
6788 .Case("omp.host_ir_filepath",
6789 [&](Attribute attr) {
6790 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6791 llvm::OpenMPIRBuilder *ompBuilder =
6793 auto VFS = llvm::vfs::getRealFileSystem();
6794 ompBuilder->loadOffloadInfoMetadata(*VFS,
6795 filepathAttr.getValue());
6801 [&](Attribute attr) {
6802 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6806 .Case("omp.version",
6807 [&](Attribute attr) {
6808 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6809 llvm::OpenMPIRBuilder *ompBuilder =
6811 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
6812 versionAttr.getVersion());
6817 .Case("omp.declare_target",
6818 [&](Attribute attr) {
6819 if (auto declareTargetAttr =
6820 dyn_cast<omp::DeclareTargetAttr>(attr))
6825 .Case("omp.requires",
6826 [&](Attribute attr) {
6827 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6828 using Requires = omp::ClauseRequires;
6829 Requires flags = requiresAttr.getValue();
6830 llvm::OpenMPIRBuilderConfig &config =
6832 config.setHasRequiresReverseOffload(
6833 bitEnumContainsAll(flags, Requires::reverse_offload));
6834 config.setHasRequiresUnifiedAddress(
6835 bitEnumContainsAll(flags, Requires::unified_address));
6836 config.setHasRequiresUnifiedSharedMemory(
6837 bitEnumContainsAll(flags, Requires::unified_shared_memory));
6838 config.setHasRequiresDynamicAllocators(
6839 bitEnumContainsAll(flags, Requires::dynamic_allocators));
6844 .Case("omp.target_triples",
6845 [&](Attribute attr) {
6846 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
6847 llvm::OpenMPIRBuilderConfig &config =
6849 config.TargetTriples.clear();
6850 config.TargetTriples.reserve(triplesAttr.size());
6851 for (Attribute tripleAttr : triplesAttr) {
6852 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
6853 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
6861 .Default([](Attribute) {
6871LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
6872 Operation *op, llvm::IRBuilderBase &builder,
6873 LLVM::ModuleTranslation &moduleTranslation) const {
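// Entry point for translating a single OpenMP operation; when compiling for a
// target device, translation is restricted to the operations nested inside
// target regions.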
6876 if (ompBuilder->Config.isTargetDevice()) {
6886 registry.insert<omp::OpenMPDialect>();
6888 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static bool isTargetDeviceOp(Operation *op)
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a termiator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
allocatedType moduleTranslation static convertType(allocatedType) LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation, omp::BlockArgOpenMPOpInterface blockArgIface)
Maps block arguments from blockArgIface (which are MLIR values) to the corresponding LLVM values of t...
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static LogicalResult convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< mlir::Value > &mlirPrivateVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Given an OpenMP MLIR operation, create the corresponding LLVM IR (including OpenMP runtime calls).
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static LogicalResult convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
LLVMTranslationDialectInterface(Dialect *dialect)
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up remapped a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region ®ion)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
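A minimal sketch of the loop-mapping lifecycle, assuming cliOp is an omp::NewCliOp and loopInfo is the llvm::CanonicalLoopInfo created while lowering the associated loop:

moduleTranslation.mapOmpLoop(cliOp, loopInfo);                 // record the lowering
llvm::CanonicalLoopInfo *cli = moduleTranslation.lookupOMPLoop(cliOp);
// ... apply an OpenMPIRBuilder loop transformation that consumes `cli` ...
moduleTranslation.invalidateOmpLoop(cliOp);                    // mark the CLI as consumed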
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
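A minimal sketch of extracting file and line information from a location, assuming loc is the mlir::Location of the operation being translated:

if (FileLineColLoc fileLoc = loc->findInstanceOf<FileLineColLoc>()) {
  llvm::StringRef file = fileLoc.getFilename().getValue();
  unsigned line = fileLoc.getLine();
  (void)file; (void)line; // e.g. used to build names handed to the OpenMP runtime
}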
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry ®istry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one).
user_range getUsers()
Returns a range of all users.
MLIRContext * getContext()
Return the context this operation is associated with.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the current symbol table.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, the given operation.
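A minimal sketch of resolving such a symbol, assuming op carries a symbol reference to a privatizer and symName is the corresponding StringAttr:

Operation *decl = SymbolTable::lookupNearestSymbolFrom(op, symName);
auto privatizer = llvm::dyn_cast_if_present<omp::PrivateClauseOp>(decl);
if (!privatizer)
  return op->emitError("could not resolve the privatizer symbol");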
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk: interrupt, advance, or skip.
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
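A minimal sketch combining Operation::walk with WalkResult; isSupported is a hypothetical predicate, not part of this file:

WalkResult result = op->walk([&](Operation *nested) {
  if (!isSupported(nested))
    return WalkResult::interrupt(); // abort the traversal early
  return WalkResult::advance();     // keep visiting nested operations
});
if (result.wasInterrupted())
  return failure();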
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region ®ion, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation, connect the PHI nodes of the corresponding LLVM IR blocks to the results of preceding blocks.
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location information.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
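A minimal sketch of the region-lowering pattern these helpers support, assuming region is the MLIR region being translated and builder points into the enclosing LLVM function:

llvm::Function *llvmFn = builder.GetInsertBlock()->getParent();
SetVector<Block *> blocks = getBlocksSortedByDominance(region);
for (Block *bb : blocks) // create and map the LLVM blocks first
  moduleTranslation.mapBlock(bb, llvm::BasicBlock::Create(
                                     builder.getContext(), "omp.region.block", llvmFn));
for (Block *bb : blocks) { // then translate their contents in dominance order
  builder.SetInsertPoint(moduleTranslation.lookupBlock(bb));
  if (failed(moduleTranslation.convertBlock(
          *bb, /*ignoreArguments=*/bb->isEntryBlock(), builder)))
    return failure();
}
connectPHINodes(region, moduleTranslation); // finally wire block arguments to PHIs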
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry ®istry)
Register the OpenMP dialect and the translation from it to LLVM IR in the given registry.
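A minimal sketch of client-side registration, assuming context is the MLIRContext that owns the module about to be translated:

DialectRegistry registry;
registerOpenMPDialectTranslation(registry);
context.appendDialectRegistry(registry);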
llvm::SetVector< T, Vector, Set, N > SetVector
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and sinking them, in their order of occurrence in forOps, under each of the targets.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A utility to collect the information needed to convert delayed privatizers from MLIR to LLVM IR.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars
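A minimal sketch of how these parallel vectors are usually consumed, assuming privVars is a populated instance of this helper (the name is a placeholder):

// The four vectors are kept in lockstep, one entry per private variable.
for (auto [mlirVar, privatizer, blockArg, llvmVar] :
     llvm::zip_equal(privVars.mlirVars, privVars.privatizers,
                     privVars.blockArgs, privVars.llvmVars)) {
  // ... materialize the private copy `llvmVar` for `mlirVar`, using the
  // regions of `privatizer` and remapping `blockArg` to the original value ...
}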