#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
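/// Translates the MLIR OpenMP dialect schedule kind into the matching
/// OpenMPIRBuilder schedule kind, using the default when no schedule clause
/// was specified.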
static llvm::omp::ScheduleKind
convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
  if (!schedKind.has_value())
    return llvm::omp::OMP_SCHEDULE_Default;
  switch (schedKind.value()) {
  case omp::ClauseScheduleKind::Static:
    return llvm::omp::OMP_SCHEDULE_Static;
  case omp::ClauseScheduleKind::Dynamic:
    return llvm::omp::OMP_SCHEDULE_Dynamic;
  case omp::ClauseScheduleKind::Guided:
    return llvm::omp::OMP_SCHEDULE_Guided;
  case omp::ClauseScheduleKind::Auto:
    return llvm::omp::OMP_SCHEDULE_Auto;
  case omp::ClauseScheduleKind::Runtime:
    return llvm::omp::OMP_SCHEDULE_Runtime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}
class OpenMPAllocaStackFrame
  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;

class OpenMPLoopInfoStackFrame
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
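/// Error signalling that a diagnostic has already been emitted on the MLIR
/// side, so the failure should be propagated without reporting it again.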
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
public:
  void log(raw_ostream &) const override {
    // Nothing to print; the diagnostic was already emitted.
  }

  std::error_code convertToErrorCode() const override {
    llvm_unreachable(
        "PreviouslyReportedError doesn't support ECError conversion");
  }

  static char ID;
};

char PreviouslyReportedError::ID = 0;
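/// Collects the state needed to lower the OpenMP `linear` clause: for each
/// linear variable it keeps the original alloca, a precondition copy, a
/// per-iteration temporary and the step, plus the basic blocks created for the
/// finalization logic. The wsloop translation below drives it in this order:
/// createLinearVar/initLinearStep, initLinearVar, updateLinearVar,
/// outlineLinearFinalizationBB, finalizeLinearVar and finally rewriteInPlace.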
class LinearClauseProcessor {
  SmallVector<llvm::Value *> linearPreconditionVars;
  SmallVector<llvm::Value *> linearLoopBodyTemps;
  SmallVector<llvm::AllocaInst *> linearOrigVars;
  SmallVector<llvm::Value *> linearOrigVal;
  SmallVector<llvm::Value *> linearSteps;
  llvm::BasicBlock *linearFinalizationBB;
  llvm::BasicBlock *linearExitBB;
  llvm::BasicBlock *linearLastIterExitBB;
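public:
  /// Allocate space for the firstprivate copy and the per-iteration result of
  /// a linear variable, and remember its original alloca and value.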
  void createLinearVar(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       mlir::Value &linearVar) {
    if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
            moduleTranslation.lookupValue(linearVar))) {
      linearPreconditionVars.push_back(builder.CreateAlloca(
          linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
      llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
          linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
      linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
      linearLoopBodyTemps.push_back(linearLoopBodyTemp);
      linearOrigVars.push_back(linearVarAlloca);
    }
  }
  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
                             mlir::Value &linearStep) {
    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  initLinearVar(llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation,
                llvm::BasicBlock *loopPreHeader) {
    builder.SetInsertPoint(loopPreHeader->getTerminator());
    for (size_t index = 0; index < linearOrigVars.size(); index++) {
      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
    }
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        moduleTranslation.getOpenMPBuilder()->createBarrier(
            builder.saveIP(), llvm::omp::OMPD_barrier);
    return afterBarrierIP;
  }
  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
                       llvm::Value *loopInductionVar) {
    builder.SetInsertPoint(loopBody->getTerminator());
    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
      llvm::LoadInst *linearVarStart =
          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
                             linearPreconditionVars[index]);
      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
      auto addInst = builder.CreateAdd(linearVarStart, mulInst);
      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
    }
  }
  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
                                   llvm::BasicBlock *loopExit) {
    linearFinalizationBB = loopExit->splitBasicBlock(
        loopExit->getTerminator(), "omp_loop.linear_finalization");
    linearExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
  }
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  finalizeLinearVar(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::Value *lastIter) {
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    llvm::Value *loopLastIterLoad = builder.CreateLoad(
        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
    llvm::Value *isLast =
        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
                          llvm::ConstantInt::get(
                              llvm::Type::getInt32Ty(builder.getContext()), 0));

    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
    for (size_t index = 0; index < linearOrigVars.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
                             linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
    }

    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
    linearFinalizationBB->getTerminator()->eraseFromParent();

    builder.SetInsertPoint(linearExitBB->getTerminator());
    return moduleTranslation.getOpenMPBuilder()->createBarrier(
        builder.saveIP(), llvm::omp::OMPD_barrier);
  }
  void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
                      size_t varIndex) {
    llvm::SmallVector<llvm::User *> users;
    for (llvm::User *user : linearOrigVal[varIndex]->users())
      users.push_back(user);
    for (auto *user : users) {
      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
        if (userInst->getParent()->getName().str() == BBName)
          user->replaceUsesOfWith(linearOrigVal[varIndex],
                                  linearLoopBodyTemps[varIndex]);
      }
    }
  }
};
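/// Looks up the `omp.private` privatizer declaration referenced by the given
/// symbol, starting from the operation that uses it.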
                               SymbolRefAttr symbolName) {
  omp::PrivateClauseOp privatizer =
  assert(privatizer && "privatizer not found in the symbol table");
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  };
  auto checkBare = [&todo](auto op, LogicalResult &result) {
      result = todo("ompx_bare");
  };
  auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
    omp::ClauseCancellationConstructType cancelledDirective =
        op.getCancelDirective();
    if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
      if (isa_and_nonnull<omp::TaskloopOp>(parent))
        result = todo("cancel directive inside of taskloop");
    }
  };
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  };
  auto checkDevice = [&todo](auto op, LogicalResult &result) {
  auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
    if (op.getDistScheduleChunkSize())
      result = todo("dist_schedule with chunk_size");
  };
  auto checkHint = [](auto op, LogicalResult &) {
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  };
  auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
    if (!op.getIsDevicePtrVars().empty())
      result = todo("is_device_ptr");
  };
  auto checkLinear = [&todo](auto op, LogicalResult &result) {
    if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
      result = todo("linear");
  };
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  };
  auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
    if (op.getParLevelSimd())
      result = todo("parallelization-level");
  };
  auto checkPriority = [&todo](auto op, LogicalResult &result) {
    if (op.getPriority())
      result = todo("priority");
  };
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if (!op.getPrivateVars().empty() || op.getPrivateSyms())
      result = todo("privatization");
  };
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    if (isa<omp::TeamsOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  };
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  };
  auto checkUntied = [&todo](auto op, LogicalResult &result) {
      .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
      .Case([&](omp::CancellationPointOp op) {
        checkCancelDirective(op, result);
      .Case([&](omp::DistributeOp op) {
        checkAllocate(op, result);
        checkDistSchedule(op, result);
      .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
      .Case([&](omp::TaskOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      .Case([&](omp::TaskwaitOp op) {
      .Case([&](omp::TaskloopOp op) {
        checkPriority(op, result);
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      .Case([&](omp::SimdOp op) {
        checkReduction(op, result);
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
        checkIsDevicePtr(op, result);
  llvm::handleAllErrors(
      [&](const PreviouslyReportedError &) { result = failure(); },
      [&](const llvm::ErrorInfoBase &err) {
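/// Finds the insertion point for allocas: the point recorded on the
/// translation stack if one is active for the current function, otherwise the
/// entry block of the function being built.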
static llvm::OpenMPIRBuilder::InsertPointTy
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
      [&](OpenMPAllocaStackFrame &frame) {
        allocaInsertPoint = frame.allocaInsertPoint;
      allocaInsertPoint.getBlock()->getParent() ==
          builder.GetInsertBlock()->getParent())
    return allocaInsertPoint;

  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);
  }

  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
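/// Walks the translation state stack to find the CanonicalLoopInfo of the
/// innermost enclosing loop that is currently being translated, if any.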
static llvm::CanonicalLoopInfo *
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        loopInfo = frame.loopInfo;
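/// Converts the blocks of an OpenMP operation's region into LLVM IR basic
/// blocks, creating a continuation block that receives control (and PHIs for
/// any values yielded by omp.yield) once the region is done executing.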
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());

  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  unsigned numYields = 0;

  if (!isLoopWrapper) {
    bool operandsProcessed = false;
      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));
          operandsProcessed = true;
          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());
            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");

  if (!continuationBlockPHITypes.empty())
        continuationBlockPHIs &&
        "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));

  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      return llvm::make_error<PreviouslyReportedError>();

      builder.CreateBr(continuationBlock);

    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);

      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(

  return continuationBlock;
  case omp::ClauseProcBindKind::Close:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
  case omp::ClauseProcBindKind::Master:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
  case omp::ClauseProcBindKind::Primary:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
  case omp::ClauseProcBindKind::Spread:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
  }
  llvm_unreachable("Unknown ClauseProcBindKind kind");
                        omp::BlockArgOpenMPOpInterface blockArgIface) {
  blockArgIface.getBlockArgsPairs(blockArgsPairs);
  for (auto [var, arg] : blockArgsPairs)
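/// Converts an `omp.masked` operation into LLVM IR using OpenMPIRBuilder's
/// createMasked, translating the region body and the optional filtered
/// thread id.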
  auto maskedOp = cast<omp::MaskedOp>(opInst);
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = maskedOp.getRegion();
    builder.restoreIP(codeGenIP);
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
  llvm::Value *filterVal = nullptr;
  if (auto filterVar = maskedOp.getFilteredThreadId()) {
    filterVal = moduleTranslation.lookupValue(filterVar);
    llvm::LLVMContext &llvmContext = builder.getContext();
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
  assert(filterVal != nullptr);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =

  builder.restoreIP(*afterIP);
743 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
744 auto masterOp = cast<omp::MasterOp>(opInst);
749 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = masterOp.getRegion();
752 builder.restoreIP(codeGenIP);
760 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
762 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
763 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
770 builder.restoreIP(*afterIP);
778 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
779 auto criticalOp = cast<omp::CriticalOp>(opInst);
784 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = cast<omp::CriticalOp>(opInst).getRegion();
787 builder.restoreIP(codeGenIP);
795 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
797 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
798 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
799 llvm::Constant *hint =
nullptr;
802 if (criticalOp.getNameAttr()) {
805 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
806 auto criticalDeclareOp =
810 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
811 static_cast<int>(criticalDeclareOp.getHint()));
813 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
815 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
820 builder.restoreIP(*afterIP);
827 template <
typename OP>
830 cast<
omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
833 collectPrivatizationDecls<OP>(op);
848 void collectPrivatizationDecls(OP op) {
849 std::optional<ArrayAttr> attr = op.getPrivateSyms();
854 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
865 std::optional<ArrayAttr> attr = op.getReductionSyms();
869 reductions.reserve(reductions.size() + op.getNumReductionVars());
870 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
871 reductions.push_back(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
892 llvm::Instruction *potentialTerminator =
893 builder.GetInsertBlock()->empty() ?
nullptr
894 : &builder.GetInsertBlock()->back();
896 if (potentialTerminator && potentialTerminator->isTerminator())
897 potentialTerminator->removeFromParent();
  moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
          region.front(), true, builder)))
905 if (continuationBlockArgs)
907 *continuationBlockArgs,
914 if (potentialTerminator && potentialTerminator->isTerminator()) {
915 llvm::BasicBlock *block = builder.GetInsertBlock();
916 if (block->empty()) {
922 potentialTerminator->insertInto(block, block->begin());
924 potentialTerminator->insertAfter(&block->back());
938 if (continuationBlockArgs)
939 llvm::append_range(*continuationBlockArgs, phis);
940 builder.SetInsertPoint(*continuationBlock,
941 (*continuationBlock)->getFirstInsertionPt());
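/// Owning callback types for the non-atomic and atomic reduction generators
/// produced from an `omp.declare_reduction` declaration.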
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
        llvm::Value *&)>;
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;

static OwningReductionGen
  OwningReductionGen gen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *lhs, llvm::Value *rhs,
                llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
                                      "omp.reduction.nonatomic.body", builder,
                                      moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `combiner` region of `omp.declare_reduction`");
    result = llvm::getSingleElement(phis);
    return builder.saveIP();
991static OwningAtomicReductionGen
993 llvm::IRBuilderBase &builder,
995 if (decl.getAtomicReductionRegion().empty())
996 return OwningAtomicReductionGen();
1002 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1003 llvm::Value *
lhs, llvm::Value *
rhs)
mutable
1004 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1005 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(),
lhs);
1006 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(),
rhs);
1007 builder.restoreIP(insertPoint);
1010 "omp.reduction.atomic.body", builder,
1011 moduleTranslation, &phis)))
1012 return llvm::createStringError(
1013 "failed to inline `atomic` region of `omp.declare_reduction`");
1014 assert(phis.empty());
1015 return builder.saveIP();
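/// Converts an `omp.ordered` operation carrying doacross dependences into the
/// corresponding createOrderedDepend runtime calls, one per group of loop
/// iteration values.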
  auto orderedOp = cast<omp::OrderedOp>(opInst);

  omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
  bool isDependSource = dependType == omp::ClauseDepend::dependsource;
  unsigned numLoops = *orderedOp.getDoacrossNumLoops();
      moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());

  size_t indexVecValues = 0;
  while (indexVecValues < vecValues.size()) {
    storeValues.reserve(numLoops);
    for (unsigned i = 0; i < numLoops; i++) {
      storeValues.push_back(vecValues[indexVecValues]);
    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
    builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
        ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1057 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1058 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1063 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1066 builder.restoreIP(codeGenIP);
1074 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1076 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1077 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1079 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1084 builder.restoreIP(*afterIP);
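/// A store of a value to an address whose emission is deferred until after
/// all reduction allocations have been created.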
struct DeferredStore {
  DeferredStore(llvm::Value *value, llvm::Value *address)
      : value(value), address(address) {}

  llvm::Value *address;
1102template <
typename T>
1105 llvm::IRBuilderBase &builder,
1107 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1113 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1114 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1117 deferredStores.reserve(loop.getNumReductionVars());
1119 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1120 Region &allocRegion = reductionDecls[i].getAllocRegion();
1122 if (allocRegion.
empty())
1127 builder, moduleTranslation, &phis)))
1128 return loop.emitError(
1129 "failed to inline `alloc` region of `omp.declare_reduction`");
1131 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1132 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1136 llvm::Value *var = builder.CreateAlloca(
1137 moduleTranslation.
convertType(reductionDecls[i].getType()));
1139 llvm::Type *ptrTy = builder.getPtrTy();
1140 llvm::Value *castVar =
1141 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1142 llvm::Value *castPhi =
1143 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1145 deferredStores.emplace_back(castPhi, castVar);
1147 privateReductionVariables[i] = castVar;
1148 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1149 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1151 assert(allocRegion.
empty() &&
1152 "allocaction is implicit for by-val reduction");
1153 llvm::Value *var = builder.CreateAlloca(
1154 moduleTranslation.
convertType(reductionDecls[i].getType()));
1156 llvm::Type *ptrTy = builder.getPtrTy();
1157 llvm::Value *castVar =
1158 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1160 moduleTranslation.
mapValue(reductionArgs[i], castVar);
1161 privateReductionVariables[i] = castVar;
1162 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1170template <
typename T>
1177 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1178 Region &initializerRegion = reduction.getInitializerRegion();
1181 mlir::Value mlirSource = loop.getReductionVars()[i];
1182 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1183 assert(llvmSource &&
"lookup reduction var");
1184 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), llvmSource);
1187 llvm::Value *allocation =
1188 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1189 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1195 llvm::BasicBlock *block =
nullptr) {
1196 if (block ==
nullptr)
1197 block = builder.GetInsertBlock();
1199 if (block->empty() || block->getTerminator() ==
nullptr)
1200 builder.SetInsertPoint(block);
1202 builder.SetInsertPoint(block->getTerminator());
1210template <
typename OP>
1213 llvm::IRBuilderBase &builder,
1215 llvm::BasicBlock *latestAllocaBlock,
1221 if (op.getNumReductionVars() == 0)
1224 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1225 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1226 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1227 builder.restoreIP(allocaIP);
1230 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1232 if (!reductionDecls[i].getAllocRegion().empty())
1238 byRefVars[i] = builder.CreateAlloca(
1239 moduleTranslation.
convertType(reductionDecls[i].getType()));
1247 for (
auto [data, addr] : deferredStores)
1248 builder.CreateStore(data, addr);
1253 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1258 reductionVariableMap, i);
1266 "omp.reduction.neutral", builder,
1267 moduleTranslation, &phis)))
1270 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1271 "reduction neutral element declaration region");
1276 if (!reductionDecls[i].getAllocRegion().empty())
1285 builder.CreateStore(phis[0], byRefVars[i]);
1287 privateReductionVariables[i] = byRefVars[i];
1288 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1289 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1292 builder.CreateStore(phis[0], privateReductionVariables[i]);
1299 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
1306template <
typename T>
1307static void collectReductionInfo(
1308 T loop, llvm::IRBuilderBase &builder,
1315 unsigned numReductions = loop.getNumReductionVars();
1317 for (
unsigned i = 0; i < numReductions; ++i) {
1320 owningAtomicReductionGens.push_back(
1325 reductionInfos.reserve(numReductions);
1326 for (
unsigned i = 0; i < numReductions; ++i) {
1327 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1328 if (owningAtomicReductionGens[i])
1329 atomicGen = owningAtomicReductionGens[i];
1330 llvm::Value *variable =
1331 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
1332 reductionInfos.push_back(
1334 privateReductionVariables[i],
1335 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1346 llvm::IRBuilderBase &builder, StringRef regionName,
1347 bool shouldLoadCleanupRegionArg =
true) {
1348 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1349 if (cleanupRegion->empty())
1355 llvm::Instruction *potentialTerminator =
1356 builder.GetInsertBlock()->empty() ?
nullptr
1357 : &builder.GetInsertBlock()->back();
1358 if (potentialTerminator && potentialTerminator->isTerminator())
1359 builder.SetInsertPoint(potentialTerminator);
1360 llvm::Value *privateVarValue =
1361 shouldLoadCleanupRegionArg
1362 ? builder.CreateLoad(
1364 privateVariables[i])
1365 : privateVariables[i];
1370 moduleTranslation)))
1383 OP op, llvm::IRBuilderBase &builder,
1385 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1388 bool isNowait =
false,
bool isTeamsReduction =
false) {
1390 if (op.getNumReductionVars() == 0)
1401 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1403 privateReductionVariables, reductionInfos);
1408 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1409 builder.SetInsertPoint(tempTerminator);
1410 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1411 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1412 isByRef, isNowait, isTeamsReduction);
1417 if (!contInsertPoint->getBlock())
1418 return op->emitOpError() <<
"failed to convert reductions";
1420 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1421 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1426 tempTerminator->eraseFromParent();
1427 builder.restoreIP(*afterIP);
1431 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1432 [](omp::DeclareReductionOp reductionDecl) {
1433 return &reductionDecl.getCleanupRegion();
1436 moduleTranslation, builder,
1437 "omp.reduction.cleanup");
1448template <
typename OP>
1452 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1457 if (op.getNumReductionVars() == 0)
1463 allocaIP, reductionDecls,
1464 privateReductionVariables, reductionVariableMap,
1465 deferredStores, isByRef)))
1468 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1469 allocaIP.getBlock(), reductionDecls,
1470 privateReductionVariables, reductionVariableMap,
1471 isByRef, deferredStores);
1485 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1488 Value blockArg = (*mappedPrivateVars)[privateVar];
1491 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1492 "A block argument corresponding to a mapped var should have "
1495 if (privVarType == blockArgType)
1502 if (!isa<LLVM::LLVMPointerType>(privVarType))
1503 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1516 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1518 Region &initRegion = privDecl.getInitRegion();
1519 if (initRegion.
empty())
1520 return llvmPrivateVar;
1524 mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1525 assert(nonPrivateVar);
1526 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1527 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1532 moduleTranslation, &phis)))
1533 return llvm::createStringError(
1534 "failed to inline `init` region of `omp.private`");
1536 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1555 return llvm::Error::success();
1557 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1560 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1563 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1565 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1566 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1569 return privVarOrErr.takeError();
1571 llvmPrivateVar = privVarOrErr.get();
1572 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1577 return llvm::Error::success();
1587 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1590 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1591 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1592 allocaTerminator->getIterator()),
1593 true, allocaTerminator->getStableDebugLoc(),
1594 "omp.region.after_alloca");
1596 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1598 allocaTerminator = allocaIP.getBlock()->getTerminator();
1599 builder.SetInsertPoint(allocaTerminator);
1601 assert(allocaTerminator->getNumSuccessors() == 1 &&
1602 "This is an unconditional branch created by splitBB");
1604 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1605 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1607 unsigned int allocaAS =
1608 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1611 .getProgramAddressSpace();
1613 for (
auto [privDecl, mlirPrivVar, blockArg] :
1616 llvm::Type *llvmAllocType =
1617 moduleTranslation.
convertType(privDecl.getType());
1618 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1619 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1620 llvmAllocType,
nullptr,
"omp.private.alloc");
1621 if (allocaAS != defaultAS)
1622 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1623 builder.getPtrTy(defaultAS));
1625 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1628 return afterAllocas;
1639 bool needsFirstprivate =
1640 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1641 return privOp.getDataSharingType() ==
1642 omp::DataSharingClauseType::FirstPrivate;
1645 if (!needsFirstprivate)
1648 llvm::BasicBlock *copyBlock =
1649 splitBB(builder,
true,
"omp.private.copy");
1652 for (
auto [decl, mlirVar, llvmVar] :
1653 llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
1654 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
    Region &copyRegion = decl.getCopyRegion();
1662 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1663 assert(nonPrivateVar);
1664 moduleTranslation.
mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1667 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1671 moduleTranslation)))
1672 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1684 if (insertBarrier) {
1686 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1687 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1702 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1703 [](omp::PrivateClauseOp privatizer) {
1704 return &privatizer.getDeallocRegion();
1708 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1709 "omp.private.dealloc",
false)))
1710 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1711 "`omp.private` op in");
1723 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1733 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1734 using StorableBodyGenCallbackTy =
1735 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1737 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1743 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1747 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1751 sectionsOp.getNumReductionVars());
1755 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1758 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1759 reductionDecls, privateReductionVariables, reductionVariableMap,
1766 auto sectionOp = dyn_cast<omp::SectionOp>(op);
    Region &region = sectionOp.getRegion();
1771 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1772 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1773 builder.restoreIP(codeGenIP);
1780 sectionsOp.getRegion().getNumArguments());
1781 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1782 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1783 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
1785 moduleTranslation.
mapValue(sectionArg, llvmVal);
1792 sectionCBs.push_back(sectionCB);
1798 if (sectionCBs.empty())
1801 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1806 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1807 llvm::Value &vPtr, llvm::Value *&replacementValue)
1808 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1809 replacementValue = &vPtr;
1815 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1819 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1820 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1822 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1823 sectionsOp.getNowait());
1828 builder.restoreIP(*afterIP);
1832 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1833 privateReductionVariables, isByRef, sectionsOp.getNowait());
1840 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1841 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1846 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1847 builder.restoreIP(codegenIP);
1849 builder, moduleTranslation)
1852 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1856 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1859 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1860 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
1862 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1863 llvmCPFuncs.push_back(
1867 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1869 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1875 builder.restoreIP(*afterIP);
1881 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1886 for (
auto ra : iface.getReductionBlockArgs())
1887 for (
auto &use : ra.getUses()) {
1888 auto *useOp = use.getOwner();
1890 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1891 debugUses.push_back(useOp);
1895 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
1900 Operation *currentOp = currentDistOp.getOperation();
1901 if (distOp && (distOp != currentOp))
1910 for (
auto use : debugUses)
1919 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1924 unsigned numReductionVars = op.getNumReductionVars();
1928 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1934 if (doTeamsReduction) {
1935 isByRef =
getIsByRef(op.getReductionByref());
1937 assert(isByRef.size() == op.getNumReductionVars());
1940 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
1945 op, reductionArgs, builder, moduleTranslation, allocaIP,
1946 reductionDecls, privateReductionVariables, reductionVariableMap,
1951 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1953 moduleTranslation, allocaIP);
1954 builder.restoreIP(codegenIP);
  llvm::Value *numTeamsLower = nullptr;
  if (Value numTeamsLowerVar = op.getNumTeamsLower())
    numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);

  llvm::Value *numTeamsUpper = nullptr;
  if (Value numTeamsUpperVar = op.getNumTeamsUpper())
    numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);

  llvm::Value *threadLimit = nullptr;
  if (Value threadLimitVar = op.getThreadLimit())
    threadLimit = moduleTranslation.lookupValue(threadLimitVar);

  llvm::Value *ifExpr = nullptr;
  if (Value ifVar = op.getIfExpr())
    ifExpr = moduleTranslation.lookupValue(ifVar);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);

  builder.restoreIP(*afterIP);
1985 if (doTeamsReduction) {
1988 op, builder, moduleTranslation, allocaIP, reductionDecls,
1989 privateReductionVariables, isByRef,
1999 if (dependVars.empty())
2001 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2002 llvm::omp::RTLDependenceKindTy type;
2004 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2005 case mlir::omp::ClauseTaskDepend::taskdependin:
2006 type = llvm::omp::RTLDependenceKindTy::DepIn;
2011 case mlir::omp::ClauseTaskDepend::taskdependout:
2012 case mlir::omp::ClauseTaskDepend::taskdependinout:
2013 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2015 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2016 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2018 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2019 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2022 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2023 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2024 dds.emplace_back(dd);
2036 llvm::IRBuilderBase &llvmBuilder,
2038 llvm::omp::Directive cancelDirective) {
2039 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2040 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2044 llvmBuilder.restoreIP(ip);
2050 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2051 return llvm::Error::success();
2056 ompBuilder.pushFinalizationCB(
2066 llvm::OpenMPIRBuilder &ompBuilder,
2067 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2068 ompBuilder.popFinalizationCB();
2069 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2070 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2071 assert(cancelBranch->getNumSuccessors() == 1 &&
2072 "cancel branch should have one target");
2073 cancelBranch->setSuccessor(0, constructFini);
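/// Manages a heap-allocated context structure that carries the private copies
/// of task-private variables: it mallocs the struct, creates GEPs to each
/// field, and frees the struct once the task region has been emitted.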
class TaskContextStructManager {
public:
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  void generateTaskContextStruct();

  void createGEPsToPrivateVars();

  void freeStructPtr();

  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;
  }

  llvm::Value *getStructPtr() { return structPtr; }

private:
  llvm::IRBuilderBase &builder;
  LLVM::ModuleTranslation &moduleTranslation;
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  SmallVector<llvm::Type *> privateVarTypes;

  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  llvm::Value *structPtr = nullptr;
  llvm::Type *structTy = nullptr;
};
2128void TaskContextStructManager::generateTaskContextStruct() {
2129 if (privateDecls.empty())
2131 privateVarTypes.reserve(privateDecls.size());
2133 for (omp::PrivateClauseOp &privOp : privateDecls) {
2136 if (!privOp.readsFromMold())
2138 Type mlirType = privOp.getType();
2139 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
2142 structTy = llvm::StructType::get(moduleTranslation.
getLLVMContext(),
2145 llvm::DataLayout dataLayout =
2146 builder.GetInsertBlock()->getModule()->getDataLayout();
2147 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2148 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2151 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2153 "omp.task.context_ptr");
2156void TaskContextStructManager::createGEPsToPrivateVars() {
2158 assert(privateVarTypes.empty());
2163 llvmPrivateVarGEPs.clear();
2164 llvmPrivateVarGEPs.reserve(privateDecls.size());
2165 llvm::Value *zero = builder.getInt32(0);
2167 for (
auto privDecl : privateDecls) {
2168 if (!privDecl.readsFromMold()) {
2170 llvmPrivateVarGEPs.push_back(
nullptr);
2173 llvm::Value *iVal = builder.getInt32(i);
2174 llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
2175 llvmPrivateVarGEPs.push_back(gep);
2180void TaskContextStructManager::freeStructPtr() {
2184 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2186 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2187 builder.CreateFree(structPtr);
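/// Converts an `omp.task` operation. Private copies that read from their mold
/// are initialized into the task context structure up front (outside the
/// outlined task body), then the task body is emitted via createTask and the
/// context structure is freed afterwards.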
2194 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2199 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2211 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2216 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2217 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2218 builder.getContext(),
"omp.task.start",
2219 builder.GetInsertBlock()->getParent());
2220 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2221 builder.SetInsertPoint(branchToTaskStartBlock);
2224 llvm::BasicBlock *copyBlock =
2225 splitBB(builder,
true,
"omp.private.copy");
2226 llvm::BasicBlock *initBlock =
2227 splitBB(builder,
true,
"omp.private.init");
2243 moduleTranslation, allocaIP);
2246 builder.SetInsertPoint(initBlock->getTerminator());
2249 taskStructMgr.generateTaskContextStruct();
2256 taskStructMgr.createGEPsToPrivateVars();
2258 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2261 taskStructMgr.getLLVMPrivateVarGEPs())) {
2263 if (!privDecl.readsFromMold())
2265 assert(llvmPrivateVarAlloc &&
2266 "reads from mold so shouldn't have been skipped");
2269 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2270 blockArg, llvmPrivateVarAlloc, initBlock);
2271 if (!privateVarOrErr)
2272 return handleError(privateVarOrErr, *taskOp.getOperation());
2281 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2282 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2283 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2285 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2286 llvmPrivateVarAlloc);
2288 assert(llvmPrivateVarAlloc->getType() ==
2289 moduleTranslation.
convertType(blockArg.getType()));
2299 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2300 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2301 taskOp.getPrivateNeedsBarrier())))
2302 return llvm::failure();
2305 builder.SetInsertPoint(taskStartBlock);
2307 auto bodyCB = [&](InsertPointTy allocaIP,
2308 InsertPointTy codegenIP) -> llvm::Error {
2312 moduleTranslation, allocaIP);
2315 builder.restoreIP(codegenIP);
2317 llvm::BasicBlock *privInitBlock =
nullptr;
2319 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2322 auto [blockArg, privDecl, mlirPrivVar] = zip;
2324 if (privDecl.readsFromMold())
2327 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2328 llvm::Type *llvmAllocType =
2329 moduleTranslation.
convertType(privDecl.getType());
2330 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2331 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2332 llvmAllocType,
nullptr,
"omp.private.alloc");
2335 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2336 blockArg, llvmPrivateVar, privInitBlock);
2337 if (!privateVarOrError)
2338 return privateVarOrError.takeError();
2339 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2340 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2343 taskStructMgr.createGEPsToPrivateVars();
2344 for (
auto [i, llvmPrivVar] :
2345 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2347 assert(privateVarsInfo.
llvmVars[i] &&
2348 "This is added in the loop above");
2351 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2356 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2360 if (!privateDecl.readsFromMold())
2363 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2364 llvmPrivateVar = builder.CreateLoad(
2365 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2367 assert(llvmPrivateVar->getType() ==
2368 moduleTranslation.
convertType(blockArg.getType()));
2369 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2373 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
2374 if (failed(
handleError(continuationBlockOrError, *taskOp)))
2375 return llvm::make_error<PreviouslyReportedError>();
2377 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2382 return llvm::make_error<PreviouslyReportedError>();
2385 taskStructMgr.freeStructPtr();
2387 return llvm::Error::success();
2396 llvm::omp::Directive::OMPD_taskgroup);
2400 moduleTranslation, dds);
2402 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2403 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2405 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2407 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dds,
2408 taskOp.getMergeable(),
2409 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
2410 moduleTranslation.
lookupValue(taskOp.getPriority()));
2418 builder.restoreIP(*afterIP);
2426 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2430 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2431 builder.restoreIP(codegenIP);
2433 builder, moduleTranslation)
2438 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2439 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2446 builder.restoreIP(*afterIP);
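/// Converts an `omp.wsloop` (worksharing loop) operation: privatization and
/// reduction variables are set up, the loop nest region is translated, the
/// linear clause (if any) is lowered via LinearClauseProcessor, and the loop
/// is finally handed to OpenMPIRBuilder::applyWorkshareLoop.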
2465 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2469 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2471 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2475 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2478 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
2479 llvm::Type *ivType = step->getType();
2480 llvm::Value *chunk =
nullptr;
2481 if (wsloopOp.getScheduleChunk()) {
2482 llvm::Value *chunkVar =
2483 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
2484 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2491 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2495 wsloopOp.getNumReductionVars());
2498 builder, moduleTranslation, privateVarsInfo, allocaIP);
2505 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2510 moduleTranslation, allocaIP, reductionDecls,
2511 privateReductionVariables, reductionVariableMap,
2512 deferredStores, isByRef)))
2521 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2523 wsloopOp.getPrivateNeedsBarrier())))
2526 assert(afterAllocas.get()->getSinglePredecessor());
2527 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
2529 afterAllocas.get()->getSinglePredecessor(),
2530 reductionDecls, privateReductionVariables,
2531 reductionVariableMap, isByRef, deferredStores)))
2535 bool isOrdered = wsloopOp.getOrdered().has_value();
2536 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2537 bool isSimd = wsloopOp.getScheduleSimd();
2538 bool loopNeedsBarrier = !wsloopOp.getNowait();
2543 llvm::omp::WorksharingLoopType workshareLoopType =
2544 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
2545 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2546 : llvm::omp::WorksharingLoopType::ForStaticLoop;
2550 llvm::omp::Directive::OMPD_for);
2552 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2555 LinearClauseProcessor linearClauseProcessor;
2556 if (!wsloopOp.getLinearVars().empty()) {
2557 for (
mlir::Value linearVar : wsloopOp.getLinearVars())
2558 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2560 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
2561 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
2565 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
2573 if (!wsloopOp.getLinearVars().empty()) {
2574 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2575 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2576 loopInfo->getPreheader());
2579 builder.restoreIP(*afterBarrierIP);
2580 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
2581 loopInfo->getIndVar());
2582 linearClauseProcessor.outlineLinearFinalizationBB(builder,
2583 loopInfo->getExit());
2586 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2589 bool noLoopMode =
false;
2590 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
2592 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
2596 if (loopOp == targetCapturedOp) {
2597 omp::TargetRegionFlags kernelFlags =
2598 targetOp.getKernelExecFlags(targetCapturedOp);
2599 if (omp::bitEnumContainsAll(kernelFlags,
2600 omp::TargetRegionFlags::spmd |
2601 omp::TargetRegionFlags::no_loop) &&
2602 !omp::bitEnumContainsAny(kernelFlags,
2603 omp::TargetRegionFlags::generic))
2608 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2609 ompBuilder->applyWorkshareLoop(
2610 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2611 convertToScheduleKind(schedule), chunk, isSimd,
2612 scheduleMod == omp::ScheduleModifier::monotonic,
2613 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2614 workshareLoopType, noLoopMode);
2620 if (!wsloopOp.getLinearVars().empty()) {
2621 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
2622 assert(loopInfo->getLastIter() &&
2623 "`lastiter` in CanonicalLoopInfo is nullptr");
2624 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2625 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
2626 loopInfo->getLastIter());
2629 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
2630 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
2632 builder.restoreIP(oldIP);
2640 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2641 privateReductionVariables, isByRef, wsloopOp.getNowait(),
2654 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2656 assert(isByRef.size() == opInst.getNumReductionVars());
2668 opInst.getNumReductionVars());
2671 auto bodyGenCB = [&](InsertPointTy allocaIP,
2672 InsertPointTy codeGenIP) -> llvm::Error {
2674 builder, moduleTranslation, privateVarsInfo, allocaIP);
2676 return llvm::make_error<PreviouslyReportedError>();
2682 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
2685 InsertPointTy(allocaIP.getBlock(),
2686 allocaIP.getBlock()->getTerminator()->getIterator());
2689 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
2690 reductionDecls, privateReductionVariables, reductionVariableMap,
2691 deferredStores, isByRef)))
2692 return llvm::make_error<PreviouslyReportedError>();
2694 assert(afterAllocas.get()->getSinglePredecessor());
2695 builder.restoreIP(codeGenIP);
2701 return llvm::make_error<PreviouslyReportedError>();
2704 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2706 opInst.getPrivateNeedsBarrier())))
2707 return llvm::make_error<PreviouslyReportedError>();
2710 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
2711 afterAllocas.get()->getSinglePredecessor(),
2712 reductionDecls, privateReductionVariables,
2713 reductionVariableMap, isByRef, deferredStores)))
2714 return llvm::make_error<PreviouslyReportedError>();
2719 moduleTranslation, allocaIP);
2723 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
2725 return regionBlock.takeError();
2728 if (opInst.getNumReductionVars() > 0) {
2733 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
2735 privateReductionVariables, reductionInfos);
2738 builder.SetInsertPoint((*regionBlock)->getTerminator());
2741 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
2742 builder.SetInsertPoint(tempTerminator);
2744 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
2745 ompBuilder->createReductions(
2746 builder.saveIP(), allocaIP, reductionInfos, isByRef,
2748 if (!contInsertPoint)
2749 return contInsertPoint.takeError();
2751 if (!contInsertPoint->getBlock())
2752 return llvm::make_error<PreviouslyReportedError>();
2754 tempTerminator->eraseFromParent();
2755 builder.restoreIP(*contInsertPoint);
2758 return llvm::Error::success();
2761 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2762 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
2771 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2772 InsertPointTy oldIP = builder.saveIP();
2773 builder.restoreIP(codeGenIP);
2778 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
2779 [](omp::DeclareReductionOp reductionDecl) {
2780 return &reductionDecl.getCleanupRegion();
2783 reductionCleanupRegions, privateReductionVariables,
2784 moduleTranslation, builder,
"omp.reduction.cleanup")))
2785 return llvm::createStringError(
2786 "failed to inline `cleanup` region of `omp.declare_reduction`");
2791 return llvm::make_error<PreviouslyReportedError>();
2793 builder.restoreIP(oldIP);
2794 return llvm::Error::success();
2797 llvm::Value *ifCond =
nullptr;
2798 if (
auto ifVar = opInst.getIfExpr())
2800 llvm::Value *numThreads =
nullptr;
2801 if (
auto numThreadsVar = opInst.getNumThreads())
2802 numThreads = moduleTranslation.
lookupValue(numThreadsVar);
2803 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
2804 if (
auto bind = opInst.getProcBindKind())
2808 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2810 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2812 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2813 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
2814 ifCond, numThreads, pbKind, isCancellable);
2819 builder.restoreIP(*afterIP);
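// The helper below maps the MLIR `omp::ClauseOrderKind` value of the `order`
// clause to the `llvm::omp::OrderKind` expected by the OpenMPIRBuilder; only
// `concurrent` is representable, and an absent clause is treated as unknown.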
2824static llvm::omp::OrderKind
2827 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2829 case omp::ClauseOrderKind::Concurrent:
2830 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2832 llvm_unreachable("Unknown ClauseOrderKind kind");
2840 auto simdOp = cast<omp::SimdOp>(opInst);
2848 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2851 simdOp.getNumReductionVars());
2856 assert(isByRef.size() == simdOp.getNumReductionVars());
2858 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2862 builder, moduleTranslation, privateVarsInfo, allocaIP);
2867 moduleTranslation, allocaIP, reductionDecls,
2868 privateReductionVariables, reductionVariableMap,
2869 deferredStores, isByRef)))
2880 assert(afterAllocas.get()->getSinglePredecessor());
2881 if (failed(initReductionVars(simdOp, reductionArgs, builder,
2883 afterAllocas.get()->getSinglePredecessor(),
2884 reductionDecls, privateReductionVariables,
2885 reductionVariableMap, isByRef, deferredStores)))
2888 llvm::ConstantInt *simdlen = nullptr;
2889 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
2890 simdlen = builder.getInt64(simdlenVar.value());
2892 llvm::ConstantInt *safelen = nullptr;
2893 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
2894 safelen = builder.getInt64(safelenVar.value());
2896 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
2899 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
2900 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
2902 for (size_t i = 0; i < operands.size(); ++i) {
2903 llvm::Value *alignment = nullptr;
2904 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
2905 llvm::Type *ty = llvmVal->getType();
2907 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
2908 alignment = builder.getInt64(intAttr.getInt());
2909 assert(ty->isPointerTy() && "Invalid type for aligned variable");
2910 assert(alignment && "Invalid alignment value");
2914 if (!intAttr.getValue().isPowerOf2())
2917 auto curInsert = builder.saveIP();
2918 builder.SetInsertPoint(sourceBlock);
2919 llvmVal = builder.CreateLoad(ty, llvmVal);
2920 builder.restoreIP(curInsert);
2921 alignedVars[llvmVal] = alignment;
2925 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
2930 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2932 ompBuilder->applySimd(loopInfo, alignedVars,
2934 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
2936 order, simdlen, safelen);
2942 for (auto [i, tuple] : llvm::enumerate(
2943 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
2944 privateReductionVariables))) {
2945 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
2947 OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
2948 llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
2949 llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());
2953 llvm::Value *redValue = originalVariable;
2956 builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
2957 llvm::Value *privateRedValue = builder.CreateLoad(
2958 reductionType, privateReductionVar, "red.private.value." + Twine(i));
2959 llvm::Value *reduced;
2961 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
2964 builder.restoreIP(res.get());
2968 builder.CreateStore(reduced, originalVariable);
2973 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
2974 [](omp::DeclareReductionOp reductionDecl) {
2975 return &reductionDecl.getCleanupRegion();
2978 moduleTranslation, builder,
2979 "omp.reduction.cleanup")))
2992 auto loopOp = cast<omp::LoopNestOp>(opInst);
2995 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3000 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
3001 llvm::Value *iv) -> llvm::Error {
3004 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
3009 bodyInsertPoints.push_back(ip);
3011 if (loopInfos.size() != loopOp.getNumLoops() - 1)
3012 return llvm::Error::success();
3015 builder.restoreIP(ip);
3017 loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
3019 return regionBlock.takeError();
3021 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3022 return llvm::Error::success();
3030 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3031 llvm::Value *lowerBound =
3032 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
3033 llvm::Value *upperBound =
3034 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
3035 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
3040 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3041 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3043 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3045 computeIP = loopInfos.front()->getPreheaderIP();
3049 ompBuilder->createCanonicalLoop(
3050 loc, bodyGen, lowerBound, upperBound, step,
3051 true, loopOp.getLoopInclusive(), computeIP);
3056 loopInfos.push_back(*loopResult);
3059 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
3060 loopInfos.front()->getAfterIP();
3063 if (const auto &tiles = loopOp.getTileSizes()) {
3064 llvm::Type *ivType = loopInfos.front()->getIndVarType();
3067 for (auto tile : tiles.value()) {
3068 llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
3069 tileSizes.push_back(tileVal);
3072 std::vector<llvm::CanonicalLoopInfo *> newLoops =
3073 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
3077 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
3078 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
3079 afterIP = {afterAfterBB, afterAfterBB->begin()};
3083 for (const auto &newLoop : newLoops)
3084 loopInfos.push_back(newLoop);
3088 const auto &numCollapse = loopOp.getCollapseNumLoops();
3090 loopInfos.begin(), loopInfos.begin() + (numCollapse));
3092 auto newTopLoopInfo =
3093 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
3095 assert(newTopLoopInfo && "New top loop information is missing");
3096 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
3097 [&](OpenMPLoopInfoStackFrame &frame) {
3098 frame.loopInfo = newTopLoopInfo;
3106 builder.restoreIP(afterIP);
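// The next translation handles the standalone canonical-loop operation: it
// builds a single canonical loop from the op's induction variable and trip
// count, inlines the body region at the loop-body insertion point, and,
// judging from the getCli() check below, associates the resulting
// CanonicalLoopInfo with the op's `cli` handle so later loop transformations
// (unroll, tile) can retrieve it.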
3116 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
3117 Value loopIV = op.getInductionVar();
3118 Value loopTC = op.getTripCount();
3120 llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
3123 ompBuilder->createCanonicalLoop(
3125 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
3128 moduleTranslation.mapValue(loopIV, llvmIV);
3130 builder.restoreIP(ip);
3135 return bodyGenStatus.takeError();
3137 llvmTC, "omp.loop");
3139 return op.emitError(llvm::toString(llvmOrError.takeError()));
3141 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
3142 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
3143 builder.restoreIP(afterIP);
3146 if (Value cli = op.getCli())
3159 Value applyee = op.getApplyee();
3160 assert(applyee && "Loop to apply unrolling on required");
3162 llvm::CanonicalLoopInfo *consBuilderCLI =
3164 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3165 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
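// applyTile translates `omp.tile`: every `sizes` operand must already have an
// LLVM value and every applyee a translated CanonicalLoopInfo; the tiling
// itself is delegated to OpenMPIRBuilder::tileLoops, and the generated loops
// are mapped back to the op's generatees when that result list is non-empty.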
3173static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3176 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3181 for (Value size : op.getSizes()) {
3182 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
3183 assert(translatedSize &&
3184 "sizes clause arguments must already be translated");
3185 translatedSizes.push_back(translatedSize);
3188 for (Value applyee : op.getApplyees()) {
3189 llvm::CanonicalLoopInfo *consBuilderCLI =
3191 assert(applyee && "Canonical loop must already have been translated");
3192 translatedLoops.push_back(consBuilderCLI);
3195 auto generatedLoops =
3196 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3197 if (!op.getGeneratees().empty()) {
3198 for (auto [mlirLoop, genLoop] :
3199 zip_equal(op.getGeneratees(), generatedLoops))
3204 for (Value applyee : op.getApplyees())
3211static llvm::AtomicOrdering
3214 return llvm::AtomicOrdering::Monotonic;
3217 case omp::ClauseMemoryOrderKind::Seq_cst:
3218 return llvm::AtomicOrdering::SequentiallyConsistent;
3219 case omp::ClauseMemoryOrderKind::Acq_rel:
3220 return llvm::AtomicOrdering::AcquireRelease;
3221 case omp::ClauseMemoryOrderKind::Acquire:
3222 return llvm::AtomicOrdering::Acquire;
3223 case omp::ClauseMemoryOrderKind::Release:
3224 return llvm::AtomicOrdering::Release;
3225 case omp::ClauseMemoryOrderKind::Relaxed:
3226 return llvm::AtomicOrdering::Monotonic;
3228 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3235 auto readOp = cast<omp::AtomicReadOp>(opInst);
3240 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3243 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3246 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3247 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
3249 llvm::Type *elementType =
3250 moduleTranslation.convertType(readOp.getElementType());
3252 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
3253 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
3254 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3262 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3267 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3270 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3272 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3273 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
3274 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3275 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, false, false};
3278 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3286 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
3287 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
3288 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
3289 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
3290 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
3291 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
3292 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
3293 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
3294 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
3295 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3299 bool &isIgnoreDenormalMode,
3300 bool &isFineGrainedMemory,
3301 bool &isRemoteMemory) {
3302 isIgnoreDenormalMode = false;
3303 isFineGrainedMemory = false;
3304 isRemoteMemory = false;
3305 if (atomicUpdateOp &&
3306 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3307 mlir::omp::AtomicControlAttr atomicControlAttr =
3308 atomicUpdateOp.getAtomicControlAttr();
3309 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3310 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3311 isRemoteMemory = atomicControlAttr.getRemoteMemory();
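// `omp.atomic.update` carries its update expression as a one-block region
// whose block argument is the current value of `x`. The code below tries to
// recognize a simple `x = x binop expr` pattern so it can use an atomicrmw
// binop, and otherwise falls back to BAD_BINOP plus an update callback that
// inlines the region. Illustrative input (a sketch only, not copied from a
// test):
//
//   omp.atomic.update %x : !llvm.ptr {
//   ^bb0(%xval: i32):
//     %newval = llvm.add %xval, %expr : i32
//     omp.yield(%newval : i32)
//   }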
3318 llvm::IRBuilderBase &builder,
3325 auto &innerOpList = opInst.getRegion().front().getOperations();
3326 bool isXBinopExpr{false};
3327 llvm::AtomicRMWInst::BinOp binop;
3329 llvm::Value *llvmExpr = nullptr;
3330 llvm::Value *llvmX = nullptr;
3331 llvm::Type *llvmXElementType = nullptr;
3332 if (innerOpList.size() == 2) {
3338 opInst.getRegion().getArgument(0))) {
3339 return opInst.emitError("no atomic update operation with region argument"
3340 " as operand found inside atomic.update region");
3343 isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
3345 llvmExpr = moduleTranslation.lookupValue(mlirExpr);
3349 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3351 llvmX = moduleTranslation.lookupValue(opInst.getX());
3353 opInst.getRegion().getArgument(0).getType());
3354 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3358 llvm::AtomicOrdering atomicOrdering =
3363 [&opInst, &moduleTranslation](
3364 llvm::Value *atomicx,
3367 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
3368 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3369 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
3370 return llvm::make_error<PreviouslyReportedError>();
3372 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
3373 assert(yieldop && yieldop.getResults().size() == 1 &&
3374 "terminator must be omp.yield op and it must have exactly one "
3376 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
3379 bool isIgnoreDenormalMode;
3380 bool isFineGrainedMemory;
3381 bool isRemoteMemory;
3386 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3387 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3388 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
3389 atomicOrdering, binop, updateFn,
3390 isXBinopExpr, isIgnoreDenormalMode,
3391 isFineGrainedMemory, isRemoteMemory);
3396 builder.restoreIP(*afterIP);
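// `omp.atomic.capture` pairs an atomic.read with either an atomic.update or
// an atomic.write. The translation below mirrors the update case: it decides
// whether the capture is postfix (the update/write is the second op),
// recognizes a simple binop update when possible, and emits the whole
// construct through createAtomicCapture.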
3402 llvm::IRBuilderBase &builder,
3409 bool isXBinopExpr =
false, isPostfixUpdate =
false;
3410 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3412 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
3413 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
3415 assert((atomicUpdateOp || atomicWriteOp) &&
3416 "internal op must be an atomic.update or atomic.write op");
3418 if (atomicWriteOp) {
3419 isPostfixUpdate = true;
3420 mlirExpr = atomicWriteOp.getExpr();
3422 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
3423 atomicCaptureOp.getAtomicUpdateOp().getOperation();
3424 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
3427 if (innerOpList.size() == 2) {
3430 atomicUpdateOp.getRegion().getArgument(0))) {
3431 return atomicUpdateOp.emitError(
3432 "no atomic update operation with region argument"
3433 " as operand found inside atomic.update region");
3437 innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
3440 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3444 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
3445 llvm::Value *llvmX =
3446 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
3447 llvm::Value *llvmV =
3448 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
3449 llvm::Type *llvmXElementType = moduleTranslation.convertType(
3450 atomicCaptureOp.getAtomicReadOp().getElementType());
3451 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3454 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
3458 llvm::AtomicOrdering atomicOrdering =
3462 [&](llvm::Value *atomicx,
3465 return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
3466 Block &bb = *atomicUpdateOp.getRegion().begin();
3467 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
3469 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
3470 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
3471 return llvm::make_error<PreviouslyReportedError>();
3473 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
3474 assert(yieldop && yieldop.getResults().size() == 1 &&
3475 "terminator must be omp.yield op and it must have exactly one "
3477 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
3480 bool isIgnoreDenormalMode;
3481 bool isFineGrainedMemory;
3482 bool isRemoteMemory;
3484 isFineGrainedMemory, isRemoteMemory);
3487 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3488 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3489 ompBuilder->createAtomicCapture(
3490 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
3491 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
3492 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
3494 if (failed(handleError(afterIP, *atomicCaptureOp)))
3497 builder.restoreIP(*afterIP);
3502 omp::ClauseCancellationConstructType directive) {
3503 switch (directive) {
3504 case omp::ClauseCancellationConstructType::Loop:
3505 return llvm::omp::Directive::OMPD_for;
3506 case omp::ClauseCancellationConstructType::Parallel:
3507 return llvm::omp::Directive::OMPD_parallel;
3508 case omp::ClauseCancellationConstructType::Sections:
3509 return llvm::omp::Directive::OMPD_sections;
3510 case omp::ClauseCancellationConstructType::Taskgroup:
3511 return llvm::omp::Directive::OMPD_taskgroup;
3513 llvm_unreachable("Unhandled cancellation construct type");
3522 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3525 llvm::Value *ifCond = nullptr;
3526 if (Value ifVar = op.getIfExpr())
3529 llvm::omp::Directive cancelledDirective =
3532 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3533 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
3535 if (failed(handleError(afterIP, *op.getOperation())))
3538 builder.restoreIP(afterIP.get());
3545 llvm::IRBuilderBase &builder,
3550 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3553 llvm::omp::Directive cancelledDirective =
3556 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3557 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
3559 if (failed(handleError(afterIP, *op.getOperation())))
3562 builder.restoreIP(afterIP.get());
3572 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3574 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
3579 Value symAddr = threadprivateOp.getSymAddr();
3582 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3585 if (!isa<LLVM::AddressOfOp>(symOp))
3586 return opInst.emitError("Addressing symbol not found");
3587 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3589 LLVM::GlobalOp global =
3590 addressOfOp.getGlobal(moduleTranslation.symbolTable());
3591 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
3593 if (!ompBuilder->Config.isTargetDevice()) {
3594 llvm::Type *type = globalValue->getValueType();
3595 llvm::TypeSize typeSize =
3596 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
3598 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
3599 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
3600 ompLoc, globalValue, size, global.getSymName() + ".cache");
3609static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
3611 switch (deviceClause) {
3612 case mlir::omp::DeclareTargetDeviceType::host:
3613 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
3615 case mlir::omp::DeclareTargetDeviceType::nohost:
3616 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
3618 case mlir::omp::DeclareTargetDeviceType::any:
3619 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
3622 llvm_unreachable("unhandled device clause");
3625static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
3627 mlir::omp::DeclareTargetCaptureClause captureClause) {
3628 switch (captureClause) {
3629 case mlir::omp::DeclareTargetCaptureClause::to:
3630 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
3631 case mlir::omp::DeclareTargetCaptureClause::link:
3632 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
3633 case mlir::omp::DeclareTargetCaptureClause::enter:
3634 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
3636 llvm_unreachable("unhandled capture clause");
3639static llvm::SmallString<64>
3641 llvm::OpenMPIRBuilder &ompBuilder) {
3643 llvm::raw_svector_ostream os(suffix);
3646 auto fileInfoCallBack = [&loc]() {
3647 return std::pair<std::string, uint64_t>(
3648 llvm::StringRef(loc.getFilename()), loc.getLine());
3651 auto vfs = llvm::vfs::getRealFileSystem();
3654 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
3656 os << "_decl_tgt_ref_ptr";
3662 if (auto addressOfOp = value.getDefiningOp<LLVM::AddressOfOp>()) {
3663 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
3664 Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
3665 if (auto declareTargetGlobal =
3666 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
3667 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3668 mlir::omp::DeclareTargetCaptureClause::link)
3682 if (auto addrCast = llvm::dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
3687 if (auto addressOfOp = llvm::dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
3688 if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
3689 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
3690 addressOfOp.getGlobalName()))) {
3692 if (auto declareTargetGlobal =
3693 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
3694 gOp.getOperation())) {
3698 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
3699 mlir::omp::DeclareTargetCaptureClause::link) ||
3700 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3701 mlir::omp::DeclareTargetCaptureClause::to &&
3702 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
3706 if (gOp.getSymName().contains(suffix))
3711 (gOp.getSymName().str() + suffix.str()).str());
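// Map-clause lowering. MapInfosTy below extends the OpenMPIRBuilder map info
// with the mapper operation (if any) attached to each entry, and MapInfoData
// additionally records, per entry, whether the variable is declare-target,
// whether it is a member of another mapped value, the originating map clause
// and the original/base LLVM values.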
3722struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
3723 SmallVector<Operation *, 4> Mappers;
3726 void append(MapInfosTy &curInfo) {
3727 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
3728 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
3737struct MapInfoData : MapInfosTy {
3738 llvm::SmallVector<bool, 4> IsDeclareTarget;
3739 llvm::SmallVector<bool, 4> IsAMember;
3741 llvm::SmallVector<bool, 4> IsAMapping;
3742 llvm::SmallVector<mlir::Operation *, 4> MapClause;
3743 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
3746 llvm::SmallVector<llvm::Type *, 4> BaseType;
3749 void append(MapInfoData &CurInfo) {
3750 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
3751 CurInfo.IsDeclareTarget.end());
3752 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
3753 OriginalValue.append(CurInfo.OriginalValue.begin(),
3754 CurInfo.OriginalValue.end());
3755 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
3756 MapInfosTy::append(CurInfo);
3763 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3764 arrTy.getElementType()))
3781 llvm::Value *basePointer,
3782 llvm::Type *baseType,
3783 llvm::IRBuilderBase &builder,
3785 if (auto memberClause =
3786 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
3791 if (!memberClause.getBounds().empty()) {
3792 llvm::Value *elementCount = builder.getInt64(1);
3793 for (auto bounds : memberClause.getBounds()) {
3794 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
3795 bounds.getDefiningOp())) {
3800 elementCount = builder.CreateMul(
3804 moduleTranslation.lookupValue(boundOp.getUpperBound()),
3805 moduleTranslation.lookupValue(boundOp.getLowerBound())),
3806 builder.getInt64(1)));
3813 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
3821 return builder.CreateMul(elementCount,
3822 builder.getInt64(underlyingTypeSzInBits / 8));
3833static llvm::omp::OpenMPOffloadMappingFlags
3835 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
3836 return (mlirFlags & flag) == flag;
3839 llvm::omp::OpenMPOffloadMappingFlags mapType =
3840 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
3843 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
3846 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
3849 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
3852 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
3855 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
3858 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
3861 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
3864 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
3867 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
3870 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
3873 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
3876 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
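// The collection routine below walks the map, use_device_ptr/addr and
// has_device_addr operands of a data-motion op and fills MapInfoData:
// declare-target globals contribute their reference pointer as the base
// pointer, use_device_* values are either folded into an existing map entry
// (marked RETURN_PARAM) or appended with a zero size, and has_device_addr
// entries keep their map type when ALWAYS is requested and are otherwise
// mapped as LITERAL.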
3885 ArrayRef<Value> useDevAddrOperands = {},
3886 ArrayRef<Value> hasDevAddrOperands = {}) {
3887 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
3895 for (Value mapValue : mapVars) {
3896 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3897 for (auto member : map.getMembers())
3898 if (member == mapOp)
3905 for (Value mapValue : mapVars) {
3906 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3908 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3909 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
3910 mapData.Pointers.push_back(mapData.OriginalValue.back());
3912 if (llvm::Value *refPtr =
3914 moduleTranslation)) {
3915 mapData.IsDeclareTarget.push_back(true);
3916 mapData.BasePointers.push_back(refPtr);
3918 mapData.IsDeclareTarget.push_back(false);
3919 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3922 mapData.BaseType.push_back(
3923 moduleTranslation.convertType(mapOp.getVarType()));
3924 mapData.Sizes.push_back(
3925 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
3926 mapData.BaseType.back(), builder, moduleTranslation));
3927 mapData.MapClause.push_back(mapOp.getOperation());
3931 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3932 if (mapOp.getMapperId())
3933 mapData.Mappers.push_back(
3935 mapOp, mapOp.getMapperIdAttr()));
3937 mapData.Mappers.push_back(nullptr);
3938 mapData.IsAMapping.push_back(true);
3939 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
3942 auto findMapInfo = [&mapData](llvm::Value *val,
3943 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3946 for (llvm::Value *basePtr : mapData.OriginalValue) {
3947 if (basePtr == val && mapData.IsAMapping[index]) {
3949 mapData.Types[index] |=
3950 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
3951 mapData.DevicePointers[index] = devInfoTy;
3959 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
3960 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3961 for (Value mapValue : useDevOperands) {
3962 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3964 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3965 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
3968 if (!findMapInfo(origValue, devInfoTy)) {
3969 mapData.OriginalValue.push_back(origValue);
3970 mapData.Pointers.push_back(mapData.OriginalValue.back());
3971 mapData.IsDeclareTarget.push_back(false);
3972 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3973 mapData.BaseType.push_back(
3974 moduleTranslation.convertType(mapOp.getVarType()));
3975 mapData.Sizes.push_back(builder.getInt64(0));
3976 mapData.MapClause.push_back(mapOp.getOperation());
3977 mapData.Types.push_back(
3978 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
3981 mapData.DevicePointers.push_back(devInfoTy);
3982 mapData.Mappers.push_back(nullptr);
3983 mapData.IsAMapping.push_back(false);
3984 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
3989 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3990 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
3992 for (Value mapValue : hasDevAddrOperands) {
3993 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3995 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3996 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
3998 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4000 mapData.OriginalValue.push_back(origValue);
4001 mapData.BasePointers.push_back(origValue);
4002 mapData.Pointers.push_back(origValue);
4003 mapData.IsDeclareTarget.push_back(false);
4004 mapData.BaseType.push_back(
4005 moduleTranslation.convertType(mapOp.getVarType()));
4006 mapData.Sizes.push_back(
4007 builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
4008 mapData.MapClause.push_back(mapOp.getOperation());
4009 if (llvm::to_underlying(mapType & mapTypeAlways)) {
4013 mapData.Types.push_back(mapType);
4017 if (mapOp.getMapperId()) {
4018 mapData.Mappers.push_back(
4020 mapOp, mapOp.getMapperIdAttr()));
4022 mapData.Mappers.push_back(nullptr);
4025 mapData.Types.push_back(
4026 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4027 mapData.Mappers.push_back(nullptr);
4031 mapData.DevicePointers.push_back(
4032 llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4033 mapData.IsAMapping.push_back(false);
4034 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
4039 auto *res = llvm::find(mapData.MapClause, memberOp);
4040 assert(res != mapData.MapClause.end() &&
4041 "MapInfoOp for member not found in MapData, cannot return index");
4042 return std::distance(mapData.MapClause.begin(), res);
4047 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4049 if (indexAttr.size() == 1)
4050 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
4055 llvm::sort(indices, [&](const size_t a, const size_t b) {
4056 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
4057 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
4058 for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
4059 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
4060 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
4062 if (aIndex == bIndex)
4065 if (aIndex < bIndex)
4068 if (aIndex > bIndex)
4075 return memberIndicesA.size() < memberIndicesB.size();
4078 return llvm::cast<omp::MapInfoOp>(
4079 mapInfo.getMembers()[indices.front()].getDefiningOp());
4102static std::vector<llvm::Value *>
4104 llvm::IRBuilderBase &builder, bool isArrayTy,
4106 std::vector<llvm::Value *> idx;
4117 idx.push_back(builder.getInt64(0));
4118 for (int i = bounds.size() - 1; i >= 0; --i) {
4119 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4120 bounds[i].getDefiningOp())) {
4121 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
4139 std::vector<llvm::Value *> dimensionIndexSizeOffset;
4140 for (int i = bounds.size() - 1; i >= 0; --i) {
4141 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4142 bounds[i].getDefiningOp())) {
4143 if (i == ((int)bounds.size() - 1))
4145 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4147 idx.back() = builder.CreateAdd(
4148 builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
4149 boundOp.getExtent())),
4150 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4175 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
4176 MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) {
4177 assert(!ompBuilder.Config.isTargetDevice() &&
4178 "function only supported for host device codegen");
4184 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
4186 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
4187 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4190 bool hasUserMapper = mapData.Mappers[mapDataIndex] != nullptr;
4191 if (hasUserMapper) {
4192 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
4196 mapFlags parentFlags = mapData.Types[mapDataIndex];
4197 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
4198 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
4199 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
4200 baseFlag |= (parentFlags & preserve);
4203 combinedInfo.Types.emplace_back(baseFlag);
4204 combinedInfo.DevicePointers.emplace_back(
4205 mapData.DevicePointers[mapDataIndex]);
4206 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4208 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4209 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
4219 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4221 llvm::Value *lowAddr, *highAddr;
4222 if (!parentClause.getPartialMap()) {
4223 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4224 builder.getPtrTy());
4225 highAddr = builder.CreatePointerCast(
4226 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4227 mapData.Pointers[mapDataIndex], 1),
4228 builder.getPtrTy());
4229 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4231 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4234 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
4235 builder.getPtrTy());
4238 highAddr = builder.CreatePointerCast(
4239 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
4240 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
4241 builder.getPtrTy());
4242 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
4245 llvm::Value *size = builder.CreateIntCast(
4246 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
4247 builder.getInt64Ty(),
4249 combinedInfo.Sizes.push_back(size);
4251 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
4252 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
4260 if (!parentClause.getPartialMap()) {
4265 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
4266 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4267 combinedInfo.Types.emplace_back(mapFlag);
4268 combinedInfo.DevicePointers.emplace_back(
4269 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4270 combinedInfo.Mappers.emplace_back(nullptr);
4272 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4273 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
4274 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4275 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
4277 return memberOfFlag;
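// A map entry is treated as a pointer mapping when its MapInfoOp carries a
// var_ptr_ptr (checked just below). Such members get the PTR_AND_OBJ flag,
// and when they appear inside a parent structure their size is guarded with
// a null-pointer select before being recorded in the combined info.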
4289 if (mapOp.getVarPtrPtr())
4304 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
4305 MapInfoData &mapData, uint64_t mapDataIndex,
4306 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
4307 assert(!ompBuilder.Config.isTargetDevice() &&
4308 "function only supported for host device codegen");
4311 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4313 for (auto mappedMembers : parentClause.getMembers()) {
4315 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
4318 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
4329 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4330 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4331 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4332 combinedInfo.Types.emplace_back(mapFlag);
4333 combinedInfo.DevicePointers.emplace_back(
4334 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4335 combinedInfo.Mappers.emplace_back(nullptr);
4336 combinedInfo.Names.emplace_back(
4338 combinedInfo.BasePointers.emplace_back(
4339 mapData.BasePointers[mapDataIndex]);
4340 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
4341 combinedInfo.Sizes.emplace_back(builder.getInt64(
4342 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
4348 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4349 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4349 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4350 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4352 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4354 combinedInfo.Types.emplace_back(mapFlag);
4355 combinedInfo.DevicePointers.emplace_back(
4356 mapData.DevicePointers[memberDataIdx]);
4357 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
4358 combinedInfo.Names.emplace_back(
4360 uint64_t basePointerIndex =
4362 combinedInfo.BasePointers.emplace_back(
4363 mapData.BasePointers[basePointerIndex]);
4364 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
4366 llvm::Value *size = mapData.Sizes[memberDataIdx];
4368 size = builder.CreateSelect(
4369 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
4370 builder.getInt64(0), size);
4373 combinedInfo.Sizes.emplace_back(size);
4378 MapInfosTy &combinedInfo, bool isTargetParams,
4379 int mapDataParentIdx = -1) {
4383 auto mapFlag = mapData.Types[mapDataIdx];
4384 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
4388 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4390 if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
4391 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4393 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
4395 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4400 if (mapDataParentIdx >= 0)
4401 combinedInfo.BasePointers.emplace_back(
4402 mapData.BasePointers[mapDataParentIdx]);
4404 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
4406 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
4407 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
4408 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
4409 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
4410 combinedInfo.Types.emplace_back(mapFlag);
4411 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
4415 llvm::IRBuilderBase &builder,
4416 llvm::OpenMPIRBuilder &ompBuilder,
4418 MapInfoData &mapData, uint64_t mapDataIndex,
4419 bool isTargetParams) {
4420 assert(!ompBuilder.Config.isTargetDevice() &&
4421 "function only supported for host device codegen");
4424 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4429 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
4430 auto memberClause = llvm::cast<omp::MapInfoOp>(
4431 parentClause.getMembers()[0].getDefiningOp());
4448 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
4450 combinedInfo, mapData, mapDataIndex, isTargetParams);
4452 combinedInfo, mapData, mapDataIndex,
4453 memberOfParentFlag);
4463 llvm::IRBuilderBase &builder) {
4465 "function only supported for host device codegen");
4466 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4468 if (!mapData.IsDeclareTarget[i]) {
4469 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4470 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
4480 switch (captureKind) {
4481 case omp::VariableCaptureKind::ByRef: {
4482 llvm::Value *newV = mapData.Pointers[i];
4484 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
4487 newV = builder.CreateLoad(builder.getPtrTy(), newV);
4489 if (!offsetIdx.empty())
4490 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
4492 mapData.Pointers[i] = newV;
4494 case omp::VariableCaptureKind::ByCopy: {
4495 llvm::Type *type = mapData.BaseType[i];
4497 if (mapData.Pointers[i]->getType()->isPointerTy())
4498 newV = builder.CreateLoad(type, mapData.Pointers[i]);
4500 newV = mapData.Pointers[i];
4503 auto curInsert = builder.saveIP();
4504 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
4506 auto *memTempAlloc =
4507 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
4508 builder.SetCurrentDebugLocation(DbgLoc);
4509 builder.restoreIP(curInsert);
4511 builder.CreateStore(newV, memTempAlloc);
4512 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
4515 mapData.Pointers[i] = newV;
4516 mapData.BasePointers[i] = newV;
4518 case omp::VariableCaptureKind::This:
4519 case omp::VariableCaptureKind::VLAType:
4520 mapData.MapClause[i]->emitOpError("Unhandled capture kind");
4531 MapInfoData &mapData, bool isTargetParams = false) {
4533 "function only supported for host device codegen");
4555 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4558 if (mapData.IsAMember[i])
4561 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
4562 if (!mapInfoOp.getMembers().empty()) {
4564 combinedInfo, mapData, i, isTargetParams);
4572static llvm::Expected<llvm::Function *>
4574 LLVM::ModuleTranslation &moduleTranslation,
4575 llvm::StringRef mapperFuncName);
4577static llvm::Expected<llvm::Function *>
4581 "function only supported for host device codegen");
4582 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4583 std::string mapperFuncName =
4585 {"omp_mapper", declMapperOp.getSymName()});
4587 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
4594static llvm::Expected<llvm::Function *>
4597 llvm::StringRef mapperFuncName) {
4599 "function only supported for host device codegen");
4600 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4601 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
4604 llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
4607 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4610 MapInfosTy combinedInfo;
4612 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
4613 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
4614 builder.restoreIP(codeGenIP);
4615 moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
4616 moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
4617 builder.GetInsertBlock());
4618 if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
4621 return llvm::make_error<PreviouslyReportedError>();
4622 MapInfoData mapData;
4625 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);
4630 return combinedInfo;
4634 if (!combinedInfo.Mappers[i])
4641 genMapInfoCB, varType, mapperFuncName, customMapperCB);
4643 return newFn.takeError();
4644 moduleTranslation.mapFunction(mapperFuncName, *newFn);
4651 llvm::Value *ifCond = nullptr;
4652 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
4656 llvm::omp::RuntimeFunction RTLFn;
4660 llvm::OpenMPIRBuilder::TargetDataInfo info(true,
4662 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
4663 bool isOffloadEntry =
4664 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
4668 .Case([&](omp::TargetDataOp dataOp) {
4672 if (auto ifVar = dataOp.getIfExpr())
4675 if (auto devId = dataOp.getDevice())
4676 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4677 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4678 deviceID = intAttr.getInt();
4680 mapVars = dataOp.getMapVars();
4681 useDevicePtrVars = dataOp.getUseDevicePtrVars();
4682 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
4685 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
4689 if (auto ifVar = enterDataOp.getIfExpr())
4692 if (auto devId = enterDataOp.getDevice())
4693 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4694 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4695 deviceID = intAttr.getInt();
4697 enterDataOp.getNowait()
4698 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
4699 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
4700 mapVars = enterDataOp.getMapVars();
4701 info.HasNoWait = enterDataOp.getNowait();
4704 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
4708 if (auto ifVar = exitDataOp.getIfExpr())
4711 if (auto devId = exitDataOp.getDevice())
4712 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4713 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4714 deviceID = intAttr.getInt();
4716 RTLFn = exitDataOp.getNowait()
4717 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
4718 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
4719 mapVars = exitDataOp.getMapVars();
4720 info.HasNoWait = exitDataOp.getNowait();
4723 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
4727 if (auto ifVar = updateDataOp.getIfExpr())
4730 if (auto devId = updateDataOp.getDevice())
4731 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
4732 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4733 deviceID = intAttr.getInt();
4736 updateDataOp.getNowait()
4737 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
4738 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
4739 mapVars = updateDataOp.getMapVars();
4740 info.HasNoWait = updateDataOp.getNowait();
4743 .DefaultUnreachable("unexpected operation");
4748 if (!isOffloadEntry)
4749 ifCond = builder.getFalse();
4751 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4752 MapInfoData mapData;
4754 builder, useDevicePtrVars, useDeviceAddrVars);
4757 MapInfosTy combinedInfo;
4758 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
4759 builder.restoreIP(codeGenIP);
4760 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
4761 return combinedInfo;
4767 [&moduleTranslation](
4768 llvm::OpenMPIRBuilder::DeviceInfoTy type,
4772 for (auto [arg, useDevVar] :
4773 llvm::zip_equal(blockArgs, useDeviceVars)) {
4775 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
4776 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
4777 : mapInfoOp.getVarPtr();
4780 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
4781 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
4782 mapInfoData.MapClause, mapInfoData.DevicePointers,
4783 mapInfoData.BasePointers)) {
4784 auto mapOp = cast<omp::MapInfoOp>(mapClause);
4785 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
4786 devicePointer != type)
4789 if (llvm::Value *devPtrInfoMap =
4790 mapper ? mapper(basePointer) : basePointer) {
4791 moduleTranslation.mapValue(arg, devPtrInfoMap);
4798 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
4799 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
4800 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
4803 builder.restoreIP(codeGenIP);
4804 assert(isa<omp::TargetDataOp>(op) &&
4805 "BodyGen requested for non TargetDataOp");
4806 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
4807 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
4808 switch (bodyGenType) {
4809 case BodyGenTy::Priv:
4811 if (!info.DevicePtrInfoMap.empty()) {
4812 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4813 blockArgIface.getUseDeviceAddrBlockArgs(),
4814 useDeviceAddrVars, mapData,
4815 [&](llvm::Value *basePointer) -> llvm::Value * {
4816 if (!info.DevicePtrInfoMap[basePointer].second)
4818 return builder.CreateLoad(
4820 info.DevicePtrInfoMap[basePointer].second);
4822 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4823 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
4824 mapData, [&](llvm::Value *basePointer) {
4825 return info.DevicePtrInfoMap[basePointer].second;
4829 moduleTranslation)))
4830 return llvm::make_error<PreviouslyReportedError>();
4833 case BodyGenTy::DupNoPriv:
4834 if (info.DevicePtrInfoMap.empty()) {
4837 if (!ompBuilder->Config.IsTargetDevice.value_or(false)) {
4838 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4839 blockArgIface.getUseDeviceAddrBlockArgs(),
4840 useDeviceAddrVars, mapData);
4841 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4842 blockArgIface.getUseDevicePtrBlockArgs(),
4843 useDevicePtrVars, mapData);
4847 case BodyGenTy::NoPriv:
4849 if (info.DevicePtrInfoMap.empty()) {
4852 if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
4853 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4854 blockArgIface.getUseDeviceAddrBlockArgs(),
4855 useDeviceAddrVars, mapData);
4856 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4857 blockArgIface.getUseDevicePtrBlockArgs(),
4858 useDevicePtrVars, mapData);
4862 moduleTranslation)))
4863 return llvm::make_error<PreviouslyReportedError>();
4867 return builder.saveIP();
4870 auto customMapperCB =
4872 if (!combinedInfo.Mappers[i])
4874 info.HasMapper = true;
4879 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4880 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4882 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
4883 if (isa<omp::TargetDataOp>(op))
4884 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
4885 builder.getInt64(deviceID), ifCond,
4886 info, genMapInfoCB, customMapperCB,
4889 return ompBuilder->createTargetData(
4890 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
4891 info, genMapInfoCB, customMapperCB, &RTLFn);
4897 builder.restoreIP(*afterIP);
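// `omp.distribute` lowering: privatization happens inside the distribute body
// callback, and when the op does not wrap an `omp.wsloop` the generated
// canonical loop is lowered as a DistributeStaticLoop workshare loop. If the
// parent `omp.teams` op owns reduction clauses and this distribute is the
// construct expected to perform them, the teams reduction is emitted after
// createDistribute returns.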
4905 auto distributeOp = cast<omp::DistributeOp>(opInst);
4912 bool doDistributeReduction =
4916 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
4921 if (doDistributeReduction) {
4922 isByRef = getIsByRef(teamsOp.getReductionByref());
4923 assert(isByRef.size() == teamsOp.getNumReductionVars());
4926 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4930 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
4931 .getReductionBlockArgs();
4934 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
4935 reductionDecls, privateReductionVariables, reductionVariableMap,
4940 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4941 auto bodyGenCB = [&](InsertPointTy allocaIP,
4942 InsertPointTy codeGenIP) -> llvm::Error {
4946 moduleTranslation, allocaIP);
4949 builder.restoreIP(codeGenIP);
4955 return llvm::make_error<PreviouslyReportedError>();
4960 return llvm::make_error<PreviouslyReportedError>();
4963 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
4965 distributeOp.getPrivateNeedsBarrier())))
4966 return llvm::make_error<PreviouslyReportedError>();
4969 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4972 builder, moduleTranslation);
4974 return regionBlock.takeError();
4975 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4980 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
4983 auto schedule = omp::ClauseScheduleKind::Static;
4984 bool isOrdered = false;
4985 std::optional<omp::ScheduleModifier> scheduleMod;
4986 bool isSimd = false;
4987 llvm::omp::WorksharingLoopType workshareLoopType =
4988 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
4989 bool loopNeedsBarrier = false;
4990 llvm::Value *chunk = nullptr;
4992 llvm::CanonicalLoopInfo *loopInfo =
4994 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4995 ompBuilder->applyWorkshareLoop(
4996 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
4997 convertToScheduleKind(schedule), chunk, isSimd,
4998 scheduleMod == omp::ScheduleModifier::monotonic,
4999 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
5003 return wsloopIP.takeError();
5007 distributeOp.getLoc(), privVarsInfo.llvmVars,
5009 return llvm::make_error<PreviouslyReportedError>();
5011 return llvm::Error::success();
5014 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5016 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5017 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5018 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
5023 builder.restoreIP(*afterIP);
5025 if (doDistributeReduction) {
5028 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
5029 privateReductionVariables, isByRef,
5041 if (!cast<mlir::ModuleOp>(op))
5046 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
5047 attribute.getOpenmpDeviceVersion());
5049 if (attribute.getNoGpuLib())
5052 ompBuilder->createGlobalFlag(
5053 attribute.getDebugKind(),
5054 "__omp_rtl_debug_kind");
5055 ompBuilder->createGlobalFlag(
5057 .getAssumeTeamsOversubscription()
5059 "__omp_rtl_assume_teams_oversubscription");
5060 ompBuilder->createGlobalFlag(
5062 .getAssumeThreadsOversubscription()
5064 "__omp_rtl_assume_threads_oversubscription");
5065 ompBuilder->createGlobalFlag(
5066 attribute.getAssumeNoThreadState(),
5067 "__omp_rtl_assume_no_thread_state");
5068 ompBuilder->createGlobalFlag(
5070 .getAssumeNoNestedParallelism()
5072 "__omp_rtl_assume_no_nested_parallelism");
5077 omp::TargetOp targetOp,
5078 llvm::StringRef parentName = "") {
5079 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
5081 assert(fileLoc && "No file found from location");
5082 StringRef fileName = fileLoc.getFilename().getValue();
5084 llvm::sys::fs::UniqueID id;
5085 uint64_t line = fileLoc.getLine();
5086 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
5087 size_t fileHash = llvm::hash_value(fileName.str());
5088 size_t deviceId = 0xdeadf17e;
5090 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
5092 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
5093 id.getFile(), line);
5100 llvm::IRBuilderBase &builder, llvm::Function *func) {
5102 "function only supported for target device codegen");
5103 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5104 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5117 if (mapData.IsDeclareTarget[i]) {
5124 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
5125 convertUsersOfConstantsToInstructions(constant, func, false);
5132 for (llvm::User *user : mapData.OriginalValue[i]->users())
5133 userVec.push_back(user);
5135 for (llvm::User *user : userVec) {
5136 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
5137 if (insn->getFunction() == func) {
5138 builder.SetCurrentDebugLocation(insn->getDebugLoc());
5139 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
5140 mapData.BasePointers[i]);
5141 load->moveBefore(insn->getIterator());
5142 user->replaceUsesOfWith(mapData.OriginalValue[i], load);
5189static llvm::IRBuilderBase::InsertPoint
5191 llvm::Value *input, llvm::Value *&retVal,
5192 llvm::IRBuilderBase &builder,
5193 llvm::OpenMPIRBuilder &ompBuilder,
5195 llvm::IRBuilderBase::InsertPoint allocaIP,
5196 llvm::IRBuilderBase::InsertPoint codeGenIP) {
5197 assert(ompBuilder.Config.isTargetDevice() &&
5198 "function only supported for target device codegen");
5199 builder.restoreIP(allocaIP);
5201 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
5203 ompBuilder.M.getContext());
5204 unsigned alignmentValue = 0;
5206 for (size_t i = 0; i < mapData.MapClause.size(); ++i)
5207 if (mapData.OriginalValue[i] == input) {
5208 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5209 capture = mapOp.getMapCaptureType();
5212 mapOp.getVarType(), ompBuilder.M.getDataLayout());
5216 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
5217 unsigned int defaultAS =
5218 ompBuilder.M.getDataLayout().getProgramAddressSpace();
5221 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
5223 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
5224 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
5226 builder.CreateStore(&arg, v);
5228 builder.restoreIP(codeGenIP);
5231 case omp::VariableCaptureKind::ByCopy: {
5235 case omp::VariableCaptureKind::ByRef: {
5236 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
5238 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
5253 if (v->getType()->isPointerTy() && alignmentValue) {
5254 llvm::MDBuilder MDB(builder.getContext());
5255 loadInst->setMetadata(
5256 llvm::LLVMContext::MD_align,
5257 llvm::MDNode::get(builder.getContext(),
5258 MDB.createConstant(llvm::ConstantInt::get(
5259 llvm::Type::getInt64Ty(builder.getContext()),
5266 case omp::VariableCaptureKind::This:
5267 case omp::VariableCaptureKind::VLAType:
5270 assert(false && "Currently unsupported capture kind");
5274 return builder.saveIP();
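// Host-evaluated clause values: for `omp.target` ops with host_eval operands,
// the walk below matches each host_eval block argument against the clauses of
// the nested teams/parallel/loop_nest ops (num_teams bounds, thread_limit,
// num_threads, loop bounds and steps) so that those values can be computed on
// the host and fed into the kernel's launch attributes.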
5291 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
5292 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
5293 blockArgIface.getHostEvalBlockArgs())) {
5294 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
5298 .Case([&](omp::TeamsOp teamsOp) {
5299 if (teamsOp.getNumTeamsLower() == blockArg)
5300 numTeamsLower = hostEvalVar;
5301 else if (teamsOp.getNumTeamsUpper() == blockArg)
5302 numTeamsUpper = hostEvalVar;
5303 else if (teamsOp.getThreadLimit() == blockArg)
5304 threadLimit = hostEvalVar;
5306 llvm_unreachable("unsupported host_eval use");
5308 .Case([&](omp::ParallelOp parallelOp) {
5309 if (parallelOp.getNumThreads() == blockArg)
5310 numThreads = hostEvalVar;
5312 llvm_unreachable("unsupported host_eval use");
5314 .Case([&](omp::LoopNestOp loopOp) {
5315 auto processBounds =
5319 for (auto [i, lb] : llvm::enumerate(opBounds)) {
5320 if (lb == blockArg) {
5323 (*outBounds)[i] = hostEvalVar;
5329 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
5330 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
5332 found = processBounds(loopOp.getLoopSteps(), steps) || found;
5334 assert(found &&
"unsupported host_eval use");
5336 .DefaultUnreachable(
"unsupported host_eval use");
template <typename OpTy>
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
  if (OpTy casted = dyn_cast<OpTy>(op))
    return casted;

  if (immediateParent)
    return dyn_cast_if_present<OpTy>(op->getParentOp());

  return op->getParentOfType<OpTy>();
}
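/// If the given value is defined by an `llvm.mlir.constant` operation and it
/// is of an integer type, return its value.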
static std::optional<int64_t> extractConstInteger(Value value) {
  if (!value)
    return std::nullopt;

  if (auto constOp =
          dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
    if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
      return constAttr.getInt();

  return std::nullopt;
}

static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
  uint64_t sizeInBits = dl.getTypeSizeInBits(type);
  uint64_t sizeInBytes = sizeInBits / 8;
  return sizeInBytes;
}
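/// Compute the byte size of the reduction data for the given operation by
/// packing the types of all of its reduction declarations into a literal
/// struct.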
template <typename OpTy>
static uint64_t getReductionDataSize(OpTy &op) {
  if (op.getNumReductionVars() > 0) {
    SmallVector<omp::DeclareReductionOp> reductions;
    collectReductionDecls(op, reductions);

    SmallVector<mlir::Type> members;
    members.reserve(reductions.size());
    for (omp::DeclareReductionOp &red : reductions)
      members.push_back(red.getType());
    auto structType = mlir::LLVM::LLVMStructType::getLiteral(
        op.getContext(), members, /*isPacked=*/false);
    return getTypeByteSize(structType,
                           DataLayout(op->template getParentOfType<ModuleOp>()));
  }
  return 0;
}
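/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
/// values as stated by the corresponding clauses, if constant.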
static void
initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
                       bool isTargetDevice, bool isGPU) {
  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
  if (!isTargetDevice) {
    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                           threadLimit);
  } else {
    // In the target device, values for these clauses are not passed as
    // host_eval, but rather taken directly from the nested clauses.
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
      numTeamsLower = teamsOp.getNumTeamsLower();
      numTeamsUpper = teamsOp.getNumTeamsUpper();
      threadLimit = teamsOp.getThreadLimit();
    }

    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
      numThreads = parallelOp.getNumThreads();
  }

  // Handle clauses impacting the number of teams.
  int32_t minTeamsVal = 1, maxTeamsVal = -1;
  if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
    if (numTeamsUpper) {
      if (auto val = extractConstInteger(numTeamsUpper))
        minTeamsVal = maxTeamsVal = *val;
    } else {
      minTeamsVal = maxTeamsVal = 0;
    }
  } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
                                                    /*immediateParent=*/true) ||
             castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                                /*immediateParent=*/true)) {
    minTeamsVal = maxTeamsVal = 1;
  } else {
    minTeamsVal = maxTeamsVal = -1;
  }

  // Handle clauses impacting the maximum number of threads.
  auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
    if (!clauseValue)
      return;

    if (auto val = extractConstInteger(clauseValue))
      result = *val;

    // Found an applicable clause, so it's not undefined. Mark as unknown
    // because it's not constant.
    if (result < 0)
      result = 0;
  };

  // Extract 'thread_limit' clauses from 'target' and 'teams'.
  int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
  setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
  setMaxValueFromClause(threadLimit, teamsThreadLimitVal);

  // Extract 'max_threads' from 'parallel', or set to 1 if it's SIMD.
  int32_t maxThreadsVal = -1;
  if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
    setMaxValueFromClause(numThreads, maxThreadsVal);
  else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                              /*immediateParent=*/true))
    maxThreadsVal = 1;

  // For max values, < 0 means unset and == 0 means set but unknown. Pick the
  // smallest of the clauses that were set.
  int32_t combinedMaxThreadsVal = targetThreadLimitVal;
  if (combinedMaxThreadsVal < 0 ||
      (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = teamsThreadLimitVal;

  if (combinedMaxThreadsVal < 0 ||
      (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = maxThreadsVal;

  int32_t reductionDataSize = 0;
  if (isGPU && capturedOp) {
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
      reductionDataSize = getReductionDataSize(teamsOp);
  }

  // Update the kernel bounds structure for the OpenMPIRBuilder to use.
  omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
  assert(
      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
                                               omp::TargetRegionFlags::spmd) &&
      "invalid kernel flags");
  attrs.ExecFlags =
      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
          ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
                ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
                : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
          : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
  if (omp::bitEnumContainsAll(kernelFlags,
                              omp::TargetRegionFlags::spmd |
                                  omp::TargetRegionFlags::no_loop) &&
      !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
    attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;

  attrs.MinTeams = minTeamsVal;
  attrs.MaxTeams.front() = maxTeamsVal;
  attrs.MinThreads = 1;
  attrs.MaxThreads.front() = combinedMaxThreadsVal;
  attrs.ReductionDataSize = reductionDataSize;
  if (attrs.ReductionDataSize != 0)
    attrs.ReductionBufferLength = 1024;
}
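/// Gather LLVM runtime values for all clauses evaluated in the host that are
/// passed to the kernel invocation.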
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
      steps(numLoops);
  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

  if (Value targetThreadLimit = targetOp.getThreadLimit())
    attrs.TargetThreadLimit.front() =
        moduleTranslation.lookupValue(targetThreadLimit);

  if (numTeamsLower)
    attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);

  if (numTeamsUpper)
    attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);

  if (teamsThreadLimit)
    attrs.TeamsThreadLimit.front() =
        moduleTranslation.lookupValue(teamsThreadLimit);

  if (numThreads)
    attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

  if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
                              omp::TargetRegionFlags::trip_count)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    attrs.LoopTripCount = nullptr;

    // The kernel trip count is the product of the trip counts of every
    // collapsed canonical loop.
    for (auto [loopLower, loopUpper, loopStep] :
         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
      llvm::Value *step = moduleTranslation.lookupValue(loopStep);

      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
          loopOp.getLoopInclusive());

      if (!attrs.LoopTripCount) {
        attrs.LoopTripCount = tripCount;
        continue;
      }

      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
                                              {}, /*HasNUW=*/true);
    }
  }
}
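/// Lowers an `omp.target` operation to an LLVM IR target kernel (or its host
/// fallback) using the OpenMPIRBuilder.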
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto targetOp = cast<omp::TargetOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
  // Use the parent function's subprogram for the debug location of code that
  // is emitted in the parent function (e.g. the kernel launch).
  llvm::BasicBlock *parentBB = builder.GetInsertBlock();
  assert(parentBB && "No insert block is set for the builder");
  llvm::Function *parentLLVMFn = parentBB->getParent();
  assert(parentLLVMFn && "Parent Function must be valid");
  if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
    builder.SetCurrentDebugLocation(llvm::DILocation::get(
        parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
        outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isTargetDevice = ompBuilder->Config.isTargetDevice();
  bool isGPU = ompBuilder->Config.isGPU();

  auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
  auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
  auto &targetRegion = targetOp.getRegion();

  // Holds private vars that have been mapped along with the block argument
  // that corresponds to the MapInfoOp for each of them.
  llvm::DenseMap<Value, Value> mappedPrivateVars;
  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
  SmallVector<Value> mapVars = targetOp.getMapVars();
  SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
  ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
  ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
  llvm::Function *llvmOutlinedFn = nullptr;

  bool isOffloadEntry =
      isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
  if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
    OperandRange privateVars = targetOp.getPrivateVars();
    std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
    std::optional<DenseI64ArrayAttr> privateMapIndices =
        targetOp.getPrivateMapsAttr();

    for (auto [privVarIdx, privVarSymPair] :
         llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
      auto privVar = std::get<0>(privVarSymPair);
      auto privSym = std::get<1>(privVarSymPair);

      SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
      omp::PrivateClauseOp privatizer =
          findPrivatizer(targetOp, privatizerName);

      if (!privatizer.needsMap())
        continue;

      mlir::Value mappedValue =
          targetOp.getMappedValueForPrivateVar(privVarIdx);
      assert(mappedValue && "Expected to find mapped value for a privatized "
                            "variable that needs mapping");

      // The MapInfoOp is only needed to double-check that the private
      // variable has the expected type.
      auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
      [[maybe_unused]] Type varType = mapInfoOp.getVarType();

      if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
        assert(
            varType == privVar.getType() &&
            "Type of private var doesn't match the type of the mapped value");

      // Map the private var to the block argument of the corresponding map
      // entry.
      mappedPrivateVars.insert(
          {privVar,
           targetRegion.getArgument(argIface.getMapBlockArgsStart() +
                                    (*privateMapIndices)[privVarIdx])});
    }
  }
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());

    // Forward target-cpu and target-features function attributes from the
    // original function to the new outlined function.
    llvm::Function *llvmParentFn =
        moduleTranslation.lookupFunction(parentFn.getName());
    llvmOutlinedFn = codeGenIP.getBlock()->getParent();
    assert(llvmParentFn && llvmOutlinedFn &&
           "Both parent and outlined functions must exist at this point");

    if (outlinedFnLoc && llvmParentFn->getSubprogram())
      llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());

    if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    if (auto attr = llvmParentFn->getFnAttribute("target-features");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }
    for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }

    // Do privatization after the mapped values have been recorded.
    PrivateVarsInfo privateVarsInfo(targetOp);
    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateVarsInfo, allocaIP,
        &mappedPrivateVars);
    if (failed(handleError(afterAllocas, *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(codeGenIP);
    if (handleError(initPrivateVars(builder, moduleTranslation,
                                    privateVarsInfo, &mappedPrivateVars),
                    *targetOp)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
      return llvm::make_error<PreviouslyReportedError>();

    SmallVector<Region *> privateCleanupRegions;
    llvm::transform(privateVarsInfo.privatizers,
                    std::back_inserter(privateCleanupRegions),
                    [](omp::PrivateClauseOp privatizer) {
                      return &privatizer.getDeallocRegion();
                    });

    llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
        targetRegion, "omp.target", builder, moduleTranslation);
    if (!exitBlock)
      return exitBlock.takeError();

    builder.SetInsertPoint(*exitBlock);
    if (!privateCleanupRegions.empty()) {
      if (failed(inlineOmpRegionCleanup(
              privateCleanupRegions, privateVarsInfo.llvmVars,
              moduleTranslation, builder, "omp.targetop.private.cleanup",
              /*shouldLoadCleanupRegionArg=*/false)))
        return llvm::createStringError(
            "failed to inline `dealloc` region of `omp.private` "
            "op in the target region");
      return builder.saveIP();
    }

    return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
  };
  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;
  getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                builder, /*useDevPtrOperands=*/{},
                                /*useDevAddrOperands=*/{}, hdaVars);

  MapInfosTy combinedInfos;
  auto genMapInfoCB =
      [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
                /*isTargetParams=*/true);
    return combinedInfos;
  };

  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // Return the unaltered argument for the host function, as not much is
    // done on the host apart from launching the kernel.
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(&arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
                                        *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP);
  };
  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
  initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
                         isTargetDevice, isGPU);

  // Collect host-evaluated values needed to properly launch the kernel from
  // the host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
                           targetCapturedOp, runtimeAttrs);

  // Pass host-evaluated values as parameters to the kernel / host fallback,
  // except if they are constants. In any case, map the MLIR block argument to
  // the corresponding LLVM values.
  llvm::SmallVector<llvm::Value *> kernelInput;
  llvm::SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
  ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
  for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(var);
    moduleTranslation.mapValue(arg, value);

    if (!llvm::isa<llvm::Constant>(value))
      kernelInput.push_back(value);
  }

  for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
    // Declare-target arguments are not passed to kernels as arguments.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(mapData.OriginalValue[i]);
  }
  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/false, /*SeparateBeginEndCalls=*/true);

  auto customMapperCB =
      [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfos.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
                                            moduleTranslation);
  };

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(targetIfCond);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createTarget(
          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
          argAccessorCB, customMapperCB, dds, targetOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
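/// Processes an `omp.declare_target` attribute: host-only functions are erased
/// during device compilation, and global variables are registered with the
/// offloading infrastructure.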
static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
                         LLVM::ModuleTranslation &moduleTranslation) {
  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
            op->getParentOfType<ModuleOp>().getOperation())) {
      if (!offloadMod.getIsTargetDevice())
        return success();

      omp::DeclareTargetDeviceType declareType =
          attribute.getDeviceType().getValue();

      if (declareType == omp::DeclareTargetDeviceType::host) {
        llvm::Function *llvmFunc =
            moduleTranslation.lookupFunction(funcOp.getName());
        llvmFunc->dropAllReferences();
        llvmFunc->eraseFromParent();
      }
    }
  }

  if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
    llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
    if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
      llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
      bool isDeclaration = gOp.isDeclaration();
      bool isExternallyVisible =
          gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
      llvm::StringRef mangledName = gOp.getSymName();
      auto captureClause =
          convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
      auto deviceClause =
          convertToDeviceClauseKind(attribute.getDeviceType().getValue());
      std::vector<llvm::GlobalVariable *> generatedRefs;

      std::vector<llvm::Triple> targetTriple;
      auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
          op->getParentOfType<mlir::ModuleOp>()->getAttr(
              LLVM::LLVMDialect::getTargetTripleAttrName()));
      if (targetTripleAttr)
        targetTriple.emplace_back(targetTripleAttr.data());

      mlir::FileLineColLoc loc =
          gOp->getLoc()->findInstanceOf<mlir::FileLineColLoc>();
      auto fileInfoCallBack = [&loc]() {
        std::string filename = "";
        std::uint64_t lineNo = 0;

        if (loc) {
          filename = loc.getFilename().str();
          lineNo = loc.getLine();
        }

        return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
                                                     lineNo);
      };

      auto vfs = llvm::vfs::getRealFileSystem();
      ompBuilder->registerTargetGlobalVariable(
          captureClause, deviceClause, isDeclaration, isExternallyVisible,
          ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
          mangledName, generatedRefs, /*OpenMPSIMD=*/false, targetTriple,
          /*GlobalInitializer=*/nullptr, /*VariableLinkage=*/nullptr,
          gVal->getType(), gVal);

      if (ompBuilder->Config.isTargetDevice() &&
          (attribute.getCaptureClause().getValue() !=
               mlir::omp::DeclareTargetCaptureClause::to ||
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        ompBuilder->getAddrOfDeclareTargetVar(
            captureClause, deviceClause, isDeclaration, isExternallyVisible,
            ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
            mangledName, generatedRefs, /*OpenMPSIMD=*/false, targetTriple,
            gVal->getType(), /*GlobalInitializer=*/nullptr,
            /*VariableLinkage=*/nullptr);
      }
    }
  }

  return success();
}
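/// Returns true if the given operation must also be translated when compiling
/// for a target device.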
static bool isTargetDeviceOp(Operation *op) {
  // Certain operations return results that may be used by LLVM dialect ops on
  // the host, so they must always be lowered even in device compilation.
  if (mlir::isa<omp::ThreadprivateOp>(op))
    return true;

  // Allocations must be lowered on both host and device.
  if (mlir::isa<omp::TargetAllocMemOp>(op) ||
      mlir::isa<omp::TargetFreeMemOp>(op))
    return true;

  if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
    if (auto declareTargetIface =
            llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
                parentFn.getOperation()))
      if (declareTargetIface.isDeclareTarget() &&
          declareTargetIface.getDeclareTargetDeviceType() !=
              mlir::omp::DeclareTargetDeviceType::host)
        return true;

  return false;
}
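/// Returns (creating it if needed) the declaration of the `omp_target_alloc`
/// runtime function: `void *omp_target_alloc(size_t size, int device_num)`.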
static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
                                         llvm::Module *llvmModule) {
  llvm::Type *i64Ty = builder.getInt64Ty();
  llvm::Type *i32Ty = builder.getInt32Ty();
  llvm::Type *returnType = builder.getPtrTy(0);
  llvm::FunctionType *fnType =
      llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, /*isVarArg=*/false);
  llvm::Function *func = cast<llvm::Function>(
      llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
  return func;
}

static LogicalResult
convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);

  // Get the "omp_target_alloc" runtime function.
  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
  llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
  // Get the device number from the "device" clause.
  mlir::Value deviceNum = allocMemOp.getDevice();
  llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
  // Compute the allocation size in bytes.
  llvm::DataLayout dataLayout = llvmModule->getDataLayout();
  mlir::Type heapTy = allocMemOp.getAllocatedType();
  llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
  for (auto typeParam : allocMemOp.getTypeparams())
    allocSize =
        builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
  // Create the call to "omp_target_alloc" and expose the result as an i64.
  llvm::CallInst *call =
      builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
  llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());

  moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
  return success();
}
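/// Returns (creating it if needed) the declaration of the `omp_target_free`
/// runtime function: `void omp_target_free(void *ptr, int device_num)`.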
static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
                                        llvm::Module *llvmModule) {
  llvm::Type *ptrTy = builder.getPtrTy(0);
  llvm::Type *i32Ty = builder.getInt32Ty();
  llvm::Type *voidTy = builder.getVoidTy();
  llvm::FunctionType *fnType =
      llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, /*isVarArg=*/false);
  llvm::Function *func = dyn_cast<llvm::Function>(
      llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
  return func;
}

static LogicalResult
convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);

  // Get the "omp_target_free" runtime function.
  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
  llvm::Function *ompTargetFreeFunc = getOmpTargetFree(builder, llvmModule);
  // Get the device number from the "device" clause.
  mlir::Value deviceNum = freeMemOp.getDevice();
  llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
  // Get the heap reference being freed.
  mlir::Value heapref = freeMemOp.getHeapref();
  llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
  // Convert the i64 heap reference back to a pointer and free it.
  llvm::Value *intToPtr =
      builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
  builder.CreateCall(ompTargetFreeFunc, {intToPtr, llvmDeviceNum});
  return success();
}
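/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
/// OpenMP runtime calls).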
static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // Push one loop-info stack frame for the outermost loop wrapper only, so a
  // single loop does not get multiple frames.
  bool isOutermostLoopWrapper =
      isa_and_present<omp::LoopWrapperInterface>(op) &&
      !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();

  auto result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::BarrierOp op) -> LogicalResult {
            if (failed(checkImplementationStatus(*op)))
              return failure();

            llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
                ompBuilder->createBarrier(builder.saveIP(),
                                          llvm::omp::OMPD_barrier);
            LogicalResult res = handleError(afterIP, *op);
            if (res.succeeded()) {
              builder.restoreIP(*afterIP);
            }
            return res;
          })
          .Case([&](omp::TaskyieldOp op) {
            ompBuilder->createTaskyield(builder.saveIP());
            return success();
          })
          .Case([&](omp::FlushOp op) {
            ompBuilder->createFlush(builder.saveIP());
            return success();
          })
          .Case([&](omp::ParallelOp op) {
            return convertOmpParallel(op, builder, moduleTranslation);
          })
          .Case([&](omp::MaskedOp) {
            return convertOmpMasked(*op, builder, moduleTranslation);
          })
          .Case([&](omp::MasterOp) {
            return convertOmpMaster(*op, builder, moduleTranslation);
          })
          .Case([&](omp::CriticalOp) {
            return convertOmpCritical(*op, builder, moduleTranslation);
          })
          .Case([&](omp::OrderedRegionOp) {
            return convertOmpOrderedRegion(*op, builder, moduleTranslation);
          })
          .Case([&](omp::OrderedOp) {
            return convertOmpOrdered(*op, builder, moduleTranslation);
          })
          .Case([&](omp::WsloopOp) {
            return convertOmpWsloop(*op, builder, moduleTranslation);
          })
          .Case([&](omp::SimdOp) {
            return convertOmpSimd(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicReadOp) {
            return convertOmpAtomicRead(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicWriteOp) {
            return convertOmpAtomicWrite(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicUpdateOp op) {
            return convertOmpAtomicUpdate(op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicCaptureOp op) {
            return convertOmpAtomicCapture(op, builder, moduleTranslation);
          })
          .Case([&](omp::CancelOp op) {
            return convertOmpCancel(op, builder, moduleTranslation);
          })
          .Case([&](omp::CancellationPointOp op) {
            return convertOmpCancellationPoint(op, builder, moduleTranslation);
          })
          .Case([&](omp::SectionsOp) {
            return convertOmpSections(*op, builder, moduleTranslation);
          })
          .Case([&](omp::SingleOp op) {
            return convertOmpSingle(op, builder, moduleTranslation);
          })
          .Case([&](omp::TeamsOp op) {
            return convertOmpTeams(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskOp op) {
            return convertOmpTaskOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskgroupOp op) {
            return convertOmpTaskgroupOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskwaitOp op) {
            return convertOmpTaskwaitOp(op, builder, moduleTranslation);
          })
          .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
                omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
                omp::CriticalDeclareOp>([](auto op) {
            // Handled as part of their owning operations; nothing to do here.
            return success();
          })
          .Case([&](omp::ThreadprivateOp) {
            return convertOmpThreadprivate(*op, builder, moduleTranslation);
          })
          .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
                omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
            return convertOmpTargetData(op, builder, moduleTranslation);
          })
          .Case([&](omp::TargetOp) {
            return convertOmpTarget(*op, builder, moduleTranslation);
          })
          .Case([&](omp::DistributeOp) {
            return convertOmpDistribute(*op, builder, moduleTranslation);
          })
          .Case([&](omp::LoopNestOp) {
            return convertOmpLoopNest(*op, builder, moduleTranslation);
          })
          .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
              [](auto op) {
                // No-op: handled by the operations that use them.
                return success();
              })
          .Case([&](omp::NewCliOp op) {
            // Meta-operation that only defines a loop identifier.
            return success();
          })
          .Case([&](omp::CanonicalLoopOp op) {
            return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::UnrollHeuristicOp op) {
            return applyUnrollHeuristic(op, builder, moduleTranslation);
          })
          .Case([&](omp::TileOp op) {
            return applyTile(op, builder, moduleTranslation);
          })
          .Case([&](omp::TargetAllocMemOp) {
            return convertTargetAllocMemOp(*op, builder, moduleTranslation);
          })
          .Case([&](omp::TargetFreeMemOp) {
            return convertTargetFreeMemOp(*op, builder, moduleTranslation);
          })
          .Default([&](Operation *inst) {
            return inst->emitError()
                   << "not yet implemented: " << inst->getName();
          });

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPop();

  return result;
}
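/// Lowers only the target-related operations nested inside `op`, translating
/// the surrounding non-target constructs just enough to provide the LLVM
/// values that the nested target regions need.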
static LogicalResult
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  if (isa<omp::TargetOp>(op))
    return convertOmpTarget(*op, builder, moduleTranslation);
  if (isa<omp::TargetDataOp>(op))
    return convertOmpTargetData(op, builder, moduleTranslation);

  bool interrupted =
      op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
          if (isa<omp::TargetOp>(oper)) {
            if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }
          if (isa<omp::TargetDataOp>(oper)) {
            if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }

          // Non-target ops might nest target-related ops, so translate them as
          // non-OpenMP scopes: nested target regions may need LLVM values
          // defined in their parent non-target ops.
          if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
              oper->getParentOfType<LLVM::LLVMFuncOp>() &&
              !oper->getRegions().empty()) {
            if (auto blockArgsIface =
                    dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
              forwardArgs(moduleTranslation, blockArgsIface);
            else if (isa<mlir::omp::AtomicUpdateOp>(oper))
              for (auto [operand, arg] :
                   llvm::zip_equal(oper->getOperands(),
                                   oper->getRegion(0).getArguments())) {
                moduleTranslation.mapValue(
                    arg, builder.CreateLoad(
                             moduleTranslation.convertType(arg.getType()),
                             moduleTranslation.lookupValue(operand)));
              }

            if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
              assert(builder.GetInsertBlock() &&
                     "No insert block is set for the builder");
              for (auto iv : loopNest.getIVs()) {
                // Map the induction variable to a poison value just to keep
                // the IR valid while building the fake region below.
                moduleTranslation.mapValue(
                    iv, llvm::PoisonValue::get(
                            moduleTranslation.convertType(iv.getType())));
              }
            }

            for (Region &region : oper->getRegions()) {
              // These regions are only translated to prepare the kernel
              // invocation args; no OpenMP runtime calls are generated here.
              SmallVector<llvm::PHINode *> phis;
              auto result = convertOmpOpRegions(
                  region,
                  oper->getName().getStringRef().str() + ".fake.region",
                  builder, moduleTranslation, &phis);
              if (failed(handleError(result, *oper)))
                return WalkResult::interrupt();

              builder.SetInsertPoint(result.get(), result.get()->end());
            }
          }

          return WalkResult::advance();
        }).wasInterrupted();
  return failure(interrupted);
}
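/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR.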
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the relevant OpenMP to
  /// LLVM IR conversion patterns.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;

  /// Given an OpenMP MLIR attribute, create the corresponding LLVM IR or
  /// amend the module-level OpenMP configuration.
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;
};
LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
    Operation *op, ArrayRef<llvm::Instruction *> instructions,
    NamedAttribute attribute,
    LLVM::ModuleTranslation &moduleTranslation) const {
  return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
             attribute.getName())
      .Case("omp.is_target_device",
            [&](Attribute attr) {
              if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsTargetDevice(deviceAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.is_gpu",
            [&](Attribute attr) {
              if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsGPU(gpuAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.host_ir_filepath",
            [&](Attribute attr) {
              if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                auto VFS = llvm::vfs::getRealFileSystem();
                ompBuilder->loadOffloadInfoMetadata(*VFS,
                                                    filepathAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.flags",
            [&](Attribute attr) {
              if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
                return convertFlagsAttr(op, rtlAttr, moduleTranslation);
              return failure();
            })
      .Case("omp.version",
            [&](Attribute attr) {
              if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
                                            versionAttr.getVersion());
                return success();
              }
              return failure();
            })
      .Case("omp.declare_target",
            [&](Attribute attr) {
              if (auto declareTargetAttr =
                      dyn_cast<omp::DeclareTargetAttr>(attr))
                return convertDeclareTargetAttr(op, declareTargetAttr,
                                                moduleTranslation);
              return failure();
            })
      .Case("omp.requires",
            [&](Attribute attr) {
              if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
                using Requires = omp::ClauseRequires;
                Requires flags = requiresAttr.getValue();
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setHasRequiresReverseOffload(
                    bitEnumContainsAll(flags, Requires::reverse_offload));
                config.setHasRequiresUnifiedAddress(
                    bitEnumContainsAll(flags, Requires::unified_address));
                config.setHasRequiresUnifiedSharedMemory(
                    bitEnumContainsAll(flags, Requires::unified_shared_memory));
                config.setHasRequiresDynamicAllocators(
                    bitEnumContainsAll(flags, Requires::dynamic_allocators));
                return success();
              }
              return failure();
            })
      .Case("omp.target_triples",
            [&](Attribute attr) {
              if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.TargetTriples.clear();
                config.TargetTriples.reserve(triplesAttr.size());
                for (Attribute tripleAttr : triplesAttr) {
                  if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
                    config.TargetTriples.emplace_back(tripleStrAttr.getValue());
                }
                return success();
              }
              return failure();
            })
      .Default([](Attribute) {
        // Fall through for omp attributes that do not require lowering.
        return success();
      })(attribute.getValue());
}
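// When compiling for a target device, only device-relevant operations are
// translated directly; everything else is handled by scanning for nested
// target regions.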
LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
    Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (ompBuilder->Config.isTargetDevice()) {
    if (isTargetDeviceOp(op))
      return convertTargetDeviceOp(op, builder, moduleTranslation);
    return convertTargetOpsInNest(op, builder, moduleTranslation);
  }
  return convertHostOrTargetOperation(op, builder, moduleTranslation);
}

void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
  registry.insert<omp::OpenMPDialect>();
  registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
    dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
  });
}