26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Frontend/OpenMP/OMPConstants.h"
30#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31#include "llvm/IR/Constants.h"
32#include "llvm/IR/DebugInfoMetadata.h"
33#include "llvm/IR/DerivedTypes.h"
34#include "llvm/IR/IRBuilder.h"
35#include "llvm/IR/MDBuilder.h"
36#include "llvm/IR/ReplaceConstant.h"
37#include "llvm/Support/AMDGPUAddrSpace.h"
38#include "llvm/Support/FileSystem.h"
39#include "llvm/Support/NVPTXAddrSpace.h"
40#include "llvm/Support/VirtualFileSystem.h"
41#include "llvm/TargetParser/Triple.h"
42#include "llvm/Transforms/Utils/ModuleUtils.h"
/// Translates the optional MLIR `omp::ClauseScheduleKind` of a schedule
/// clause into the corresponding `llvm::omp::ScheduleKind` enumerator.
///
/// \param schedKind schedule kind taken from the clause, or `std::nullopt`
///        when no kind was given.
/// \returns `OMP_SCHEDULE_Default` when \p schedKind holds no value, otherwise
///          the enumerator with the matching name (Static, Dynamic, Guided,
///          Auto, Runtime, Distribute).
53static llvm::omp::ScheduleKind
54convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
// No kind given: fall back to the default schedule.
55 if (!schedKind.has_value())
56 return llvm::omp::OMP_SCHEDULE_Default;
// One-to-one mapping between the MLIR enumerators and the LLVM ones.
57 switch (schedKind.value()) {
58 case omp::ClauseScheduleKind::Static:
59 return llvm::omp::OMP_SCHEDULE_Static;
60 case omp::ClauseScheduleKind::Dynamic:
61 return llvm::omp::OMP_SCHEDULE_Dynamic;
62 case omp::ClauseScheduleKind::Guided:
63 return llvm::omp::OMP_SCHEDULE_Guided;
64 case omp::ClauseScheduleKind::Auto:
65 return llvm::omp::OMP_SCHEDULE_Auto;
66 case omp::ClauseScheduleKind::Runtime:
67 return llvm::omp::OMP_SCHEDULE_Runtime;
68 case omp::ClauseScheduleKind::Distribute:
69 return llvm::omp::OMP_SCHEDULE_Distribute;
// Every enumerator listed above returns; reaching this point means a kind
// that the switch does not handle.
71 llvm_unreachable(
"unhandled schedule clause argument");
/// Stack frame that carries an alloca insertion point together with a list of
/// basic blocks associated with deallocation.
/// NOTE(review): the base class and the exact meaning of the "dealloc" blocks
/// are not visible in this excerpt — confirm against the full definition.
76class OpenMPAllocStackFrame
// Stores the given insertion point and copies the block list out of the
// (non-owning) ArrayRef into the frame's own SmallVector.
81 explicit OpenMPAllocStackFrame(
82 llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
83 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks)
84 : allocInsertPoint(allocaIP), deallocBlocks(deallocBlocks) {}
// Insertion point recorded for this frame.
85 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
// Owned copy of the blocks passed to the constructor.
86 llvm::SmallVector<llvm::BasicBlock *> deallocBlocks;
/// Stack frame holding a pointer to a `llvm::CanonicalLoopInfo`.
/// NOTE(review): base class and any further members are not visible here.
92class OpenMPLoopInfoStackFrame
// Starts out null until some code assigns a loop to the frame.
96 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
/// Custom `llvm::ErrorInfo` subclass with no payload; it overrides `log` and
/// `convertToErrorCode` as required by the `ErrorInfo` interface.
/// NOTE(review): the name suggests it marks failures whose diagnostic was
/// already emitted elsewhere — confirm against the full definition, since the
/// method bodies are cut in this excerpt.
115class PreviouslyReportedError
116 :
public llvm::ErrorInfo<PreviouslyReportedError> {
// Logging hook; the stream parameter is unnamed, i.e. unused by name.
118 void log(raw_ostream &)
const override {
// Conversion to std::error_code; the message below states that this
// conversion is not supported for this error type.
122 std::error_code convertToErrorCode()
const override {
124 "PreviouslyReportedError doesn't support ECError conversion");
// Out-of-class definition of the static `ID` member used by llvm::ErrorInfo
// for type identification.
131char PreviouslyReportedError::ID = 0;
/// Helper that keeps the per-variable state needed to lower linear variables
/// of a loop: allocas, original pointers, steps, types and the basic blocks
/// created for finalization. All vectors are indexed by the linear variable's
/// position.
142class LinearClauseProcessor {
// Allocas named ".linear_var"; receive the value of each original variable
// loaded in the loop preheader.
145 SmallVector<llvm::Value *> linearPreconditionVars;
// Allocas named ".linear_result"; receive the value recomputed in the loop
// body for each iteration.
146 SmallVector<llvm::Value *> linearLoopBodyTemps;
// Pointers to the original linear variables.
147 SmallVector<llvm::Value *> linearOrigVal;
// LLVM values of the linear steps.
148 SmallVector<llvm::Value *> linearSteps;
// Converted LLVM element types of the linear variables.
149 SmallVector<llvm::Type *> linearVarTypes;
// Blocks produced by splitLinearFiniBB(); not initialized before that call.
150 llvm::BasicBlock *linearFinalizationBB;
151 llvm::BasicBlock *linearExitBB;
152 llvm::BasicBlock *linearLastIterExitBB;
/// Records the LLVM type of one linear variable.
/// \param moduleTranslation used to convert the MLIR type to an LLVM type.
/// \param ty attribute that must be a `mlir::TypeAttr` (enforced by
///        `mlir::cast`); its wrapped type is converted and appended to
///        `linearVarTypes`.
156 void registerType(LLVM::ModuleTranslation &moduleTranslation,
157 mlir::Attribute &ty) {
158 linearVarTypes.push_back(moduleTranslation.
convertType(
159 mlir::cast<mlir::TypeAttr>(ty).getValue()));
/// Creates the two allocas backing one linear variable and records the
/// original variable.
/// \param builder IR builder; allocas are emitted at its current insertion
///        point.
/// \param moduleTranslation not used by the visible statements.
/// \param linearVar pointer to the original linear variable.
/// \param idx index into `linearVarTypes` selecting the element type; the type
///        must already have been registered for this index.
163 void createLinearVar(llvm::IRBuilderBase &builder,
164 LLVM::ModuleTranslation &moduleTranslation,
165 llvm::Value *linearVar,
int idx) {
// Storage for the value the variable has before the loop starts.
166 linearPreconditionVars.push_back(
167 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_var"));
// Storage for the value recomputed inside the loop body.
168 llvm::Value *linearLoopBodyTemp =
169 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_result");
170 linearOrigVal.push_back(linearVar);
171 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
/// Records the step of one linear variable.
/// \param moduleTranslation used to look up the LLVM value already mapped to
///        \p linearStep.
/// \param linearStep MLIR value of the step; its LLVM counterpart is appended
///        to `linearSteps`.
175 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
176 mlir::Value &linearStep) {
177 linearSteps.push_back(moduleTranslation.
lookupValue(linearStep));
/// Emits, at the end of the loop preheader, the copy of every original linear
/// variable into its ".linear_var" alloca.
/// \param builder IR builder; its insertion point is moved to just before the
///        terminator of \p loopPreHeader and is not restored by the visible
///        code.
/// \param moduleTranslation not used by the visible statements.
/// \param loopPreHeader preheader block of the loop; must have a terminator.
181 void initLinearVar(llvm::IRBuilderBase &builder,
182 LLVM::ModuleTranslation &moduleTranslation,
183 llvm::BasicBlock *loopPreHeader) {
184 builder.SetInsertPoint(loopPreHeader->getTerminator());
185 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
// Load the current value of the original variable and save it as the
// starting value used by the per-iteration update.
186 llvm::LoadInst *linearVarLoad =
187 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
188 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
/// Emits, before the terminator of the loop body, the computation
/// `start + iv * step` for every linear variable and stores the result into
/// the variable's ".linear_result" alloca.
/// \param builder IR builder; its insertion point is moved to just before the
///        terminator of \p loopBody.
/// \param loopBody loop body block; must have a terminator.
/// \param loopInductionVar induction variable of the loop; must be of integer
///        type.
193 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
194 llvm::Value *loopInductionVar) {
195 builder.SetInsertPoint(loopBody->getTerminator());
196 for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
197 llvm::Type *linearVarType = linearVarTypes[index];
198 llvm::Value *iv = loopInductionVar;
199 llvm::Value *step = linearSteps[index];
// A non-integer induction variable is treated as a programming error.
201 if (!iv->getType()->isIntegerTy())
202 llvm_unreachable(
"OpenMP loop induction variable must be an integer "
// Integer variable: bring both iv and step to the variable's type (sign
// extend or truncate), then do the whole computation in integer arithmetic.
205 if (linearVarType->isIntegerTy()) {
207 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
208 step = builder.CreateSExtOrTrunc(step, linearVarType);
210 llvm::LoadInst *linearVarStart =
211 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
212 llvm::Value *mulInst = builder.CreateMul(iv, step);
213 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
214 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
215 }
else if (linearVarType->isFloatingPointTy()) {
// Floating-point variable: the product iv * step is formed as an integer in
// the induction variable's type, converted with a signed int-to-fp cast, and
// only the final addition is done in floating point.
217 step = builder.CreateSExtOrTrunc(step, iv->getType());
218 llvm::Value *mulInst = builder.CreateMul(iv, step);
220 llvm::LoadInst *linearVarStart =
221 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
222 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
223 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
224 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
// Message used for variable types that are neither integer nor floating point.
227 "Linear variable must be of integer or floating-point type");
/// Carves the blocks used for linear-variable finalization out of the loop
/// exit block and stores them in the corresponding members.
/// \param builder not used by the visible statements.
/// \param loopExit exit block of the loop; must have a terminator. Everything
///        from its terminator onwards moves into the new finalization block.
234 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
235 llvm::BasicBlock *loopExit) {
// loopExit -> linearFinalizationBB
236 linearFinalizationBB = loopExit->splitBasicBlock(
237 loopExit->getTerminator(),
"omp_loop.linear_finalization");
// linearFinalizationBB -> linearExitBB
238 linearExitBB = linearFinalizationBB->splitBasicBlock(
239 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
// Split the finalization block a second time, so the resulting chain is
// linearFinalizationBB -> linearLastIterExitBB -> linearExitBB.
240 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
241 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
/// Emits the finalization of the linear variables: when the last-iteration
/// flag is set, the values in the ".linear_result" allocas are copied back to
/// the original variables. Requires splitLinearFiniBB() to have been called.
/// \param builder IR builder; its insertion point is changed several times.
/// \param moduleTranslation NOTE(review): presumably used by the partially
///        visible call at the end to reach the OpenMP builder — confirm.
/// \param lastIter pointer to an i32 flag; a non-zero value selects the
///        copy-back path.
/// \returns the result of the call whose arguments end with
///          `llvm::omp::OMPD_barrier`. NOTE(review): this looks like a barrier
///          creation, but the callee name is not visible in this excerpt.
245 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
246 finalizeLinearVar(llvm::IRBuilderBase &builder,
247 LLVM::ModuleTranslation &moduleTranslation,
248 llvm::Value *lastIter) {
// Compute `isLast = (*lastIter != 0)` in the finalization block.
250 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
251 llvm::Value *loopLastIterLoad = builder.CreateLoad(
252 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
253 llvm::Value *isLast =
254 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
255 llvm::ConstantInt::get(
256 llvm::Type::getInt32Ty(builder.getContext()), 0));
// In the last-iteration block, write every loop-body result back to the
// original variable.
258 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
259 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
260 llvm::LoadInst *linearVarTemp =
261 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
262 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// Insert the conditional branch in front of the finalization block's old
// terminator, then erase that old terminator (it is still the last
// instruction of the block, so getTerminator() returns it).
268 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
269 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
270 linearFinalizationBB->getTerminator()->eraseFromParent();
// Continue emission at the end of the common exit block.
272 builder.SetInsertPoint(linearExitBB->getTerminator());
274 builder.saveIP(), llvm::omp::OMPD_barrier);
/// Unconditionally copies every ".linear_result" value back to its original
/// linear variable.
/// \param builder IR builder; loads and stores are emitted at its current
///        insertion point, which this function does not change.
279 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
280 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
281 llvm::LoadInst *linearVarTemp =
282 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
283 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
/// Redirects uses of one original linear variable to its ".linear_result"
/// alloca, restricted to instructions that live in basic blocks reachable from
/// \p startBB whose name starts with \p prefix.
/// \param builder not used by the visible statements.
/// \param startBB block where the traversal over successors begins; must not
///        be null.
/// \param endBB must not be null. NOTE(review): how it bounds the traversal is
///        not visible in this excerpt — confirm.
/// \param prefix name prefix selecting the blocks in which uses are rewritten.
/// NOTE(review): `varIndex` (index of the linear variable) is used below but
/// its declaration is cut from this excerpt; presumably the last parameter.
289 void rewriteInPlace(llvm::IRBuilderBase &builder, llvm::BasicBlock *startBB,
290 llvm::BasicBlock *endBB, llvm::StringRef prefix,
292 llvm::SmallVector<llvm::BasicBlock *, 32> worklist;
293 llvm::SmallPtrSet<llvm::BasicBlock *, 32> visited;
294 llvm::SmallPtrSet<llvm::BasicBlock *, 32> matchingBBs;
296 assert(startBB && endBB &&
"Invalid startBB/endBB");
// Worklist traversal of the CFG; `visited` guarantees each block is pushed at
// most once, so cycles terminate.
300 worklist.push_back(startBB);
301 visited.insert(startBB);
303 while (!worklist.empty()) {
304 llvm::BasicBlock *bb = worklist.pop_back_val();
// Remember named blocks whose name carries the requested prefix.
306 if (bb->hasName() && bb->getName().starts_with(prefix))
307 matchingBBs.insert(bb);
312 for (llvm::BasicBlock *succ : llvm::successors(bb)) {
313 if (visited.insert(succ).second)
314 worklist.push_back(succ);
// Snapshot the user list first: replaceUsesOfWith() modifies the use list of
// the original value, which must not happen while iterating over it.
319 llvm::SmallVector<llvm::User *> users(linearOrigVal[varIndex]->users());
320 for (
auto *user : users) {
// Only instruction users located in a matching block are rewritten.
321 if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
322 if (matchingBBs.contains(userInst->getParent()))
323 user->replaceUsesOfWith(linearOrigVal[varIndex],
324 linearLoopBodyTemps[varIndex]);
335 SymbolRefAttr symbolName) {
336 omp::PrivateClauseOp privatizer =
339 assert(privatizer &&
"privatizer not found in the symbol table");
350 auto todo = [&op](StringRef clauseName) {
351 return op.
emitError() <<
"not yet implemented: Unhandled clause "
352 << clauseName <<
" in " << op.
getName()
356 auto checkAllocate = [&todo](
auto op, LogicalResult &
result) {
357 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
358 result = todo(
"allocate");
360 auto checkBare = [&todo](
auto op, LogicalResult &
result) {
362 result = todo(
"ompx_bare");
364 auto checkDepend = [&todo](
auto op, LogicalResult &
result) {
365 if (!op.getDependVars().empty() || op.getDependKinds())
368 auto checkHint = [](
auto op, LogicalResult &) {
372 auto checkInReduction = [&todo](
auto op, LogicalResult &
result) {
373 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
374 op.getInReductionSyms())
375 result = todo(
"in_reduction");
377 auto checkNowait = [&todo](
auto op, LogicalResult &
result) {
381 auto checkOrder = [&todo](
auto op, LogicalResult &
result) {
382 if (op.getOrder() || op.getOrderMod())
385 auto checkPrivate = [&todo](
auto op, LogicalResult &
result) {
386 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
387 result = todo(
"privatization");
389 auto checkReduction = [&todo](
auto op, LogicalResult &
result) {
390 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopContextOp>(op))
391 if (!op.getReductionVars().empty() || op.getReductionByref() ||
392 op.getReductionSyms())
393 result = todo(
"reduction");
394 if (op.getReductionMod() &&
395 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
396 result = todo(
"reduction with modifier");
398 auto checkTaskReduction = [&todo](
auto op, LogicalResult &
result) {
399 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
400 op.getTaskReductionSyms())
401 result = todo(
"task_reduction");
403 auto checkNumTeams = [&todo](
auto op, LogicalResult &
result) {
404 if (op.hasNumTeamsMultiDim())
405 result = todo(
"num_teams with multi-dimensional values");
407 auto checkNumThreads = [&todo](
auto op, LogicalResult &
result) {
408 if (op.hasNumThreadsMultiDim())
409 result = todo(
"num_threads with multi-dimensional values");
412 auto checkThreadLimit = [&todo](
auto op, LogicalResult &
result) {
413 if (op.hasThreadLimitMultiDim())
414 result = todo(
"thread_limit with multi-dimensional values");
416 auto checkMap = [&todo](
auto op, LogicalResult &
result) {
417 if (!op.getMapIterated().empty())
418 result = todo(
"map/motion clause with iterator modifier");
421 auto checkDynGroupprivate = [&todo](
auto op, LogicalResult &
result) {
422 if (op.getDynGroupprivateSize())
423 result = todo(
"dyn_groupprivate");
428 .Case([&](omp::DistributeOp op) {
429 checkAllocate(op,
result);
432 .Case([&](omp::SectionsOp op) {
433 checkAllocate(op,
result);
435 checkReduction(op,
result);
437 .Case([&](omp::ScopeOp op) {
438 checkAllocate(op,
result);
439 checkReduction(op,
result);
441 .Case([&](omp::SingleOp op) {
442 checkAllocate(op,
result);
445 .Case([&](omp::TeamsOp op) {
446 checkAllocate(op,
result);
448 checkNumTeams(op,
result);
449 checkThreadLimit(op,
result);
450 checkDynGroupprivate(op,
result);
452 .Case([&](omp::TaskOp op) {
453 checkAllocate(op,
result);
454 checkInReduction(op,
result);
456 .Case([&](omp::TaskgroupOp op) {
457 checkAllocate(op,
result);
458 checkTaskReduction(op,
result);
460 .Case([&](omp::TaskwaitOp op) {
464 .Case([&](omp::TaskloopContextOp op) {
465 checkAllocate(op,
result);
466 checkInReduction(op,
result);
467 checkReduction(op,
result);
469 .Case([&](omp::WsloopOp op) {
470 checkAllocate(op,
result);
472 checkReduction(op,
result);
474 .Case([&](omp::ParallelOp op) {
475 checkAllocate(op,
result);
476 checkReduction(op,
result);
477 checkNumThreads(op,
result);
479 .Case([&](omp::SimdOp op) { checkReduction(op,
result); })
480 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
481 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op,
result); })
482 .Case([&](omp::AtomicCompareOp op) {
488 auto structTy = dyn_cast<LLVM::LLVMStructType>(argType);
494 result = todo(
"compare for complex types wider than 128 bits");
496 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>([&](
auto op) {
500 .Case([&](omp::TargetUpdateOp op) {
504 .Case([&](omp::TargetOp op) {
505 checkAllocate(op,
result);
507 checkInReduction(op,
result);
509 checkThreadLimit(op,
result);
511 .Case([&](omp::TargetDataOp op) { checkMap(op,
result); })
512 .Case([&](omp::DeclareMapperInfoOp op) { checkMap(op,
result); })
523 llvm::handleAllErrors(
525 [&](
const PreviouslyReportedError &) {
result = failure(); },
526 [&](
const llvm::ErrorInfoBase &err) {
549 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
552 [&](OpenMPAllocStackFrame &frame) {
553 allocInsertPoint = frame.allocInsertPoint;
554 deallocInsertPoints = frame.deallocBlocks;
562 allocInsertPoint.getBlock()->getParent() ==
563 builder.GetInsertBlock()->getParent()) {
565 deallocBlocks->insert(deallocBlocks->end(), deallocInsertPoints.begin(),
566 deallocInsertPoints.end());
567 return allocInsertPoint;
577 if (builder.GetInsertBlock() ==
578 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
579 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
580 "Assuming end of basic block");
581 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
582 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
583 builder.GetInsertBlock()->getNextNode());
584 builder.CreateBr(entryBB);
585 builder.SetInsertPoint(entryBB);
591 for (llvm::BasicBlock &block : *builder.GetInsertBlock()->getParent()) {
595 llvm::Instruction *terminator = block.getTerminatorOrNull();
596 if (isa_and_present<llvm::ReturnInst>(terminator))
597 deallocBlocks->emplace_back(&block);
601 llvm::BasicBlock &funcEntryBlock =
602 builder.GetInsertBlock()->getParent()->getEntryBlock();
603 return llvm::OpenMPIRBuilder::InsertPointTy(
604 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
610static llvm::CanonicalLoopInfo *
612 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
613 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
614 [&](OpenMPLoopInfoStackFrame &frame) {
615 loopInfo = frame.loopInfo;
627 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
630 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
632 llvm::BasicBlock *continuationBlock =
633 splitBB(builder,
true,
"omp.region.cont");
634 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
636 llvm::LLVMContext &llvmContext = builder.getContext();
637 for (
Block &bb : region) {
638 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
639 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
640 builder.GetInsertBlock()->getNextNode());
641 moduleTranslation.
mapBlock(&bb, llvmBB);
644 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
651 unsigned numYields = 0;
653 if (!isLoopWrapper) {
654 bool operandsProcessed =
false;
656 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
657 if (!operandsProcessed) {
658 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
659 continuationBlockPHITypes.push_back(
660 moduleTranslation.
convertType(yield->getOperand(i).getType()));
662 operandsProcessed =
true;
664 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
665 "mismatching number of values yielded from the region");
666 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
667 llvm::Type *operandType =
668 moduleTranslation.
convertType(yield->getOperand(i).getType());
670 assert(continuationBlockPHITypes[i] == operandType &&
671 "values of mismatching types yielded from the region");
681 if (!continuationBlockPHITypes.empty())
683 continuationBlockPHIs &&
684 "expected continuation block PHIs if converted regions yield values");
685 if (continuationBlockPHIs) {
686 llvm::IRBuilderBase::InsertPointGuard guard(builder);
687 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
688 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
689 for (llvm::Type *ty : continuationBlockPHITypes)
690 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
696 for (
Block *bb : blocks) {
697 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
700 if (bb->isEntryBlock()) {
701 assert(sourceTerminator->getNumSuccessors() == 1 &&
702 "provided entry block has multiple successors");
703 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
704 "ContinuationBlock is not the successor of the entry block");
705 sourceTerminator->setSuccessor(0, llvmBB);
708 llvm::IRBuilderBase::InsertPointGuard guard(builder);
710 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
711 return llvm::make_error<PreviouslyReportedError>();
716 builder.CreateBr(continuationBlock);
727 Operation *terminator = bb->getTerminator();
728 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
729 builder.CreateBr(continuationBlock);
731 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
732 (*continuationBlockPHIs)[i]->addIncoming(
746 return continuationBlock;
752 case omp::ClauseProcBindKind::Close:
753 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
754 case omp::ClauseProcBindKind::Master:
755 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
756 case omp::ClauseProcBindKind::Primary:
757 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
758 case omp::ClauseProcBindKind::Spread:
759 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
761 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
768 auto maskedOp = cast<omp::MaskedOp>(opInst);
769 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
774 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
777 auto ®ion = maskedOp.getRegion();
778 builder.restoreIP(codeGenIP);
786 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
788 llvm::Value *filterVal =
nullptr;
789 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
790 filterVal = moduleTranslation.
lookupValue(filterVar);
792 llvm::LLVMContext &llvmContext = builder.getContext();
794 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
796 assert(filterVal !=
nullptr);
797 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
798 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
805 builder.restoreIP(*afterIP);
813 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
814 auto masterOp = cast<omp::MasterOp>(opInst);
819 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
822 auto ®ion = masterOp.getRegion();
823 builder.restoreIP(codeGenIP);
831 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
833 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
834 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
841 builder.restoreIP(*afterIP);
849 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
850 auto criticalOp = cast<omp::CriticalOp>(opInst);
855 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
858 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
859 builder.restoreIP(codeGenIP);
867 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
869 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
870 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
871 llvm::Constant *hint =
nullptr;
874 if (criticalOp.getNameAttr()) {
877 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
878 auto criticalDeclareOp =
882 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
883 static_cast<int>(criticalDeclareOp.getHint()));
885 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
887 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
892 builder.restoreIP(*afterIP);
899 template <
typename OP>
902 cast<
omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
905 collectPrivatizationDecls<OP>(op);
920 void collectPrivatizationDecls(OP op) {
921 std::optional<ArrayAttr> attr = op.getPrivateSyms();
926 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
937 std::optional<ArrayAttr> attr = op.getReductionSyms();
941 reductions.reserve(reductions.size() + op.getNumReductionVars());
942 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
943 reductions.push_back(
955 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
964 llvm::Instruction *potentialTerminator =
965 builder.GetInsertBlock()->empty() ?
nullptr
966 : &builder.GetInsertBlock()->back();
968 if (potentialTerminator && potentialTerminator->isTerminator())
969 potentialTerminator->removeFromParent();
970 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
973 region.
front(),
true, builder)))
977 if (continuationBlockArgs)
979 *continuationBlockArgs,
986 if (potentialTerminator && potentialTerminator->isTerminator()) {
987 llvm::BasicBlock *block = builder.GetInsertBlock();
988 if (block->empty()) {
994 potentialTerminator->insertInto(block, block->begin());
996 potentialTerminator->insertAfter(&block->back());
1010 if (continuationBlockArgs)
1011 llvm::append_range(*continuationBlockArgs, phis);
1012 builder.SetInsertPoint(*continuationBlock,
1013 (*continuationBlock)->getFirstInsertionPt());
1020using OwningReductionGen =
1021 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1022 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
1024using OwningAtomicReductionGen =
1025 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1026 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
1028using OwningDataPtrPtrReductionGen =
1029 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1030 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
1036static OwningReductionGen
1042 OwningReductionGen gen =
1043 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1044 llvm::Value *
lhs, llvm::Value *
rhs,
1045 llvm::Value *&
result)
mutable
1046 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1047 moduleTranslation.
mapValue(decl.getReductionLhsArg(),
lhs);
1048 moduleTranslation.
mapValue(decl.getReductionRhsArg(),
rhs);
1049 builder.restoreIP(insertPoint);
1052 "omp.reduction.nonatomic.body", builder,
1053 moduleTranslation, &phis)))
1054 return llvm::createStringError(
1055 "failed to inline `combiner` region of `omp.declare_reduction`");
1056 result = llvm::getSingleElement(phis);
1057 return builder.saveIP();
1066static OwningAtomicReductionGen
1068 llvm::IRBuilderBase &builder,
1070 if (decl.getAtomicReductionRegion().empty())
1071 return OwningAtomicReductionGen();
1077 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1078 llvm::Value *
lhs, llvm::Value *
rhs)
mutable
1079 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1080 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(),
lhs);
1081 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(),
rhs);
1082 builder.restoreIP(insertPoint);
1085 "omp.reduction.atomic.body", builder,
1086 moduleTranslation, &phis)))
1087 return llvm::createStringError(
1088 "failed to inline `atomic` region of `omp.declare_reduction`");
1089 assert(phis.empty());
1090 return builder.saveIP();
1099static OwningDataPtrPtrReductionGen
1102 if (!isByRef || decl.getDataPtrPtrRegion().empty())
1103 return OwningDataPtrPtrReductionGen();
1105 OwningDataPtrPtrReductionGen refDataPtrGen =
1106 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1107 llvm::Value *byRefVal, llvm::Value *&
result)
mutable
1108 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1109 moduleTranslation.
mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1110 builder.restoreIP(insertPoint);
1113 "omp.data_ptr_ptr.body", builder,
1114 moduleTranslation, &phis)))
1115 return llvm::createStringError(
1116 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1117 result = llvm::getSingleElement(phis);
1118 return builder.saveIP();
1121 return refDataPtrGen;
1128 auto orderedOp = cast<omp::OrderedOp>(opInst);
1133 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1134 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1135 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1137 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
1139 size_t indexVecValues = 0;
1140 while (indexVecValues < vecValues.size()) {
1142 storeValues.reserve(numLoops);
1143 for (
unsigned i = 0; i < numLoops; i++) {
1144 storeValues.push_back(vecValues[indexVecValues]);
1147 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1149 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1150 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
1151 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
1161 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1162 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1167 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1170 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1171 builder.restoreIP(codeGenIP);
1179 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1181 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1182 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1184 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1189 builder.restoreIP(*afterIP);
/// Pair of a value and the address it should later be stored to.
/// NOTE(review): the declaration of the `value` member initialized by the
/// constructor is not visible in this excerpt.
1195struct DeferredStore {
// Records both operands of the store without emitting any IR.
1196 DeferredStore(llvm::Value *value, llvm::Value *address)
1197 : value(value), address(address) {}
// Destination pointer of the store.
1200 llvm::Value *address;
1207template <
typename T>
1210 llvm::IRBuilderBase &builder,
1212 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1218 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1219 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1225 deferredStores.reserve(op.getNumReductionVars());
1227 for (std::size_t i = 0; i < op.getNumReductionVars(); ++i) {
1228 Region &allocRegion = reductionDecls[i].getAllocRegion();
1230 if (allocRegion.
empty())
1235 builder, moduleTranslation, &phis)))
1236 return op.emitError(
1237 "failed to inline `alloc` region of `omp.declare_reduction`");
1239 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1240 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1244 llvm::Type *ptrTy = builder.getPtrTy();
1248 if (useDeviceSharedMem) {
1249 var = ompBuilder->createOMPAllocShared(builder, varTy);
1251 var = builder.CreateAlloca(varTy);
1252 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1255 llvm::Value *castPhi =
1256 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1258 deferredStores.emplace_back(castPhi, var);
1260 privateReductionVariables[i] = var;
1261 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1262 reductionVariableMap.try_emplace(op.getReductionVars()[i], castPhi);
1264 assert(allocRegion.
empty() &&
1265 "allocaction is implicit for by-val reduction");
1267 llvm::Type *ptrTy = builder.getPtrTy();
1271 if (useDeviceSharedMem) {
1272 var = ompBuilder->createOMPAllocShared(builder, varTy);
1274 var = builder.CreateAlloca(varTy);
1275 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1278 moduleTranslation.
mapValue(reductionArgs[i], var);
1279 privateReductionVariables[i] = var;
1280 reductionVariableMap.try_emplace(op.getReductionVars()[i], var);
1288template <
typename T>
1291 llvm::IRBuilderBase &builder,
1296 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1297 Region &initializerRegion = reduction.getInitializerRegion();
1300 mlir::Value mlirSource = loop.getReductionVars()[i];
1301 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1302 llvm::Value *origVal = llvmSource;
1304 if (!isa<LLVM::LLVMPointerType>(
1305 reduction.getInitializerMoldArg().getType()) &&
1306 isa<LLVM::LLVMPointerType>(mlirSource.
getType())) {
1309 reduction.getInitializerMoldArg().getType()),
1310 llvmSource,
"omp_orig");
1312 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), origVal);
1315 llvm::Value *allocation =
1316 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1317 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1323 llvm::BasicBlock *block =
nullptr) {
1324 if (block ==
nullptr)
1325 block = builder.GetInsertBlock();
1327 if (!block->hasTerminator())
1328 builder.SetInsertPoint(block);
1330 builder.SetInsertPoint(block->getTerminator());
1338template <
typename OP>
1341 llvm::IRBuilderBase &builder,
1343 llvm::BasicBlock *latestAllocaBlock,
1349 if (op.getNumReductionVars() == 0)
1355 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1356 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1357 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1358 builder.restoreIP(allocaIP);
1361 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1363 if (!reductionDecls[i].getAllocRegion().empty())
1371 if (useDeviceSharedMem)
1372 byRefVars[i] = ompBuilder->createOMPAllocShared(builder, varTy);
1374 byRefVars[i] = builder.CreateAlloca(varTy);
1382 for (
auto [data, addr] : deferredStores)
1383 builder.CreateStore(data, addr);
1388 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1393 reductionVariableMap, i);
1401 "omp.reduction.neutral", builder,
1402 moduleTranslation, &phis)))
1405 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1406 "reduction neutral element declaration region");
1411 if (!reductionDecls[i].getAllocRegion().empty())
1420 builder.CreateStore(phis[0], byRefVars[i]);
1422 privateReductionVariables[i] = byRefVars[i];
1423 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1424 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1427 builder.CreateStore(phis[0], privateReductionVariables[i]);
1434 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
1441template <
typename T>
1442static void collectReductionInfo(
1443 T loop, llvm::IRBuilderBase &builder,
1452 unsigned numReductions = loop.getNumReductionVars();
1454 for (
unsigned i = 0; i < numReductions; ++i) {
1457 owningAtomicReductionGens.push_back(
1460 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1464 reductionInfos.reserve(numReductions);
1465 for (
unsigned i = 0; i < numReductions; ++i) {
1466 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1467 if (owningAtomicReductionGens[i])
1468 atomicGen = owningAtomicReductionGens[i];
1469 llvm::Value *variable =
1470 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
1473 if (
auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1474 allocatedType = alloca.getElemType();
1481 reductionInfos.push_back(
1483 privateReductionVariables[i],
1484 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1488 allocatedType ? moduleTranslation.
convertType(allocatedType) :
nullptr,
1489 reductionDecls[i].getByrefElementType()
1491 *reductionDecls[i].getByrefElementType())
1501 llvm::IRBuilderBase &builder, StringRef regionName,
1502 bool shouldLoadCleanupRegionArg =
true) {
1503 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1504 if (cleanupRegion->empty())
1510 llvm::Instruction *potentialTerminator =
1511 builder.GetInsertBlock()->empty() ?
nullptr
1512 : &builder.GetInsertBlock()->back();
1513 if (potentialTerminator && potentialTerminator->isTerminator())
1514 builder.SetInsertPoint(potentialTerminator);
1515 llvm::Value *privateVarValue =
1516 shouldLoadCleanupRegionArg
1517 ? builder.CreateLoad(
1519 privateVariables[i])
1520 : privateVariables[i];
1525 moduleTranslation)))
1538 OP op, llvm::IRBuilderBase &builder,
1540 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1543 bool isNowait =
false,
bool isTeamsReduction =
false) {
1545 if (op.getNumReductionVars() == 0)
1557 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1559 owningReductionGenRefDataPtrGens,
1560 privateReductionVariables, reductionInfos, isByRef);
1565 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1566 builder.SetInsertPoint(tempTerminator);
1567 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1568 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1569 isByRef, isNowait, isTeamsReduction);
1574 if (!contInsertPoint->getBlock())
1575 return op->emitOpError() <<
"failed to convert reductions";
1577 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
1578 if (!isTeamsReduction) {
1579 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1580 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1584 afterIP = *barrierIP;
1587 tempTerminator->eraseFromParent();
1588 builder.restoreIP(afterIP);
1592 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1593 [](omp::DeclareReductionOp reductionDecl) {
1594 return &reductionDecl.getCleanupRegion();
1597 reductionRegions, privateReductionVariables, moduleTranslation, builder,
1598 "omp.reduction.cleanup");
1601 if (useDeviceSharedMem) {
1602 for (
auto [var, reductionDecl] :
1603 llvm::zip_equal(privateReductionVariables, reductionDecls))
1604 ompBuilder->createOMPFreeShared(
1605 builder, var, moduleTranslation.
convertType(reductionDecl.getType()));
1618template <
typename OP>
1622 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1627 if (op.getNumReductionVars() == 0)
1633 allocaIP, reductionDecls,
1634 privateReductionVariables, reductionVariableMap,
1635 deferredStores, isByRef)))
1638 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1639 allocaIP.getBlock(), reductionDecls,
1640 privateReductionVariables, reductionVariableMap,
1641 isByRef, deferredStores);
1655 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1658 Value blockArg = (*mappedPrivateVars)[privateVar];
1661 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1662 "A block argument corresponding to a mapped var should have "
1665 if (privVarType == blockArgType)
1672 if (!isa<LLVM::LLVMPointerType>(privVarType))
1673 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1686 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1688 llvm::BasicBlock *privInitBlock,
1690 Region &initRegion = privDecl.getInitRegion();
1691 if (initRegion.
empty())
1692 return llvmPrivateVar;
1694 assert(nonPrivateVar);
1695 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1696 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1701 moduleTranslation, &phis)))
1702 return llvm::createStringError(
1703 "failed to inline `init` region of `omp.private`");
1705 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1722 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1725 builder, moduleTranslation, privDecl,
1728 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1737 return llvm::Error::success();
1739 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1742 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1745 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1747 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1748 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1751 return privVarOrErr.takeError();
1753 llvmPrivateVar = privVarOrErr.get();
1754 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1759 return llvm::Error::success();
1765template <
typename T>
1770 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1773 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1774 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1775 allocaTerminator->getIterator()),
1776 true, allocaTerminator->getStableDebugLoc(),
1777 "omp.region.after_alloca");
1779 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1781 allocaTerminator = allocaIP.getBlock()->getTerminator();
1782 builder.SetInsertPoint(allocaTerminator);
1784 assert(allocaTerminator->getNumSuccessors() == 1 &&
1785 "This is an unconditional branch created by splitBB");
1787 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1788 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1792 unsigned int allocaAS =
1793 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1796 .getProgramAddressSpace();
1798 for (
auto [privDecl, mlirPrivVar, blockArg] :
1801 llvm::Type *llvmAllocType =
1802 moduleTranslation.
convertType(privDecl.getType());
1803 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1804 llvm::Value *llvmPrivateVar =
nullptr;
1806 llvmPrivateVar = ompBuilder->createOMPAllocShared(builder, llvmAllocType);
1808 llvmPrivateVar = builder.CreateAlloca(
1809 llvmAllocType,
nullptr,
"omp.private.alloc");
1810 if (allocaAS != defaultAS)
1811 llvmPrivateVar = builder.CreateAddrSpaceCast(
1812 llvmPrivateVar, builder.getPtrTy(defaultAS));
1815 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1818 return afterAllocas;
1826 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1835 if (mlir::isa<omp::ParallelOp>(parent))
1849 bool needsFirstprivate =
1850 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1851 return privOp.getDataSharingType() ==
1852 omp::DataSharingClauseType::FirstPrivate;
1855 if (!needsFirstprivate)
1858 llvm::BasicBlock *copyBlock =
1859 splitBB(builder,
true,
"omp.private.copy");
1862 for (
auto [decl, moldVar, llvmVar] :
1863 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1864 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1868 Region ©Region = decl.getCopyRegion();
1870 moduleTranslation.
mapValue(decl.getCopyMoldArg(), moldVar);
1873 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1877 moduleTranslation)))
1878 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1892 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1893 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1909 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](
mlir::Value mlirVar) {
1911 llvm::Value *moldVar = findAssociatedValue(
1912 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1917 llvmPrivateVars, privateDecls, insertBarrier,
1921template <
typename T>
1929 std::back_inserter(privateCleanupRegions),
1930 [](omp::PrivateClauseOp privatizer) {
1931 return &privatizer.getDeallocRegion();
1935 privateVarsInfo.
llvmVars, moduleTranslation,
1936 builder,
"omp.private.dealloc",
1938 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1939 "`omp.private` op in");
1943 for (
auto [privDecl, llvmPrivVar, blockArg] :
1947 ompBuilder->createOMPFreeShared(
1948 builder, llvmPrivVar,
1949 moduleTranslation.
convertType(privDecl.getType()));
1963 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1973 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1974 using StorableBodyGenCallbackTy =
1975 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1977 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1983 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1987 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1991 sectionsOp.getNumReductionVars());
1995 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1998 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1999 reductionDecls, privateReductionVariables, reductionVariableMap,
2006 auto sectionOp = dyn_cast<omp::SectionOp>(op);
2010 Region ®ion = sectionOp.getRegion();
2011 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
2012 InsertPointTy allocaIP, InsertPointTy codeGenIP,
2014 builder.restoreIP(codeGenIP);
2021 sectionsOp.getRegion().getNumArguments());
2022 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
2023 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
2024 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
2026 moduleTranslation.
mapValue(sectionArg, llvmVal);
2033 sectionCBs.push_back(sectionCB);
2039 if (sectionCBs.empty())
2042 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
2047 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
2048 llvm::Value &vPtr, llvm::Value *&replacementValue)
2049 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
2050 replacementValue = &vPtr;
2056 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
2060 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2061 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2063 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
2064 sectionsOp.getNowait());
2069 builder.restoreIP(*afterIP);
2073 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
2074 privateReductionVariables, isByRef, sectionsOp.getNowait());
2081 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2088 assert(isByRef.size() == scopeOp.getNumReductionVars());
2097 scopeOp.getNumReductionVars());
2101 cast<omp::BlockArgOpenMPOpInterface>(*scopeOp).getReductionBlockArgs();
2105 scopeOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
2110 scopeOp, reductionArgs, builder, moduleTranslation, allocaIP,
2111 reductionDecls, privateReductionVariables, reductionVariableMap,
2116 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2118 builder.restoreIP(codeGenIP);
2124 return llvm::make_error<PreviouslyReportedError>();
2127 scopeOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2129 scopeOp.getPrivateNeedsBarrier())))
2130 return llvm::make_error<PreviouslyReportedError>();
2137 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2138 InsertPointTy oldIP = builder.saveIP();
2139 builder.restoreIP(codeGenIP);
2141 scopeOp.getLoc(), privateVarsInfo)))
2142 return llvm::make_error<PreviouslyReportedError>();
2143 builder.restoreIP(oldIP);
2144 return llvm::Error::success();
2147 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2148 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2149 ompBuilder->createScope(ompLoc, bodyCB, finiCB, scopeOp.getNowait());
2154 builder.restoreIP(*afterIP);
2158 scopeOp, builder, moduleTranslation, allocaIP, reductionDecls,
2159 privateReductionVariables, isByRef, scopeOp.getNowait(),
2167 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2168 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2173 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2175 builder.restoreIP(codegenIP);
2177 builder, moduleTranslation)
2180 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
2184 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
2187 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
2188 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
2190 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
2191 llvmCPFuncs.push_back(
2195 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2197 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
2203 builder.restoreIP(*afterIP);
2207static omp::DistributeOp
2211 omp::DistributeOp distOp;
2212 WalkResult walk = teamsOp.getRegion().walk([&](omp::DistributeOp op) {
2218 if (walk.wasInterrupted() || !distOp)
2222 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
2226 for (
auto ra : iface.getReductionBlockArgs())
2227 for (
auto &use : ra.getUses()) {
2228 auto *useOp = use.getOwner();
2230 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
2231 debugUses.push_back(useOp);
2234 if (!distOp->isProperAncestor(useOp))
2241 for (
auto *use : debugUses)
2250 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2255 unsigned numReductionVars = op.getNumReductionVars();
2259 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2265 if (doTeamsReduction) {
2266 isByRef =
getIsByRef(op.getReductionByref());
2268 assert(isByRef.size() == op.getNumReductionVars());
2271 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2276 op, reductionArgs, builder, moduleTranslation, allocaIP,
2277 reductionDecls, privateReductionVariables, reductionVariableMap,
2282 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2285 moduleTranslation, allocaIP, deallocBlocks);
2286 builder.restoreIP(codegenIP);
2292 llvm::Value *numTeamsLower =
nullptr;
2293 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
2294 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
2296 llvm::Value *numTeamsUpper =
nullptr;
2297 if (!op.getNumTeamsUpperVars().empty())
2298 numTeamsUpper = moduleTranslation.
lookupValue(op.getNumTeams(0));
2300 llvm::Value *threadLimit =
nullptr;
2301 if (!op.getThreadLimitVars().empty())
2302 threadLimit = moduleTranslation.
lookupValue(op.getThreadLimit(0));
2304 llvm::Value *ifExpr =
nullptr;
2305 if (
Value ifVar = op.getIfExpr())
2308 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2309 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2311 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2316 builder.restoreIP(*afterIP);
2317 if (doTeamsReduction) {
2320 op, builder, moduleTranslation, allocaIP, reductionDecls,
2321 privateReductionVariables, isByRef,
2327static llvm::omp::RTLDependenceKindTy
2330 case mlir::omp::ClauseTaskDepend::taskdependin:
2331 return llvm::omp::RTLDependenceKindTy::DepIn;
2335 case mlir::omp::ClauseTaskDepend::taskdependout:
2336 case mlir::omp::ClauseTaskDepend::taskdependinout:
2337 return llvm::omp::RTLDependenceKindTy::DepInOut;
2338 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2339 return llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2340 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2341 return llvm::omp::RTLDependenceKindTy::DepInOutSet;
2343 llvm_unreachable(
"unhandled depend kind");
2347 std::optional<ArrayAttr> dependKinds,
OperandRange dependVars,
2350 if (dependVars.empty())
2352 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2354 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue();
2356 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2357 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2358 dds.emplace_back(dd);
2370 llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder,
2372 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2373 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2377 llvmBuilder.restoreIP(ip);
2383 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2384 return llvm::Error::success();
2389 ompBuilder.pushFinalizationCB(
2399 llvm::OpenMPIRBuilder &ompBuilder,
2400 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2401 ompBuilder.popFinalizationCB();
2402 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2403 for (llvm::UncondBrInst *cancelBranch : cancelTerminators)
2404 cancelBranch->setSuccessor(constructFini);
2410class TaskContextStructManager {
2412 TaskContextStructManager(llvm::IRBuilderBase &builder,
2413 LLVM::ModuleTranslation &moduleTranslation,
2414 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2415 : builder{builder}, moduleTranslation{moduleTranslation},
2416 privateDecls{privateDecls} {}
2422 void generateTaskContextStruct();
2428 void createGEPsToPrivateVars();
2434 SmallVector<llvm::Value *>
2435 createGEPsToPrivateVars(llvm::Value *altStructPtr)
const;
2438 void freeStructPtr();
2440 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2441 return llvmPrivateVarGEPs;
/// Returns the current value of `structPtr`. The member defaults to nullptr
/// and is assigned in generateTaskContextStruct() when the context struct is
/// allocated, so callers may observe a null pointer.
2444 llvm::Value *getStructPtr() {
return structPtr; }
2447 llvm::IRBuilderBase &builder;
2448 LLVM::ModuleTranslation &moduleTranslation;
2449 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
2452 SmallVector<llvm::Type *> privateVarTypes;
2456 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
2459 llvm::Value *structPtr =
nullptr;
2461 llvm::Type *structTy =
nullptr;
2472 llvm::SmallVector<llvm::Value *> lowerBounds;
2473 llvm::SmallVector<llvm::Value *> upperBounds;
2474 llvm::SmallVector<llvm::Value *> steps;
2475 llvm::SmallVector<llvm::Value *> trips;
2477 llvm::Value *totalTrips;
2479 llvm::Value *lookUpAsI64(mlir::Value val,
const LLVM::ModuleTranslation &mt,
2480 llvm::IRBuilderBase &builder) {
2484 if (v->getType()->isIntegerTy(64))
2486 if (v->getType()->isIntegerTy())
2487 return builder.CreateSExtOrTrunc(v, builder.getInt64Ty());
2492 IteratorInfo(mlir::omp::IteratorOp itersOp,
2493 mlir::LLVM::ModuleTranslation &moduleTranslation,
2494 llvm::IRBuilderBase &builder) {
2495 dims = itersOp.getLoopLowerBounds().size();
2496 lowerBounds.resize(dims);
2497 upperBounds.resize(dims);
2501 for (
unsigned d = 0; d < dims; ++d) {
2502 llvm::Value *lb = lookUpAsI64(itersOp.getLoopLowerBounds()[d],
2503 moduleTranslation, builder);
2504 llvm::Value *ub = lookUpAsI64(itersOp.getLoopUpperBounds()[d],
2505 moduleTranslation, builder);
2507 lookUpAsI64(itersOp.getLoopSteps()[d], moduleTranslation, builder);
2508 assert(lb && ub && st &&
2509 "Expect lowerBounds, upperBounds, and steps in IteratorOp");
2510 assert((!llvm::isa<llvm::ConstantInt>(st) ||
2511 !llvm::cast<llvm::ConstantInt>(st)->isZero()) &&
2512 "Expect non-zero step in IteratorOp");
2514 lowerBounds[d] = lb;
2515 upperBounds[d] = ub;
2519 llvm::Value *diff = builder.CreateSub(ub, lb);
2520 llvm::Value *
div = builder.CreateSDiv(diff, st);
2521 trips[d] = builder.CreateAdd(
2522 div, llvm::ConstantInt::get(builder.getInt64Ty(), 1));
2525 totalTrips = llvm::ConstantInt::get(builder.getInt64Ty(), 1);
2526 for (
unsigned d = 0; d < dims; ++d)
2527 totalTrips = builder.CreateMul(totalTrips, trips[d]);
/// Returns the number of iterator dimensions, which the constructor records
/// as the size of the iterator op's lower-bound list.
2530 unsigned getDims()
const {
return dims; }
/// Returns a view of the per-dimension lower bounds that the constructor
/// obtained through lookUpAsI64().
2531 llvm::ArrayRef<llvm::Value *> getLowerBounds()
const {
return lowerBounds; }
/// Returns a view of the per-dimension upper bounds that the constructor
/// obtained through lookUpAsI64().
2532 llvm::ArrayRef<llvm::Value *> getUpperBounds()
const {
return upperBounds; }
/// Returns a view of the step values held in `steps`, indexed by dimension.
2533 llvm::ArrayRef<llvm::Value *> getSteps()
const {
return steps; }
/// Returns a view of the per-dimension trip counts. The constructor computes
/// each entry as (upper - lower) sdiv step, plus one.
2534 llvm::ArrayRef<llvm::Value *> getTrips()
const {
return trips; }
/// Returns the total trip count: the value the constructor builds by starting
/// from an i64 constant 1 and multiplying in every per-dimension trip count.
2535 llvm::Value *getTotalTrips()
const {
return totalTrips; }
2540void TaskContextStructManager::generateTaskContextStruct() {
2541 if (privateDecls.empty())
2543 privateVarTypes.reserve(privateDecls.size());
2545 for (omp::PrivateClauseOp &privOp : privateDecls) {
2548 if (!privOp.readsFromMold())
2550 Type mlirType = privOp.getType();
2551 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
2554 if (privateVarTypes.empty())
2557 structTy = llvm::StructType::get(moduleTranslation.
getLLVMContext(),
2560 llvm::DataLayout dataLayout =
2561 builder.GetInsertBlock()->getModule()->getDataLayout();
2562 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2563 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2566 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2568 "omp.task.context_ptr");
2571SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2572 llvm::Value *altStructPtr)
const {
2573 SmallVector<llvm::Value *> ret;
2576 ret.reserve(privateDecls.size());
2577 llvm::Value *zero = builder.getInt32(0);
2579 for (
auto privDecl : privateDecls) {
2580 if (!privDecl.readsFromMold()) {
2582 ret.push_back(
nullptr);
2585 llvm::Value *iVal = builder.getInt32(i);
2586 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2593void TaskContextStructManager::createGEPsToPrivateVars() {
2595 assert(privateVarTypes.empty());
2599 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2602void TaskContextStructManager::freeStructPtr() {
2606 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2608 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2609 builder.CreateFree(structPtr);
2613 llvm::OpenMPIRBuilder &ompBuilder,
2614 llvm::Value *affinityList, llvm::Value *
index,
2615 llvm::Value *addr, llvm::Value *len) {
2616 llvm::StructType *kmpTaskAffinityInfoTy =
2617 ompBuilder.getKmpTaskAffinityInfoTy();
2618 llvm::Value *entry = builder.CreateInBoundsGEP(
2619 kmpTaskAffinityInfoTy, affinityList,
index,
"omp.affinity.entry");
2621 addr = builder.CreatePtrToInt(addr, kmpTaskAffinityInfoTy->getElementType(0));
2622 len = builder.CreateIntCast(len, kmpTaskAffinityInfoTy->getElementType(1),
2624 llvm::Value *flags = builder.getInt32(0);
2626 builder.CreateStore(addr,
2627 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 0));
2628 builder.CreateStore(len,
2629 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 1));
2630 builder.CreateStore(flags,
2631 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 2));
2635 llvm::IRBuilderBase &builder,
2637 llvm::Value *affinityList) {
2638 for (
auto [i, affinityVar] : llvm::enumerate(affinityVars)) {
2639 auto entryOp = affinityVar.getDefiningOp<mlir::omp::AffinityEntryOp>();
2640 assert(entryOp &&
"affinity item must be omp.affinity_entry");
2642 llvm::Value *addr = moduleTranslation.
lookupValue(entryOp.getAddr());
2643 llvm::Value *len = moduleTranslation.
lookupValue(entryOp.getLen());
2644 assert(addr && len &&
"expect affinity addr and len to be non-null");
2646 affinityList, builder.getInt64(i), addr, len);
2650static mlir::LogicalResult
2653 llvm::IRBuilderBase &builder,
2655 llvm::Value *tmp = linearIV;
2656 for (
int d = (
int)iterInfo.getDims() - 1; d >= 0; --d) {
2657 llvm::Value *trip = iterInfo.getTrips()[d];
2659 llvm::Value *idx = builder.CreateURem(tmp, trip);
2661 tmp = builder.CreateUDiv(tmp, trip);
2664 llvm::Value *physIV = builder.CreateAdd(
2665 iterInfo.getLowerBounds()[d],
2666 builder.CreateMul(idx, iterInfo.getSteps()[d]),
"omp.it.phys_iv");
2672 moduleTranslation.
mapBlock(&iteratorRegionBlock, builder.GetInsertBlock());
2673 if (mlir::failed(moduleTranslation.
convertBlock(iteratorRegionBlock,
2676 return mlir::failure();
2678 return mlir::success();
2684static mlir::LogicalResult
2687 IteratorInfo &iterInfo, llvm::StringRef loopName,
2692 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
2694 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy bodyIP,
2695 llvm::Value *linearIV) -> llvm::Error {
2696 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2697 builder.restoreIP(bodyIP);
2700 builder, moduleTranslation))) {
2701 return llvm::make_error<llvm::StringError>(
2702 "failed to convert iterator region", llvm::inconvertibleErrorCode());
2706 mlir::dyn_cast<mlir::omp::YieldOp>(iteratorRegionBlock.
getTerminator());
2707 assert(yield && yield.getResults().size() == 1 &&
2708 "expect omp.yield in iterator region to have one result");
2710 genStoreEntry(linearIV, yield);
2716 return llvm::Error::success();
2719 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2721 loc, iterInfo.getTotalTrips(), bodyGen, loopName);
2725 builder.restoreIP(*afterIP);
2727 return mlir::success();
2730static mlir::LogicalResult
2733 llvm::OpenMPIRBuilder::AffinityData &ad) {
2735 if (taskOp.getAffinityVars().empty() && taskOp.getIterated().empty()) {
2738 return mlir::success();
2742 llvm::StructType *kmpTaskAffinityInfoTy =
2745 auto allocateAffinityList = [&](llvm::Value *count) -> llvm::Value * {
2746 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2747 if (llvm::isa<llvm::Constant>(count) || llvm::isa<llvm::Argument>(count))
2749 return builder.CreateAlloca(kmpTaskAffinityInfoTy, count,
2750 "omp.affinity_list");
2753 auto createAffinity =
2754 [&](llvm::Value *count,
2755 llvm::Value *info) -> llvm::OpenMPIRBuilder::AffinityData {
2756 llvm::OpenMPIRBuilder::AffinityData ad{};
2757 ad.Count = builder.CreateTrunc(count, builder.getInt32Ty());
2759 builder.CreatePointerBitCastOrAddrSpaceCast(info, builder.getPtrTy(0));
2763 if (!taskOp.getAffinityVars().empty()) {
2764 llvm::Value *count = llvm::ConstantInt::get(
2765 builder.getInt64Ty(), taskOp.getAffinityVars().size());
2766 llvm::Value *list = allocateAffinityList(count);
2769 ads.emplace_back(createAffinity(count, list));
2772 if (!taskOp.getIterated().empty()) {
2773 for (
auto [i, iter] : llvm::enumerate(taskOp.getIterated())) {
2774 auto itersOp = iter.getDefiningOp<omp::IteratorOp>();
2775 assert(itersOp &&
"iterated value must be defined by omp.iterator");
2776 IteratorInfo iterInfo(itersOp, moduleTranslation, builder);
2777 llvm::Value *affList = allocateAffinityList(iterInfo.getTotalTrips());
2779 itersOp, builder, moduleTranslation, iterInfo,
"iterator",
2780 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2781 auto entryOp = yield.getResults()[0]
2782 .getDefiningOp<mlir::omp::AffinityEntryOp>();
2783 assert(entryOp &&
"expect yield produce an affinity entry");
2790 affList, linearIV, addr, len);
2792 return llvm::failure();
2793 ads.emplace_back(createAffinity(iterInfo.getTotalTrips(), affList));
2797 llvm::Value *totalAffinityCount = builder.getInt32(0);
2798 for (
const auto &affinity : ads)
2799 totalAffinityCount = builder.CreateAdd(
2801 builder.CreateIntCast(affinity.Count, builder.getInt32Ty(),
2804 llvm::Value *affinityInfo = ads.front().Info;
2805 if (ads.size() > 1) {
2806 llvm::StructType *kmpTaskAffinityInfoTy =
2808 llvm::Value *affinityInfoElemSize = builder.getInt64(
2809 moduleTranslation.
getLLVMModule()->getDataLayout().getTypeAllocSize(
2810 kmpTaskAffinityInfoTy));
2812 llvm::Value *packedAffinityInfo = allocateAffinityList(totalAffinityCount);
2813 llvm::Value *packedAffinityInfoOffset = builder.getInt32(0);
2814 for (
const auto &affinity : ads) {
2815 llvm::Value *affinityCount = builder.CreateIntCast(
2816 affinity.Count, builder.getInt32Ty(),
false);
2817 llvm::Value *affinityCountInt64 = builder.CreateIntCast(
2818 affinityCount, builder.getInt64Ty(),
false);
2819 llvm::Value *affinityInfoSize =
2820 builder.CreateMul(affinityCountInt64, affinityInfoElemSize);
2822 llvm::Value *packedAffinityInfoIndex = builder.CreateIntCast(
2823 packedAffinityInfoOffset, kmpTaskAffinityInfoTy->getElementType(0),
2825 packedAffinityInfoIndex = builder.CreateInBoundsGEP(
2826 kmpTaskAffinityInfoTy, packedAffinityInfo, packedAffinityInfoIndex);
2828 builder.CreateMemCpy(
2829 packedAffinityInfoIndex, llvm::Align(1),
2830 builder.CreatePointerBitCastOrAddrSpaceCast(
2831 affinity.Info, builder.getPtrTy(packedAffinityInfoIndex->getType()
2832 ->getPointerAddressSpace())),
2833 llvm::Align(1), affinityInfoSize);
2835 packedAffinityInfoOffset =
2836 builder.CreateAdd(packedAffinityInfoOffset, affinityCount);
2839 affinityInfo = packedAffinityInfo;
2842 ad.Count = totalAffinityCount;
2843 ad.Info = affinityInfo;
2845 return mlir::success();
2851static mlir::LogicalResult
2854 std::optional<ArrayAttr> dependIteratedKinds,
2855 llvm::IRBuilderBase &builder,
2857 llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps) {
2858 if (dependIterated.empty()) {
2861 return mlir::success();
2865 llvm::Type *dependInfoTy = ompBuilder.DependInfo;
2866 unsigned numLocator = dependVars.size();
2869 llvm::Value *totalCount =
2870 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2873 for (
auto iter : dependIterated) {
2874 auto itersOp = iter.getDefiningOp<mlir::omp::IteratorOp>();
2875 assert(itersOp &&
"depend_iterated value must be defined by omp.iterator");
2876 iterInfos.emplace_back(itersOp, moduleTranslation, builder);
2878 builder.CreateAdd(totalCount, iterInfos.back().getTotalTrips());
2883 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(dependInfoTy);
2884 llvm::Value *depArray =
2885 builder.CreateMalloc(ompBuilder.SizeTy, dependInfoTy, allocSize,
2886 totalCount,
nullptr,
".dep.arr.addr");
2889 if (numLocator > 0) {
2892 for (
auto [i, dd] : llvm::enumerate(dds)) {
2893 llvm::Value *idx = llvm::ConstantInt::get(builder.getInt64Ty(), i);
2894 llvm::Value *entry =
2895 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2896 ompBuilder.emitTaskDependency(builder, entry, dd);
2901 llvm::Value *offset =
2902 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2903 for (
auto [i, iterInfo] : llvm::enumerate(iterInfos)) {
2904 auto kindAttr = cast<mlir::omp::ClauseTaskDependAttr>(
2905 dependIteratedKinds->getValue()[i]);
2906 llvm::omp::RTLDependenceKindTy rtlKind =
2909 auto itersOp = dependIterated[i].getDefiningOp<mlir::omp::IteratorOp>();
2911 itersOp, builder, moduleTranslation, iterInfo,
"dep_iterator",
2912 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2914 moduleTranslation.
lookupValue(yield.getResults()[0]);
2915 llvm::Value *idx = builder.CreateAdd(offset, linearIV);
2916 llvm::Value *entry =
2917 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2918 ompBuilder.emitTaskDependency(
2920 llvm::OpenMPIRBuilder::DependData{rtlKind, addr->getType(),
2923 return mlir::failure();
2926 offset = builder.CreateAdd(offset, iterInfo.getTotalTrips());
2929 taskDeps.DepArray = depArray;
2930 taskDeps.NumDeps = builder.CreateTrunc(totalCount, builder.getInt32Ty());
2931 return mlir::success();
2938 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2943 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2955 InsertPointTy allocaIP =
2960 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2961 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2962 builder.getContext(),
"omp.task.start",
2963 builder.GetInsertBlock()->getParent());
2964 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2965 builder.SetInsertPoint(branchToTaskStartBlock);
2968 llvm::BasicBlock *copyBlock =
2969 splitBB(builder,
true,
"omp.private.copy");
2970 llvm::BasicBlock *initBlock =
2971 splitBB(builder,
true,
"omp.private.init");
2987 moduleTranslation, allocaIP, deallocBlocks);
2990 builder.SetInsertPoint(initBlock->getTerminator());
2993 taskStructMgr.generateTaskContextStruct();
3000 taskStructMgr.createGEPsToPrivateVars();
3002 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
3005 taskStructMgr.getLLVMPrivateVarGEPs())) {
3007 if (!privDecl.readsFromMold())
3009 assert(llvmPrivateVarAlloc &&
3010 "reads from mold so shouldn't have been skipped");
3013 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3014 blockArg, llvmPrivateVarAlloc, initBlock);
3015 if (!privateVarOrErr)
3016 return handleError(privateVarOrErr, *taskOp.getOperation());
3025 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3026 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3027 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3029 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3030 llvmPrivateVarAlloc);
3032 assert(llvmPrivateVarAlloc->getType() ==
3033 moduleTranslation.
convertType(blockArg.getType()));
3043 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3044 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
3045 taskOp.getPrivateNeedsBarrier())))
3046 return llvm::failure();
3048 llvm::OpenMPIRBuilder::AffinityData ad;
3050 return llvm::failure();
3053 builder.SetInsertPoint(taskStartBlock);
3056 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3061 moduleTranslation, allocaIP, deallocBlocks);
3064 builder.restoreIP(codegenIP);
3066 llvm::BasicBlock *privInitBlock =
nullptr;
3068 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3071 auto [blockArg, privDecl, mlirPrivVar] = zip;
3073 if (privDecl.readsFromMold())
3076 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3077 llvm::Type *llvmAllocType =
3078 moduleTranslation.
convertType(privDecl.getType());
3079 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3080 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3081 llvmAllocType,
nullptr,
"omp.private.alloc");
3084 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3085 blockArg, llvmPrivateVar, privInitBlock);
3086 if (!privateVarOrError)
3087 return privateVarOrError.takeError();
3088 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
3089 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
3092 taskStructMgr.createGEPsToPrivateVars();
3093 for (
auto [i, llvmPrivVar] :
3094 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3096 assert(privateVarsInfo.
llvmVars[i] &&
3097 "This is added in the loop above");
3100 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
3105 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
3109 if (!privateDecl.readsFromMold())
3112 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3113 llvmPrivateVar = builder.CreateLoad(
3114 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
3116 assert(llvmPrivateVar->getType() ==
3117 moduleTranslation.
convertType(blockArg.getType()));
3118 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
3122 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
3123 if (failed(
handleError(continuationBlockOrError, *taskOp)))
3124 return llvm::make_error<PreviouslyReportedError>();
3126 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3129 taskOp.getLoc(), privateVarsInfo)))
3130 return llvm::make_error<PreviouslyReportedError>();
3133 taskStructMgr.freeStructPtr();
3135 return llvm::Error::success();
3144 llvm::omp::Directive::OMPD_taskgroup);
3146 llvm::OpenMPIRBuilder::DependenciesInfo dependencies;
3147 if (failed(
buildDependData(taskOp.getDependVars(), taskOp.getDependKinds(),
3148 taskOp.getDependIterated(),
3149 taskOp.getDependIteratedKinds(), builder,
3150 moduleTranslation, dependencies)))
3153 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3154 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3156 ompLoc, allocaIP, deallocBlocks, bodyCB, !taskOp.getUntied(),
3158 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dependencies, ad,
3159 taskOp.getMergeable(),
3160 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
3161 moduleTranslation.
lookupValue(taskOp.getPriority()));
3169 builder.restoreIP(*afterIP);
3171 if (dependencies.DepArray)
3172 builder.CreateFree(dependencies.DepArray);
// Lowering of the taskloop wrapper op (`loopWrapperOp`): its region is
// converted under the block label "omp.taskloop.wrapper.region".
3181 llvm::IRBuilderBase &builder,
3189 loopWrapperOp.getRegion(),
"omp.taskloop.wrapper.region", builder,
// Check whether converting the region reported an error.
3192 if (failed(
handleError(continuationBlockOrError, opInst)))
// Continue emitting code in the continuation block produced by the region
// conversion.
3195 builder.SetInsertPoint(continuationBlockOrError.get());
// Helper for taskloop loop bounds: yields the LLVM value for an MLIR value,
// or an llvm::StringError (see the messages below) when no value can be
// produced.
3203static llvm::Expected<llvm::Value *>
3206 llvm::IRBuilderBase &builder) {
// First ask the module translation whether the value is already mapped.
3207 if (llvm::Value *mapped = moduleTranslation.
lookupValue(value))
3212 return llvm::make_error<llvm::StringError>(
3213 "value is a block argument and is not mapped",
3214 llvm::inconvertibleErrorCode());
3216 return llvm::make_error<llvm::StringError>(
3217 "unsupported op defining taskloop loop bound",
3218 llvm::inconvertibleErrorCode());
// Propagate a failure to obtain an operand; otherwise map the operand and
// remember it in mappingsToRemove.
3228 if (!operandOrError)
3229 return operandOrError.takeError();
3230 moduleTranslation.
mapValue(operand, *operandOrError);
3231 mappingsToRemove.push_back(operand);
3235 return llvm::make_error<llvm::StringError>(
3236 "failed to convert op defining taskloop loop bound",
3237 llvm::inconvertibleErrorCode());
// A successfully converted defining op must have produced a value.
3240 assert(
result &&
"expected conversion of loop bound op to produce a value");
3244 mappingsToRemove.push_back(resultValue);
// Walk every value recorded above (presumably to drop the temporary
// mappings again -- TODO confirm against the loop body).
3246 for (
Value mappedValue : mappingsToRemove)
// Fills the reference out-parameters lbVal, ubVal and stepVal with LLVM values
// describing a loop's bounds. Returns llvm::Error::success() once all three
// are set; errors from translating individual bounds are forwarded to the
// caller.
3255 llvm::Value *&lbVal, llvm::Value *&ubVal,
3256 llvm::Value *&stepVal) {
3264 return firstLbOrErr.takeError();
// Integer constants below use the bit width of the first lower bound's type.
// ubVal starts at 1, the neutral element of the product accumulated below.
3266 llvm::Type *boundType = (*firstLbOrErr)->getType();
3267 ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
// Collapsed nest: visit each of the collapsed loops.
3268 if (loopOp.getCollapseNumLoops() > 1) {
3286 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
// The first lower bound was already translated, so it is moved rather than
// translated again.
3288 i == 0 ? std::move(firstLbOrErr)
3292 return lbOrErr.takeError();
3294 upperBounds[i], moduleTranslation, builder);
3296 return ubOrErr.takeError();
3300 return stepOrErr.takeError();
3302 llvm::Value *loopLb = *lbOrErr;
3303 llvm::Value *loopUb = *ubOrErr;
3304 llvm::Value *loopStep = *stepOrErr;
// Inclusive distance between the bounds, taken from whichever side is larger
// under a signed comparison:
//   lb < ub ? ub - (lb - 1) : lb - (ub - 1)
3310 llvm::Value *loopLbMinusOne = builder.CreateSub(
3311 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3312 llvm::Value *loopUbMinusOne = builder.CreateSub(
3313 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3314 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
3315 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
3316 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
3317 llvm::Value *loopTripCount =
3318 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
// llvm.abs with its second operand (is_int_min_poison) set to false.
3319 loopTripCount = builder.CreateBinaryIntrinsic(
3320 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
// Signed division and remainder by the step, each passed through llvm.abs.
3324 llvm::Value *loopTripCountDivStep =
3325 builder.CreateSDiv(loopTripCount, loopStep);
3326 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
3327 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
3328 llvm::Value *loopTripCountRem =
3329 builder.CreateSRem(loopTripCount, loopStep);
3330 loopTripCountRem = builder.CreateBinaryIntrinsic(
3331 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
// needsRoundUp is an inequality test against a constant of the remainder's
// width (presumably 0 -- TODO confirm); it is widened to the quotient's type
// and added to the quotient, i.e. the division is rounded up.
3332 llvm::Value *needsRoundUp = builder.CreateICmpNE(
3334 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
3337 builder.CreateAdd(loopTripCountDivStep,
3338 builder.CreateZExtOrTrunc(
3339 needsRoundUp, loopTripCountDivStep->getType()));
// Fold this loop's trip count into the running product kept in ubVal.
3340 ubVal = builder.CreateMul(ubVal, loopTripCount);
// Lower bound and step are the constant 1 here.
3342 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3343 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3348 return ubOrErr.takeError();
3352 return stepOrErr.takeError();
// Here the translated first lower bound and step are used as they are.
3353 lbVal = *firstLbOrErr;
3355 stepVal = *stepOrErr;
// Postcondition: all three outputs are set.
3358 assert(lbVal !=
nullptr &&
"Expected value for lbVal");
3359 assert(ubVal !=
nullptr &&
"Expected value for ubVal");
3360 assert(stepVal !=
nullptr &&
"Expected value for stepVal");
3361 return llvm::Error::success();
// Lowering of the taskloop context op (`contextOp`) together with the
// omp::TaskloopWrapperOp it returns from getLoopOp(). The visible steps are:
// set up blocks for private-variable initialisation and copying, compute the
// loop bounds, define the callbacks, and issue the OpenMPIRBuilder call.
3367 llvm::IRBuilderBase &builder,
3369 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3371 omp::TaskloopWrapperOp loopWrapperOp = contextOp.getLoopOp();
3379 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
3383 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// Precondition: the builder sits at the end of its current block.
3386 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
// Create the block "omp.taskloop.wrapper.start" in the current function,
// branch to it, and move the builder in front of that branch so that what
// follows is emitted ahead of the jump.
3387 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
3388 builder.getContext(),
"omp.taskloop.wrapper.start",
3389 builder.GetInsertBlock()->getParent());
3390 llvm::Instruction *branchToTaskloopStartBlock =
3391 builder.CreateBr(taskloopStartBlock);
3392 builder.SetInsertPoint(branchToTaskloopStartBlock);
// Two splitBB calls at the builder position yield the blocks labelled
// "omp.private.copy" and "omp.private.init".
3394 llvm::BasicBlock *copyBlock =
3395 splitBB(builder,
true,
"omp.private.copy");
3396 llvm::BasicBlock *initBlock =
3397 splitBB(builder,
true,
"omp.private.init");
3400 moduleTranslation, allocaIP, deallocBlocks);
// Emit the initialisation code in front of initBlock's terminator.
3403 builder.SetInsertPoint(initBlock->getTerminator());
// Have the struct manager generate the task context struct and the GEPs to
// its private-variable slots.
3406 taskStructMgr.generateTaskContextStruct();
3407 taskStructMgr.createGEPsToPrivateVars();
3409 llvmFirstPrivateVars.resize(privateVarsInfo.
blockArgs.size());
// Initialise the private variables that read from their mold; the result of
// initPrivateVar is remembered per index in llvmFirstPrivateVars.
3411 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3413 privateVarsInfo.
blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
3414 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
3416 if (!privDecl.readsFromMold())
3418 assert(llvmPrivateVarAlloc &&
3419 "reads from mold so shouldn't have been skipped");
3422 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3423 blockArg, llvmPrivateVarAlloc, initBlock);
3424 if (!privateVarOrErr)
3425 return handleError(privateVarOrErr, *contextOp.getOperation());
3427 llvmFirstPrivateVars[i] = privateVarOrErr.get();
// The guard restores the builder position at scope exit; meanwhile code goes
// in front of the current block's terminator.
3429 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3430 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
// If initialisation produced a value other than the struct slot and the block
// argument is not a pointer, store that value into the slot and load it back
// with the value's type.
3432 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3433 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3434 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3436 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3437 llvmPrivateVarAlloc);
3439 assert(llvmPrivateVarAlloc->getType() ==
3440 moduleTranslation.
convertType(blockArg.getType()));
3446 contextOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3447 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
3448 contextOp.getPrivateNeedsBarrier())))
3449 return llvm::failure();
// Private-variable setup is done; continue in the taskloop start block.
3452 builder.SetInsertPoint(taskloopStartBlock);
// Bounds of the wrapped loop nest; all three start out null and are passed,
// together with loopOp, to the call that follows.
3454 auto loopOp = cast<omp::LoopNestOp>(loopWrapperOp.getWrappedLoop());
3455 llvm::Value *lbVal =
nullptr;
3456 llvm::Value *ubVal =
nullptr;
3457 llvm::Value *stepVal =
nullptr;
3459 loopOp, builder, moduleTranslation, lbVal, ubVal, stepVal))
3463 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3468 moduleTranslation, allocaIP, deallocBlocks);
// Inside the callback: generate code at the insertion point handed in.
3471 builder.restoreIP(codegenIP);
3473 llvm::BasicBlock *privInitBlock =
nullptr;
// Private variables tested with readsFromMold() below: the visible code
// allocates each as "omp.private.alloc" in front of the terminator of the
// alloca block, initialises it, maps the block argument to the result and
// records it in llvmVars. NOTE(review): the statement governed by the
// readsFromMold() test is not visible here.
3475 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3478 auto [blockArg, privDecl, mlirPrivVar] = zip;
3480 if (privDecl.readsFromMold())
3483 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3484 llvm::Type *llvmAllocType =
3485 moduleTranslation.
convertType(privDecl.getType());
3486 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3487 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3488 llvmAllocType,
nullptr,
"omp.private.alloc");
3491 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3492 blockArg, llvmPrivateVar, privInitBlock);
3493 if (!privateVarOrError)
3494 return privateVarOrError.takeError();
3495 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
3496 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
// Create the GEPs again at the current position and record them in llvmVars.
// NOTE(review): the condition guarding the assert is not visible here.
3499 taskStructMgr.createGEPsToPrivateVars();
3500 for (
auto [i, llvmPrivVar] :
3501 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3503 assert(privateVarsInfo.
llvmVars[i] &&
3504 "This is added in the loop above");
3507 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
// Map block arguments to their private values; a non-pointer block argument
// gets the value loaded from its slot first.
3512 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
3516 if (!privateDecl.readsFromMold())
3519 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3520 llvmPrivateVar = builder.CreateLoad(
3521 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
3523 assert(llvmPrivateVar->getType() ==
3524 moduleTranslation.
convertType(blockArg.getType()));
3525 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
// The context op's region is converted under the label
// "omp.taskloop.context.region"; failure is reported through handleError and
// signalled to the caller as PreviouslyReportedError.
3531 contextOp.getRegion(),
"omp.taskloop.context.region", builder,
3534 if (failed(
handleError(continuationBlockOrError, opInst)))
3535 return llvm::make_error<PreviouslyReportedError>();
// What follows is emitted in front of the continuation block's terminator.
3537 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3545 contextOp.getLoc(), privateVarsInfo)))
3546 return llvm::make_error<PreviouslyReportedError>();
// Release the task context struct.
3549 taskStructMgr.freeStructPtr();
3551 return llvm::Error::success();
// Task duplication callback: loads the source context pointer from srcPtr,
// generates a new context struct, stores its pointer to destPtr, and then
// initialises the private variables from the source struct's slots.
3557 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3558 llvm::Value *destPtr, llvm::Value *srcPtr)
3560 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3561 builder.restoreIP(codegenIP);
// The pointer type uses the address space of srcPtr's type.
3564 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
3566 builder.CreateLoad(ptrTy, srcPtr,
"omp.taskloop.context.src");
// The outer struct manager serves as the source; a second manager generates
// the destination struct.
3568 TaskContextStructManager &srcStructMgr = taskStructMgr;
3569 TaskContextStructManager destStructMgr(builder, moduleTranslation,
3571 destStructMgr.generateTaskContextStruct();
3572 llvm::Value *dest = destStructMgr.getStructPtr();
3573 dest->setName(
"omp.taskloop.context.dest");
3574 builder.CreateStore(dest, destPtr);
// GEPs into the source struct (based on the loaded pointer) and into the new
// destination struct.
3577 srcStructMgr.createGEPsToPrivateVars(src);
3579 destStructMgr.createGEPsToPrivateVars(dest);
// Same initialisation scheme as for the original task, with the source
// struct's GEPs acting as the molds.
3582 for (
auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
3583 llvm::zip_equal(privateVarsInfo.
privatizers, srcGEPs,
3586 if (!privDecl.readsFromMold())
3588 assert(llvmPrivateVarAlloc &&
3589 "reads from mold so shouldn't have been skipped");
3592 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
3593 llvmPrivateVarAlloc, builder.GetInsertBlock());
3594 if (!privateVarOrErr)
3595 return privateVarOrErr.takeError();
3604 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3605 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3606 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3608 llvmPrivateVarAlloc = builder.CreateLoad(
3609 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
3611 assert(llvmPrivateVarAlloc->getType() ==
3612 moduleTranslation.
convertType(blockArg.getType()));
3620 moduleTranslation, srcGEPs, destGEPs,
3622 contextOp.getPrivateNeedsBarrier())))
3623 return llvm::make_error<PreviouslyReportedError>();
// The callback's result is the builder position after the emitted code.
3625 return builder.saveIP();
// Clause operands. The same `grainsize` variable receives either the
// grainsize operand or the num_tasks operand.
3633 llvm::Value *ifCond =
nullptr;
3634 llvm::Value *grainsize =
nullptr;
3636 mlir::Value grainsizeVal = contextOp.getGrainsize();
3637 mlir::Value numTasksVal = contextOp.getNumTasks();
3638 if (
Value ifVar = contextOp.getIfExpr())
3641 grainsize = moduleTranslation.
lookupValue(grainsizeVal);
3643 }
else if (numTasksVal) {
3644 grainsize = moduleTranslation.
lookupValue(numTasksVal);
// The duplication callback is only passed on when a context struct exists.
3648 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull =
nullptr;
3649 if (taskStructMgr.getStructPtr())
3650 taskDupOrNull = taskDupCB;
3660 llvm::omp::Directive::OMPD_taskgroup);
// Argument list of the OpenMPIRBuilder call; its result is checked before the
// builder continues at the returned insertion point.
3662 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3663 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3665 ompLoc, allocaIP, deallocBlocks, bodyCB, loopInfo, lbVal, ubVal,
3666 stepVal, contextOp.getUntied(), ifCond, grainsize,
3667 contextOp.getNogroup(), sched,
3668 moduleTranslation.
lookupValue(contextOp.getFinal()),
3669 contextOp.getMergeable(),
3670 moduleTranslation.
lookupValue(contextOp.getPriority()),
3671 loopOp.getCollapseNumLoops(), taskDupOrNull,
3672 taskStructMgr.getStructPtr());
3679 builder.restoreIP(*afterIP);
// Lowering of a region-carrying op through a body callback and one
// OpenMPIRBuilder call. NOTE(review): neither the op kind nor the builder
// method is visible here -- confirm before relying on this summary.
3687 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// The body callback moves the builder to the code-generation insertion point
// it is given before the region is converted.
3691 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3693 builder.restoreIP(codegenIP);
3695 builder, moduleTranslation)
3700 InsertPointTy allocaIP =
3702 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3703 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3705 ompLoc, allocaIP, deallocBlocks, bodyCB);
// Continue at the insertion point returned by the builder call.
3710 builder.restoreIP(*afterIP);
// Lowering of omp::WsloopOp: collects the schedule and chunk information,
// sets up private, reduction and linear variables, converts the region, and
// applies OpenMPIRBuilder::applyWorkshareLoop to the resulting loop.
3729 auto wsloopOp = cast<omp::WsloopOp>(opInst);
3733 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
3735 assert(isByRef.size() == wsloopOp.getNumReductionVars());
// A missing schedule kind falls back to Static.
3739 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
// The induction variable type is taken from the first loop step.
3742 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
3743 llvm::Type *ivType = step->getType();
// The schedule chunk, when present, is sign-extended or truncated to ivType.
3744 llvm::Value *chunk =
nullptr;
3745 if (wsloopOp.getScheduleChunk()) {
3746 llvm::Value *chunkVar =
3747 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
3748 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
// If the parent op is an omp::DistributeOp, also pick up its static
// dist_schedule flag and chunk size (converted to ivType like the chunk).
3751 omp::DistributeOp distributeOp =
nullptr;
3752 llvm::Value *distScheduleChunk =
nullptr;
3753 bool hasDistSchedule =
false;
3754 if (llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())) {
3755 distributeOp = cast<omp::DistributeOp>(opInst.
getParentOp());
3756 hasDistSchedule = distributeOp.getDistScheduleStatic();
3757 if (distributeOp.getDistScheduleChunkSize()) {
3758 llvm::Value *chunkVar = moduleTranslation.
lookupValue(
3759 distributeOp.getDistScheduleChunkSize());
3760 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3769 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3773 wsloopOp.getNumReductionVars());
3776 wsloopOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
3783 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3788 moduleTranslation, allocaIP, reductionDecls,
3789 privateReductionVariables, reductionVariableMap,
3790 deferredStores, isByRef)))
3799 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3801 wsloopOp.getPrivateNeedsBarrier())))
// Reduction variables are initialised in the single predecessor of the block
// that follows the allocas.
3804 assert(afterAllocas.get()->getSinglePredecessor());
3805 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3807 afterAllocas.get()->getSinglePredecessor(),
3808 reductionDecls, privateReductionVariables,
3809 reductionVariableMap, isByRef, deferredStores)))
// Clause flags consumed by applyWorkshareLoop below. The barrier is requested
// unless `nowait` is set.
3813 bool isOrdered = wsloopOp.getOrdered().has_value();
3814 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3815 bool isSimd = wsloopOp.getScheduleSimd();
3816 bool loopNeedsBarrier = !wsloopOp.getNowait();
// A wsloop directly nested in omp::DistributeOp uses the combined
// distribute/for loop type.
3821 llvm::omp::WorksharingLoopType workshareLoopType =
3822 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
3823 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3824 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3828 llvm::omp::Directive::OMPD_for);
3830 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Linear clause: register the variable types, create the linear variables and
// record their steps.
3833 LinearClauseProcessor linearClauseProcessor;
3835 if (!wsloopOp.getLinearVars().empty()) {
3836 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3838 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3840 for (
auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3841 linearClauseProcessor.createLinearVar(
3842 builder, moduleTranslation, moduleTranslation.
lookupValue(linearVar),
3844 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
3845 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
// Remember the block the builder is in before the region is converted; it is
// used again when the linear variables are rewritten.
3848 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3850 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
// Linear variables are initialised in the loop preheader, followed by a
// barrier; they are updated in the loop body from the induction variable, and
// the finalisation block is split off at the loop exit.
3858 if (!wsloopOp.getLinearVars().empty()) {
3859 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3860 loopInfo->getPreheader());
3861 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3863 builder.saveIP(), llvm::omp::OMPD_barrier);
3866 builder.restoreIP(*afterBarrierIP);
3867 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3868 loopInfo->getIndVar());
3869 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
// Position the builder at the start of the block returned by the region
// conversion.
3872 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// noLoopMode defaults to false. The enclosing omp.target, if any, is
// consulted when this loop nest is its innermost captured op
// (NOTE(review): the assignment that sets the flag is not visible here).
3875 bool noLoopMode =
false;
3876 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3878 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3882 if (loopOp == targetCapturedOp) {
3883 if (targetOp.getKernelExecFlags(targetCapturedOp) ==
3884 omp::TargetExecMode::no_loop)
// Hand the loop to the OpenMPIRBuilder. The two comparisons turn the optional
// schedule modifier into the monotonic and nonmonotonic flags.
3889 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3890 ompBuilder->applyWorkshareLoop(
3891 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3892 convertToScheduleKind(schedule), chunk, isSimd,
3893 scheduleMod == omp::ScheduleModifier::monotonic,
3894 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3895 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
// Finalise the linear variables using the loop's last-iteration value, then
// rewrite the uses of each linear variable between the successor of
// sourceBlock and the region block. The builder position is saved and
// restored around this work.
3901 if (!wsloopOp.getLinearVars().empty()) {
3902 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3903 assert(loopInfo->getLastIter() &&
3904 "`lastiter` in CanonicalLoopInfo is nullptr");
3905 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3906 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3907 loopInfo->getLastIter());
3910 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
3911 linearClauseProcessor.rewriteInPlace(
3912 builder, sourceBlock->getSingleSuccessor(), *regionBlock,
3913 "omp.loop_nest.region",
index);
3915 builder.restoreIP(oldIP);
3923 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3924 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3929 wsloopOp.getLoc(), privateVarsInfo);
// Lowering of a parallel region through OpenMPIRBuilder::createParallel. Three
// callbacks are handed to the builder: bodyGenCB (privatisation, reductions,
// region body), privCB and finiCB (cleanup).
3936 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3938 assert(isByRef.size() == opInst.getNumReductionVars());
3951 opInst.getNumReductionVars());
3955 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3958 opInst, builder, moduleTranslation, privateVarsInfo, allocaIP);
3960 return llvm::make_error<PreviouslyReportedError>();
3966 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
// Insertion point directly in front of the terminator of the alloca block.
3969 InsertPointTy(allocaIP.getBlock(),
3970 allocaIP.getBlock()->getTerminator()->getIterator());
3973 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3974 reductionDecls, privateReductionVariables, reductionVariableMap,
3975 deferredStores, isByRef)))
3976 return llvm::make_error<PreviouslyReportedError>();
// With the allocas in place, continue at the code-generation insertion point.
3978 assert(afterAllocas.get()->getSinglePredecessor());
3979 builder.restoreIP(codeGenIP);
3985 return llvm::make_error<PreviouslyReportedError>();
3988 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3990 opInst.getPrivateNeedsBarrier())))
3991 return llvm::make_error<PreviouslyReportedError>();
3994 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3995 afterAllocas.get()->getSinglePredecessor(),
3996 reductionDecls, privateReductionVariables,
3997 reductionVariableMap, isByRef, deferredStores)))
3998 return llvm::make_error<PreviouslyReportedError>();
4003 moduleTranslation, allocaIP, deallocBlocks);
// The parallel region body is converted under the label "omp.par.region".
4007 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
4009 return regionBlock.takeError();
// Reductions are only processed when the op has reduction variables.
4012 if (opInst.getNumReductionVars() > 0) {
4017 owningReductionGenRefDataPtrGens;
4019 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
4021 owningReductionGenRefDataPtrGens,
4022 privateReductionVariables, reductionInfos, isByRef);
// A temporary `unreachable` is inserted in front of the region block's
// terminator and serves as the insertion point for createReductions; it is
// erased again below once the reductions have been emitted.
4025 builder.SetInsertPoint((*regionBlock)->getTerminator());
4028 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
4029 builder.SetInsertPoint(tempTerminator);
4031 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
4032 ompBuilder->createReductions(
4033 builder.saveIP(), allocaIP, reductionInfos, isByRef,
4035 if (!contInsertPoint)
4036 return contInsertPoint.takeError();
// An insertion point without a block is treated as an error that has already
// been reported.
4038 if (!contInsertPoint->getBlock())
4039 return llvm::make_error<PreviouslyReportedError>();
4041 tempTerminator->eraseFromParent();
4042 builder.restoreIP(*contInsertPoint);
4045 return llvm::Error::success();
// Privatisation callback handed to createParallel.
4048 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
4049 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
// Finalisation callback: emits the cleanup code at codeGenIP and puts the
// builder back to its previous position before returning success.
4058 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
4059 InsertPointTy oldIP = builder.saveIP();
4060 builder.restoreIP(codeGenIP);
// Collect the cleanup region of every reduction declaration.
4065 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
4066 [](omp::DeclareReductionOp reductionDecl) {
4067 return &reductionDecl.getCleanupRegion();
4070 reductionCleanupRegions, privateReductionVariables,
4071 moduleTranslation, builder,
"omp.reduction.cleanup")))
4072 return llvm::createStringError(
4073 "failed to inline `cleanup` region of `omp.declare_reduction`");
4076 opInst.getLoc(), privateVarsInfo)))
4077 return llvm::make_error<PreviouslyReportedError>();
// A cancellable region additionally gets a barrier here.
4081 if (isCancellable) {
4082 auto IPOrErr = ompBuilder->createBarrier(
4083 llvm::OpenMPIRBuilder::LocationDescription(builder),
4084 llvm::omp::Directive::OMPD_unknown,
4088 return IPOrErr.takeError();
4091 builder.restoreIP(oldIP);
4092 return llvm::Error::success();
// Clause operands: the `if` condition, the thread count (read from the first
// num_threads operand when there is one) and the proc_bind kind, which
// defaults to OMP_PROC_BIND_default.
4095 llvm::Value *ifCond =
nullptr;
4096 if (
auto ifVar = opInst.getIfExpr())
4098 llvm::Value *numThreads =
nullptr;
4099 if (!opInst.getNumThreadsVars().empty())
4100 numThreads = moduleTranslation.
lookupValue(opInst.getNumThreads(0));
4101 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
4102 if (
auto bind = opInst.getProcBindKind())
4106 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4108 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Emit the parallel region and continue at the returned insertion point.
4110 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4111 ompBuilder->createParallel(ompLoc, allocaIP, deallocBlocks, bodyGenCB,
4112 privCB, finiCB, ifCond, numThreads, pbKind,
4118 builder.restoreIP(*afterIP);
// Maps the MLIR order clause kind to llvm::omp::OrderKind: OMP_ORDER_unknown
// on the early return, OMP_ORDER_concurrent for ClauseOrderKind::Concurrent.
// Any other enumerator reaches llvm_unreachable.
4123static llvm::omp::OrderKind
4126 return llvm::omp::OrderKind::OMP_ORDER_unknown;
4128 case omp::ClauseOrderKind::Concurrent:
4129 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
4131 llvm_unreachable(
"Unknown ClauseOrderKind kind");
// Lowering of omp::SimdOp: sets up private, linear and reduction variables,
// gathers the simdlen, safelen and aligned clause data, converts the region,
// applies OpenMPIRBuilder::applySimd and finally combines the reductions.
4139 auto simdOp = cast<omp::SimdOp>(opInst);
4147 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
4150 simdOp.getNumReductionVars());
4155 assert(isByRef.size() == simdOp.getNumReductionVars());
4157 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4161 simdOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
// Linear clause: a linear variable that matches one of the privatised
// variables is created from that private LLVM value.
4166 LinearClauseProcessor linearClauseProcessor;
4168 if (!simdOp.getLinearVars().empty()) {
4169 auto linearVarTypes = simdOp.getLinearVarTypes().value();
4171 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
4172 for (
auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
4173 bool isImplicit =
false;
4174 for (
auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
4178 if (linearVar == mlirPrivVar) {
4180 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
4181 llvmPrivateVar, idx);
4187 linearClauseProcessor.createLinearVar(
4188 builder, moduleTranslation,
4191 for (
mlir::Value linearStep : simdOp.getLinearStepVars())
4192 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
4196 moduleTranslation, allocaIP, reductionDecls,
4197 privateReductionVariables, reductionVariableMap,
4198 deferredStores, isByRef)))
4209 assert(afterAllocas.get()->getSinglePredecessor());
4210 if (failed(initReductionVars(simdOp, reductionArgs, builder,
4212 afterAllocas.get()->getSinglePredecessor(),
4213 reductionDecls, privateReductionVariables,
4214 reductionVariableMap, isByRef, deferredStores)))
// simdlen and safelen become 64-bit integer constants when present and stay
// null otherwise.
4217 llvm::ConstantInt *simdlen =
nullptr;
4218 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
4219 simdlen = builder.getInt64(simdlenVar.value());
4221 llvm::ConstantInt *safelen =
nullptr;
4222 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
4223 safelen = builder.getInt64(safelenVar.value());
// Aligned clause: each aligned operand must be a pointer. The value it points
// to is loaded at the end of sourceBlock (the block current at this point)
// and recorded in alignedVars with its alignment as a 64-bit constant.
4225 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
4228 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
4229 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
4231 for (
size_t i = 0; i < operands.size(); ++i) {
4232 llvm::Value *alignment =
nullptr;
4233 llvm::Value *llvmVal = moduleTranslation.
lookupValue(operands[i]);
4234 llvm::Type *ty = llvmVal->getType();
4236 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
4237 alignment = builder.getInt64(intAttr.getInt());
4238 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
4239 assert(alignment &&
"Invalid alignment value");
// Alignments that are not a power of two get separate handling.
// NOTE(review): the statement taken in that case is not visible here.
4243 if (!intAttr.getValue().isPowerOf2())
4246 auto curInsert = builder.saveIP();
4247 builder.SetInsertPoint(sourceBlock);
4248 llvmVal = builder.CreateLoad(ty, llvmVal);
4249 builder.restoreIP(curInsert);
4250 alignedVars[llvmVal] = alignment;
// The simd region is converted under the label "omp.simd.region".
4254 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
// Linear variables are initialised in the loop preheader and updated in the
// loop body from the induction variable.
4261 if (simdOp.getLinearVars().size()) {
4262 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
4263 loopInfo->getPreheader());
4265 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
4266 loopInfo->getIndVar());
4268 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// Apply the simd transformation to the loop with the gathered clause data.
4270 ompBuilder->applySimd(loopInfo, alignedVars,
4272 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
4274 order, simdlen, safelen);
4276 linearClauseProcessor.emitStoresForLinearVar(builder);
// Record whether the simd region contains any omp::OrderedRegionOp.
4279 bool hasOrderedRegions =
false;
4280 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
4281 hasOrderedRegions =
true;
// Rewrite the uses of each linear variable between the successor of
// sourceBlock and the region block, one call per block-name prefix. The
// "omp.ordered.region" rewrite is issued only when ordered regions exist.
4285 for (
size_t index = 0;
index < simdOp.getLinearVars().size();
index++) {
4286 llvm::BasicBlock *startBB = sourceBlock->getSingleSuccessor();
4287 llvm::BasicBlock *endBB = *regionBlock;
4288 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4289 "omp.loop_nest.region",
index);
4291 if (hasOrderedRegions) {
4293 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4294 "omp.ordered.region",
index);
4296 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4297 "omp_region.finalize",
index);
// Reductions: for each reduction variable, load a value for the original
// variable and the private copy, combine the two with the generator built
// from the reduction declaration, and store the result to the original
// variable.
4305 for (
auto [i, tuple] : llvm::enumerate(
4306 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
4307 privateReductionVariables))) {
4308 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
4310 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
4311 llvm::Value *originalVariable = moduleTranslation.
lookupValue(reductionVar);
4312 llvm::Type *reductionType = moduleTranslation.
convertType(decl.getType());
4316 llvm::Value *redValue = originalVariable;
4319 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
4320 llvm::Value *privateRedValue = builder.CreateLoad(
4321 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
4322 llvm::Value *reduced;
// The generator receives `reduced` as an output argument and returns the
// insertion point at which code generation continues.
4324 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
4327 builder.restoreIP(res.get());
4331 builder.CreateStore(reduced, originalVariable);
// Collect the cleanup region of every reduction declaration; they are inlined
// under the label "omp.reduction.cleanup".
4336 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
4337 [](omp::DeclareReductionOp reductionDecl) {
4338 return &reductionDecl.getCleanupRegion();
4341 moduleTranslation, builder,
4342 "omp.reduction.cleanup")))
// Lowering of omp::LoopNestOp: creates one canonical loop per loop in the
// nest, optionally tiles and collapses them, and leaves the builder after the
// generated loops.
4354 auto loopOp = cast<omp::LoopNestOp>(opInst);
4360 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Body generator called for each canonical loop. The block argument indexed
// by the number of loops created so far is bound to the new induction
// variable, and the body insertion point is recorded.
4365 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
4366 llvm::Value *iv) -> llvm::Error {
4369 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
4374 bodyInsertPoints.push_back(ip);
// Only the last loop of the nest converts the region body.
4376 if (loopInfos.size() != loopOp.getNumLoops() - 1)
4377 return llvm::Error::success();
4380 builder.restoreIP(ip);
4382 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
4384 return regionBlock.takeError();
4386 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4387 return llvm::Error::success();
// Create the canonical loops, looking up the bounds and step of each.
4395 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
4396 llvm::Value *lowerBound =
4397 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
4398 llvm::Value *upperBound =
4399 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
4400 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
// Default location and compute point come from ompLoc. They can be replaced
// by the most recent body insertion point and the preheader of the first
// loop. NOTE(review): the condition selecting the replacement is not visible
// here.
4405 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
4406 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
4408 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
4410 computeIP = loopInfos.front()->getPreheaderIP();
4414 ompBuilder->createCanonicalLoop(
4415 loc, bodyGen, lowerBound, upperBound, step,
4416 true, loopOp.getLoopInclusive(), computeIP);
4421 loopInfos.push_back(*loopResult);
// By default, continue after the first loop.
4424 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
4425 loopInfos.front()->getAfterIP();
// Tiling: the tile sizes become constants of the induction variable type.
4428 if (
const auto &tiles = loopOp.getTileSizes()) {
4429 llvm::Type *ivType = loopInfos.front()->getIndVarType();
4432 for (
auto tile : tiles.value()) {
4433 llvm::Value *tileVal = llvm::ConstantInt::get(ivType,
tile);
4434 tileSizes.push_back(tileVal);
4437 std::vector<llvm::CanonicalLoopInfo *> newLoops =
4438 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
// After tiling, continue at the start of the single successor of the first
// new loop's after block.
4442 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
4443 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
4444 afterIP = {afterAfterBB, afterAfterBB->begin()};
4448 for (
const auto &newLoop : newLoops)
4449 loopInfos.push_back(newLoop);
// Collapse the first getCollapseNumLoops() loops into one and publish the
// resulting loop info through the OpenMPLoopInfoStackFrame.
4453 const auto &numCollapse = loopOp.getCollapseNumLoops();
4455 loopInfos.begin(), loopInfos.begin() + (numCollapse));
4457 auto newTopLoopInfo =
4458 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
4460 assert(newTopLoopInfo &&
"New top loop information is missing");
4461 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
4462 [&](OpenMPLoopInfoStackFrame &frame) {
4463 frame.loopInfo = newTopLoopInfo;
// Continue after the generated loop nest.
4471 builder.restoreIP(afterIP);
// Emits a canonical loop through the OpenMPIRBuilder, driven by the trip
// count of `op`, and leaves the builder positioned after the generated loop.
4481 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
4482 Value loopIV = op.getInductionVar();
4483 Value loopTC = op.getTripCount();
// LLVM value that the module translation associates with the MLIR trip count.
4485 llvm::Value *llvmTC = moduleTranslation.
lookupValue(loopTC);
4488 ompBuilder->createCanonicalLoop(
// Body generator: associates the MLIR induction variable with the LLVM value
// handed in by the builder and resumes emission at the body insert point.
4490 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
4493 moduleTranslation.
mapValue(loopIV, llvmIV);
4495 builder.restoreIP(ip);
// Hands the error state held by bodyGenStatus back to the builder.
4500 return bodyGenStatus.takeError();
4502 llvmTC,
"omp.loop");
// A failed loop construction is reported as a diagnostic on the operation.
4504 return op.emitError(llvm::toString(llvmOrError.takeError()));
4506 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
// Continue emitting code after the newly created loop.
4507 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
4508 builder.restoreIP(afterIP);
// NOTE(review): presumably the CLI value, when present, is associated with
// llvmCLI for later loop transformations -- confirm against the guarded code.
4511 if (
Value cli = op.getCli())
// Applies heuristic unrolling to the loop referenced by the applyee operand.
4524 Value applyee = op.getApplyee();
4525 assert(applyee &&
"Loop to apply unrolling on required");
4527 llvm::CanonicalLoopInfo *consBuilderCLI =
4529 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
// Only the debug location of the current builder position is passed along.
4530 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
/// Tiles the loops named by the applyees of `op` with the sizes operands,
/// using the OpenMPIRBuilder, and processes the generatees and applyees of
/// the operation afterwards.
4538static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
4541 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
// Collect the LLVM values of the tile sizes; each one must already have been
// translated.
4546 for (
Value size : op.getSizes()) {
4547 llvm::Value *translatedSize = moduleTranslation.
lookupValue(size);
4548 assert(translatedSize &&
4549 "sizes clause arguments must already be translated");
4550 translatedSizes.push_back(translatedSize);
// Collect the canonical loop handle of every loop to be tiled.
// NOTE(review): the assert in this loop tests the MLIR value `applyee`, not
// the looked-up `consBuilderCLI`; confirm which check was intended.
4553 for (
Value applyee : op.getApplyees()) {
4554 llvm::CanonicalLoopInfo *consBuilderCLI =
4556 assert(applyee &&
"Canonical loop must already been translated");
4557 translatedLoops.push_back(consBuilderCLI);
4560 auto generatedLoops =
4561 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
// Generatees are optional; when present they are walked in lockstep with the
// generated loops (zip_equal expects both ranges to have the same length).
4562 if (!op.getGeneratees().empty()) {
4563 for (
auto [mlirLoop,
genLoop] :
4564 zip_equal(op.getGeneratees(), generatedLoops))
// NOTE(review): presumably the consumed applyees are invalidated here --
// confirm against the loop body.
4569 for (
Value applyee : op.getApplyees())
/// Fuses a run of the applyee loops of `op` into a single loop via the
/// OpenMPIRBuilder; applyees outside that run are carried over unchanged.
4577static LogicalResult
applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
4580 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
// Partition the applyees by position into loops before the fused range, loops
// after it, and the loops to fuse. `first` appears to be a 1-based position
// (note the `- 1` adjustments).
4584 for (
size_t i = 0; i < op.getApplyees().size(); i++) {
4585 Value applyee = op.getApplyees()[i];
// NOTE(review): the assert below tests the MLIR value `applyee` rather than
// the looked-up `consBuilderCLI`; confirm which one was meant.
4586 llvm::CanonicalLoopInfo *consBuilderCLI =
4588 assert(applyee &&
"Canonical loop must already been translated");
4589 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
4590 beforeFuse.push_back(consBuilderCLI);
// NOTE(review): getFirst().value() is read here guarded only by
// getCount().has_value(); presumably the op verifier ties the two together.
4591 else if (op.getCount().has_value() &&
4592 i >= op.getFirst().value() + op.getCount().value() - 1)
4593 afterFuse.push_back(consBuilderCLI);
4595 toFuse.push_back(consBuilderCLI);
// Expected generatee count: one per untouched loop plus one for the fused
// loop.
4598 (op.getGeneratees().empty() ||
4599 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
4600 "Wrong number of generatees");
4603 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
// Generatees are assigned in order: loops before the fused range, the fused
// loop itself, then the loops after it.
// NOTE(review): the last loop keeps counting with the same `i`, which by then
// is already past the fused entry, yet compares it against afterFuse.size()
// and uses it to index afterFuse. Trailing loops therefore look like they are
// skipped or mismatched -- verify and consider a separate index for afterFuse.
4604 if (!op.getGeneratees().empty()) {
4606 for (; i < beforeFuse.size(); i++)
4607 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
4608 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
4609 for (; i < afterFuse.size(); i++)
4610 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
4614 for (
Value applyee : op.getApplyees())
/// Maps an OpenMP memory-order clause kind onto an LLVM atomic ordering.
4621static llvm::AtomicOrdering
// Early return of Monotonic.
// NOTE(review): presumably taken when no memory order is given -- confirm the
// guard on this return.
4624 return llvm::AtomicOrdering::Monotonic;
4627 case omp::ClauseMemoryOrderKind::Seq_cst:
4628 return llvm::AtomicOrdering::SequentiallyConsistent;
4629 case omp::ClauseMemoryOrderKind::Acq_rel:
4630 return llvm::AtomicOrdering::AcquireRelease;
4631 case omp::ClauseMemoryOrderKind::Acquire:
4632 return llvm::AtomicOrdering::Acquire;
4633 case omp::ClauseMemoryOrderKind::Release:
4634 return llvm::AtomicOrdering::Release;
// Relaxed is expressed as Monotonic on the LLVM side.
4635 case omp::ClauseMemoryOrderKind::Relaxed:
4636 return llvm::AtomicOrdering::Monotonic;
4638 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
// Lowers an atomic read: the value at location x is read atomically into v by
// the OpenMPIRBuilder.
4645 auto readOp = cast<omp::AtomicReadOp>(opInst);
4650 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4653 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Translated addresses of the source (x) and destination (v) operands.
4656 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
4657 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
4659 llvm::Type *elementType =
4660 moduleTranslation.
convertType(readOp.getElementType());
// Both descriptors share the element type; the two trailing flags are false.
4662 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
4663 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
// Emission continues at the insert point returned by the builder.
4664 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
// Lowers an atomic write: the translated expression value is stored
// atomically to location x by the OpenMPIRBuilder.
4672 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
4677 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4680 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4682 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
4683 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
// The element type of the destination is taken from the type of the written
// expression.
4684 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
4685 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
4688 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
// Maps an LLVM dialect arithmetic or logic operation onto the corresponding
// atomicrmw binary operation.
4696 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
4697 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
4698 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
4699 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
4700 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
4701 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
4702 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
4703 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
4704 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
// Any operation kind without a case above yields BAD_BINOP.
4705 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
4709 bool &isIgnoreDenormalMode,
4710 bool &isFineGrainedMemory,
4711 bool &isRemoteMemory) {
// All three output flags default to false.
4712 isIgnoreDenormalMode =
false;
4713 isFineGrainedMemory =
false;
4714 isRemoteMemory =
false;
// The flags are overwritten only when an update op is given and it carries
// the atomic control attribute.
4715 if (atomicUpdateOp &&
4716 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
4717 mlir::omp::AtomicControlAttr atomicControlAttr =
4718 atomicUpdateOp.getAtomicControlAttr();
4719 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
4720 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
4721 isRemoteMemory = atomicControlAttr.getRemoteMemory();
4728 llvm::IRBuilderBase &builder,
// Lowers an atomic update. The update region is inspected to see whether it
// can be expressed as a single atomicrmw binop; the region is also made
// available to the builder through the updateFn callback.
4735 auto &innerOpList = opInst.getRegion().front().getOperations();
4736 bool isXBinopExpr{
false};
4737 llvm::AtomicRMWInst::BinOp binop;
4739 llvm::Value *llvmExpr =
nullptr;
4740 llvm::Value *llvmX =
nullptr;
4741 llvm::Type *llvmXElementType =
nullptr;
// A region holding exactly two operations (presumably the update operation
// plus the terminator) is a candidate for a direct binop mapping.
4742 if (innerOpList.size() == 2) {
4748 opInst.getRegion().getArgument(0))) {
4749 return opInst.emitError(
"no atomic update operation with region argument"
4750 " as operand found inside atomic.update region");
// True when the region argument (the current value of x) is operand 0 of the
// inner operation.
4753 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
4755 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
// Fallback: no binop is selected.
4759 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4761 llvmX = moduleTranslation.
lookupValue(opInst.getX());
4763 opInst.getRegion().getArgument(0).getType());
4764 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4768 llvm::AtomicOrdering atomicOrdering =
// Update callback: binds the region argument to the atomically loaded value,
// translates the region block in place, and returns the translated operand of
// the terminating yield.
4773 [&opInst, &moduleTranslation](
4774 llvm::Value *atomicx,
4777 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
4778 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
// A failed block conversion is signalled with PreviouslyReportedError.
4779 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4780 return llvm::make_error<PreviouslyReportedError>();
4782 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4783 assert(yieldop && yieldop.getResults().size() == 1 &&
4784 "terminator must be omp.yield op and it must have exactly one "
4786 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
// Atomic control flags forwarded to the builder call below.
4789 bool isIgnoreDenormalMode;
4790 bool isFineGrainedMemory;
4791 bool isRemoteMemory;
4796 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4797 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4798 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
4799 atomicOrdering, binop, updateFn,
4800 isXBinopExpr, isIgnoreDenormalMode,
4801 isFineGrainedMemory, isRemoteMemory);
// Emission continues at the insert point produced by the builder.
4806 builder.restoreIP(*afterIP);
4812 llvm::IRBuilderBase &builder,
// Lowers an atomic capture, which pairs an atomic read with either an atomic
// update or an atomic write on the same location.
4819 bool isXBinopExpr =
false, isPostfixUpdate =
false;
4820 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4822 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
4823 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
4825 assert((atomicUpdateOp || atomicWriteOp) &&
4826 "internal op must be an atomic.update or atomic.write op");
// Write form: always treated as a postfix update, with the written expression
// as the operand.
4828 if (atomicWriteOp) {
4829 isPostfixUpdate =
true;
4830 mlirExpr = atomicWriteOp.getExpr();
// Update form: postfix when the update is the second operation of the
// capture.
4832 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
4833 atomicCaptureOp.getAtomicUpdateOp().getOperation();
4834 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
// A two-operation update region is a candidate for a direct binop mapping.
4837 if (innerOpList.size() == 2) {
4840 atomicUpdateOp.getRegion().getArgument(0))) {
4841 return atomicUpdateOp.emitError(
4842 "no atomic update operation with region argument"
4843 " as operand found inside atomic.update region");
4847 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
4850 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
// x and v are taken from the atomic read; both descriptors use the element
// type of the read.
4854 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4855 llvm::Value *llvmX =
4856 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
4857 llvm::Value *llvmV =
4858 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
4859 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
4860 atomicCaptureOp.getAtomicReadOp().getElementType());
4861 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4864 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
4868 llvm::AtomicOrdering atomicOrdering =
// Update callback. One path returns the translated written expression; the
// other binds the update region argument to the loaded value, translates the
// region block in place, and returns the translated yield operand.
4872 [&](llvm::Value *atomicx,
4875 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
4876 Block &bb = *atomicUpdateOp.getRegion().
begin();
4877 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
4879 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4880 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4881 return llvm::make_error<PreviouslyReportedError>();
4883 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4884 assert(yieldop && yieldop.getResults().size() == 1 &&
4885 "terminator must be omp.yield op and it must have exactly one "
4887 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
// Atomic control flags forwarded to the builder call below.
4890 bool isIgnoreDenormalMode;
4891 bool isFineGrainedMemory;
4892 bool isRemoteMemory;
4894 isFineGrainedMemory, isRemoteMemory);
4897 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4898 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4899 ompBuilder->createAtomicCapture(
4900 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4901 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4902 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
// Builder errors are routed through handleError before the insert point is
// used.
4904 if (failed(
handleError(afterIP, *atomicCaptureOp)))
4907 builder.restoreIP(*afterIP);
/// Maps an integer comparison predicate onto an atomic compare operation
/// kind. Signed and unsigned variants map to the same kind.
4913static std::optional<llvm::omp::OMPAtomicCompareOp>
4915 switch (predicate) {
4916 case LLVM::ICmpPredicate::eq:
4917 return llvm::omp::OMPAtomicCompareOp::EQ;
// Less-than predicates select MIN.
4918 case LLVM::ICmpPredicate::slt:
4919 case LLVM::ICmpPredicate::ult:
4920 return llvm::omp::OMPAtomicCompareOp::MIN;
// Greater-than predicates select MAX.
4921 case LLVM::ICmpPredicate::sgt:
4922 case LLVM::ICmpPredicate::ugt:
4923 return llvm::omp::OMPAtomicCompareOp::MAX;
// NOTE(review): presumably the result for every other predicate -- confirm.
4925 return std::nullopt;
/// Maps a floating-point comparison predicate onto an atomic compare
/// operation kind. Ordered and unordered variants map to the same kind.
4931static std::optional<llvm::omp::OMPAtomicCompareOp>
4933 switch (predicate) {
4934 case LLVM::FCmpPredicate::oeq:
4935 case LLVM::FCmpPredicate::ueq:
4936 return llvm::omp::OMPAtomicCompareOp::EQ;
// Less-than predicates select MIN.
4937 case LLVM::FCmpPredicate::olt:
4938 case LLVM::FCmpPredicate::ult:
4939 return llvm::omp::OMPAtomicCompareOp::MIN;
// Greater-than predicates select MAX.
4940 case LLVM::FCmpPredicate::ogt:
4941 case LLVM::FCmpPredicate::ugt:
4942 return llvm::omp::OMPAtomicCompareOp::MAX;
// NOTE(review): presumably the result for every other predicate -- confirm.
4944 return std::nullopt;
4966 llvm::IRBuilderBase &builder,
// Lowers an atomic compare. The operations of the region are pattern-matched
// to recover the comparison kind, the expected value (e) and the desired
// value (d), which are then handed to the OpenMPIRBuilder or, for the complex
// pattern, turned into a cmpxchg directly.
// NOTE(review): the token after `Region` on the next line looks like a
// mis-encoded reference declarator for a variable named region (the same
// token recurs in the load check further down) -- confirm against the
// upstream file.
4972 Region ®ion = atomicCompareOp.getRegion();
4976 llvm::Type *llvmXElementType =
4978 if (!llvmXElementType)
4979 return atomicCompareOp.emitError(
4980 "unable to determine element type for atomic compare");
4982 llvm::Value *llvmX = moduleTranslation.
lookupValue(atomicCompareOp.getX());
// Signedness starts out false and is refined from the integer predicate
// later on.
4987 bool isSigned =
false;
4988 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4992 llvm::AtomicOrdering atomicOrdering =
// Predicate identifying the operation kinds that take part in the compare
// pattern.
4995 auto isAtomicComparePatternOp = [](
Operation &op) {
4996 return llvm::isa<LLVM::ICmpOp, LLVM::FCmpOp, LLVM::SelectOp, LLVM::AndOp,
5017 if (isAtomicComparePatternOp(op))
// An operand counts as mapped when the module translation already knows an
// LLVM value for it.
5022 return moduleTranslation.lookupValue(v) != nullptr;
5024 if (!allOperandsMapped)
5028 return atomicCompareOp.emitError(
5029 "failed to translate operation inside atomic compare region");
// Produces an LLVM value for an MLIR value: an existing translation is
// consulted first; a load that lives inside this region is re-emitted at the
// current insertion point.
5034 auto materializeValue = [&](
mlir::Value val) -> llvm::Value * {
5036 if (llvm::Value *existing = moduleTranslation.
lookupValue(val))
5041 if (loadOp->getParentRegion() == ®ion) {
5042 llvm::Value *loadAddr = moduleTranslation.
lookupValue(loadOp.getAddr());
5045 llvm::Type *loadType =
5046 moduleTranslation.
convertType(loadOp.getResult().getType());
5047 return builder.CreateLoad(loadType, loadAddr);
// Results of the pattern match; EQ is the initial comparison kind.
5055 llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
5056 llvm::Value *eVal =
nullptr;
5057 llvm::Value *dVal =
nullptr;
5058 bool isXBinopExpr =
false;
// A value produced by extractvalue is traced back to its container
// aggregate.
5061 if (
auto extractOp = v.getDefiningOp<LLVM::ExtractValueOp>())
5062 return extractOp.getContainer();
// Complex pattern: an and/or combining two floating-point comparisons whose
// operands trace back to aggregates, one of which is the region argument.
5076 bool isComplexPattern =
false;
5078 if (!isa<LLVM::AndOp, LLVM::OrOp>(op))
5084 if (!lhsFcmp || !rhsFcmp)
5089 mlir::Value lhsAgg0 = traceToAggregate(lhsFcmp.getOperand(0));
5090 mlir::Value lhsAgg1 = traceToAggregate(lhsFcmp.getOperand(1));
5091 bool lhsXIsOp0 = (lhsAgg0 == block.
getArgument(0));
5092 bool lhsXIsOp1 = (lhsAgg1 == block.
getArgument(0));
5093 if (!lhsXIsOp0 && !lhsXIsOp1)
// The aggregate that is not x is the expected value.
5095 mlir::Value eAggregate = lhsXIsOp0 ? lhsAgg1 : lhsAgg0;
// An and-combination corresponds to equality.
5099 if (isa<LLVM::AndOp>(op))
5100 compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
5103 return atomicCompareOp.emitError(
5104 "unsupported comparison predicate (NE) for complex atomic compare");
5106 isXBinopExpr = lhsXIsOp0;
5107 eVal = materializeValue(eAggregate);
5108 isComplexPattern =
true;
5112 if (isComplexPattern) {
// The desired value comes from the true operand of a select, or from the
// first yielded result.
5115 if (
auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5116 dVal = materializeValue(selectOp.getTrueValue());
5122 if (yieldOp.getResults().empty())
5123 return atomicCompareOp.emitError(
5124 "failed to extract desired value (d) from atomic compare region");
5125 dVal = materializeValue(yieldOp.getResults()[0]);
// The aggregate is compared and exchanged as one integer whose width equals
// the store size of the element type.
5128 const llvm::DataLayout &DL =
5129 builder.GetInsertBlock()->getModule()->getDataLayout();
5130 unsigned totalBits =
5131 DL.getTypeStoreSizeInBits(llvmXElementType).getFixedValue();
5133 llvm::IntegerType *intTy =
5134 llvm::IntegerType::get(builder.getContext(), totalBits);
// Use the stricter of the two ABI alignments for the temporaries and the
// exchange.
5136 llvm::Align complexAlign = DL.getABITypeAlign(llvmXElementType);
5137 llvm::Align intAlign = DL.getABITypeAlign(intTy);
5138 llvm::Align maxAlign = std::max(complexAlign, intAlign);
// Each value is stored to a temporary of the element type and loaded back as
// the integer type.
5140 llvm::AllocaInst *eAlloca =
5141 builder.CreateAlloca(llvmXElementType,
nullptr,
"cmplx.e");
5142 eAlloca->setAlignment(maxAlign);
5143 llvm::AllocaInst *dAlloca =
5144 builder.CreateAlloca(llvmXElementType,
nullptr,
"cmplx.d");
5145 dAlloca->setAlignment(maxAlign);
5147 builder.CreateAlignedStore(eVal, eAlloca, maxAlign);
5149 builder.CreateAlignedLoad(intTy, eAlloca, maxAlign,
"cmplx.e.int");
5150 builder.CreateAlignedStore(dVal, dAlloca, maxAlign);
5152 builder.CreateAlignedLoad(intTy, dAlloca, maxAlign,
"cmplx.d.int");
// The failure ordering is the strongest one permitted for the success
// ordering.
5154 llvm::AtomicOrdering failOrdering =
5155 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(atomicOrdering);
5156 auto *cmpXchg = builder.CreateAtomicCmpXchg(llvmX, eInt, dInt, maxAlign,
5157 atomicOrdering, failOrdering);
5158 cmpXchg->setWeak(atomicCompareOp.getWeak());
// A flush is emitted for release, acquire-release and sequentially
// consistent orderings.
5162 if (atomicOrdering == llvm::AtomicOrdering::Release ||
5163 atomicOrdering == llvm::AtomicOrdering::AcquireRelease ||
5164 atomicOrdering == llvm::AtomicOrdering::SequentiallyConsistent) {
5165 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5166 ompBuilder->createFlush(ompLoc);
// Scalar pattern, integer comparison: the predicate determines the compare
// kind and the signedness.
5172 if (
auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
5176 return atomicCompareOp.emitError(
5177 "unsupported comparison predicate in atomic compare");
5178 compareOp = *maybeOp;
5180 LLVM::ICmpPredicate pred = icmpOp.getPredicate();
5181 isSigned = (pred == LLVM::ICmpPredicate::slt ||
5182 pred == LLVM::ICmpPredicate::sgt ||
5183 pred == LLVM::ICmpPredicate::sle ||
5184 pred == LLVM::ICmpPredicate::sge);
// x is the region argument; the other comparison operand is the expected
// value.
5187 isXBinopExpr = (icmpOp.getOperand(0) == block.
getArgument(0));
5189 isXBinopExpr ? icmpOp.getOperand(1) : icmpOp.getOperand(0);
5190 eVal = materializeValue(eOperand);
5191 }
else if (
auto fcmpOp = dyn_cast<LLVM::FCmpOp>(op)) {
5195 return atomicCompareOp.emitError(
5196 "unsupported comparison predicate in atomic compare");
5197 compareOp = *maybeOp;
5199 isXBinopExpr = (fcmpOp.getOperand(0) == block.
getArgument(0));
5201 isXBinopExpr ? fcmpOp.getOperand(1) : fcmpOp.getOperand(0);
5202 eVal = materializeValue(eOperand);
5203 }
else if (
auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5205 dVal = materializeValue(selectOp.getTrueValue());
5213 if (
auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5214 dVal = materializeValue(selectOp.getTrueValue());
5221 return atomicCompareOp.emitError(
5222 "failed to extract expected value (e) from atomic compare region");
5226 if (yieldOp.getResults().empty())
5227 return atomicCompareOp.emitError(
5228 "failed to extract desired value (d) from atomic compare region");
5229 dVal = materializeValue(yieldOp.getResults()[0]);
5232 llvmAtomicX.IsSigned = isSigned;
// v and r are passed as empty descriptors (null value and type).
5234 llvm::OpenMPIRBuilder::AtomicOpValue vOpVal = {
nullptr,
nullptr,
false,
5236 llvm::OpenMPIRBuilder::AtomicOpValue rOpVal = {
nullptr,
nullptr,
false,
5238 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5240 bool isWeak = atomicCompareOp.getWeak();
// The negative-zero handling flag of the builder is switched on for this
// call only and restored right after it.
5242 bool savedHandleFPNegZero = ompBuilder->setHandleFPNegZero(
true);
5243 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5244 ompBuilder->createAtomicCompare(ompLoc, llvmAtomicX, vOpVal, rOpVal, eVal,
5245 dVal, atomicOrdering, compareOp,
5246 isXBinopExpr,
false,
false, isWeak);
5247 ompBuilder->setHandleFPNegZero(savedHandleFPNegZero);
5249 if (failed(
handleError(afterIP, *atomicCompareOp)))
5252 builder.restoreIP(*afterIP);
5257 omp::ClauseCancellationConstructType directive) {
// Maps the construct type named by a cancellation clause onto the OpenMP
// directive kind it refers to.
5258 switch (directive) {
5259 case omp::ClauseCancellationConstructType::Loop:
5260 return llvm::omp::Directive::OMPD_for;
5261 case omp::ClauseCancellationConstructType::Parallel:
5262 return llvm::omp::Directive::OMPD_parallel;
5263 case omp::ClauseCancellationConstructType::Sections:
5264 return llvm::omp::Directive::OMPD_sections;
5265 case omp::ClauseCancellationConstructType::Taskgroup:
5266 return llvm::omp::Directive::OMPD_taskgroup;
5268 llvm_unreachable(
"Unhandled cancellation construct type");
// Lowers a cancel operation through the OpenMPIRBuilder.
5277 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// The if condition stays null unless the operation has an if expression.
5280 llvm::Value *ifCond =
nullptr;
5281 if (
Value ifVar = op.getIfExpr())
5284 llvm::omp::Directive cancelledDirective =
5287 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5288 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
// Builder errors are routed through handleError before the insert point is
// used.
5290 if (failed(
handleError(afterIP, *op.getOperation())))
5293 builder.restoreIP(afterIP.get());
5300 llvm::IRBuilderBase &builder,
// Lowers a cancellation point through the OpenMPIRBuilder.
5305 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5308 llvm::omp::Directive cancelledDirective =
5311 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5312 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
// Builder errors are routed through handleError before the insert point is
// used.
5314 if (failed(
handleError(afterIP, *op.getOperation())))
5317 builder.restoreIP(afterIP.get());
// Lowers a threadprivate operation: the global referenced by the symbol
// address is resolved and a cached thread-private access is created for it.
5327 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5329 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
5334 Value symAddr = threadprivateOp.getSymAddr();
// NOTE(review): presumably an address-space cast is looked through to reach
// the address-of operation -- confirm against the guarded statement.
5337 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
5340 if (!isa<LLVM::AddressOfOp>(symOp))
5341 return opInst.
emitError(
"Addressing symbol not found");
// The isa check above already rejected any other operation kind.
5342 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
5344 LLVM::GlobalOp global =
5345 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
5346 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
// The size handed to the runtime is the store size of the value type of the
// global, as a 64-bit constant.
5347 llvm::Type *type = globalValue->getValueType();
5348 llvm::TypeSize typeSize =
5349 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
5351 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
// The cache name is the symbol name of the global with a ".cache" suffix.
5352 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
5353 ompLoc, globalValue, size, global.getSymName() +
".cache");
/// Maps a declare-target device type onto the offload entry device clause
/// kind.
5359static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
5361 switch (deviceClause) {
5362 case mlir::omp::DeclareTargetDeviceType::host:
5363 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
5365 case mlir::omp::DeclareTargetDeviceType::nohost:
5366 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
5368 case mlir::omp::DeclareTargetDeviceType::any:
5369 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
5372 llvm_unreachable(
"unhandled device clause");
/// Maps a declare-target capture clause onto the offload entry kind used for
/// global variables.
5375static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
5377 mlir::omp::DeclareTargetCaptureClause captureClause) {
5378 switch (captureClause) {
5379 case mlir::omp::DeclareTargetCaptureClause::to:
5380 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
5381 case mlir::omp::DeclareTargetCaptureClause::link:
5382 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
5383 case mlir::omp::DeclareTargetCaptureClause::enter:
5384 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
5385 case mlir::omp::DeclareTargetCaptureClause::none:
5386 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
5388 llvm_unreachable(
"unhandled capture clause");
// NOTE(review): presumably an address-space cast is looked through before
// the address-of check -- confirm against the guarded statement.
5393 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
// For an address-of operation, the referenced global is resolved by name in
// the enclosing module.
5395 if (
auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
5396 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
5397 return modOp.lookupSymbol(addressOfOp.getGlobalName());
// When the defining operation is an address-space cast, continue with the
// operand of the cast.
5404 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
5405 value = addrCast.getOperand();
/// Builds the name suffix used for declare-target reference pointers. The
/// suffix is streamed into a small string and ends in "_decl_tgt_ref_ptr".
5422static llvm::SmallString<64>
5424 llvm::OpenMPIRBuilder &ompBuilder,
5425 llvm::vfs::FileSystem &vfs) {
5427 llvm::raw_svector_ostream os(suffix);
// Supplies the file name and line of `loc` to the builder, which derives the
// unique target entry info from it; only the file ID is used here.
5430 auto fileInfoCallBack = [&loc]() {
5431 return std::pair<std::string, uint64_t>(
5432 llvm::StringRef(loc.getFilename()), loc.getLine());
5437 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs).FileID);
5439 os <<
"_decl_tgt_ref_ptr";
// Checks whether the operation implements the declare-target interface and
// carries the link capture clause.
5445 if (
auto declareTargetGlobal =
5446 dyn_cast_if_present<omp::DeclareTargetInterface>(
5448 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5449 omp::DeclareTargetCaptureClause::link)
// Checks whether the operation implements the declare-target interface and
// carries the to or the enter capture clause.
5455 if (
auto declareTargetGlobal =
5456 dyn_cast_if_present<omp::DeclareTargetInterface>(
5458 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5459 omp::DeclareTargetCaptureClause::to ||
5460 declareTargetGlobal.getDeclareTargetCaptureClause() ==
5461 omp::DeclareTargetCaptureClause::enter)
// Part of a condition that consults the unified-shared-memory requirement of
// the builder configuration.
5479 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
// NOTE(review): globals whose name already contains the suffix are treated
// separately; the guarded statement determines how -- confirm.
5483 if (gOp.getSymName().contains(suffix))
// The looked-up name is the symbol name of the global followed by the suffix.
5488 (gOp.getSymName().str() + suffix.str()).str());
/// Extends the map info structure of the OpenMPIRBuilder with one mapper
/// operation entry per map.
5496struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
5497 SmallVector<Operation *, 4> Mappers;
// Appends the mappers of `curInfo`, then the base-class data.
5500 void append(MapInfosTy &curInfo) {
5501 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
5502 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
/// Map information plus per-entry bookkeeping collected while translating
/// map clauses. The vectors are meant to be index-aligned with each other.
5511struct MapInfoData : MapInfosTy {
5512 llvm::SmallVector<bool, 4> IsDeclareTarget;
5513 llvm::SmallVector<bool, 4> IsAMember;
5515 llvm::SmallVector<bool, 4> IsAMapping;
5516 llvm::SmallVector<mlir::Operation *, 4> MapClause;
5517 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
5520 llvm::SmallVector<llvm::Type *, 4> BaseType;
// Appends the entries of `CurInfo` to this object.
// NOTE(review): IsAMember and IsAMapping are not appended by the statements
// shown here -- confirm whether that is intended.
5523 void append(MapInfoData &CurInfo) {
5524 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
5525 CurInfo.IsDeclareTarget.end());
5526 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
5527 OriginalValue.append(CurInfo.OriginalValue.begin(),
5528 CurInfo.OriginalValue.end());
5529 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
5530 MapInfosTy::append(CurInfo);
/// Identifies the kind of target-related directive; stored as a 32-bit
/// unsigned value.
5534enum class TargetDirectiveEnumTy : uint32_t {
5538 TargetEnterData = 3,
/// Classifies `op` by the target-related directive it represents. Operations
/// that are none of the listed kinds yield TargetDirectiveEnumTy::None.
5543static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
5544 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
5545 .Case([](omp::TargetDataOp) {
return TargetDirectiveEnumTy::TargetData; })
5546 .Case([](omp::TargetEnterDataOp) {
5547 return TargetDirectiveEnumTy::TargetEnterData;
5549 .Case([&](omp::TargetExitDataOp) {
5550 return TargetDirectiveEnumTy::TargetExitData;
5552 .Case([&](omp::TargetUpdateOp) {
5553 return TargetDirectiveEnumTy::TargetUpdate;
5555 .Case([&](omp::TargetOp) {
return TargetDirectiveEnumTy::Target; })
5556 .Default([&](Operation *op) {
return TargetDirectiveEnumTy::None; });
// An array whose element type is itself an array is handled separately.
// NOTE(review): presumably by recursing into the nested array type --
// confirm.
5563 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
5564 arrTy.getElementType()))
5581 llvm::Value *basePointer,
5582 llvm::Type *baseType,
5583 llvm::IRBuilderBase &builder,
// Only map info operations take the path below.
5585 if (
auto memberClause =
5586 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
// With bounds present, the size is an element count multiplied by the
// element size in bytes. The count starts at 1 and is multiplied by a
// per-bound extent.
5591 if (!memberClause.getBounds().empty()) {
5592 llvm::Value *elementCount = builder.getInt64(1);
5593 for (
auto bounds : memberClause.getBounds()) {
5594 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
5595 bounds.getDefiningOp())) {
// NOTE(review): the extent is computed from the upper bound, the lower bound
// and the constant 1, presumably as upper - lower + 1 -- confirm.
5600 elementCount = builder.CreateMul(
5604 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
5605 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
5606 builder.getInt64(1)));
5613 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
// Bits are converted to bytes with an integer division by 8.
5621 return builder.CreateMul(elementCount,
5622 builder.getInt64(underlyingTypeSzInBits / 8));
/// Translates the map flags of an OpenMP clause into the offload mapping
/// flags understood by the runtime, one flag at a time.
5633static llvm::omp::OpenMPOffloadMappingFlags
// True when any flag other than is_device_ptr is set.
5635 const bool hasExplicitMap =
5636 (mlirFlags &
~omp::ClauseMapFlags::is_device_ptr) !=
5637 omp::ClauseMapFlags::none;
5639 llvm::omp::OpenMPOffloadMappingFlags mapType =
5640 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5642 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::to))
5643 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
5645 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::from))
5646 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
5648 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::always))
5649 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5651 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::del))
5652 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
5654 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::return_param))
5655 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5657 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::priv))
5658 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
5660 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::literal))
5661 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5663 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::implicit))
5664 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
5666 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::close))
5667 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5669 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::present))
5670 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
5672 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::ompx_hold))
5673 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
5675 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::attach))
5676 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
// is_device_ptr has no one-to-one flag: it always adds TARGET_PARAM, and adds
// LITERAL as well when no other map flag was given.
5678 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::is_device_ptr)) {
5679 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5680 if (!hasExplicitMap)
5681 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5691 ArrayRef<Value> useDevAddrOperands = {},
5692 ArrayRef<Value> hasDevAddrOperands = {}) {
5694 auto checkRefPtrOrPteeMapWithAttach = [](omp::ClauseMapFlags mapType) {
5696 bitEnumContainsAll(mapType, omp::ClauseMapFlags::ref_ptr) ||
5697 bitEnumContainsAll(mapType, omp::ClauseMapFlags::ref_ptee);
5698 return hasRefType &&
5699 bitEnumContainsAll(mapType, omp::ClauseMapFlags::attach);
5702 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
5710 for (Value mapValue : mapVars) {
5711 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5712 for (
auto member : map.getMembers())
5713 if (member == mapOp)
5720 for (Value mapValue : mapVars) {
5721 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5722 bool isRefPtrOrPteeMapWithAttach =
5723 checkRefPtrOrPteeMapWithAttach(mapOp.getMapType());
5724 Value offloadPtr = (mapOp.getVarPtrPtr() && !isRefPtrOrPteeMapWithAttach)
5725 ? mapOp.getVarPtrPtr()
5726 : mapOp.getVarPtr();
5727 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
5728 mapData.Pointers.push_back(
5729 isRefPtrOrPteeMapWithAttach
5730 ? moduleTranslation.
lookupValue(mapOp.getVarPtrPtr())
5731 : mapData.OriginalValue.back());
5733 if (llvm::Value *refPtr =
5735 mapData.IsDeclareTarget.push_back(
true);
5736 mapData.BasePointers.push_back(refPtr);
5738 mapData.IsDeclareTarget.push_back(
true);
5739 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5741 mapData.IsDeclareTarget.push_back(
false);
5742 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5748 mapData.BaseType.push_back(moduleTranslation.
convertType(
5749 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtrType().value()
5750 : mapOp.getVarPtrType()));
5757 mlir::Type sizeType = (isRefPtrOrPteeMapWithAttach || !mapOp.getVarPtrPtr())
5758 ? mapOp.getVarPtrType()
5759 : mapOp.getVarPtrPtrType().value();
5761 dl, sizeType, isRefPtrOrPteeMapWithAttach ?
nullptr : mapOp,
5762 mapData.Pointers.back(), moduleTranslation.
convertType(sizeType),
5763 builder, moduleTranslation));
5764 mapData.MapClause.push_back(mapOp.getOperation());
5768 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5769 if (mapOp.getMapperId())
5770 mapData.Mappers.push_back(
5772 mapOp, mapOp.getMapperIdAttr()));
5774 mapData.Mappers.push_back(
nullptr);
5775 mapData.IsAMapping.push_back(
true);
5776 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
5779 auto findMapInfo = [&mapData](llvm::Value *val,
5780 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy,
5781 size_t memberCount) {
5784 for (llvm::Value *basePtr : mapData.OriginalValue) {
5785 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[index]);
5796 (mapData.Types[index] &
5797 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
5798 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
5799 if (!isAttachMap && basePtr == val && mapData.IsAMapping[index] &&
5800 memberCount == mapOp.getMembers().size()) {
5802 mapData.Types[index] |=
5803 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5804 mapData.DevicePointers[index] = devInfoTy;
5812 auto addDevInfos = [&](
const llvm::ArrayRef<Value> &useDevOperands,
5813 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5814 for (Value mapValue : useDevOperands) {
5815 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5817 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5818 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5821 if (!findMapInfo(origValue, devInfoTy, mapOp.getMembers().size())) {
5822 mapData.OriginalValue.push_back(origValue);
5823 mapData.Pointers.push_back(mapData.OriginalValue.back());
5824 mapData.IsDeclareTarget.push_back(
false);
5825 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5826 mlir::Type baseTy = mapOp.getVarPtrPtr()
5827 ? mapOp.getVarPtrPtrType().value()
5828 : mapOp.getVarPtrType();
5829 mapData.BaseType.push_back(moduleTranslation.
convertType(baseTy));
5830 mapData.Sizes.push_back(builder.getInt64(0));
5831 mapData.MapClause.push_back(mapOp.getOperation());
5832 mapData.Types.push_back(
5833 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
5836 mapData.DevicePointers.push_back(devInfoTy);
5837 mapData.Mappers.push_back(
nullptr);
5838 mapData.IsAMapping.push_back(
false);
5839 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
5844 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5845 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
5847 for (Value mapValue : hasDevAddrOperands) {
5848 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5850 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5851 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5853 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5855 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
5856 omp::ClauseMapFlags::none;
5858 mapData.OriginalValue.push_back(origValue);
5859 mapData.BasePointers.push_back(origValue);
5860 mapData.Pointers.push_back(origValue);
5861 mapData.IsDeclareTarget.push_back(
false);
5863 mlir::Type baseTy = mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtrType().value()
5864 : mapOp.getVarPtrType();
5865 mapData.BaseType.push_back(moduleTranslation.
convertType(baseTy));
5866 mapData.Sizes.push_back(builder.getInt64(dl.
getTypeSize(baseTy)));
5868 mapData.MapClause.push_back(mapOp.getOperation());
5869 if (llvm::to_underlying(mapType & mapTypeAlways)) {
5873 mapData.Types.push_back(mapType);
5877 if (mapOp.getMapperId()) {
5878 mapData.Mappers.push_back(
5880 mapOp, mapOp.getMapperIdAttr()));
5882 mapData.Mappers.push_back(
nullptr);
5887 mapData.Types.push_back(
5888 isDevicePtr ? mapType
5889 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
5890 mapData.Mappers.push_back(
nullptr);
5894 mapData.DevicePointers.push_back(
5895 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
5896 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5897 mapData.IsAMapping.push_back(
false);
5898 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
// Fragment: tail of a helper that returns the position of `memberOp` within
// mapData.MapClause. NOTE(review): the signature is not visible in this
// listing.
5903 auto *res = llvm::find(mapData.MapClause, memberOp);
// A miss is treated as a programming error. With assertions compiled out, the
// return below would yield MapClause.size() for an op that is not present.
5904 assert(res != mapData.MapClause.end() &&
5905 "MapInfoOp for member not found in MapData, cannot return index");
5906 return std::distance(mapData.MapClause.begin(), res);
// Fragment: orders member positions of `mapInfo` by comparing their index
// paths. NOTE(review): the function name, its leading parameter(s) and the
// statements guarded by the comparisons below are not visible in this listing.
5910 omp::MapInfoOp mapInfo,
bool first =
true) {
// One ArrayAttr of IntegerAttr per member (see the casts below).
5911 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
// `a` and `b` are positions into indexAttr; fetch both members' index paths.
5921 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
5922 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
// Walk the two paths in lockstep, element by element.
5924 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
5925 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
5926 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
// Three-way comparison of the current path element; the outcome of each
// branch is not visible here.
5928 if (aIndex == bIndex)
5931 if (aIndex < bIndex)
5934 if (aIndex > bIndex)
// True when A's path is strictly shorter than B's.
5941 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
// Either b or a is recorded in occludedChildren; the selecting condition is
// not visible (presumably memberAParent - TODO confirm).
5943 occludedChildren.push_back(
b);
5945 occludedChildren.push_back(a);
5946 return memberAParent;
// Post-processing of the recorded positions; the loop body is not visible.
5949 for (
auto v : occludedChildren)
// Fragment: picks one member MapInfoOp of `mapInfo`. NOTE(review): the
// signature and the code that fills `indices` are not visible in this listing.
5956 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
// With a single member there is nothing to choose between: return the op
// defining member 0.
5958 if (indexAttr.size() == 1)
5959 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
// Otherwise return the op defining the member at position indices.front().
5963 return llvm::cast<omp::MapInfoOp>(
5964 mapInfo.getMembers()[
indices.front()].getDefiningOp());
// Fragment: builds a list of index values (`idx`) from the MapBoundsOp
// operands in `bounds`. NOTE(review): the function name, part of the parameter
// list and the condition that selects between the two loops below are not
// visible in this listing (presumably `isArrayTy` - TODO confirm).
5987static std::vector<llvm::Value *>
5989 llvm::IRBuilderBase &builder,
bool isArrayTy,
5991 std::vector<llvm::Value *> idx;
// First form: a leading constant 0 followed by one lower bound per bounds
// operand, visited from the last operand to the first.
6002 idx.push_back(builder.getInt64(0));
6003 for (
int i = bounds.size() - 1; i >= 0; --i) {
// Operands not defined by a MapBoundsOp (or with no defining op) are skipped.
6004 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
6005 bounds[i].getDefiningOp())) {
6006 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
// NOTE(review): dimensionIndexSizeOffset is declared here but not used in any
// visible line.
6024 std::vector<llvm::Value *> dimensionIndexSizeOffset;
// Second form: again from the last operand to the first, but folding into a
// single running value held in idx.back().
6025 for (
int i = bounds.size() - 1; i >= 0; --i) {
6026 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
6027 bounds[i].getDefiningOp())) {
// The last bounds operand seeds the running value with its lower bound (the
// call receiving this argument is not visible).
6028 if (i == ((
int)bounds.size() - 1))
6030 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
// Every other operand updates it as: running * extent + lowerBound.
6032 idx.back() = builder.CreateAdd(
6033 builder.CreateMul(idx.back(), moduleTranslation.
lookupValue(
6034 boundOp.getExtent())),
6035 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
// Appends the integer value of every attribute in `values` to `ints`. Each
// element must be an IntegerAttr; cast<> asserts otherwise.
6044 llvm::transform(values, std::back_inserter(ints), [](
Attribute value) {
6045 return cast<IntegerAttr>(value).getInt();
// Fragment: fills overlapMapDataIdxs with the positions of the members of
// `parentOp` that are not nested under another listed member.
// NOTE(review): the function name, its leading parameter(s) and the early
// exits after the first two checks are not visible in this listing.
6053 omp::MapInfoOp parentOp) {
// No members: nothing is appended by any visible line.
6055 if (parentOp.getMembers().empty())
// A single member cannot be nested under a sibling; record position 0.
6059 if (parentOp.getMembers().size() == 1) {
6060 overlapMapDataIdxs.push_back(0);
6064 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
6065 size_t numMembers = indexAttr.size();
// Decode every member's index path into integers, stored at memberIndices[i].
6069 for (
auto [i, indicesAttr] : llvm::enumerate(indexAttr))
6070 getAsIntegers(cast<ArrayAttr>(indicesAttr), memberIndices[i]);
// Pairwise scan: member i is marked for skipping when some member j has a
// strictly shorter index path that is a prefix of i's path.
6076 llvm::SmallDenseSet<size_t> skipIndices;
6077 for (
size_t i = 0; i < numMembers; ++i) {
6078 const auto &iIndices = memberIndices[i];
6079 for (
size_t j = 0;
j < numMembers; ++
j) {
6082 const auto &jIndices = memberIndices[
j];
// The size check guarantees std::equal never reads past the end of iIndices.
6084 if (jIndices.size() < iIndices.size() &&
6085 std::equal(jIndices.begin(), jIndices.end(), iIndices.begin())) {
6086 skipIndices.insert(i);
// Emit the surviving positions in ascending order.
6093 for (
size_t i = 0; i < numMembers; ++i)
6094 if (!skipIndices.contains(i))
6095 overlapMapDataIdxs.push_back(i);
6107 if (mapOp.getVarPtrPtr())
// Fragment: appends one entry for mapData[mapDataIdx] to each parallel array
// of combinedInfo (BasePointers, Pointers, DevicePointers, Mappers, Names,
// Types, Sizes). NOTE(review): the function name, its leading parameter(s),
// the definition of isPtrTy and several condition tails are not visible in
// this listing.
6130 llvm::OpenMPIRBuilder &ompBuilder, MapInfoData &mapData,
6131 size_t mapDataIdx, MapInfosTy &combinedInfo,
6132 TargetDirectiveEnumTy targetDirective,
6133 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
6134 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE,
6135 bool isTargetParam =
true,
int mapDataParentIdx = -1) {
// Start from the flags recorded for this entry; they are adjusted below
// before being emitted.
6136 auto mapFlag = mapData.Types[mapDataIdx];
6137 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
// Tail of a test for the OMP_MAP_ATTACH bit (presumably the initialiser of
// isAttachMap, which is used below - TODO confirm).
6141 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
6142 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH);
// TARGET_PARAM is only considered when the caller asked for it, the directive
// is Target and the entry is not declare-target; the final conjunct is not
// visible.
6148 if (isTargetParam &&
6149 (targetDirective == TargetDirectiveEnumTy::Target &&
6150 !mapData.IsDeclareTarget[mapDataIdx]) &&
6152 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
// By-copy captures may additionally get LITERAL; the second conjunct is not
// visible.
6154 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
6156 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
// When a MEMBER_OF flag was supplied, entries that are neither pointer-typed
// nor attach maps get it merged into mapFlag.
6165 if (memberOfFlag != llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE) {
6166 if (!isPtrTy && !isAttachMap)
6167 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
// RETURN_PARAM is cleared here. NOTE(review): any guarding condition is not
// visible.
6174 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
// Declare-target pointer entries (other than attach maps) are tagged
// PTR_AND_OBJ.
6184 if (isPtrTy && !isAttachMap && mapData.IsDeclareTarget[mapDataIdx])
6185 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
// Tail of a declaration (presumably isRefPtee, used below): ref_ptee set and
// ref_ptr clear.
6194 !bitEnumContainsAll(mapInfoOp.getMapType(),
6195 omp::ClauseMapFlags::ref_ptr) &&
6196 bitEnumContainsAll(mapInfoOp.getMapType(), omp::ClauseMapFlags::ref_ptee);
// Both ref_ptr and ref_ptee set.
6197 bool isRefPtrPtee = bitEnumContainsAll(mapInfoOp.getMapType(),
6198 omp::ClauseMapFlags::ref_ptr |
6199 omp::ClauseMapFlags::ref_ptee);
// Base pointer selection: the parent's base pointer is used when a parent
// index was given, the op is not nested in a DeclareMapperOp, and the entry is
// neither ref_ptee-only nor a pointer-typed ref_ptr|ref_ptee entry. The other
// push uses the entry's own base pointer (the separating line, presumably an
// else, is not visible).
6201 if (!mapInfoOp->getParentOfType<omp::DeclareMapperOp>() &&
6202 mapDataParentIdx >= 0 && !(isRefPtee || (isRefPtrPtee && isPtrTy))) {
6203 combinedInfo.BasePointers.emplace_back(
6204 mapData.BasePointers[mapDataParentIdx]);
6206 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
6209 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
// Entries emitted with a MEMBER_OF flag never carry device-pointer info.
6210 combinedInfo.DevicePointers.emplace_back(
6211 memberOfFlag != llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE
6212 ? llvm::OpenMPIRBuilder::DeviceInfoTy::None
6213 : mapData.DevicePointers[mapDataIdx]);
6214 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
6215 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
6216 combinedInfo.Types.emplace_back(mapFlag);
// For pointer-typed entries the size is chosen in the generated IR: 0 when
// the pointer is null, otherwise the recorded size.
6217 combinedInfo.Sizes.emplace_back(
6218 isPtrTy ? builder.CreateSelect(
6219 builder.CreateIsNull(mapData.Pointers[mapDataIdx]),
6220 builder.getInt64(0), mapData.Sizes[mapDataIdx])
6221 : mapData.Sizes[mapDataIdx]);
// Fragment: emits the combinedInfo entries for a parent map that has member
// maps. It first appends one entry spanning the parent (or, for a partial map,
// the span from its first to its last mapped member), then for non-partial
// maps appends further entries carrying the MEMBER_OF flag.
// NOTE(review): the function name, leading parameters, the definitions of
// firstMemberIdx / lastMemberIdx / overlapIdxs / mapDataOverlapIdx / isPtrMap
// and several branch conditions are not visible in this listing.
6241 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
6242 MapInfoData &mapData, uint64_t mapDataIndex,
6243 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
6244 TargetDirectiveEnumTy targetDirective) {
6245 using MapFlags = llvm::omp::OpenMPOffloadMappingFlags;
6246 assert(!ompBuilder.Config.isTargetDevice() &&
6247 "function only supported for host device codegen");
// Initialiser of a declaration whose left-hand side is not visible
// (presumably parentClause, used below).
6249 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6250 auto *parentMapper = mapData.Mappers[mapDataIndex];
// The first entry is a TARGET_PARAM only for Target directives on entries
// that are not declare-target.
6256 MapFlags baseFlag = (targetDirective == TargetDirectiveEnumTy::Target &&
6257 !mapData.IsDeclareTarget[mapDataIndex])
6258 ? MapFlags::OMP_MAP_TARGET_PARAM
6259 : MapFlags::OMP_MAP_NONE;
// Two alternative sets of parent flags are carried over into baseFlag; the
// condition selecting between them is not visible. First alternative:
6265 MapFlags parentFlags = mapData.Types[mapDataIndex];
6266 MapFlags preserve = MapFlags::OMP_MAP_TO | MapFlags::OMP_MAP_FROM |
6267 MapFlags::OMP_MAP_ALWAYS | MapFlags::OMP_MAP_CLOSE |
6268 MapFlags::OMP_MAP_PRESENT |
6269 MapFlags::OMP_MAP_OMPX_HOLD |
6270 MapFlags::OMP_MAP_IMPLICIT;
6271 baseFlag |= (parentFlags & preserve);
// Second alternative (start of the `preserve` initialiser is not visible):
6273 MapFlags parentFlags = mapData.Types[mapDataIndex];
6275 MapFlags::OMP_MAP_PRESENT | MapFlags::OMP_MAP_RETURN_PARAM;
6276 baseFlag |= (parentFlags & preserve);
6279 combinedInfo.Types.emplace_back(baseFlag);
6280 combinedInfo.DevicePointers.emplace_back(
6281 mapData.DevicePointers[mapDataIndex]);
// The parent's mapper is attached to this entry only for non-partial maps.
6285 combinedInfo.Mappers.emplace_back(
6286 parentMapper && !parentClause.getPartialMap() ? parentMapper :
nullptr);
6288 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6289 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
// [lowAddr, highAddr) delimits the bytes the first entry covers.
6298 llvm::Value *lowAddr, *highAddr;
// Non-partial map: from the parent's pointer to one element of its base type
// past it.
6299 if (!parentClause.getPartialMap()) {
6300 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
6301 builder.getPtrTy());
6302 highAddr = builder.CreatePointerCast(
6303 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
6304 mapData.Pointers[mapDataIndex], 1),
6305 builder.getPtrTy());
6306 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
// Partial map: from the first member's base pointer to one element past the
// last member's base pointer.
6308 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6311 lowAddr = builder.CreatePointerCast(mapData.BasePointers[firstMemberIdx],
6312 builder.getPtrTy());
6316 auto lastMemberMapInfo =
6317 cast<omp::MapInfoOp>(mapData.MapClause[lastMemberIdx]);
// True when the last member is ref_ptee without ref_ptr.
6326 bool isRefPteeMap = bitEnumContainsAll(lastMemberMapInfo.getMapType(),
6327 omp::ClauseMapFlags::ref_ptee) &&
6328 !bitEnumContainsAll(lastMemberMapInfo.getMapType(),
6329 omp::ClauseMapFlags::ref_ptr);
// Element type used to step past the last member; it may be replaced by the
// converted var_ptr type (the guarding condition and assignment target are
// not visible - presumably keyed on isRefPteeMap, TODO confirm).
6330 llvm::Type *castType = mapData.BaseType[lastMemberIdx];
6333 moduleTranslation.
convertType(lastMemberMapInfo.getVarPtrType());
6334 highAddr = builder.CreatePointerCast(
6335 builder.CreateGEP(castType, mapData.BasePointers[lastMemberIdx],
6336 builder.getInt64(1)),
6337 builder.getPtrTy());
6338 combinedInfo.Pointers.emplace_back(mapData.BasePointers[firstMemberIdx]);
// Size of the first entry: byte distance highAddr - lowAddr, as an i64.
6341 llvm::Value *size = builder.CreateIntCast(
6342 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
6343 builder.getInt64Ty(),
6345 combinedInfo.Sizes.push_back(size);
// Non-partial maps get additional entries that carry the MEMBER_OF flag.
6353 if (!parentClause.getPartialMap()) {
6358 MapFlags mapFlag = mapData.Types[mapDataIndex];
6359 bool hasMapClose = (MapFlags(mapFlag) & MapFlags::OMP_MAP_CLOSE) ==
6360 MapFlags::OMP_MAP_CLOSE;
6361 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
// Simple case (target update, `close` modifier, or exactly one overlap
// index): a single extra entry covering the whole parent.
6377 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose ||
6378 overlapIdxs.size() == 1) {
6379 combinedInfo.Types.emplace_back(mapFlag);
6380 combinedInfo.DevicePointers.emplace_back(
6381 mapData.DevicePointers[mapDataIndex]);
6383 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6384 combinedInfo.BasePointers.emplace_back(
6385 mapData.BasePointers[mapDataIndex]);
6386 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
6387 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
6388 combinedInfo.Mappers.emplace_back(
nullptr);
// Otherwise the parent is emitted in pieces. Restart the address window at
// the parent's own extent.
6394 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
6395 builder.getPtrTy());
6396 highAddr = builder.CreatePointerCast(
6397 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
6398 mapData.Pointers[mapDataIndex], 1),
6399 builder.getPtrTy());
// The pieces do not carry RETURN_PARAM.
6406 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
// One entry per overlap index, each starting at the current lowAddr.
6413 for (
auto v : overlapIdxs) {
6416 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
6418 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataOverlapIdx]));
6419 combinedInfo.Types.emplace_back(mapFlag);
6420 combinedInfo.DevicePointers.emplace_back(
6421 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
6423 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6424 combinedInfo.BasePointers.emplace_back(
6425 mapData.BasePointers[mapDataIndex]);
6426 combinedInfo.Mappers.emplace_back(
nullptr);
6427 combinedInfo.Pointers.emplace_back(lowAddr);
// Signed byte distance involving the overlapping member's original value
// (the second PtrDiff operand is not visible - presumably lowAddr).
6428 auto sizeCalc = builder.CreateIntCast(
6429 builder.CreatePtrDiff(builder.getInt8Ty(),
6430 mapData.OriginalValue[mapDataOverlapIdx],
6432 builder.getInt64Ty(),
true);
// A zero distance falls back to pointer size (pointer maps) or the member's
// recorded size.
6437 auto sizeSel = builder.CreateSelect(
6438 builder.CreateICmpNE(builder.getInt64(0), sizeCalc), sizeCalc,
6439 isPtrMap ? llvm::ConstantExpr::getSizeOf(builder.getPtrTy())
6440 : mapData.Sizes[mapDataOverlapIdx]);
6441 combinedInfo.Sizes.emplace_back(sizeSel);
// Advance lowAddr to one element past the overlapping member's base pointer.
6442 lowAddr = builder.CreateConstGEP1_32(
6443 isPtrMap ? builder.getPtrTy() : mapData.BaseType[mapDataOverlapIdx],
6444 mapData.BasePointers[mapDataOverlapIdx], 1);
// Final piece: whatever remains between lowAddr and highAddr.
6447 combinedInfo.Types.emplace_back(mapFlag);
6448 combinedInfo.DevicePointers.emplace_back(
6449 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
6451 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6452 combinedInfo.BasePointers.emplace_back(
6453 mapData.BasePointers[mapDataIndex]);
6454 combinedInfo.Mappers.emplace_back(
nullptr);
6455 combinedInfo.Pointers.emplace_back(lowAddr);
6456 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
6457 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
6458 builder.getInt64Ty(),
true));
// Fragment: drives emission for a parent map and its members on the host.
// NOTE(review): the function name, some parameters and the names of the
// callees receiving the argument lists below are not visible in this listing.
6464 llvm::IRBuilderBase &builder,
6465 llvm::OpenMPIRBuilder &ompBuilder,
6467 MapInfoData &mapData, uint64_t mapDataIndex,
6468 TargetDirectiveEnumTy targetDirective) {
6469 assert(!ompBuilder.Config.isTargetDevice() &&
6470 "function only supported for host device codegen");
// Initialiser of a declaration whose left-hand side is not visible
// (presumably parentClause, used below).
6473 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
// Special case: a partial map with exactly one member. The call below passes
// OMP_MAP_NONE as the member-of flag, `true`, and the parent index for that
// member.
6478 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
6479 auto memberClause = llvm::cast<omp::MapInfoOp>(
6480 parentClause.getMembers()[0].getDefiningOp());
6493 builder, ompBuilder, mapData, memberDataIdx, combinedInfo,
6495 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE,
6496 true, mapDataIndex);
// General case: a local callable gathers map-data indices related to the
// parent's members (its parameter list and most of its body are not visible).
6500 auto collectMapInfoIdxs =
6503 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6505 for (
auto member : parentClause.getMembers())
6507 mapData, llvm::cast<omp::MapInfoOp>(member.getDefiningOp())));
6511 collectMapInfoIdxs(mapInfoIdx);
// The MEMBER_OF flag is derived from the current number of entries in
// combinedInfo.Types, i.e. before anything is appended by the loop below.
6513 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
6514 ompBuilder.getMemberOfFlag(combinedInfo.Types.size());
// Each gathered index is processed with that flag; the last visible call
// passes `false` and the parent index as its trailing arguments.
6515 for (
size_t i = 0; i < mapInfoIdx.size(); i++) {
6520 combinedInfo, mapData, mapInfoIdx[i], memberOfFlag,
6524 combinedInfo, targetDirective, memberOfFlag,
6525 false, mapDataIndex);
// Fragment: rewrites mapData.Pointers (and, for by-copy captures,
// mapData.BasePointers) according to each map clause's capture kind.
// NOTE(review): the function name, leading parameters and several statements
// are not visible in this listing.
6537 llvm::IRBuilderBase &builder) {
6539 "function only supported for host device codegen");
6540 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6541 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
// Tail of a test for the OMP_MAP_ATTACH bit (presumably the initialiser of
// isAttachMap, used just below).
6544 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
6545 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH);
// Entries that are not declare-target, plus declare-target attach maps, are
// processed. NOTE(review): the second IsDeclareTarget test is redundant; the
// condition is equivalent to `!IsDeclareTarget[i] || isAttachMap`.
6550 if (!mapData.IsDeclareTarget[i] ||
6551 (mapData.IsDeclareTarget[i] && isAttachMap)) {
6552 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
6562 switch (captureKind) {
// By reference: optionally load through the pointer and, when there are
// offset indices, step to the addressed element with an in-bounds GEP.
6563 case omp::VariableCaptureKind::ByRef: {
6564 llvm::Value *newV = mapData.Pointers[i];
6566 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
// NOTE(review): the condition guarding this load is not visible.
6569 newV = builder.CreateLoad(builder.getPtrTy(), newV);
6571 if (!offsetIdx.empty())
6572 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
6574 mapData.Pointers[i] = newV;
// By copy: obtain the value itself (loading it when the recorded pointer is
// pointer-typed) ...
6576 case omp::VariableCaptureKind::ByCopy: {
6577 llvm::Type *type = mapData.BaseType[i];
6579 if (mapData.Pointers[i]->getType()->isPointerTy())
6580 newV = builder.CreateLoad(type, mapData.Pointers[i]);
6582 newV = mapData.Pointers[i];
// ... then round-trip it through a pointer-sized ".casted" alloca so that it
// is re-read as a pointer. The insertion point and debug location are saved
// and restored around the alloca (where the alloca is placed is not visible).
6585 auto curInsert = builder.saveIP();
6586 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
6588 auto *memTempAlloc =
6589 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
6590 builder.SetCurrentDebugLocation(DbgLoc);
6591 builder.restoreIP(curInsert);
6593 builder.CreateStore(newV, memTempAlloc);
6594 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
// Both the pointer and the base pointer become the re-read value.
6597 mapData.Pointers[i] = newV;
6598 mapData.BasePointers[i] = newV;
// Remaining capture kinds are reported as an op error.
6600 case omp::VariableCaptureKind::This:
6601 case omp::VariableCaptureKind::VLAType:
6602 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
// Fragment: walks every map clause in mapData and dispatches the ones that are
// not themselves members. NOTE(review): the function name, leading parameters
// and the names of the callees are not visible in this listing.
6613 MapInfoData &mapData,
6614 TargetDirectiveEnumTy targetDirective) {
6616 "function only supported for host device codegen");
6637 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
// Entries flagged as members are tested here; the guarded statement is not
// visible (presumably they are skipped at this level - TODO confirm).
6638 if (mapData.IsAMember[i])
// NOTE(review): the dyn_cast result is used on the very next line without a
// null check; other sites in this listing use cast<> for the same conversion.
6641 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
// Clauses that have members take the path below; the callee name is not
// visible.
6642 if (!mapInfoOp.getMembers().empty()) {
6644 combinedInfo, mapData, i, targetDirective);
// Forward declaration (note the terminating `;`) of a file-local function
// returning llvm::Expected<llvm::Function *>. NOTE(review): the function name
// and first parameter(s) are not visible in this listing.
6653static llvm::Expected<llvm::Function *>
6655 LLVM::ModuleTranslation &moduleTranslation,
6656 llvm::StringRef mapperFuncName,
6657 TargetDirectiveEnumTy targetDirective);
// Fragment: returns the LLVM function for a declare-mapper op, reusing an
// existing one when possible. NOTE(review): the function name, leading
// parameters and some statements are not visible in this listing.
6659static llvm::Expected<llvm::Function *>
6662 TargetDirectiveEnumTy targetDirective) {
6664 "function only supported for host device codegen");
6665 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
// The function name is built from the pieces "omp_mapper" and the op's symbol
// name (the callee that combines them is not visible).
6666 std::string mapperFuncName =
6668 {
"omp_mapper", declMapperOp.getSymName()});
// First lookup: the translation's own function mapping. The statement guarded
// by this check is not visible (presumably returns lookupFunc).
6670 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
// Second lookup: a function of that name already present in the LLVM module.
// It is recorded in the translation's mapping and returned.
6678 if (llvm::Function *existingFunc =
6679 moduleTranslation.
getLLVMModule()->getFunction(mapperFuncName)) {
6680 moduleTranslation.
mapFunction(mapperFuncName, existingFunc);
6681 return existingFunc;
// Otherwise a call taking mapperFuncName and targetDirective is made; its
// callee name is not visible.
6685 mapperFuncName, targetDirective);
// Fragment: emits the LLVM function for a declare-mapper op and registers it
// under mapperFuncName. NOTE(review): the function name, leading parameters,
// the callee that creates the function and several statements are not visible
// in this listing.
6688static llvm::Expected<llvm::Function *>
6691 llvm::StringRef mapperFuncName,
6692 TargetDirectiveEnumTy targetDirective) {
6694 "function only supported for host device codegen");
6695 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6696 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
// Early failure exit; the condition is not visible. PreviouslyReportedError
// is the error type used for every failure return in this fragment.
6698 return llvm::make_error<PreviouslyReportedError>();
6702 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
6705 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6708 MapInfosTy combinedInfo;
// Callback (referred to as genMapInfoCB further down, presumably this lambda)
// that produces the map infos for the mapper body.
6710 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
6711 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
6712 builder.restoreIP(codeGenIP);
// Bind the mapper's symbol value to ptrPHI and its entry block to the current
// insertion block, then translate that block.
6713 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
6714 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
6715 builder.GetInsertBlock());
6716 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
6719 return llvm::make_error<PreviouslyReportedError>();
6720 MapInfoData mapData;
// Populate combinedInfo from mapData; how mapData is filled is not visible.
6723 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
6729 return combinedInfo;
// Part of a second callback (presumably customMapperCB): entries without a
// mapper are tested here; the guarded statement is not visible.
6733 if (!combinedInfo.Mappers[i])
6736 moduleTranslation, targetDirective);
// Arguments of the call that creates the mapper function (callee not visible).
6740 genMapInfoCB, varType, mapperFuncName, customMapperCB,
6743 return newFn.takeError();
// If a mapping for this name already exists, it must agree with the function
// just emitted (checked only in builds with assertions).
6744 if ([[maybe_unused]] llvm::Function *mappedFunc =
6746 assert(mappedFunc == *newFn &&
6747 "mapper function mapping disagrees with emitted function");
6749 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
// Fragment: lowers omp.target_data / target_enter_data / target_exit_data /
// target_update through OpenMPIRBuilder::createTargetData.
// NOTE(review): the function signature and many statements are not visible in
// this listing; comments describe only the visible lines.
6757 llvm::Value *ifCond =
nullptr;
// Default device id, overwritten below when the op names a device.
6758 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
// Runtime entry point; assigned in the enter/exit/update cases below. No
// assignment is visible in the TargetDataOp case.
6762 llvm::omp::RuntimeFunction RTLFn;
6764 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
6767 llvm::OpenMPIRBuilder::TargetDataInfo info(
6770 assert(!ompBuilder->Config.isTargetDevice() &&
6771 "target data/enter/exit/update are host ops");
// Offloading is considered enabled when at least one target triple is
// configured.
6772 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
// Translates a device operand and converts it to i64 with a signed cast.
6774 auto getDeviceID = [&](
mlir::Value dev) -> llvm::Value * {
6775 llvm::Value *v = moduleTranslation.
lookupValue(dev);
6776 return builder.CreateIntCast(v, builder.getInt64Ty(),
true);
// Per-op setup. Each case reads the op's if-expression, device and map
// operands; the statements under the `if (auto ifVar ...)` checks are not
// visible.
6781 .Case([&](omp::TargetDataOp dataOp) {
6785 if (
auto ifVar = dataOp.getIfExpr())
6789 deviceID = getDeviceID(devId);
// Only target_data carries use_device_ptr / use_device_addr operands here.
6791 mapVars = dataOp.getMapVars();
6792 useDevicePtrVars = dataOp.getUseDevicePtrVars();
6793 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
6796 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
6800 if (
auto ifVar = enterDataOp.getIfExpr())
6804 deviceID = getDeviceID(devId);
// nowait selects the *_nowait_mapper runtime entry (assignment target not
// visible, presumably RTLFn as in the exit-data case).
6807 enterDataOp.getNowait()
6808 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
6809 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
6810 mapVars = enterDataOp.getMapVars();
6811 info.HasNoWait = enterDataOp.getNowait();
6814 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
6818 if (
auto ifVar = exitDataOp.getIfExpr())
6822 deviceID = getDeviceID(devId);
6824 RTLFn = exitDataOp.getNowait()
6825 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
6826 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
6827 mapVars = exitDataOp.getMapVars();
6828 info.HasNoWait = exitDataOp.getNowait();
6831 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
6835 if (
auto ifVar = updateDataOp.getIfExpr())
6839 deviceID = getDeviceID(devId);
6842 updateDataOp.getNowait()
6843 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
6844 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
6845 mapVars = updateDataOp.getMapVars();
6846 info.HasNoWait = updateDataOp.getNowait();
6849 .DefaultUnreachable(
"unexpected operation");
// Without any offload target the if-condition is forced to false.
6854 if (!isOffloadEntry)
6855 ifCond = builder.getFalse();
6857 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6858 MapInfoData mapData;
// Tail of the call that populates mapData (callee not visible).
6860 builder, useDevicePtrVars, useDeviceAddrVars);
// Callback handed to createTargetData: fills and returns combinedInfo at the
// requested insertion point.
6863 MapInfosTy combinedInfo;
6864 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
6865 builder.restoreIP(codeGenIP);
6866 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
6868 return combinedInfo;
// Helper lambda (called as mapUseDevice below; its declaration line is not
// visible): for each (block argument, use_device operand) pair, finds the
// mapData entry with the same base pointer operand and the requested device
// info kind, and binds the block argument to that entry's base pointer,
// optionally transformed by `mapper`.
6874 [&moduleTranslation](
6875 llvm::OpenMPIRBuilder::DeviceInfoTy type,
6879 for (
auto [arg, useDevVar] :
6880 llvm::zip_equal(blockArgs, useDeviceVars)) {
// A map's base pointer operand is var_ptr_ptr when present, else var_ptr.
6882 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
6883 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
6884 : mapInfoOp.getVarPtr();
6887 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
6888 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
6889 mapInfoData.MapClause, mapInfoData.DevicePointers,
6890 mapInfoData.BasePointers)) {
6891 auto mapOp = cast<omp::MapInfoOp>(mapClause);
// Non-matching entries are tested here; the guarded statement is not visible
// (presumably `continue`).
6892 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
6893 devicePointer != type)
// A null result from `mapper` leaves the block argument unmapped.
6896 if (llvm::Value *devPtrInfoMap =
6897 mapper ? mapper(basePointer) : basePointer) {
6898 moduleTranslation.
mapValue(arg, devPtrInfoMap);
// Body generation callback; only meaningful for TargetDataOp (asserted).
6905 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
6906 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
6907 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6910 builder.restoreIP(codeGenIP);
6911 assert(isa<omp::TargetDataOp>(op) &&
6912 "BodyGen requested for non TargetDataOp");
6913 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
// NOTE(review): "®ion" looks like a mis-decoded "&region" (HTML entity
// mangling); confirm against the original file.
6914 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
6915 switch (bodyGenType) {
// Priv: when device pointer info exists, bind use_device_addr arguments to a
// load through the recorded slot and use_device_ptr arguments to the slot
// itself, then translate the region (translation call only partly visible).
6916 case BodyGenTy::Priv:
6918 if (!info.DevicePtrInfoMap.empty()) {
6919 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6920 blockArgIface.getUseDeviceAddrBlockArgs(),
6921 useDeviceAddrVars, mapData,
6922 [&](llvm::Value *basePointer) -> llvm::Value * {
6923 if (!info.DevicePtrInfoMap[basePointer].second)
6925 return builder.CreateLoad(
6927 info.DevicePtrInfoMap[basePointer].second);
6929 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6930 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6931 mapData, [&](llvm::Value *basePointer) {
6932 return info.DevicePtrInfoMap[basePointer].second;
6936 moduleTranslation)))
6937 return llvm::make_error<PreviouslyReportedError>();
// DupNoPriv: bind the block arguments to the recorded base pointers. The
// contents of the DevicePtrInfoMap.empty() branch and its extent are not
// visible.
6940 case BodyGenTy::DupNoPriv:
6941 if (info.DevicePtrInfoMap.empty()) {
6944 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6945 blockArgIface.getUseDeviceAddrBlockArgs(),
6946 useDeviceAddrVars, mapData);
6947 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6948 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
// NoPriv: the region is translated here only when there is no device pointer
// info.
6952 case BodyGenTy::NoPriv:
6954 if (info.DevicePtrInfoMap.empty()) {
6956 moduleTranslation)))
6957 return llvm::make_error<PreviouslyReportedError>();
6961 return builder.saveIP();
// Callback resolving the user-defined mapper for combinedInfo entry i; it also
// records in `info` that a mapper is in use.
6964 auto customMapperCB =
6966 if (!combinedInfo.Mappers[i])
6968 info.HasMapper =
true;
6970 moduleTranslation, targetDirective);
6973 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6975 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// target_data passes extra trailing arguments that are not visible here; the
// standalone ops pass the selected runtime function instead.
6977 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
6978 if (isa<omp::TargetDataOp>(op))
6979 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6980 deallocBlocks, deviceID, ifCond, info,
6981 genMapInfoCB, customMapperCB,
6984 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6985 deallocBlocks, deviceID, ifCond, info,
6986 genMapInfoCB, customMapperCB, &RTLFn);
// Continue emitting after the generated construct (error handling for afterIP
// is not visible).
6992 builder.restoreIP(*afterIP);
// Fragment: lowers omp.distribute via OpenMPIRBuilder::createDistribute,
// optionally handling a reduction that belongs to the enclosing teams op.
// NOTE(review): the function signature, the definitions of teamsOp /
// deallocBlocks / privVarsInfo and many statements are not visible in this
// listing.
7000 auto distributeOp = cast<omp::DistributeOp>(opInst);
7007 bool doDistributeReduction =
// teamsOp may be null, in which case there are no reduction variables.
7011 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
// Reduction setup: one by-ref flag per teams reduction variable.
7016 if (doDistributeReduction) {
7017 isByRef =
getIsByRef(teamsOp.getReductionByref());
7018 assert(isByRef.size() == teamsOp.getNumReductionVars());
7021 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7025 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
7026 .getReductionBlockArgs();
7029 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
7030 reductionDecls, privateReductionVariables, reductionVariableMap,
7035 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// Body callback (passed as bodyGenCB to createDistribute below). Every failure
// inside it is reported as PreviouslyReportedError or the callee's own error.
7037 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
7042 moduleTranslation, allocaIP, deallocBlocks);
7045 builder.restoreIP(codeGenIP);
// Privatization steps; the callees' names are not visible.
7049 distributeOp, builder, moduleTranslation, privVarsInfo, allocaIP);
7051 return llvm::make_error<PreviouslyReportedError>();
7056 return llvm::make_error<PreviouslyReportedError>();
7059 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
7061 distributeOp.getPrivateNeedsBarrier())))
7062 return llvm::make_error<PreviouslyReportedError>();
7065 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Translate the region and continue emitting at the start of the block it
// returns.
7068 builder, moduleTranslation);
7070 return regionBlock.takeError();
7071 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// The workshare loop is applied here only when the nested wrapper is not a
// WsloopOp.
7076 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
// dist_schedule(static) selects the Distribute schedule kind and also sets
// isOrdered; otherwise plain Static is used.
7079 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
7080 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
7081 : omp::ClauseScheduleKind::Static;
7083 bool isOrdered = hasDistSchedule;
// scheduleMod is never assigned in the visible lines, so both modifier
// comparisons passed below evaluate to false.
7084 std::optional<omp::ScheduleModifier> scheduleMod;
7085 bool isSimd =
false;
7086 llvm::omp::WorksharingLoopType workshareLoopType =
7087 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
7088 bool loopNeedsBarrier =
false;
7089 llvm::Value *chunk = moduleTranslation.
lookupValue(
7090 distributeOp.getDistScheduleChunkSize());
7091 llvm::CanonicalLoopInfo *loopInfo =
// `chunk` is passed twice: once as the 6th argument and once as the last.
7093 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
7094 ompBuilder->applyWorkshareLoop(
7095 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
7096 convertToScheduleKind(schedule), chunk, isSimd,
7097 scheduleMod == omp::ScheduleModifier::monotonic,
7098 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
7099 workshareLoopType,
false, hasDistSchedule, chunk);
7102 return wsloopIP.takeError();
// Cleanup of privatized variables (callee name not visible).
7105 distributeOp.getLoc(), privVarsInfo)))
7106 return llvm::make_error<PreviouslyReportedError>();
7108 return llvm::Error::success();
// Emit the distribute construct and continue after it (error handling for
// afterIP is not visible).
7112 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7114 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7115 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7116 ompBuilder->createDistribute(ompLoc, allocaIP, deallocBlocks, bodyGenCB);
7121 builder.restoreIP(*afterIP);
// Reduction finalisation for the teams op (callee name not visible).
7123 if (doDistributeReduction) {
7126 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
7127 privateReductionVariables, isByRef,
// Fragment: emits OpenMP device-runtime configuration from `attribute` into
// the LLVM module. NOTE(review): the function signature and parts of each
// createGlobalFlag argument list are not visible in this listing.
// NOTE(review): cast<> asserts on a type mismatch rather than returning null,
// so this check cannot detect a non-module op; dyn_cast<> or isa<> may have
// been intended.
7139 if (!cast<mlir::ModuleOp>(op))
// Module flag "openmp-device" with Max merge behaviour, set to the device
// version.
7144 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
7145 attribute.getOpenmpDeviceVersion());
// The statement guarded by the no-GPU-lib check is not visible.
7147 if (attribute.getNoGpuLib())
// One global flag per runtime knob; each flag's name is the string literal
// that closes the call.
7150 ompBuilder->createGlobalFlag(
7151 attribute.getDebugKind() ,
7152 "__omp_rtl_debug_kind");
7153 ompBuilder->createGlobalFlag(
7155 .getAssumeTeamsOversubscription()
7157 "__omp_rtl_assume_teams_oversubscription");
7158 ompBuilder->createGlobalFlag(
7160 .getAssumeThreadsOversubscription()
7162 "__omp_rtl_assume_threads_oversubscription");
7163 ompBuilder->createGlobalFlag(
7164 attribute.getAssumeNoThreadState() ,
7165 "__omp_rtl_assume_no_thread_state");
7166 ompBuilder->createGlobalFlag(
7168 .getAssumeNoNestedParallelism()
7170 "__omp_rtl_assume_no_nested_parallelism");
// Fragment: derives the unique target-entry info for `targetOp` from its
// source location. NOTE(review): the function name, return type and any
// leading parameter are not visible in this listing.
7175 omp::TargetOp targetOp,
7176 llvm::OpenMPIRBuilder &ompBuilder,
7177 llvm::vfs::FileSystem &vfs,
7178 llvm::StringRef parentName =
"") {
// The op's location must contain a FileLineColLoc. This is only asserted, so
// with assertions compiled out a missing location is not diagnosed here.
7179 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
7180 assert(fileLoc &&
"No file found from location");
// Supplies (file name, line) to the builder on demand. fileLoc is captured by
// reference, so the callback must not outlive this scope.
7182 auto fileInfoCallBack = [&fileLoc]() {
7183 return std::pair<std::string, uint64_t>(
7184 llvm::StringRef(fileLoc.getFilename()), fileLoc.getLine());
7188 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs, parentName);
// Fragment: on the target device, redirects uses of declare-target mapped
// values inside `func` to the corresponding mapData.BasePointers entry.
// NOTE(review): the function name, leading parameters and several statements
// are not visible in this listing.
7194 llvm::IRBuilderBase &builder, llvm::Function *
func) {
7196 "function only supported for target device codegen");
// Restores the builder's insertion point when this scope exits.
7197 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7198 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
// Entries that are not declare-target are tested here; the guarded statement
// is not visible (presumably `continue`).
7211 if (!mapData.IsDeclareTarget[i])
// Constant originals first have their constant-expression users within `func`
// turned into instructions, so they appear as instruction users below.
7219 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
7220 convertUsersOfConstantsToInstructions(constant,
func,
false);
// Snapshot the users first: the replacement below then iterates userVec
// rather than the live users() range it modifies.
7227 for (llvm::User *user : mapData.OriginalValue[i]->users())
7228 userVec.push_back(user);
7230 for (llvm::User *user : userVec) {
// Only instruction users inside `func` are candidates; the statement guarded
// by this check is not visible.
7231 auto *insn = dyn_cast<llvm::Instruction>(user);
7232 if (!insn || insn->getFunction() !=
func)
7234 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
// Default replacement is the base pointer itself.
7235 llvm::Value *substitute = mapData.BasePointers[i];
7237 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
// Under a condition that is only partly visible (it involves the
// unified-shared-memory requirement), the replacement is instead a load
// through the base pointer, placed immediately before the using instruction
// and given that instruction's debug location.
7241 ->Config.hasRequiresUnifiedSharedMemory())) {
7242 builder.SetCurrentDebugLocation(insn->getDebugLoc());
7243 substitute = builder.CreateLoad(mapData.BasePointers[i]->getType(),
7244 mapData.BasePointers[i]);
7245 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
7247 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
// Fragment: on the target device, spills kernel argument `arg` to memory and
// produces the value to use for `input` according to its capture kind,
// returning the builder's insertion point afterwards.
// NOTE(review): the function name, return type, some parameters (including
// the source of deallocIPs), the switch header and several statements are not
// visible in this listing.
7292 omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg,
7293 llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder,
7294 llvm::OpenMPIRBuilder &ompBuilder,
7296 llvm::IRBuilderBase::InsertPoint allocaIP,
7297 llvm::IRBuilderBase::InsertPoint codeGenIP,
7299 assert(ompBuilder.Config.isTargetDevice() &&
7300 "function only supported for target device codegen");
7301 builder.restoreIP(allocaIP);
// Defaults used when no map clause matches `input`: by-reference capture and
// no alignment.
7303 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
7305 ompBuilder.M.getContext());
7306 unsigned alignmentValue = 0;
7309 cast<omp::BlockArgOpenMPOpInterface>(*targetOp).getBlockArgsPairs(
// Find the map clause whose original value is `input` and take its capture
// kind (and further data computed from its var_ptr type and the data layout).
7312 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
7313 if (mapData.OriginalValue[i] == input) {
7314 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
7315 capture = mapOp.getMapCaptureType();
7318 mapOp.getVarPtrType(), ompBuilder.M.getDataLayout());
// Locate the entry block argument paired with this map clause.
// NOTE(review): the structured binding `arg` shadows the function parameter
// `arg` inside this loop.
7322 for (
auto &[val, arg] : blockArgsPairs) {
7323 if (mapOp.getResult() == val) {
7328 assert(mlirArg &&
"expected to find entry block argument for map clause");
7333 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
7334 unsigned int defaultAS =
7335 ompBuilder.M.getDataLayout().getProgramAddressSpace();
// Storage for the argument: either a shared allocation made at the start of
// the code-gen block, with a matching free at every dealloc point, or a plain
// alloca in the alloca address space. The condition selecting between the two
// is not visible.
7338 llvm::Value *v =
nullptr;
7346 builder.SetInsertPoint(codeGenIP.getBlock()->getFirstInsertionPt());
7347 v = ompBuilder.createOMPAllocShared(builder, arg.getType());
7351 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7352 for (
auto deallocIP : deallocIPs) {
7353 builder.SetInsertPoint(deallocIP.getBlock(), deallocIP.getPoint());
7354 ompBuilder.createOMPFreeShared(builder, v, arg.getType());
7358 v = builder.CreateAlloca(arg.getType(), allocaAS);
// Pointer arguments are addressed through the program address space when it
// differs from the alloca address space.
7360 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
7361 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
7364 builder.CreateStore(&arg, v);
7366 builder.restoreIP(codeGenIP);
// By copy: handling is not visible in this listing.
7369 case omp::VariableCaptureKind::ByCopy: {
// By reference: reload the stored argument with the preferred alignment of
// v's type.
7373 case omp::VariableCaptureKind::ByRef: {
7374 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
7376 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
// When a non-zero alignment was computed, attach it to the load as !align
// metadata holding an i64 constant (the constant's value argument is not
// visible, presumably alignmentValue).
7391 if (v->getType()->isPointerTy() && alignmentValue) {
7392 llvm::MDBuilder MDB(builder.getContext());
7393 loadInst->setMetadata(
7394 llvm::LLVMContext::MD_align,
7395 llvm::MDNode::get(builder.getContext(),
7396 MDB.createConstant(llvm::ConstantInt::get(
7397 llvm::Type::getInt64Ty(builder.getContext()),
// Unsupported kinds are rejected only by an assertion; with assertions
// compiled out, control continues past this point.
7404 case omp::VariableCaptureKind::This:
7405 case omp::VariableCaptureKind::VLAType:
7408 assert(
false &&
"Currently unsupported capture kind");
7412 return builder.saveIP();
// Pairs every host_eval operand of `targetOp` with its entry block argument
// and records which clause of a nested operation the argument feeds. Each
// handler stores the host-side value (`hostEvalVar`) in the output that
// corresponds to the clause it matched.
7429 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
// zip_equal requires both ranges to have the same length.
7430 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
7431 blockArgIface.getHostEvalBlockArgs())) {
7432 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
// TeamsOp: the argument may be the num_teams lower bound, one of the num_teams
// upper values, or the first thread_limit value.
7436 .Case([&](omp::TeamsOp teamsOp) {
7437 if (teamsOp.getNumTeamsLower() == blockArg)
7438 numTeamsLower = hostEvalVar;
7439 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
7441 numTeamsUpper = hostEvalVar;
7442 else if (!teamsOp.getThreadLimitVars().empty() &&
7443 teamsOp.getThreadLimit(0) == blockArg)
7444 threadLimit = hostEvalVar;
7446 llvm_unreachable(
"unsupported host_eval use");
// ParallelOp: only the first num_threads value is recognised.
7448 .Case([&](omp::ParallelOp parallelOp) {
7449 if (!parallelOp.getNumThreadsVars().empty() &&
7450 parallelOp.getNumThreads(0) == blockArg)
7451 numThreads = hostEvalVar;
7453 llvm_unreachable(
"unsupported host_eval use");
// LoopNestOp: the argument may appear among the lower bounds, upper bounds or
// steps; the host value is written at the same index of the output vector.
7455 .Case([&](omp::LoopNestOp loopOp) {
7456 auto processBounds =
7460 for (
auto [i, lb] : llvm::enumerate(opBounds)) {
7461 if (lb == blockArg) {
7464 (*outBounds)[i] = hostEvalVar;
7470 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
// `found` accumulates across the bound kinds, so one match anywhere satisfies
// the assertion.
7471 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
7473 found = processBounds(loopOp.getLoopSteps(), steps) || found;
7475 assert(found &&
"unsupported host_eval use");
// Any other kind of user of a host_eval argument is a hard error.
7477 .DefaultUnreachable(
"unsupported host_eval use");
// Template helper that tries to view `op` as an `OpTy` and otherwise looks at
// its parent. With `immediateParent` set, only the direct parent is considered
// and the result is null when that parent is absent or of another type.
7489template <
typename OpTy>
7494 if (OpTy casted = dyn_cast<OpTy>(op))
7497 if (immediateParent)
7498 return dyn_cast_if_present<OpTy>(op->
getParentOp());
// Returns the integer held by a constant operation, or std::nullopt when no
// integer constant can be extracted.
7507 return std::nullopt;
// Only IntegerAttr-valued constants yield a value.
7510 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
7511 return constAttr.getInt();
7513 return std::nullopt;
// Integer division: any remainder is dropped when the bit count is not a
// multiple of 8.
7518 uint64_t sizeInBytes = sizeInBits / 8;
// Template helper working on the reduction variables of `op`: it collects one
// member type per reduction declaration and forms a literal LLVM struct type
// from them. Nothing is collected when the op has no reduction variables.
7522template <
typename OpTy>
7524 if (op.getNumReductionVars() > 0) {
7529 members.reserve(reductions.size());
7530 for (omp::DeclareReductionOp &red : reductions) {
// Use the by-ref element type when the declaration provides one, otherwise the
// declaration's own type.
7534 if (red.getByrefElementType())
7535 members.push_back(*red.getByrefElementType());
7537 members.push_back(red.getType());
7540 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
// Fills `attrs` with the kernel's default (compile-time) launch attributes:
// execution mode, min/max teams, min/max threads and reduction buffer data.
7556 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
7557 bool isTargetDevice,
bool isGPU) {
7560 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
// The branch below is taken only when not compiling for the target device.
7561 if (!isTargetDevice) {
// Clause values are read directly from the nested teams / parallel ops.
7569 numTeamsLower = teamsOp.getNumTeamsLower();
7571 if (!teamsOp.getNumTeamsUpperVars().empty())
7572 numTeamsUpper = teamsOp.getNumTeams(0);
7573 if (!teamsOp.getThreadLimitVars().empty())
7574 threadLimit = teamsOp.getThreadLimit(0);
7578 if (!parallelOp.getNumThreadsVars().empty())
7579 numThreads = parallelOp.getNumThreads(0);
// Initial values; the assignments below always set min and max to the same
// number.
7585 int32_t minTeamsVal = 1, maxTeamsVal = -1;
7589 if (numTeamsUpper) {
7591 minTeamsVal = maxTeamsVal = *val;
7593 minTeamsVal = maxTeamsVal = 0;
7599 minTeamsVal = maxTeamsVal = 1;
7601 minTeamsVal = maxTeamsVal = -1;
// Helper lambda that writes into `result` from a clause value.
7606 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &
result) {
// -1 is the starting value for each limit.
7620 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
7621 if (!targetOp.getThreadLimitVars().empty())
7622 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
7623 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
7626 int32_t maxThreadsVal = -1;
7628 setMaxValueFromClause(numThreads, maxThreadsVal);
// Combine the three limits: a candidate replaces the running value when the
// running value is negative, or when the candidate is non-negative and
// smaller. The result is the smallest non-negative limit, or a negative value
// if all three are negative.
7636 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
7637 if (combinedMaxThreadsVal < 0 ||
7638 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
7639 combinedMaxThreadsVal = teamsThreadLimitVal;
7641 if (combinedMaxThreadsVal < 0 ||
7642 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
7643 combinedMaxThreadsVal = maxThreadsVal;
// The reduction data size stays 0 unless this is a GPU build and there is a
// captured op.
7645 int32_t reductionDataSize = 0;
7646 if (isGPU && capturedOp) {
// Translate the MLIR execution mode into the OpenMPIRBuilder flag.
7652 omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
7654 case omp::TargetExecMode::bare:
7655 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
7657 case omp::TargetExecMode::generic:
7658 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
7660 case omp::TargetExecMode::spmd:
7661 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
7663 case omp::TargetExecMode::no_loop:
7664 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
// Publish the computed values. MinThreads is always 1.
7667 attrs.MinTeams = minTeamsVal;
7668 attrs.MaxTeams.front() = maxTeamsVal;
7669 attrs.MinThreads = 1;
7670 attrs.MaxThreads.front() = combinedMaxThreadsVal;
7671 attrs.ReductionDataSize = reductionDataSize;
// A fixed buffer length of 1024 is used whenever reduction data is present.
7674 if (attrs.ReductionDataSize != 0)
7675 attrs.ReductionBufferLength = 1024;
// Fills `attrs` with the kernel's runtime launch attributes as LLVM values:
// thread limits, team bounds, max threads, loop trip count and device id.
7687 omp::TargetOp targetOp,
Operation *capturedOp,
7688 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
// No loop op means zero loops.
7690 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
7692 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
7696 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
// Only the first thread_limit value of the target op is used.
7699 if (!targetOp.getThreadLimitVars().empty()) {
7700 Value targetThreadLimit = targetOp.getThreadLimit(0);
7701 attrs.TargetThreadLimit.front() =
// Team bounds and the teams thread limit are sign-extended or truncated to
// i32.
7709 attrs.MinTeams = builder.CreateSExtOrTrunc(
7710 moduleTranslation.
lookupValue(numTeamsLower), builder.getInt32Ty());
7713 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
7714 moduleTranslation.
lookupValue(numTeamsUpper), builder.getInt32Ty());
7716 if (teamsThreadLimit)
7717 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
7718 moduleTranslation.
lookupValue(teamsThreadLimit), builder.getInt32Ty());
// MaxThreads is stored without a width conversion.
7721 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
// getKernelExecFlags reports through its second argument whether the trip
// count is evaluated on the host.
7723 bool hostEvalTripCount;
7724 targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
7725 if (hostEvalTripCount) {
7727 attrs.LoopTripCount =
nullptr;
// One trip count is computed per loop of the nest and the results are
// multiplied together.
7732 for (
auto [loopLower, loopUpper, loopStep] :
7733 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
7734 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
7735 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
7736 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
// If any bound or step has no LLVM value, the trip count is reset to null.
7738 if (!lowerBound || !upperBound || !step) {
7739 attrs.LoopTripCount =
nullptr;
7743 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
7744 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
7745 loc, lowerBound, upperBound, step,
true,
7746 loopOp.getLoopInclusive());
// The first loop seeds the product.
7748 if (!attrs.LoopTripCount) {
7749 attrs.LoopTripCount = tripCount;
7754 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
// The device id starts as OMP_DEVICEID_UNDEF; a looked-up device value is
// converted to i64.
7759 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
7761 attrs.DeviceID = moduleTranslation.
lookupValue(devId);
7763 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
// Maps the dyn_groupprivate fallback modifier to the matching
// llvm::omp::OMPDynGroupprivateFallbackType value.
7767static llvm::omp::OMPDynGroupprivateFallbackType
// A null `fallbackAttr` is treated as FallbackModifier::default_mem.
7769 omp::FallbackModifier fb = fallbackAttr ? fallbackAttr.getValue()
7770 : omp::FallbackModifier::default_mem;
7772 case omp::FallbackModifier::abort:
7773 return llvm::omp::OMPDynGroupprivateFallbackType::Abort;
7774 case omp::FallbackModifier::null:
7775 return llvm::omp::OMPDynGroupprivateFallbackType::Null;
7776 case omp::FallbackModifier::default_mem:
7777 return llvm::omp::OMPDynGroupprivateFallbackType::DefaultMem;
// Every listed modifier returns above, so reaching this point is a bug.
7780 llvm_unreachable(
"unexpected dyn_groupprivate fallback type");
// Translates an omp.target operation: sets up the callbacks that generate the
// outlined region body, the map information and the argument accessors, then
// hands everything to the OpenMPIRBuilder and resumes at the returned point.
7786 auto targetOp = cast<omp::TargetOp>(opInst);
// Remember the location that was current on entry; it is reused below for the
// outlined function.
7791 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
7800 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
7801 assert(parentBB &&
"No insert block is set for the builder");
7802 llvm::Function *parentLLVMFn = parentBB->getParent();
7803 assert(parentLLVMFn &&
"Parent Function must be valid");
// When the parent function has debug info, rebuild the current location with
// the same line, column and inlined-at but scoped to the parent's subprogram.
7804 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
7805 builder.SetCurrentDebugLocation(llvm::DILocation::get(
7806 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
7807 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
7810 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
7811 bool isGPU = ompBuilder->Config.isGPU();
7814 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
7815 auto &targetRegion = targetOp.getRegion();
// Assigned inside bodyCB once the outlined function exists.
7832 llvm::Function *llvmOutlinedFn =
nullptr;
7833 TargetDirectiveEnumTy targetDirective =
7834 getTargetDirectiveEnumTyFromOp(&opInst);
// An offload entry is produced on the device, or on the host when at least one
// target triple is configured.
7838 bool isOffloadEntry =
7839 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
// Private variables are associated with map block arguments only when the op
// has both private and map variables.
7846 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
7848 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
7849 std::optional<DenseI64ArrayAttr> privateMapIndices =
7850 targetOp.getPrivateMapsAttr();
7852 for (
auto [privVarIdx, privVarSymPair] :
7853 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
7854 auto privVar = std::get<0>(privVarSymPair);
7855 auto privSym = std::get<1>(privVarSymPair);
7857 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
7858 omp::PrivateClauseOp privatizer =
// Privatizers that do not need a map are not processed further here.
7861 if (!privatizer.needsMap())
7865 targetOp.getMappedValueForPrivateVar(privVarIdx);
7866 assert(mappedValue &&
"Expected to find mapped value for a privatized "
7867 "variable that needs mapping");
7872 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
// `varType` is used only in an assertion, hence [[maybe_unused]].
7873 [[maybe_unused]]
Type varType = mapInfoOp.getVarPtrType();
// The type check applies only to private variables that are not LLVM pointers.
7877 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
7879 varType == privVar.getType() &&
7880 "Type of private var doesn't match the type of the mapped value");
// The block argument is found at the start of the map block arguments plus
// this variable's private-map index.
7884 mappedPrivateVars.insert(
7886 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
7887 (*privateMapIndices)[privVarIdx])});
7891 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// Body callback: generates the contents of the outlined target function.
7892 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
7894 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
// The builder position is restored on exit; the debug location is cleared for
// the code emitted in this callback.
7895 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7896 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7899 llvm::Function *llvmParentFn =
7901 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
7902 assert(llvmParentFn && llvmOutlinedFn &&
7903 "Both parent and outlined functions must exist at this point");
// Give the outlined function the subprogram of the saved location's scope, but
// only if the parent function has debug info.
7905 if (outlinedFnLoc && llvmParentFn->getSubprogram())
7906 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
// Copy the "target-cpu" and "target-features" string attributes from the
// parent function, when present.
7908 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
7909 attr.isStringAttribute())
7910 llvmOutlinedFn->addFnAttr(attr);
7912 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
7913 attr.isStringAttribute())
7914 llvmOutlinedFn->addFnAttr(attr);
// Bind each map block argument to the LLVM value of its map op's var pointer.
7916 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
7917 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7918 llvm::Value *mapOpValue =
7919 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
7920 moduleTranslation.
mapValue(arg, mapOpValue);
// Same binding for the `hda` block arguments.
7922 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
7923 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7924 llvm::Value *mapOpValue =
7925 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
7926 moduleTranslation.
mapValue(arg, mapOpValue);
7935 privateVarsInfo, allocaIP, &mappedPrivateVars);
// Failures in this callback are returned as PreviouslyReportedError.
7938 return llvm::make_error<PreviouslyReportedError>();
7940 builder.restoreIP(codeGenIP);
7942 &mappedPrivateVars),
7945 return llvm::make_error<PreviouslyReportedError>();
7948 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
7950 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
7951 return llvm::make_error<PreviouslyReportedError>();
7954 moduleTranslation, allocaIP, deallocBlocks);
7956 targetRegion,
"omp.target", builder, moduleTranslation);
7959 return llvm::make_error<PreviouslyReportedError>();
// Continue just before the terminator of the region's exit block.
7961 builder.SetInsertPoint(exitBlock.get()->getTerminator());
7964 targetOp.getLoc(), privateVarsInfo)))
7965 return llvm::make_error<PreviouslyReportedError>();
7967 return builder.saveIP();
7970 StringRef parentName = parentFn.getName();
7972 llvm::TargetRegionEntryInfo entryInfo;
7978 MapInfoData mapData;
7983 MapInfosTy combinedInfos;
// Map-info callback: generates the combined map information at `codeGenIP` and
// returns it by reference.
7985 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
7986 builder.restoreIP(codeGenIP);
7987 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
// Append a null, zero-sized entry flagged TARGET_PARAM | LITERAL. All the
// parallel arrays receive an entry; Names only when it is already in use.
7992 auto *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy());
7993 combinedInfos.BasePointers.push_back(nullPtr);
7994 combinedInfos.Pointers.push_back(nullPtr);
7995 combinedInfos.DevicePointers.push_back(
7996 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
7997 combinedInfos.Sizes.push_back(builder.getInt64(0));
7998 combinedInfos.Types.push_back(
7999 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
8000 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
8001 if (!combinedInfos.Names.empty())
8002 combinedInfos.Names.push_back(nullPtr);
8003 combinedInfos.Mappers.push_back(
nullptr);
8005 return combinedInfos;
// Argument accessor callback: decides how the outlined function reads each
// kernel argument.
8008 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
8009 llvm::Value *&retVal, InsertPointTy allocaIP,
8010 InsertPointTy codeGenIP,
8012 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
8013 llvm::IRBuilderBase::InsertPointGuard guard(builder);
8014 builder.SetCurrentDebugLocation(llvm::DebugLoc());
// On the host the argument itself is used; the device path follows.
8020 if (!isTargetDevice) {
8021 retVal = cast<llvm::Value>(&arg);
8026 builder, *ompBuilder, moduleTranslation,
8027 allocaIP, codeGenIP, deallocIPs);
8030 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
8031 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
8032 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
8034 isTargetDevice, isGPU);
// The call completed below is made only when compiling for the host.
8038 if (!isTargetDevice)
8040 targetCapturedOp, runtimeAttrs);
// host_eval block arguments are bound to their host values; non-constant ones
// also become kernel inputs.
8048 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
8049 llvm::Value *value = moduleTranslation.
lookupValue(var);
8050 moduleTranslation.
mapValue(arg, value);
8052 if (!llvm::isa<llvm::Constant>(value))
8053 kernelInput.push_back(value);
// Mapped values become kernel inputs unless they are declare-target, a member
// of another mapping, or an ATTACH map.
8056 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
8065 bool isAttachMap = (mapData.Types[i] &
8066 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
8067 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
8068 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i] && !isAttachMap)
8069 kernelInput.push_back(mapData.OriginalValue[i]);
8073 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// Dependency information built from the op's depend clauses.
8076 llvm::OpenMPIRBuilder::DependenciesInfo dds;
8078 targetOp.getDependVars(), targetOp.getDependKinds(),
8079 targetOp.getDependIterated(), targetOp.getDependIteratedKinds(),
8080 builder, moduleTranslation, dds)))
8083 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8085 llvm::OpenMPIRBuilder::TargetDataInfo info(
// Custom mapper callback: records in `info` that a mapper exists whenever
// entry `i` has one.
8089 auto customMapperCB =
8091 if (!combinedInfos.Mappers[i])
8093 info.HasMapper =
true;
8095 moduleTranslation, targetDirective);
// Optional `if` clause condition; stays null when the clause is absent.
8098 llvm::Value *ifCond =
nullptr;
8099 if (
Value targetIfCond = targetOp.getIfExpr())
8100 ifCond = moduleTranslation.
lookupValue(targetIfCond);
// Optional dyn_groupprivate size, cast to i32 when present.
8102 Value dynGroupPrivateSize = targetOp.getDynGroupprivateSize();
8103 llvm::Value *dynSizeVal =
nullptr;
8104 if (dynGroupPrivateSize) {
8105 dynSizeVal = moduleTranslation.
lookupValue(dynGroupPrivateSize);
8106 dynSizeVal = builder.CreateIntCast(dynSizeVal, builder.getInt32Ty(),
8110 llvm::omp::OMPDynGroupprivateFallbackType fallbackType =
// Pass all collected pieces to the OpenMPIRBuilder.
8113 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
8115 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), deallocBlocks,
8116 info, entryInfo, defaultAttrs, runtimeAttrs, ifCond, kernelInput,
8117 genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
8118 targetOp.getNowait(), dynSizeVal, fallbackType);
// Continue emitting after the generated target construct.
8123 builder.restoreIP(*afterIP);
// Release the dependency array.
8126 builder.CreateFree(dds.DepArray);
// Applies a declare-target attribute to the annotated operation. Functions and
// LLVM globals are handled separately.
8139 llvm::OpenMPIRBuilder *ompBuilder,
// Function case.
8148 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
8149 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
8151 if (!offloadMod.getIsTargetDevice())
8154 omp::DeclareTargetDeviceType declareType =
8155 attribute.getDeviceType().getValue();
// A function whose declare-target device type is `host` is detached from its
// uses and erased from the LLVM module.
8157 if (declareType == omp::DeclareTargetDeviceType::host) {
8158 llvm::Function *llvmFunc =
8160 llvmFunc->dropAllReferences();
8161 llvmFunc->eraseFromParent();
// Reset the builder's insertion point and debug location.
// NOTE(review): presumably because they may refer to the erased function -
// confirm.
8165 ompBuilder->Builder.ClearInsertionPoint();
8166 ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc());
// Global variable case: the already-translated LLVM global is looked up by
// symbol name.
8172 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
8173 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8174 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
8176 bool isDeclaration = gOp.isDeclaration();
8177 bool isExternallyVisible =
8180 llvm::StringRef mangledName = gOp.getSymName();
8181 auto captureClause =
8187 std::vector<llvm::GlobalVariable *> generatedRefs;
// At most one triple is collected, and only if a string-valued target-triple
// attribute is found.
8189 std::vector<llvm::Triple> targetTriple;
8190 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
8192 LLVM::LLVMDialect::getTargetTripleAttrName()));
8193 if (targetTripleAttr)
8194 targetTriple.emplace_back(targetTripleAttr.data());
// Callback returning a (filename, line) pair; both keep their empty / zero
// defaults unless overwritten from `loc`.
8196 auto fileInfoCallBack = [&loc]() {
8197 std::string filename =
"";
8198 std::uint64_t lineNo = 0;
8201 filename = loc.getFilename().str();
8202 lineNo = loc.getLine();
8205 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
8209 llvm::vfs::FileSystem &vfs = moduleTranslation.
getFileSystem();
// Register the global with the OpenMPIRBuilder as a target global variable.
8211 ompBuilder->registerTargetGlobalVariable(
8212 captureClause, deviceClause, isDeclaration, isExternallyVisible,
8213 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
8214 mangledName, generatedRefs,
false, targetTriple,
8216 gVal->getType(), gVal);
// On the device, `link` captures (and a further condition completed below)
// additionally request the address of the declare-target variable.
8218 bool requiresUSM = ompBuilder->Config.hasRequiresUnifiedSharedMemory();
8219 if (ompBuilder->Config.isTargetDevice() &&
8220 (attribute.getCaptureClause().getValue() ==
8221 mlir::omp::DeclareTargetCaptureClause::link ||
8223 llvm::Type *ptrTy = gVal->getType();
// Pointer type in address space 0.
8227 ptrTy = llvm::PointerType::get(llvmModule->getContext(), 0);
8228 ompBuilder->getAddrOfDeclareTargetVar(
8229 captureClause, deviceClause, isDeclaration, isExternallyVisible,
8230 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
8231 mangledName, generatedRefs,
false, targetTriple,
// LLVM IR translation interface for the OpenMP dialect. Besides the two
// translation hooks it keeps a map from MLIR values to LLVM pointers, written
// by registerAllocatedPtr and read by lookupAllocatedPtr.
8245class OpenMPDialectLLVMIRTranslationInterface
8246 :
public LLVMTranslationDialectInterface {
// Inherit the base class constructors.
8248 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
// Translates a single OpenMP dialect operation.
8253 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
8254 LLVM::ModuleTranslation &moduleTranslation)
const final;
// Handles an OpenMP dialect attribute attached to `op`.
8259 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
8260 NamedAttribute attribute,
8261 LLVM::ModuleTranslation &moduleTranslation)
const final;
// Records `ptr` as the pointer associated with `var`, replacing any earlier
// entry. NOTE(review): this method is const yet writes to ompAllocatedPtrs,
// so that member is presumably declared mutable - confirm.
8266 void registerAllocatedPtr(Value var, llvm::Value *ptr)
const {
8267 ompAllocatedPtrs[var] = ptr;
// Returns the pointer recorded for `var`, or nullptr when there is none.
8272 llvm::Value *lookupAllocatedPtr(Value var)
const {
8273 auto it = ompAllocatedPtrs.find(var);
8274 return it != ompAllocatedPtrs.end() ? it->second :
nullptr;
// Dispatches on an attribute name and applies the attribute to the
// OpenMPIRBuilder state. Every handler first checks that the attribute has the
// expected kind before acting on it.
8286LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
8287 Operation *op, ArrayRef<llvm::Instruction *> instructions,
8288 NamedAttribute attribute,
8289 LLVM::ModuleTranslation &moduleTranslation)
const {
// The switch selects a callable that takes the attribute and returns a
// LogicalResult.
8290 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
// Boolean: whether code is being generated for the target device.
8292 .Case(
"omp.is_target_device",
8293 [&](Attribute attr) {
8294 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
8295 llvm::OpenMPIRBuilderConfig &config =
8297 config.setIsTargetDevice(deviceAttr.getValue());
// Boolean: whether the target is a GPU.
8303 [&](Attribute attr) {
8304 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
8305 llvm::OpenMPIRBuilderConfig &config =
8307 config.setIsGPU(gpuAttr.getValue());
// String: path from which offload info metadata is loaded, read through the
// translation's file system.
8312 .Case(
"omp.host_ir_filepath",
8313 [&](Attribute attr) {
8314 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
8315 llvm::OpenMPIRBuilder *ompBuilder =
8317 ompBuilder->loadOffloadInfoMetadata(
8318 moduleTranslation.
getFileSystem(), filepathAttr.getValue());
// Handler for an omp::FlagsAttr.
8324 [&](Attribute attr) {
8325 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
// The version becomes the "openmp" module flag with Max merge behaviour.
8329 .Case(
"omp.version",
8330 [&](Attribute attr) {
8331 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
8332 llvm::OpenMPIRBuilder *ompBuilder =
8334 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
8335 versionAttr.getVersion());
// Declare-target attributes use the OpenMPIRBuilder and the module
// translation.
8340 .Case(
"omp.declare_target",
8341 [&](Attribute attr) {
8342 if (
auto declareTargetAttr =
8343 dyn_cast<omp::DeclareTargetAttr>(attr)) {
8344 llvm::OpenMPIRBuilder *ompBuilder =
8347 ompBuilder, moduleTranslation);
// Each `requires` flag is copied into the matching config setting.
8351 .Case(
"omp.requires",
8352 [&](Attribute attr) {
8353 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
8354 using Requires = omp::ClauseRequires;
8355 Requires flags = requiresAttr.getValue();
8356 llvm::OpenMPIRBuilderConfig &config =
8358 config.setHasRequiresReverseOffload(
8359 bitEnumContainsAll(flags, Requires::reverse_offload));
8360 config.setHasRequiresUnifiedAddress(
8361 bitEnumContainsAll(flags, Requires::unified_address));
8362 config.setHasRequiresUnifiedSharedMemory(
8363 bitEnumContainsAll(flags, Requires::unified_shared_memory));
8364 config.setHasRequiresDynamicAllocators(
8365 bitEnumContainsAll(flags, Requires::dynamic_allocators));
// The configured triple list is replaced by the string elements of the array;
// elements that are not strings are not added.
8370 .Case(
"omp.target_triples",
8371 [&](Attribute attr) {
8372 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
8373 llvm::OpenMPIRBuilderConfig &config =
8375 config.TargetTriples.clear();
8376 config.TargetTriples.reserve(triplesAttr.size());
8377 for (Attribute tripleAttr : triplesAttr) {
8378 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
8379 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
// Fallback handler for attribute names not listed above.
8387 .Default([](Attribute) {
// Checks whether the parent function implements the declare-target interface,
// is marked declare target, and has a device type other than `host`.
8403 if (
auto declareTargetIface =
8404 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
8405 parentFn.getOperation()))
8406 if (declareTargetIface.isDeclareTarget() &&
8407 declareTargetIface.getDeclareTargetDeviceType() !=
8408 mlir::omp::DeclareTargetDeviceType::host)
// Gets or inserts the declaration of `omp_target_alloc` in `llvmModule`, with
// the non-variadic signature `ptr (i64, i32)`.
8418 llvm::Module *llvmModule) {
8419 llvm::Type *i64Ty = builder.getInt64Ty();
8420 llvm::Type *i32Ty = builder.getInt32Ty();
// The returned pointer is in address space 0.
8421 llvm::Type *returnType = builder.getPtrTy(0);
8422 llvm::FunctionType *fnType =
8423 llvm::FunctionType::get(returnType, {i64Ty, i32Ty},
false);
// cast<> asserts that the callee is an llvm::Function.
8424 llvm::Function *
func = cast<llvm::Function>(
8425 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
// Template helper computing an allocation size in bytes as an i64 value: the
// per-element size multiplied by the op's array size.
8429template <
typename T>
8433 llvm::DataLayout dataLayout =
8435 llvm::Type *llvmHeapTy =
8436 moduleTranslation.
convertType(op.getMemElemTypeAttr().getValue());
// The element store size is rounded up to the op's alignment when one is
// given, otherwise to the ABI alignment of the element type.
8438 auto alignment = op.getMemAlignment();
8439 llvm::TypeSize typeSize = llvm::alignTo(
8440 dataLayout.getTypeStoreSize(llvmHeapTy),
8441 alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
8443 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
// The array size is cast to i64 before the multiplication.
8444 return builder.CreateMul(
8446 builder.CreateIntCast(moduleTranslation.
lookupValue(op.getMemArraySize()),
8447 builder.getInt64Ty(),
// Computes the allocation size for a TargetAllocMemOp: the alloc size of the
// allocated type, multiplied by each type parameter in turn.
8454 omp::TargetAllocMemOp op) {
8455 llvm::DataLayout dataLayout =
8457 llvm::Type *llvmHeapTy = moduleTranslation.
convertType(op.getAllocatedType());
8458 llvm::TypeSize typeSize = dataLayout.getTypeAllocSize(llvmHeapTy);
8459 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
// Each type parameter is cast to i64 and folded into the running product.
8460 for (
auto typeParam : op.getTypeparams()) {
8461 allocSize = builder.CreateMul(
8463 builder.CreateIntCast(moduleTranslation.
lookupValue(typeParam),
8464 builder.getInt64Ty(),
// Lowers a TargetAllocMemOp to a call taking (size, device number) and maps
// the op's result to the returned pointer converted to an i64.
8473 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
8478 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8482 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
8484 llvm::Value *allocSize =
8487 llvm::CallInst *call =
8488 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
// The MLIR result is an integer, so the pointer is converted with ptrtoint.
8489 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
8492 moduleTranslation.
mapValue(allocMemOp.getResult(), resultI64);
// Lowers a shared-memory allocation: the op's result is mapped to the value
// produced by OpenMPIRBuilder::createOMPAllocShared for `size`.
8498 llvm::IRBuilderBase &builder,
8502 moduleTranslation.
mapValue(allocMemOp.getResult(),
8503 ompBuilder->createOMPAllocShared(builder, size));
// Lowers an AllocateDirOp: for each listed variable it computes a byte size
// rounded up to the alignment, emits an OpenMP allocation call and registers
// the resulting pointer with `ompIface`.
8510 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8511 auto allocateDirOp = cast<omp::AllocateDirOp>(opInst);
8514 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8515 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8516 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
8518 std::optional<int64_t> alignAttr = allocateDirOp.getAlign();
// Normalise the allocator operand to a pointer: integers go through inttoptr,
// pointers through a bitcast / address-space cast.
8520 llvm::Value *allocator;
8521 if (
auto allocatorVar = allocateDirOp.getAllocator()) {
8522 allocator = moduleTranslation.
lookupValue(allocatorVar);
8523 if (allocator->getType()->isIntegerTy())
8524 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8525 else if (allocator->getType()->isPointerTy())
8526 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8527 allocator, builder.getPtrTy());
// A null pointer is used as the allocator otherwise.
8529 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
8532 for (
Value var : vars) {
8533 llvm::Type *llvmVarTy = moduleTranslation.
convertType(var.getType());
// For a pointer-typed variable defined by an LLVM global, the global's own
// type is inspected instead of the pointer type.
8537 llvm::Type *typeToInspect = llvmVarTy;
8538 if (llvmVarTy->isPointerTy()) {
8541 if (
auto gop = dyn_cast<LLVM::GlobalOp>(globalOp))
8542 typeToInspect = moduleTranslation.
convertType(gop.getGlobalType());
// Arrays: multiply the extents of all nested array dimensions, then multiply
// by the innermost element size in bytes.
8547 if (
auto arrTy = llvm::dyn_cast<llvm::ArrayType>(typeToInspect)) {
8548 llvm::Value *elementCount = builder.getInt64(1);
8549 llvm::Type *currentType = arrTy;
8550 while (
auto nestedArrTy = llvm::dyn_cast<llvm::ArrayType>(currentType)) {
8551 elementCount = builder.CreateMul(
8552 elementCount, builder.getInt64(nestedArrTy->getNumElements()));
8553 currentType = nestedArrTy->getElementType();
// NOTE(review): integer division by 8 drops the remainder for element types
// whose bit size is not a multiple of 8.
8555 uint64_t elemSizeInBits = dataLayout.getTypeSizeInBits(currentType);
8557 builder.CreateMul(elementCount, builder.getInt64(elemSizeInBits / 8));
// Non-array types use their store size.
8559 size = builder.getInt64(
8560 dataLayout.getTypeStoreSize(typeToInspect).getFixedValue());
// Alignment is the `align` attribute when present, otherwise the ABI alignment
// of the inspected type.
8563 uint64_t alignValue =
8564 alignAttr ? alignAttr.value()
8565 : dataLayout.getABITypeAlign(typeToInspect).value();
8566 llvm::Value *alignConst = builder.getInt64(alignValue);
// Round `size` up to a multiple of the alignment:
// ((size + align - 1) / align) * align.
8568 size = builder.CreateAdd(size, builder.getInt64(alignValue - 1),
"",
true);
8569 size = builder.CreateUDiv(size, alignConst);
8570 size = builder.CreateMul(size, alignConst,
"",
true);
8572 std::string allocName =
8573 ompBuilder->createPlatformSpecificName({
".void.addr"});
// Use the aligned allocation entry point only when an explicit alignment was
// given.
8574 llvm::CallInst *allocCall;
8575 if (alignAttr.has_value()) {
8576 allocCall = ompBuilder->createOMPAlignedAlloc(
8577 ompLoc, builder.getInt64(alignAttr.value()), size, allocator,
8581 ompBuilder->createOMPAlloc(ompLoc, size, allocator, allocName);
// Record the allocation so it can be looked up again for this variable.
8584 ompIface.registerAllocatedPtr(var, allocCall);
// Lowers an AllocateFreeOp: emits an OpenMP free call for every listed
// variable, using the pointer previously registered with `ompIface`.
8593 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8594 auto freeOp = cast<omp::AllocateFreeOp>(opInst);
8596 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Normalise the allocator operand to a pointer: integers go through inttoptr,
// pointers through a bitcast / address-space cast.
8598 llvm::Value *allocator;
8599 if (
auto allocatorVar = freeOp.getAllocator()) {
8600 allocator = moduleTranslation.
lookupValue(allocatorVar);
8601 if (allocator->getType()->isIntegerTy())
8602 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8603 else if (allocator->getType()->isPointerTy())
8604 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8605 allocator, builder.getPtrTy());
// A null pointer is used as the allocator otherwise.
8607 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
// Variables are visited in reverse order.
8612 for (
Value var : llvm::reverse(vars)) {
8613 llvm::Value *allocPtr = ompIface.lookupAllocatedPtr(var);
// An error is emitted on the operation when no allocation was recorded.
8615 return opInst.
emitError(
"omp.allocate_free: no allocation recorded");
8616 ompBuilder->createOMPFree(ompLoc, allocPtr, allocator,
"");
// Gets or inserts the declaration of `omp_target_free` in `llvmModule`, with
// the non-variadic signature `void (ptr, i32)`.
8623 llvm::Module *llvmModule) {
// The pointer parameter is in address space 0.
8624 llvm::Type *ptrTy = builder.getPtrTy(0);
8625 llvm::Type *i32Ty = builder.getInt32Ty();
8626 llvm::Type *voidTy = builder.getVoidTy();
8627 llvm::FunctionType *fnType =
8628 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty},
false);
// NOTE(review): dyn_cast<> yields null if the callee is not an
// llvm::Function; confirm that callers handle a null result.
8629 llvm::Function *
func = dyn_cast<llvm::Function>(
8630 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
// Lowers a TargetFreeMemOp to a call taking (pointer, device number).
8637 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
8642 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8646 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
8649 llvm::Value *llvmHeapref = moduleTranslation.
lookupValue(heapref);
// The heap reference value is converted to an address-space-0 pointer with
// inttoptr before the call.
8651 llvm::Value *intToPtr =
8652 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
8653 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
// Lowers a shared-memory free: passes the LLVM value of the op's heap
// reference and `size` to OpenMPIRBuilder::createOMPFreeShared.
8659 llvm::IRBuilderBase &builder,
8663 ompBuilder->createOMPFreeShared(
8664 builder, moduleTranslation.
lookupValue(freeMemOp.getHeapref()), size);
// Lowers a GroupprivateOp. On a target device where allocation applies, a new
// global is created in the target's shared/local address space; otherwise the
// original global is used.
8673 auto groupprivateOp = cast<omp::GroupprivateOp>(opInst);
8678 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
// Decide from the device type whether this compilation should allocate:
// `host` only on the host, `nohost` only on the device, `any` (also the
// default when the attribute is absent) on both.
8682 bool shouldAllocate =
true;
8683 switch (groupprivateOp.getDeviceType().value_or(
8684 mlir::omp::DeclareTargetDeviceType::any)) {
8685 case mlir::omp::DeclareTargetDeviceType::host:
8686 shouldAllocate = !isTargetDevice;
8688 case mlir::omp::DeclareTargetDeviceType::nohost:
8689 shouldAllocate = isTargetDevice;
8691 case mlir::omp::DeclareTargetDeviceType::any:
8692 shouldAllocate =
true;
// The op's symbol name must refer to an LLVM global variable.
8698 &opInst, groupprivateOp.getSymNameAttr());
8701 <<
"expected symbol '" << groupprivateOp.getSymName()
8702 <<
"' to reference an LLVM global variable";
8704 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
8705 llvm::Type *varType = moduleTranslation.
convertType(global.getType());
8706 std::string varName = globalValue->getName().str();
8708 llvm::Value *resultPtr;
// Device path: choose the address space from the module's target triple. Only
// AMDGCN and NVPTX are handled; other targets produce an error.
8709 if (shouldAllocate && isTargetDevice) {
8710 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8711 llvm::Triple targetTriple(llvmModule->getTargetTriple());
8712 unsigned sharedAddressSpace;
8713 if (targetTriple.isAMDGCN())
8714 sharedAddressSpace = llvm::AMDGPUAS::LOCAL_ADDRESS;
8715 else if (targetTriple.isNVPTX())
8716 sharedAddressSpace = llvm::NVPTXAS::ADDRESS_SPACE_SHARED;
8718 return opInst.
emitError() <<
"groupprivate is not supported for target: "
8719 << targetTriple.str();
// The new global is non-constant, has internal linkage, a poison initializer
// and is not thread-local.
8720 llvm::GlobalVariable *sharedVar =
new llvm::GlobalVariable(
8721 *llvmModule, varType,
false,
8722 llvm::GlobalValue::InternalLinkage, llvm::PoisonValue::get(varType),
8723 varName,
nullptr, llvm::GlobalValue::NotThreadLocal,
8726 resultPtr = sharedVar;
// On the host, a warning is emitted when allocation was requested; the
// original global is the result.
8728 if (shouldAllocate && !isTargetDevice)
8729 opInst.
emitWarning(
"groupprivate directive is currently ignored on the "
8730 "host, using original global");
8731 resultPtr = globalValue;
8740LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
8741 Operation *op, llvm::IRBuilderBase &builder,
8742 LLVM::ModuleTranslation &moduleTranslation)
const {
8745 if (ompBuilder->Config.isTargetDevice() &&
8746 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
8749 return op->
emitOpError() <<
"unsupported host op found in device";
8757 bool isOutermostLoopWrapper =
8758 isa_and_present<omp::LoopWrapperInterface>(op) &&
8759 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
8768 if (isa<omp::TaskloopContextOp>(op))
8769 isOutermostLoopWrapper =
true;
8770 else if (isa<omp::TaskloopWrapperOp>(op))
8771 isOutermostLoopWrapper =
false;
8773 if (isOutermostLoopWrapper)
8774 moduleTranslation.
stackPush<OpenMPLoopInfoStackFrame>();
8777 llvm::TypeSwitch<Operation *, LogicalResult>(op)
8778 .Case([&](omp::BarrierOp op) -> LogicalResult {
8782 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
8783 ompBuilder->createBarrier(builder.saveIP(),
8784 llvm::omp::OMPD_barrier);
8786 if (res.succeeded()) {
8789 builder.restoreIP(*afterIP);
8793 .Case([&](omp::TaskyieldOp op) {
8797 ompBuilder->createTaskyield(builder.saveIP());
8800 .Case([&](omp::FlushOp op) {
8812 ompBuilder->createFlush(builder.saveIP());
8815 .Case([&](omp::ParallelOp op) {
8818 .Case([&](omp::MaskedOp) {
8821 .Case([&](omp::MasterOp) {
8824 .Case([&](omp::CriticalOp) {
8827 .Case([&](omp::OrderedRegionOp) {
8830 .Case([&](omp::OrderedOp) {
8833 .Case([&](omp::WsloopOp) {
8836 .Case([&](omp::SimdOp) {
8839 .Case([&](omp::AtomicReadOp) {
8842 .Case([&](omp::AtomicWriteOp) {
8845 .Case([&](omp::AtomicUpdateOp op) {
8848 .Case([&](omp::AtomicCaptureOp op) {
8851 .Case([&](omp::AtomicCompareOp op) {
8854 .Case([&](omp::CancelOp op) {
8857 .Case([&](omp::CancellationPointOp op) {
8860 .Case([&](omp::SectionsOp) {
8863 .Case([&](omp::ScopeOp op) {
8866 .Case([&](omp::SingleOp op) {
8869 .Case([&](omp::TeamsOp op) {
8872 .Case([&](omp::TaskOp op) {
8875 .Case([&](omp::TaskloopWrapperOp op) {
8878 .Case([&](omp::TaskloopContextOp op) {
8881 .Case([&](omp::TaskgroupOp op) {
8884 .Case([&](omp::TaskwaitOp op) {
8887 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
8888 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
8889 omp::CriticalDeclareOp>([](
auto op) {
8902 .Case([&](omp::ThreadprivateOp) {
8905 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
8906 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
8909 .Case([&](omp::TargetOp) {
8912 .Case([&](omp::DistributeOp) {
8915 .Case([&](omp::LoopNestOp) {
8918 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp,
8919 omp::AffinityEntryOp, omp::IteratorOp>([&](
auto op) {
8925 .Case([&](omp::NewCliOp op) {
8930 .Case([&](omp::CanonicalLoopOp op) {
8933 .Case([&](omp::UnrollHeuristicOp op) {
8942 .Case([&](omp::TileOp op) {
8943 return applyTile(op, builder, moduleTranslation);
8945 .Case([&](omp::FuseOp op) {
8946 return applyFuse(op, builder, moduleTranslation);
8948 .Case([&](omp::TargetAllocMemOp) {
8951 .Case([&](omp::TargetFreeMemOp) {
8954 .Case([&](omp::AllocateDirOp) {
8957 .Case([&](omp::AllocateFreeOp) {
8961 .Case([&](omp::AllocSharedMemOp op) {
8964 .Case([&](omp::FreeSharedMemOp op) {
8967 .Case([&](omp::GroupprivateOp) {
8970 .Default([&](Operation *inst) {
8972 <<
"not yet implemented: " << inst->
getName();
8975 if (isOutermostLoopWrapper)
8982 registry.
insert<omp::OpenMPDialect>();
8984 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static mlir::LogicalResult buildDependData(OperandRange dependVars, std::optional< ArrayAttr > dependKinds, OperandRange dependIterated, std::optional< ArrayAttr > dependIteratedKinds, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static void mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static void processIndividualMap(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag=llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE, bool isTargetParam=true, int mapDataParentIdx=-1)
This function handles the insertion of a single item of map data from MapInfoData into the OMPIRBuild...
static llvm::OpenMPIRBuilder::InsertPointTy findAllocInsertPoints(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::SmallVectorImpl< llvm::BasicBlock * > *deallocBlocks=nullptr)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo, bool first=true)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static LogicalResult convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static mlir::LogicalResult convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo, mlir::Block &iteratorRegionBlock, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static void popCancelFinalizationCB(const ArrayRef< llvm::UncondBrInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static llvm::Expected< llvm::Value * > lookupOrTranslatePureValue(Value value, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
Look up the given value in the mapping, and if it's not there, translate its defining operation at th...
static LogicalResult allocReductionVars(T op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static mlir::LogicalResult fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, IteratorInfo &iterInfo, llvm::StringRef loopName, IteratorStoreEntryTy genStoreEntry)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
llvm::function_ref< void(llvm::Value *linearIV, mlir::omp::YieldOp yield)> IteratorStoreEntryTy
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::DistributeOp getDistributeCapturingTeamsReduction(omp::TeamsOp teamsOp)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static mlir::LogicalResult buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::AffinityData &ad)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Value * getAllocationSize(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, T op)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::omp::OMPDynGroupprivateFallbackType getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult cleanupPrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, PrivateVarsInfo &privateVarsInfo)
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpScope(omp::ScopeOp &scopeOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP scope construct into LLVM IR.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void storeAffinityEntry(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, llvm::Value *affinityList, llvm::Value *index, llvm::Value *addr, llvm::Value *len)
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::UncondBrInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs, llvm::StringRef parentName="")
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP groupprivate operation into LLVM IR.
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs)
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.compare operation to LLVM IR.
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static void buildDependDataLocator(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static std::optional< llvm::omp::OMPAtomicCompareOp > convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate)
Helper to extract the OMPAtomicCompareOp from a floating-point comparison predicate....
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void fillAffinityLocators(Operation::operand_range affinityVars, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *affinityList)
static LogicalResult convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
The correct entry point is convertOmpTaskloopContextOp. This gets called whilst lowering the body of ...
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static std::optional< llvm::omp::OMPAtomicCompareOp > convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate)
Helper to extract the OMPAtomicCompareOp from an integer comparison predicate. Returns std::nullopt f...
static llvm::Error computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *&lbVal, llvm::Value *&ubVal, llvm::Value *&stepVal)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP, llvm::ArrayRef< llvm::IRBuilderBase::InsertPoint > deallocIPs)
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static llvm::omp::RTLDependenceKindTy convertDependKind(mlir::omp::ClauseTaskDepend kind)
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Value getBaseValueForTypeLookup(Value value)
static bool isHostDeviceOp(Operation *op)
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, llvm::OpenMPIRBuilder *ompBuilder, LLVM::ModuleTranslation &moduleTranslation)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static LogicalResult convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::vfs::FileSystem & getFileSystem()
Returns the virtual filesystem to use for file operations.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder)
Converts the given MLIR operation into LLVM IR using this translator.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
bool opInSharedDeviceContext(Operation &op)
Check whether the given operation is located in a context where an allocation to be used by multiple ...
bool allocaUsesRequireSharedMem(Value alloc)
Check whether the value representing an allocation, assumed to have been defined in a shared device c...
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable and does not touch memory.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.