26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Frontend/OpenMP/OMPConstants.h"
30#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31#include "llvm/IR/Constants.h"
32#include "llvm/IR/DebugInfoMetadata.h"
33#include "llvm/IR/DerivedTypes.h"
34#include "llvm/IR/IRBuilder.h"
35#include "llvm/IR/MDBuilder.h"
36#include "llvm/IR/ReplaceConstant.h"
37#include "llvm/Support/AMDGPUAddrSpace.h"
38#include "llvm/Support/FileSystem.h"
39#include "llvm/Support/NVPTXAddrSpace.h"
40#include "llvm/Support/VirtualFileSystem.h"
41#include "llvm/TargetParser/Triple.h"
42#include "llvm/Transforms/Utils/ModuleUtils.h"
/// Maps an optional MLIR `omp::ClauseScheduleKind` onto the corresponding
/// `llvm::omp::ScheduleKind`.
///
/// \param schedKind schedule kind taken from the clause, if any.
/// \returns `OMP_SCHEDULE_Default` when no kind is provided, otherwise the
///          same-named `llvm::omp::OMP_SCHEDULE_*` value.
53static llvm::omp::ScheduleKind
54convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
// No schedule kind given: use the default schedule.
55 if (!schedKind.has_value())
56 return llvm::omp::OMP_SCHEDULE_Default;
// One-to-one mapping between the two enumerations.
57 switch (schedKind.value()) {
58 case omp::ClauseScheduleKind::Static:
59 return llvm::omp::OMP_SCHEDULE_Static;
60 case omp::ClauseScheduleKind::Dynamic:
61 return llvm::omp::OMP_SCHEDULE_Dynamic;
62 case omp::ClauseScheduleKind::Guided:
63 return llvm::omp::OMP_SCHEDULE_Guided;
64 case omp::ClauseScheduleKind::Auto:
65 return llvm::omp::OMP_SCHEDULE_Auto;
66 case omp::ClauseScheduleKind::Runtime:
67 return llvm::omp::OMP_SCHEDULE_Runtime;
68 case omp::ClauseScheduleKind::Distribute:
69 return llvm::omp::OMP_SCHEDULE_Distribute;
// Any kind not returned from above is flagged as unreachable.
71 llvm_unreachable(
"unhandled schedule clause argument");
/// Bundles an insertion point for allocations with the list of basic blocks
/// recorded for deallocation.
/// NOTE(review): the base class is not visible in this excerpt; the name
/// suggests a module-translation stack frame -- confirm against the full file.
76class OpenMPAllocStackFrame
/// Stores `allocaIP` and copies `deallocBlocks` into an owned vector.
81 explicit OpenMPAllocStackFrame(
82 llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
83 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks)
84 : allocInsertPoint(allocaIP), deallocBlocks(deallocBlocks) {}
// Insertion point handed to the constructor as `allocaIP`.
85 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
// Owned copy of the blocks passed to the constructor.
86 llvm::SmallVector<llvm::BasicBlock *> deallocBlocks;
/// Carries a `llvm::CanonicalLoopInfo` pointer.
/// NOTE(review): the base class is not visible in this excerpt -- confirm.
92class OpenMPLoopInfoStackFrame
// Null until a loop info is assigned.
96 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
/// Custom `llvm::ErrorInfo` subclass. Elsewhere in this file it is created
/// when block conversion fails and is handled by setting the result to
/// failure.
115class PreviouslyReportedError
116 :
public llvm::ErrorInfo<PreviouslyReportedError> {
// Logging hook; the stream parameter is unnamed (body not visible here).
118 void log(raw_ostream &)
const override {
// Per the message below, conversion to std::error_code is not supported.
122 std::error_code convertToErrorCode()
const override {
124 "PreviouslyReportedError doesn't support ECError conversion");
// Out-of-line definition of the class's static `ID` member.
131char PreviouslyReportedError::ID = 0;
/// Helper that keeps the per-variable state needed to generate code for
/// linear variables of a loop. The vectors below are parallel: entry `i` of
/// each refers to the same linear variable.
142class LinearClauseProcessor {
// Allocas named ".linear_var"; initLinearVar stores the variable's value
// into them in the loop preheader.
145 SmallVector<llvm::Value *> linearPreconditionVars;
// Allocas named ".linear_result"; updateLinearVar stores the per-iteration
// value into them.
146 SmallVector<llvm::Value *> linearLoopBodyTemps;
// The original linear variables passed to createLinearVar.
147 SmallVector<llvm::Value *> linearOrigVal;
// Step values recorded by initLinearStep.
148 SmallVector<llvm::Value *> linearSteps;
// LLVM types recorded by registerType.
149 SmallVector<llvm::Type *> linearVarTypes;
// Blocks created by splitLinearFiniBB; not initialized here.
150 llvm::BasicBlock *linearFinalizationBB;
151 llvm::BasicBlock *linearExitBB;
152 llvm::BasicBlock *linearLastIterExitBB;
/// Records the LLVM type of a linear variable.
///
/// \param moduleTranslation used to convert the MLIR type to an LLVM type.
/// \param ty attribute that is cast to `mlir::TypeAttr`; its held type is
///           converted and appended to `linearVarTypes`.
156 void registerType(LLVM::ModuleTranslation &moduleTranslation,
157 mlir::Attribute &ty) {
158 linearVarTypes.push_back(moduleTranslation.
convertType(
159 mlir::cast<mlir::TypeAttr>(ty).getValue()));
/// Creates the two allocas used for one linear variable and records the
/// original variable.
///
/// \param builder IR builder; allocas are emitted at its current insertion
///        point.
/// \param moduleTranslation not used in the visible part of the body.
/// \param linearVar original linear variable, appended to `linearOrigVal`.
/// \param idx index into `linearVarTypes` selecting the allocated type; the
///        type must already have been registered.
163 void createLinearVar(llvm::IRBuilderBase &builder,
164 LLVM::ModuleTranslation &moduleTranslation,
165 llvm::Value *linearVar,
int idx) {
// Alloca that later receives the variable's value in the loop preheader.
166 linearPreconditionVars.push_back(
167 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_var"));
// Alloca that later receives the value computed in the loop body.
168 llvm::Value *linearLoopBodyTemp =
169 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_result");
170 linearOrigVal.push_back(linearVar);
171 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
/// Looks up the LLVM value mapped to the MLIR value `linearStep` and appends
/// it to `linearSteps`.
175 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
176 mlir::Value &linearStep) {
177 linearSteps.push_back(moduleTranslation.
lookupValue(linearStep));
/// Emits, before the terminator of `loopPreHeader`, a load of every original
/// linear variable followed by a store of that value into the matching
/// ".linear_var" alloca.
///
/// \param builder IR builder; its insertion point is moved into
///        `loopPreHeader`.
/// \param moduleTranslation not used in the visible part of the body.
/// \param loopPreHeader block that receives the loads and stores; it must
///        already have a terminator.
181 void initLinearVar(llvm::IRBuilderBase &builder,
182 LLVM::ModuleTranslation &moduleTranslation,
183 llvm::BasicBlock *loopPreHeader) {
184 builder.SetInsertPoint(loopPreHeader->getTerminator());
185 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
186 llvm::LoadInst *linearVarLoad =
187 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
188 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
/// Emits, before the terminator of `loopBody`, the computation
/// `start + iv * step` for every linear variable and stores the result into
/// the matching ".linear_result" alloca, where `start` is loaded from the
/// ".linear_var" alloca.
///
/// \param builder IR builder; its insertion point is moved into `loopBody`.
/// \param loopBody block that receives the generated code; it must already
///        have a terminator.
/// \param loopInductionVar induction variable value; must have integer type.
193 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
194 llvm::Value *loopInductionVar) {
195 builder.SetInsertPoint(loopBody->getTerminator());
196 for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
197 llvm::Type *linearVarType = linearVarTypes[index];
// Start from the unmodified induction variable for every linear variable.
198 llvm::Value *iv = loopInductionVar;
199 llvm::Value *step = linearSteps[index];
// A non-integer induction variable is treated as a programming error.
201 if (!iv->getType()->isIntegerTy())
202 llvm_unreachable(
"OpenMP loop induction variable must be an integer "
// Integer variable: bring iv and step to the variable's type (sign-extend or
// truncate), then use integer multiply and add.
205 if (linearVarType->isIntegerTy()) {
207 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
208 step = builder.CreateSExtOrTrunc(step, linearVarType);
210 llvm::LoadInst *linearVarStart =
211 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
212 llvm::Value *mulInst = builder.CreateMul(iv, step);
213 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
214 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
215 }
// Floating-point variable: multiply in the induction variable's integer
// type, convert the product to floating point as a signed value, then FAdd.
else if (linearVarType->isFloatingPointTy()) {
217 step = builder.CreateSExtOrTrunc(step, iv->getType());
218 llvm::Value *mulInst = builder.CreateMul(iv, step);
220 llvm::LoadInst *linearVarStart =
221 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
222 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
223 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
224 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
// NOTE(review): the call that carries this message is not visible in this
// excerpt; presumably the fallback for any other variable type -- confirm.
227 "Linear variable must be of integer or floating-point type");
/// Creates the three blocks used for linear-variable finalization by
/// repeatedly splitting at a block's terminator, and records them in
/// `linearFinalizationBB`, `linearExitBB` and `linearLastIterExitBB`.
///
/// \param builder not used in the visible part of the body.
/// \param loopExit block whose terminator is split off first; it must
///        already have a terminator.
234 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
235 llvm::BasicBlock *loopExit) {
// loopExit's terminator moves into the new finalization block.
236 linearFinalizationBB = loopExit->splitBasicBlock(
237 loopExit->getTerminator(),
"omp_loop.linear_finalization");
// The finalization block's terminator moves into the new exit block.
238 linearExitBB = linearFinalizationBB->splitBasicBlock(
239 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
// Splitting the finalization block once more places the last-iteration
// block between the finalization block and the exit block.
240 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
241 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
/// Emits the finalization code for linear variables into the blocks created
/// by splitLinearFiniBB: the last-iteration block copies every
/// ".linear_result" value back to its original variable, and the
/// finalization block branches there only when the loaded last-iteration
/// flag is non-zero.
///
/// \param builder IR builder; its insertion point is changed several times.
/// \param moduleTranslation not used in the visible part of the body.
/// \param lastIter pointer from which a 32-bit integer flag is loaded.
/// \returns an insertion point or an error. NOTE(review): the call producing
///          the result is only partly visible in this excerpt -- confirm.
245 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
246 finalizeLinearVar(llvm::IRBuilderBase &builder,
247 LLVM::ModuleTranslation &moduleTranslation,
248 llvm::Value *lastIter) {
// isLast := (load i32 lastIter) != 0, computed in the finalization block.
250 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
251 llvm::Value *loopLastIterLoad = builder.CreateLoad(
252 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
253 llvm::Value *isLast =
254 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
255 llvm::ConstantInt::get(
256 llvm::Type::getInt32Ty(builder.getContext()), 0));
// Copy-back of the loop-body results, placed in the last-iteration block.
258 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
259 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
260 llvm::LoadInst *linearVarTemp =
261 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
262 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// Insert the conditional branch ahead of the old terminator, then erase the
// old terminator, which is still the last instruction of the block.
268 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
269 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
270 linearFinalizationBB->getTerminator()->eraseFromParent();
// Continue in the exit block. The visible arguments below pass the saved
// insertion point and OMPD_barrier to a call whose callee is not visible.
272 builder.SetInsertPoint(linearExitBB->getTerminator());
274 builder.saveIP(), llvm::omp::OMPD_barrier);
/// Emits, at the builder's current insertion point, a load from every
/// ".linear_result" alloca followed by a store of that value to the matching
/// original linear variable.
279 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
280 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
281 llvm::LoadInst *linearVarTemp =
282 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
283 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
/// Redirects uses of one original linear variable to its ".linear_result"
/// alloca, restricted to instructions located in blocks that are reachable
/// from `startBB` and whose name starts with `prefix`.
///
/// \param builder not used in the visible part of the body.
/// \param startBB block where the control-flow traversal starts; non-null.
/// \param endBB must be non-null. NOTE(review): how it bounds the traversal
///        is not visible in this excerpt -- confirm.
/// \param prefix block-name prefix selecting the blocks to rewrite.
/// NOTE(review): `varIndex`, used below, is presumably declared on the
/// parameter line missing from this excerpt -- confirm.
289 void rewriteInPlace(llvm::IRBuilderBase &builder, llvm::BasicBlock *startBB,
290 llvm::BasicBlock *endBB, llvm::StringRef prefix,
292 llvm::SmallVector<llvm::BasicBlock *, 32> worklist;
293 llvm::SmallPtrSet<llvm::BasicBlock *, 32> visited;
294 llvm::SmallPtrSet<llvm::BasicBlock *, 32> matchingBBs;
296 assert(startBB && endBB &&
"Invalid startBB/endBB");
// Worklist traversal over successors; `visited` ensures each block is
// processed once.
300 worklist.push_back(startBB);
301 visited.insert(startBB);
303 while (!worklist.empty()) {
304 llvm::BasicBlock *bb = worklist.pop_back_val();
// Remember blocks whose name begins with the requested prefix.
306 if (bb->hasName() && bb->getName().starts_with(prefix))
307 matchingBBs.insert(bb);
312 for (llvm::BasicBlock *succ : llvm::successors(bb)) {
313 if (visited.insert(succ).second)
314 worklist.push_back(succ);
// Take a snapshot of the users first: the loop below rewrites uses of the
// very value whose user list would otherwise be iterated.
319 llvm::SmallVector<llvm::User *> users(linearOrigVal[varIndex]->users());
320 for (
auto *user : users) {
// Only instruction users located in a matching block are rewritten.
321 if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
322 if (matchingBBs.contains(userInst->getParent()))
323 user->replaceUsesOfWith(linearOrigVal[varIndex],
324 linearLoopBodyTemps[varIndex]);
335 SymbolRefAttr symbolName) {
336 omp::PrivateClauseOp privatizer =
339 assert(privatizer &&
"privatizer not found in the symbol table");
350 auto todo = [&op](StringRef clauseName) {
351 return op.
emitError() <<
"not yet implemented: Unhandled clause "
352 << clauseName <<
" in " << op.
getName()
356 auto checkAllocate = [&todo](
auto op, LogicalResult &
result) {
357 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
358 result = todo(
"allocate");
360 auto checkBare = [&todo](
auto op, LogicalResult &
result) {
362 result = todo(
"ompx_bare");
364 auto checkDepend = [&todo](
auto op, LogicalResult &
result) {
365 if (!op.getDependVars().empty() || op.getDependKinds())
368 auto checkHint = [](
auto op, LogicalResult &) {
372 auto checkInReduction = [&todo](
auto op, LogicalResult &
result) {
373 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
374 op.getInReductionSyms())
375 result = todo(
"in_reduction");
377 auto checkNowait = [&todo](
auto op, LogicalResult &
result) {
381 auto checkOrder = [&todo](
auto op, LogicalResult &
result) {
382 if (op.getOrder() || op.getOrderMod())
385 auto checkPrivate = [&todo](
auto op, LogicalResult &
result) {
386 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
387 result = todo(
"privatization");
389 auto checkReduction = [&todo](
auto op, LogicalResult &
result) {
390 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopContextOp>(op))
391 if (!op.getReductionVars().empty() || op.getReductionByref() ||
392 op.getReductionSyms())
393 result = todo(
"reduction");
394 if (op.getReductionMod() &&
395 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
396 result = todo(
"reduction with modifier");
398 auto checkTaskReduction = [&todo](
auto op, LogicalResult &
result) {
399 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
400 op.getTaskReductionSyms())
401 result = todo(
"task_reduction");
403 auto checkNumTeams = [&todo](
auto op, LogicalResult &
result) {
404 if (op.hasNumTeamsMultiDim())
405 result = todo(
"num_teams with multi-dimensional values");
407 auto checkNumThreads = [&todo](
auto op, LogicalResult &
result) {
408 if (op.hasNumThreadsMultiDim())
409 result = todo(
"num_threads with multi-dimensional values");
412 auto checkThreadLimit = [&todo](
auto op, LogicalResult &
result) {
413 if (op.hasThreadLimitMultiDim())
414 result = todo(
"thread_limit with multi-dimensional values");
417 auto checkDynGroupprivate = [&todo](
auto op, LogicalResult &
result) {
418 if (op.getDynGroupprivateSize())
419 result = todo(
"dyn_groupprivate");
424 .Case([&](omp::DistributeOp op) {
425 checkAllocate(op,
result);
428 .Case([&](omp::SectionsOp op) {
429 checkAllocate(op,
result);
431 checkReduction(op,
result);
433 .Case([&](omp::ScopeOp op) {
434 checkAllocate(op,
result);
435 checkReduction(op,
result);
437 .Case([&](omp::SingleOp op) {
438 checkAllocate(op,
result);
441 .Case([&](omp::TeamsOp op) {
442 checkAllocate(op,
result);
444 checkNumTeams(op,
result);
445 checkThreadLimit(op,
result);
446 checkDynGroupprivate(op,
result);
448 .Case([&](omp::TaskOp op) {
449 checkAllocate(op,
result);
450 checkInReduction(op,
result);
452 .Case([&](omp::TaskgroupOp op) {
453 checkAllocate(op,
result);
454 checkTaskReduction(op,
result);
456 .Case([&](omp::TaskwaitOp op) {
460 .Case([&](omp::TaskloopContextOp op) {
461 checkAllocate(op,
result);
462 checkInReduction(op,
result);
463 checkReduction(op,
result);
465 .Case([&](omp::WsloopOp op) {
466 checkAllocate(op,
result);
468 checkReduction(op,
result);
470 .Case([&](omp::ParallelOp op) {
471 checkAllocate(op,
result);
472 checkReduction(op,
result);
473 checkNumThreads(op,
result);
475 .Case([&](omp::SimdOp op) { checkReduction(op,
result); })
476 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
477 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op,
result); })
478 .Case([&](omp::AtomicCompareOp op) {
484 auto structTy = dyn_cast<LLVM::LLVMStructType>(argType);
490 result = todo(
"compare for complex types wider than 128 bits");
492 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
493 [&](
auto op) { checkDepend(op,
result); })
494 .Case([&](omp::TargetUpdateOp op) { checkDepend(op,
result); })
495 .Case([&](omp::TargetOp op) {
496 checkAllocate(op,
result);
498 checkInReduction(op,
result);
499 checkThreadLimit(op,
result);
511 llvm::handleAllErrors(
513 [&](
const PreviouslyReportedError &) {
result = failure(); },
514 [&](
const llvm::ErrorInfoBase &err) {
537 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
540 [&](OpenMPAllocStackFrame &frame) {
541 allocInsertPoint = frame.allocInsertPoint;
542 deallocInsertPoints = frame.deallocBlocks;
550 allocInsertPoint.getBlock()->getParent() ==
551 builder.GetInsertBlock()->getParent()) {
553 deallocBlocks->insert(deallocBlocks->end(), deallocInsertPoints.begin(),
554 deallocInsertPoints.end());
555 return allocInsertPoint;
565 if (builder.GetInsertBlock() ==
566 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
567 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
568 "Assuming end of basic block");
569 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
570 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
571 builder.GetInsertBlock()->getNextNode());
572 builder.CreateBr(entryBB);
573 builder.SetInsertPoint(entryBB);
579 for (llvm::BasicBlock &block : *builder.GetInsertBlock()->getParent()) {
583 llvm::Instruction *terminator = block.getTerminatorOrNull();
584 if (isa_and_present<llvm::ReturnInst>(terminator))
585 deallocBlocks->emplace_back(&block);
589 llvm::BasicBlock &funcEntryBlock =
590 builder.GetInsertBlock()->getParent()->getEntryBlock();
591 return llvm::OpenMPIRBuilder::InsertPointTy(
592 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
598static llvm::CanonicalLoopInfo *
600 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
601 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
602 [&](OpenMPLoopInfoStackFrame &frame) {
603 loopInfo = frame.loopInfo;
615 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
618 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
620 llvm::BasicBlock *continuationBlock =
621 splitBB(builder,
true,
"omp.region.cont");
622 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
624 llvm::LLVMContext &llvmContext = builder.getContext();
625 for (
Block &bb : region) {
626 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
627 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
628 builder.GetInsertBlock()->getNextNode());
629 moduleTranslation.
mapBlock(&bb, llvmBB);
632 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
639 unsigned numYields = 0;
641 if (!isLoopWrapper) {
642 bool operandsProcessed =
false;
644 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
645 if (!operandsProcessed) {
646 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
647 continuationBlockPHITypes.push_back(
648 moduleTranslation.
convertType(yield->getOperand(i).getType()));
650 operandsProcessed =
true;
652 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
653 "mismatching number of values yielded from the region");
654 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
655 llvm::Type *operandType =
656 moduleTranslation.
convertType(yield->getOperand(i).getType());
658 assert(continuationBlockPHITypes[i] == operandType &&
659 "values of mismatching types yielded from the region");
669 if (!continuationBlockPHITypes.empty())
671 continuationBlockPHIs &&
672 "expected continuation block PHIs if converted regions yield values");
673 if (continuationBlockPHIs) {
674 llvm::IRBuilderBase::InsertPointGuard guard(builder);
675 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
676 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
677 for (llvm::Type *ty : continuationBlockPHITypes)
678 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
684 for (
Block *bb : blocks) {
685 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
688 if (bb->isEntryBlock()) {
689 assert(sourceTerminator->getNumSuccessors() == 1 &&
690 "provided entry block has multiple successors");
691 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
692 "ContinuationBlock is not the successor of the entry block");
693 sourceTerminator->setSuccessor(0, llvmBB);
696 llvm::IRBuilderBase::InsertPointGuard guard(builder);
698 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
699 return llvm::make_error<PreviouslyReportedError>();
704 builder.CreateBr(continuationBlock);
715 Operation *terminator = bb->getTerminator();
716 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
717 builder.CreateBr(continuationBlock);
719 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
720 (*continuationBlockPHIs)[i]->addIncoming(
734 return continuationBlock;
740 case omp::ClauseProcBindKind::Close:
741 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
742 case omp::ClauseProcBindKind::Master:
743 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
744 case omp::ClauseProcBindKind::Primary:
745 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
746 case omp::ClauseProcBindKind::Spread:
747 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
749 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
756 auto maskedOp = cast<omp::MaskedOp>(opInst);
757 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
762 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
765 auto ®ion = maskedOp.getRegion();
766 builder.restoreIP(codeGenIP);
774 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
776 llvm::Value *filterVal =
nullptr;
777 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
778 filterVal = moduleTranslation.
lookupValue(filterVar);
780 llvm::LLVMContext &llvmContext = builder.getContext();
782 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
784 assert(filterVal !=
nullptr);
785 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
786 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
793 builder.restoreIP(*afterIP);
801 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
802 auto masterOp = cast<omp::MasterOp>(opInst);
807 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
810 auto ®ion = masterOp.getRegion();
811 builder.restoreIP(codeGenIP);
819 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
821 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
822 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
829 builder.restoreIP(*afterIP);
837 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
838 auto criticalOp = cast<omp::CriticalOp>(opInst);
843 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
846 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
847 builder.restoreIP(codeGenIP);
855 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
857 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
858 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
859 llvm::Constant *hint =
nullptr;
862 if (criticalOp.getNameAttr()) {
865 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
866 auto criticalDeclareOp =
870 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
871 static_cast<int>(criticalDeclareOp.getHint()));
873 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
875 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
880 builder.restoreIP(*afterIP);
887 template <
typename OP>
890 cast<
omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
893 collectPrivatizationDecls<OP>(op);
908 void collectPrivatizationDecls(OP op) {
909 std::optional<ArrayAttr> attr = op.getPrivateSyms();
914 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
925 std::optional<ArrayAttr> attr = op.getReductionSyms();
929 reductions.reserve(reductions.size() + op.getNumReductionVars());
930 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
931 reductions.push_back(
943 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
952 llvm::Instruction *potentialTerminator =
953 builder.GetInsertBlock()->empty() ?
nullptr
954 : &builder.GetInsertBlock()->back();
956 if (potentialTerminator && potentialTerminator->isTerminator())
957 potentialTerminator->removeFromParent();
958 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
961 region.
front(),
true, builder)))
965 if (continuationBlockArgs)
967 *continuationBlockArgs,
974 if (potentialTerminator && potentialTerminator->isTerminator()) {
975 llvm::BasicBlock *block = builder.GetInsertBlock();
976 if (block->empty()) {
982 potentialTerminator->insertInto(block, block->begin());
984 potentialTerminator->insertAfter(&block->back());
998 if (continuationBlockArgs)
999 llvm::append_range(*continuationBlockArgs, phis);
1000 builder.SetInsertPoint(*continuationBlock,
1001 (*continuationBlock)->getFirstInsertionPt());
/// `std::function` type for a reduction generator: takes an insertion point
/// and two values and returns an insertion point or an error.
/// NOTE(review): the trailing parameter(s) are not visible in this excerpt.
1008using OwningReductionGen =
1009 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1010 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
/// `std::function` type for an atomic reduction generator: takes an
/// insertion point, a type and a value, and returns an insertion point or an
/// error.
/// NOTE(review): the trailing parameter(s) are not visible in this excerpt.
1012using OwningAtomicReductionGen =
1013 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1014 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
/// `std::function` type taking an insertion point, an input value and a
/// reference to an output value pointer; returns an insertion point or an
/// error.
1016using OwningDataPtrPtrReductionGen =
1017 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1018 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
1024static OwningReductionGen
1030 OwningReductionGen gen =
1031 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1032 llvm::Value *
lhs, llvm::Value *
rhs,
1033 llvm::Value *&
result)
mutable
1034 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1035 moduleTranslation.
mapValue(decl.getReductionLhsArg(),
lhs);
1036 moduleTranslation.
mapValue(decl.getReductionRhsArg(),
rhs);
1037 builder.restoreIP(insertPoint);
1040 "omp.reduction.nonatomic.body", builder,
1041 moduleTranslation, &phis)))
1042 return llvm::createStringError(
1043 "failed to inline `combiner` region of `omp.declare_reduction`");
1044 result = llvm::getSingleElement(phis);
1045 return builder.saveIP();
1054static OwningAtomicReductionGen
1056 llvm::IRBuilderBase &builder,
1058 if (decl.getAtomicReductionRegion().empty())
1059 return OwningAtomicReductionGen();
1065 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1066 llvm::Value *
lhs, llvm::Value *
rhs)
mutable
1067 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1068 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(),
lhs);
1069 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(),
rhs);
1070 builder.restoreIP(insertPoint);
1073 "omp.reduction.atomic.body", builder,
1074 moduleTranslation, &phis)))
1075 return llvm::createStringError(
1076 "failed to inline `atomic` region of `omp.declare_reduction`");
1077 assert(phis.empty());
1078 return builder.saveIP();
1087static OwningDataPtrPtrReductionGen
1090 if (!isByRef || decl.getDataPtrPtrRegion().empty())
1091 return OwningDataPtrPtrReductionGen();
1093 OwningDataPtrPtrReductionGen refDataPtrGen =
1094 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1095 llvm::Value *byRefVal, llvm::Value *&
result)
mutable
1096 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1097 moduleTranslation.
mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1098 builder.restoreIP(insertPoint);
1101 "omp.data_ptr_ptr.body", builder,
1102 moduleTranslation, &phis)))
1103 return llvm::createStringError(
1104 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1105 result = llvm::getSingleElement(phis);
1106 return builder.saveIP();
1109 return refDataPtrGen;
1116 auto orderedOp = cast<omp::OrderedOp>(opInst);
1121 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1122 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1123 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1125 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
1127 size_t indexVecValues = 0;
1128 while (indexVecValues < vecValues.size()) {
1130 storeValues.reserve(numLoops);
1131 for (
unsigned i = 0; i < numLoops; i++) {
1132 storeValues.push_back(vecValues[indexVecValues]);
1135 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1137 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1138 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
1139 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
1149 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1150 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1155 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1158 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1159 builder.restoreIP(codeGenIP);
1167 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1169 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1170 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1172 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1177 builder.restoreIP(*afterIP);
/// Pairs a value with the address it is to be stored to, so the store
/// instruction can be created at a later point.
1183struct DeferredStore {
/// Records the value and the destination address as given.
1184 DeferredStore(llvm::Value *value, llvm::Value *address)
1185 : value(value), address(address) {}
// Destination of the store.
1188 llvm::Value *address;
1195template <
typename T>
1198 llvm::IRBuilderBase &builder,
1200 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1206 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1207 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1213 deferredStores.reserve(op.getNumReductionVars());
1215 for (std::size_t i = 0; i < op.getNumReductionVars(); ++i) {
1216 Region &allocRegion = reductionDecls[i].getAllocRegion();
1218 if (allocRegion.
empty())
1223 builder, moduleTranslation, &phis)))
1224 return op.emitError(
1225 "failed to inline `alloc` region of `omp.declare_reduction`");
1227 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1228 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1232 llvm::Type *ptrTy = builder.getPtrTy();
1236 if (useDeviceSharedMem) {
1237 var = ompBuilder->createOMPAllocShared(builder, varTy);
1239 var = builder.CreateAlloca(varTy);
1240 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1243 llvm::Value *castPhi =
1244 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1246 deferredStores.emplace_back(castPhi, var);
1248 privateReductionVariables[i] = var;
1249 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1250 reductionVariableMap.try_emplace(op.getReductionVars()[i], castPhi);
1252 assert(allocRegion.
empty() &&
1253 "allocaction is implicit for by-val reduction");
1255 llvm::Type *ptrTy = builder.getPtrTy();
1259 if (useDeviceSharedMem) {
1260 var = ompBuilder->createOMPAllocShared(builder, varTy);
1262 var = builder.CreateAlloca(varTy);
1263 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1266 moduleTranslation.
mapValue(reductionArgs[i], var);
1267 privateReductionVariables[i] = var;
1268 reductionVariableMap.try_emplace(op.getReductionVars()[i], var);
1276template <
typename T>
1279 llvm::IRBuilderBase &builder,
1284 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1285 Region &initializerRegion = reduction.getInitializerRegion();
1288 mlir::Value mlirSource = loop.getReductionVars()[i];
1289 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1290 llvm::Value *origVal = llvmSource;
1292 if (!isa<LLVM::LLVMPointerType>(
1293 reduction.getInitializerMoldArg().getType()) &&
1294 isa<LLVM::LLVMPointerType>(mlirSource.
getType())) {
1297 reduction.getInitializerMoldArg().getType()),
1298 llvmSource,
"omp_orig");
1300 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), origVal);
1303 llvm::Value *allocation =
1304 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1305 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1311 llvm::BasicBlock *block =
nullptr) {
1312 if (block ==
nullptr)
1313 block = builder.GetInsertBlock();
1315 if (!block->hasTerminator())
1316 builder.SetInsertPoint(block);
1318 builder.SetInsertPoint(block->getTerminator());
1326template <
typename OP>
1329 llvm::IRBuilderBase &builder,
1331 llvm::BasicBlock *latestAllocaBlock,
1337 if (op.getNumReductionVars() == 0)
1343 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1344 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1345 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1346 builder.restoreIP(allocaIP);
1349 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1351 if (!reductionDecls[i].getAllocRegion().empty())
1359 if (useDeviceSharedMem)
1360 byRefVars[i] = ompBuilder->createOMPAllocShared(builder, varTy);
1362 byRefVars[i] = builder.CreateAlloca(varTy);
1370 for (
auto [data, addr] : deferredStores)
1371 builder.CreateStore(data, addr);
1376 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1381 reductionVariableMap, i);
1389 "omp.reduction.neutral", builder,
1390 moduleTranslation, &phis)))
1393 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1394 "reduction neutral element declaration region");
1399 if (!reductionDecls[i].getAllocRegion().empty())
1408 builder.CreateStore(phis[0], byRefVars[i]);
1410 privateReductionVariables[i] = byRefVars[i];
1411 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1412 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1415 builder.CreateStore(phis[0], privateReductionVariables[i]);
1422 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
/// Collects per-reduction generator callbacks and appends one entry per
/// reduction variable of `loop` to `reductionInfos`.
/// NOTE(review): most of the parameter list and several statements are
/// missing from this excerpt; the comments below cover the visible lines only.
///
/// \tparam T operation type providing `getNumReductionVars()` and
///         `getReductionVars()`.
/// \param loop operation whose reduction variables are processed.
/// \param builder IR builder forwarded to the generator factories.
1429template <
typename T>
1430static void collectReductionInfo(
1431 T loop, llvm::IRBuilderBase &builder,
1440 unsigned numReductions = loop.getNumReductionVars();
// First pass: build the owning generator objects, one per reduction.
1442 for (
unsigned i = 0; i < numReductions; ++i) {
1445 owningAtomicReductionGens.push_back(
1448 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
// Second pass: assemble the reduction info entries.
1452 reductionInfos.reserve(numReductions);
1453 for (
unsigned i = 0; i < numReductions; ++i) {
// The atomic callback stays null unless an atomic generator exists.
1454 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1455 if (owningAtomicReductionGens[i])
1456 atomicGen = owningAtomicReductionGens[i];
// LLVM value mapped to the i-th reduction variable.
1457 llvm::Value *variable =
1458 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
// When `op` is an LLVM alloca operation, use its element type as the
// allocated type (`op` is defined on a line not visible in this excerpt).
1461 if (
auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1462 allocatedType = alloca.getElemType();
1469 reductionInfos.push_back(
1471 privateReductionVariables[i],
1472 llvm::OpenMPIRBuilder::EvalKind::Scalar,
// A null type is passed when no allocated type was determined.
1476 allocatedType ? moduleTranslation.
convertType(allocatedType) :
nullptr,
1477 reductionDecls[i].getByrefElementType()
1479 *reductionDecls[i].getByrefElementType())
1489 llvm::IRBuilderBase &builder, StringRef regionName,
1490 bool shouldLoadCleanupRegionArg =
true) {
1491 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1492 if (cleanupRegion->empty())
1498 llvm::Instruction *potentialTerminator =
1499 builder.GetInsertBlock()->empty() ?
nullptr
1500 : &builder.GetInsertBlock()->back();
1501 if (potentialTerminator && potentialTerminator->isTerminator())
1502 builder.SetInsertPoint(potentialTerminator);
1503 llvm::Value *privateVarValue =
1504 shouldLoadCleanupRegionArg
1505 ? builder.CreateLoad(
1507 privateVariables[i])
1508 : privateVariables[i];
1513 moduleTranslation)))
// NOTE(review): the template header and name line of this function are
// elided from this listing. The visible body emits the reductions for `op`
// through the OpenMPIRBuilder and then handles the per-reduction cleanup
// regions.
1526 OP op, llvm::IRBuilderBase &builder,
1528 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1531 bool isNowait =
false,
bool isTeamsReduction =
false) {
// Ops without reduction variables are special-cased (action elided here).
1533 if (op.getNumReductionVars() == 0)
1545 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1547 owningReductionGenRefDataPtrGens,
1548 privateReductionVariables, reductionInfos, isByRef);
// A temporary `unreachable` is created and used as the insertion point for
// createReductions; it is erased again once the reductions are emitted.
1553 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1554 builder.SetInsertPoint(tempTerminator);
1555 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1556 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1557 isByRef, isNowait, isTeamsReduction);
1562 if (!contInsertPoint->getBlock())
1563 return op->emitOpError() <<
"failed to convert reductions";
1565 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
// Non-teams reductions are followed by a barrier; code generation then
// continues after that barrier rather than directly after the reductions.
1566 if (!isTeamsReduction) {
1567 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1568 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1572 afterIP = *barrierIP;
1575 tempTerminator->eraseFromParent();
1576 builder.restoreIP(afterIP);
// Collect the cleanup region of every reduction declaration.
1580 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1581 [](omp::DeclareReductionOp reductionDecl) {
1582 return &reductionDecl.getCleanupRegion();
1585 reductionRegions, privateReductionVariables, moduleTranslation, builder,
1586 "omp.reduction.cleanup");
// When device shared memory was used, each private reduction variable is
// released with createOMPFreeShared.
1589 if (useDeviceSharedMem) {
1590 for (
auto [var, reductionDecl] :
1591 llvm::zip_equal(privateReductionVariables, reductionDecls))
1592 ompBuilder->createOMPFreeShared(
1593 builder, var, moduleTranslation.
convertType(reductionDecl.getType()));
// NOTE(review): the name line and most parameters of this template are
// elided from this listing. The visible body runs an allocation step (callee
// name elided) and then returns the result of initReductionVars.
1606template <
typename OP>
1610 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// Ops without reduction variables are special-cased (action elided here).
1615 if (op.getNumReductionVars() == 0)
1621 allocaIP, reductionDecls,
1622 privateReductionVariables, reductionVariableMap,
1623 deferredStores, isByRef)))
// Initialization is handed the block of the alloca insertion point and the
// stores deferred by the step above.
1626 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1627 allocaIP.getBlock(), reductionDecls,
1628 privateReductionVariables, reductionVariableMap,
1629 isByRef, deferredStores);
// NOTE(review): the signature of this function and the statements taken in
// most branches are elided from this listing.
// A null `mappedPrivateVars`, or one without an entry for `privateVar`, is
// handled first (the action is elided).
1643 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
// Otherwise the mapped block argument is used; the assert below requires it
// to be of LLVM pointer type.
1646 Value blockArg = (*mappedPrivateVars)[privateVar];
1649 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1650 "A block argument corresponding to a mapped var should have "
1653 if (privVarType == blockArgType)
// A private variable whose type is not a pointer leads to a load of the
// converted private-variable type (the load's pointer operand is elided).
1660 if (!isa<LLVM::LLVMPointerType>(privVarType))
1661 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
// NOTE(review): the head of this signature is elided from this listing.
// The visible body handles the `init` region of an `omp.private` declaration
// for one private variable.
1674 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1676 llvm::BasicBlock *privInitBlock,
1678 Region &initRegion = privDecl.getInitRegion();
// Without an init region the private variable is returned unchanged.
1679 if (initRegion.
empty())
1680 return llvmPrivateVar;
// Bind the init region's two arguments: the mold (original variable) and
// the private copy.
1682 assert(nonPrivateVar);
1683 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1684 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1689 moduleTranslation, &phis)))
1690 return llvm::createStringError(
1691 "failed to inline `init` region of `omp.private`");
// Exactly one value is expected in `phis` after inlining the region.
1693 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
// NOTE(review): fragment of a definition whose signature head and callee
// name are elided from this listing. The visible lines pass the parameters
// on to another call, including `mappedPrivateVars`.
1710 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1713 builder, moduleTranslation, privDecl,
1716 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
// NOTE(review): the signature of this function and the condition guarding
// the first return are elided from this listing.
1725 return llvm::Error::success();
// Split off a dedicated "omp.private.init" block for the initialization
// code.
1727 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1730 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1733 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1735 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1736 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1739 return privVarOrErr.takeError();
// Record the resulting value and map the region's block argument to it.
1741 llvmPrivateVar = privVarOrErr.get();
1742 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1747 return llvm::Error::success();
// NOTE(review): the name line and most parameters of this template are
// elided from this listing. The visible body creates storage for each
// private variable at the alloca insertion point and returns the block that
// follows the allocas.
1753template <
typename T>
1758 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// Split the alloca block at its terminator so that everything after the
// allocas lives in a separate "omp.region.after_alloca" block.
1761 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1762 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1763 allocaTerminator->getIterator()),
1764 true, allocaTerminator->getStableDebugLoc(),
1765 "omp.region.after_alloca");
// The guard restores the builder's insertion point when this scope ends.
1767 llvm::IRBuilderBase::InsertPointGuard guard(builder);
// Re-read the terminator: after the split it is the branch to the new block.
1769 allocaTerminator = allocaIP.getBlock()->getTerminator();
1770 builder.SetInsertPoint(allocaTerminator);
1772 assert(allocaTerminator->getNumSuccessors() == 1 &&
1773 "This is an unconditional branch created by splitBB");
1775 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1776 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1780 unsigned int allocaAS =
1781 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1784 .getProgramAddressSpace();
1786 for (
auto [privDecl, mlirPrivVar, blockArg] :
1789 llvm::Type *llvmAllocType =
1790 moduleTranslation.
convertType(privDecl.getType());
1791 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1792 llvm::Value *llvmPrivateVar =
nullptr;
// Storage comes either from createOMPAllocShared or from a plain alloca (the
// condition selecting between the two is elided here).
1794 llvmPrivateVar = ompBuilder->createOMPAllocShared(builder, llvmAllocType);
1796 llvmPrivateVar = builder.CreateAlloca(
1797 llvmAllocType,
nullptr,
"omp.private.alloc");
// Cast the pointer to `defaultAS` when allocas use a different address
// space.
1798 if (allocaAS != defaultAS)
1799 llvmPrivateVar = builder.CreateAddrSpaceCast(
1800 llvmPrivateVar, builder.getPtrTy(defaultAS));
1803 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1806 return afterAllocas;
1814 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1823 if (mlir::isa<omp::ParallelOp>(parent))
// NOTE(review): the signature of this function is elided from this listing.
// The visible body handles the `copy` region of firstprivate `omp.private`
// declarations.
// Determine whether any declaration is firstprivate; the case where none is
// gets special handling (elided).
1837 bool needsFirstprivate =
1838 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1839 return privOp.getDataSharingType() ==
1840 omp::DataSharingClauseType::FirstPrivate;
1843 if (!needsFirstprivate)
// The copies are emitted into a dedicated "omp.private.copy" block.
1846 llvm::BasicBlock *copyBlock =
1847 splitBB(builder,
true,
"omp.private.copy");
1850 for (
auto [decl, moldVar, llvmVar] :
1851 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1852 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
// NOTE(review): `©Region` below looks like a mis-decoded `&copyRegion`
// (the HTML entity `&copy` was turned into the copyright sign); restore the
// original spelling before compiling.
1856 Region ©Region = decl.getCopyRegion();
// Bind the copy region's arguments: the mold and the private copy.
1858 moduleTranslation.
mapValue(decl.getCopyMoldArg(), moldVar);
1861 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1865 moduleTranslation)))
1866 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
// A barrier is created at the current insertion point (the condition under
// which this happens is elided here).
1880 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1881 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
// NOTE(review): fragment of a definition whose signature is elided from this
// listing. Each MLIR private variable is resolved to an LLVM value through
// findAssociatedValue and written into `moldVars`.
1897 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](
mlir::Value mlirVar) {
1899 llvm::Value *moldVar = findAssociatedValue(
1900 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1905 llvmPrivateVars, privateDecls, insertBarrier,
// NOTE(review): the name line and parameters of this template are elided
// from this listing. The visible body handles the `dealloc` regions of the
// privatizers and releases storage obtained through the OpenMPIRBuilder.
1909template <
typename T>
// Collect the dealloc region of each privatizer.
1917 std::back_inserter(privateCleanupRegions),
1918 [](omp::PrivateClauseOp privatizer) {
1919 return &privatizer.getDeallocRegion();
1923 privateVarsInfo.
llvmVars, moduleTranslation,
1924 builder,
"omp.private.dealloc",
1926 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1927 "`omp.private` op in");
// Per-variable release via createOMPFreeShared (the range being iterated
// and any guarding condition are elided here).
1931 for (
auto [privDecl, llvmPrivVar, blockArg] :
1935 ompBuilder->createOMPFreeShared(
1936 builder, llvmPrivVar,
1937 moduleTranslation.
convertType(privDecl.getType()));
1951 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
// NOTE(review): the signature of this function is elided from this listing.
// The visible body translates an `omp::SectionsOp`: it builds one body
// callback per nested `omp::SectionOp`, passes them to the OpenMPIRBuilder,
// and finishes with the reduction handling.
1961 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1962 using StorableBodyGenCallbackTy =
1963 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1965 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1971 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1975 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1979 sectionsOp.getNumReductionVars());
1983 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1986 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1987 reductionDecls, privateReductionVariables, reductionVariableMap,
1994 auto sectionOp = dyn_cast<omp::SectionOp>(op);
// NOTE(review): `®ion` and `§ionsOp` below look like mis-decoded `&region`
// and `&sectionsOp` (the HTML entities `&reg` / `&sect` were turned into
// symbols); restore the original spelling before compiling.
1998 Region ®ion = sectionOp.getRegion();
1999 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
2000 InsertPointTy allocaIP, InsertPointTy codeGenIP,
2002 builder.restoreIP(codeGenIP);
2009 sectionsOp.getRegion().getNumArguments());
// Each block argument of the section region is mapped to the LLVM value
// already associated with the matching argument of the enclosing sections
// region.
2010 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
2011 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
2012 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
2014 moduleTranslation.
mapValue(sectionArg, llvmVal);
2021 sectionCBs.push_back(sectionCB);
// Having no section callbacks is special-cased (action elided here).
2027 if (sectionCBs.empty())
2030 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
// The privatization callback hands back the incoming pointer as the
// replacement value.
2035 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
2036 llvm::Value &vPtr, llvm::Value *&replacementValue)
2037 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
2038 replacementValue = &vPtr;
// The finalization callback does nothing and reports success.
2044 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
2048 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2049 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2051 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
2052 sectionsOp.getNowait());
// Code generation continues after the construct.
2057 builder.restoreIP(*afterIP);
2061 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
2062 privateReductionVariables, isByRef, sectionsOp.getNowait());
// NOTE(review): the signature of this function is elided from this listing.
// The visible body translates `scopeOp` through
// OpenMPIRBuilder::createScope, using a body callback and a finalization
// callback, and finishes with the reduction handling.
2069 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2076 assert(isByRef.size() == scopeOp.getNumReductionVars());
2085 scopeOp.getNumReductionVars());
2089 cast<omp::BlockArgOpenMPOpInterface>(*scopeOp).getReductionBlockArgs();
2093 scopeOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
2098 scopeOp, reductionArgs, builder, moduleTranslation, allocaIP,
2099 reductionDecls, privateReductionVariables, reductionVariableMap,
// Body callback: failures inside it are returned as PreviouslyReportedError.
2104 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2106 builder.restoreIP(codeGenIP);
2112 return llvm::make_error<PreviouslyReportedError>();
2115 scopeOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2117 scopeOp.getPrivateNeedsBarrier())))
2118 return llvm::make_error<PreviouslyReportedError>();
// Finalization callback: emits at `codeGenIP` and then puts the builder back
// where it was (the callee on the elided line takes `privateVarsInfo`).
2125 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2126 InsertPointTy oldIP = builder.saveIP();
2127 builder.restoreIP(codeGenIP);
2129 scopeOp.getLoc(), privateVarsInfo)))
2130 return llvm::make_error<PreviouslyReportedError>();
2131 builder.restoreIP(oldIP);
2132 return llvm::Error::success();
2135 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2136 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2137 ompBuilder->createScope(ompLoc, bodyCB, finiCB, scopeOp.getNowait());
// Code generation continues after the construct.
2142 builder.restoreIP(*afterIP);
2146 scopeOp, builder, moduleTranslation, allocaIP, reductionDecls,
2147 privateReductionVariables, isByRef, scopeOp.getNowait(),
// NOTE(review): the signature of this function is elided from this listing.
// The visible body translates `singleOp`, including its copyprivate
// variables and their associated function symbols.
2155 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2156 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2161 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2163 builder.restoreIP(codegenIP);
2165 builder, moduleTranslation)
// The finalization callback does nothing and reports success.
2168 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
2172 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
// For every copyprivate variable, record its LLVM value and an entry derived
// from the symbol at the same index of `cpFuncs`.
// NOTE(review): `*cpFuncs` is dereferenced without a visible has_value()
// check — presumably guaranteed whenever `cpVars` is non-empty; confirm.
2175 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
2176 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
2178 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
2179 llvmCPFuncs.push_back(
2183 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2185 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
// Code generation continues after the construct.
2191 builder.restoreIP(*afterIP);
// NOTE(review): the name line and parameters of this function are elided
// from this listing. The visible body looks for an `omp::DistributeOp`
// inside `teamsOp`'s region and inspects the uses of the teams op's
// reduction block arguments.
2195static omp::DistributeOp
2199 omp::DistributeOp distOp;
2200 WalkResult walk = teamsOp.getRegion().walk([&](omp::DistributeOp op) {
// An interrupted walk or a missing distribute op is special-cased (action
// elided here).
2206 if (walk.wasInterrupted() || !distOp)
2210 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
2214 for (
auto ra : iface.getReductionBlockArgs())
2215 for (
auto &use : ra.getUses()) {
2216 auto *useOp = use.getOwner();
// Debug-intrinsic users are collected separately in `debugUses`.
2218 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
2219 debugUses.push_back(useOp);
// Users not strictly nested inside the distribute op are special-cased
// (action elided here).
2222 if (!distOp->isProperAncestor(useOp))
2229 for (
auto *use : debugUses)
// NOTE(review): the signature of this function is elided from this listing.
// The visible body translates a teams op: optional teams-reduction setup, a
// body callback, lookup of the optional clause operands, and the reduction
// handling afterwards.
2238 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2243 unsigned numReductionVars = op.getNumReductionVars();
2247 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// Reduction state is only prepared when `doTeamsReduction` is set.
2253 if (doTeamsReduction) {
2254 isByRef =
getIsByRef(op.getReductionByref());
2256 assert(isByRef.size() == op.getNumReductionVars());
2259 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2264 op, reductionArgs, builder, moduleTranslation, allocaIP,
2265 reductionDecls, privateReductionVariables, reductionVariableMap,
2270 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2273 moduleTranslation, allocaIP, deallocBlocks);
2274 builder.restoreIP(codegenIP);
// Each clause operand below stays null when the op does not provide it.
2280 llvm::Value *numTeamsLower =
nullptr;
2281 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
2282 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
// Only the first entry (index 0) of the upper num_teams and thread_limit
// operand lists is looked up.
2284 llvm::Value *numTeamsUpper =
nullptr;
2285 if (!op.getNumTeamsUpperVars().empty())
2286 numTeamsUpper = moduleTranslation.
lookupValue(op.getNumTeams(0));
2288 llvm::Value *threadLimit =
nullptr;
2289 if (!op.getThreadLimitVars().empty())
2290 threadLimit = moduleTranslation.
lookupValue(op.getThreadLimit(0));
2292 llvm::Value *ifExpr =
nullptr;
2293 if (
Value ifVar = op.getIfExpr())
2296 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2297 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2299 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
// Code generation continues after the construct.
2304 builder.restoreIP(*afterIP);
2305 if (doTeamsReduction) {
2308 op, builder, moduleTranslation, allocaIP, reductionDecls,
2309 privateReductionVariables, isByRef,
// Maps an `mlir::omp::ClauseTaskDepend` value to the corresponding
// `llvm::omp::RTLDependenceKindTy`.
// NOTE(review): the name line, the switch header and a few lines after the
// `DepIn` return are elided from this listing.
2315static llvm::omp::RTLDependenceKindTy
2318 case mlir::omp::ClauseTaskDepend::taskdependin:
2319 return llvm::omp::RTLDependenceKindTy::DepIn;
// `out` and `inout` share a single runtime kind, DepInOut.
2323 case mlir::omp::ClauseTaskDepend::taskdependout:
2324 case mlir::omp::ClauseTaskDepend::taskdependinout:
2325 return llvm::omp::RTLDependenceKindTy::DepInOut;
2326 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2327 return llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2328 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2329 return llvm::omp::RTLDependenceKindTy::DepInOutSet;
// Reached only when none of the visible cases returned.
2331 llvm_unreachable(
"unhandled depend kind");
// NOTE(review): the head of this signature is elided from this listing.
// The visible body appends one DependData entry to `dds` for each pair of
// dependence variable and dependence kind.
2335 std::optional<ArrayAttr> dependKinds,
OperandRange dependVars,
// No dependence variables is special-cased (action elided here).
2338 if (dependVars.empty())
// NOTE(review): `dependKinds` is dereferenced without a visible has_value()
// check — presumably always set when `dependVars` is non-empty; confirm.
2340 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2342 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue();
2344 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2345 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2346 dds.emplace_back(dd);
// NOTE(review): the head of this signature is elided from this listing.
// The visible body defines a finalization callback and registers one with
// `ompBuilder` via pushFinalizationCB.
2358 llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder,
2360 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
// The guard restores the builder's insertion point when the callback
// returns.
2361 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2365 llvmBuilder.restoreIP(ip);
// The created branch initially targets `ip`'s own block and is recorded in
// `cancelTerminators`.
2371 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2372 return llvm::Error::success();
2377 ompBuilder.pushFinalizationCB(
// NOTE(review): the head of this signature is elided from this listing.
// Pops the finalization callback from `ompBuilder` and points every recorded
// cancel branch at the single predecessor of `afterIP`'s block.
2387 llvm::OpenMPIRBuilder &ompBuilder,
2388 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2389 ompBuilder.popFinalizationCB();
// NOTE(review): getSinglePredecessor() yields null unless the block has
// exactly one predecessor; no null check is visible here — confirm the
// invariant holds for all callers.
2390 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2391 for (llvm::UncondBrInst *cancelBranch : cancelTerminators)
2392 cancelBranch->setSuccessor(constructFini);
/// Helper that owns the bookkeeping for a heap-allocated struct holding
/// private-variable storage for a task: the struct type, the pointer to the
/// allocation, and the GEPs to the individual fields.
/// NOTE(review): access specifiers, the original member comments and the
/// closing brace of the class are elided from this listing.
2398class TaskContextStructManager {
/// Stores references to the builder, the module translation and the list of
/// privatizer declarations; nothing is emitted by the constructor itself.
2400 TaskContextStructManager(llvm::IRBuilderBase &builder,
2401 LLVM::ModuleTranslation &moduleTranslation,
2402 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2403 : builder{builder}, moduleTranslation{moduleTranslation},
2404 privateDecls{privateDecls} {}
/// Builds the struct type from the privatizer types and allocates it.
2410 void generateTaskContextStruct();
/// Computes the field GEPs against `structPtr` and caches them in
/// `llvmPrivateVarGEPs`.
2416 void createGEPsToPrivateVars();
/// Computes the field GEPs against `altStructPtr` and returns them without
/// touching the cached list.
2422 SmallVector<llvm::Value *>
2423 createGEPsToPrivateVars(llvm::Value *altStructPtr)
const;
/// Emits the free of the allocation referenced by `structPtr`.
2426 void freeStructPtr();
/// Returns the cached GEPs produced by createGEPsToPrivateVars().
2428 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2429 return llvmPrivateVarGEPs;
/// Returns the pointer to the allocated struct (null until allocated).
2432 llvm::Value *getStructPtr() {
return structPtr; }
2435 llvm::IRBuilderBase &builder;
2436 LLVM::ModuleTranslation &moduleTranslation;
2437 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
// LLVM types of the struct members, in field order.
2440 SmallVector<llvm::Type *> privateVarTypes;
// Cached GEPs filled in by createGEPsToPrivateVars().
2444 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
// Pointer to the allocation; null until generateTaskContextStruct() creates
// it.
2447 llvm::Value *structPtr =
nullptr;
// Struct type of the allocation; null until generateTaskContextStruct()
// creates it.
2449 llvm::Type *structTy =
nullptr;
// Members of the iterator-description helper whose constructor is named
// `IteratorInfo` below.
// NOTE(review): the class header, access specifiers, the `dims` declaration
// and some constructor lines are elided from this listing.
// Per-dimension loop description; the constructor fills index `d` of each
// vector for dimension `d`.
2460 llvm::SmallVector<llvm::Value *> lowerBounds;
2461 llvm::SmallVector<llvm::Value *> upperBounds;
2462 llvm::SmallVector<llvm::Value *> steps;
2463 llvm::SmallVector<llvm::Value *> trips;
// Product of all per-dimension trip counts.
2465 llvm::Value *totalTrips;
// Looks up the LLVM value for `val` and yields it as an i64: integer values
// of another width are sign-extended or truncated to 64 bits. The i64 and
// non-integer outcomes are elided from this listing.
2467 llvm::Value *lookUpAsI64(mlir::Value val,
const LLVM::ModuleTranslation &mt,
2468 llvm::IRBuilderBase &builder) {
2472 if (v->getType()->isIntegerTy(64))
2474 if (v->getType()->isIntegerTy())
2475 return builder.CreateSExtOrTrunc(v, builder.getInt64Ty());
// Reads the bounds and steps of `itersOp` and emits IR computing the trip
// count of each dimension and the total trip count.
2480 IteratorInfo(mlir::omp::IteratorOp itersOp,
2481 mlir::LLVM::ModuleTranslation &moduleTranslation,
2482 llvm::IRBuilderBase &builder) {
// The number of dimensions is the number of lower bounds.
2483 dims = itersOp.getLoopLowerBounds().size();
2484 lowerBounds.resize(dims);
2485 upperBounds.resize(dims);
2489 for (
unsigned d = 0; d < dims; ++d) {
2490 llvm::Value *lb = lookUpAsI64(itersOp.getLoopLowerBounds()[d],
2491 moduleTranslation, builder);
2492 llvm::Value *ub = lookUpAsI64(itersOp.getLoopUpperBounds()[d],
2493 moduleTranslation, builder);
2495 lookUpAsI64(itersOp.getLoopSteps()[d], moduleTranslation, builder);
2496 assert(lb && ub && st &&
2497 "Expect lowerBounds, upperBounds, and steps in IteratorOp");
// Only a compile-time constant step can be checked for zero here.
2498 assert((!llvm::isa<llvm::ConstantInt>(st) ||
2499 !llvm::cast<llvm::ConstantInt>(st)->isZero()) &&
2500 "Expect non-zero step in IteratorOp");
2502 lowerBounds[d] = lb;
2503 upperBounds[d] = ub;
// trips[d] = (ub - lb) / st + 1, using signed division.
// NOTE(review): presumably the upper bound is inclusive — confirm.
2507 llvm::Value *diff = builder.CreateSub(ub, lb);
2508 llvm::Value *
div = builder.CreateSDiv(diff, st);
2509 trips[d] = builder.CreateAdd(
2510 div, llvm::ConstantInt::get(builder.getInt64Ty(), 1));
// totalTrips starts at the i64 constant 1 and is multiplied by each trip
// count.
2513 totalTrips = llvm::ConstantInt::get(builder.getInt64Ty(), 1);
2514 for (
unsigned d = 0; d < dims; ++d)
2515 totalTrips = builder.CreateMul(totalTrips, trips[d]);
// Read-only accessors for the values computed by the constructor.
2518 unsigned getDims()
const {
return dims; }
2519 llvm::ArrayRef<llvm::Value *> getLowerBounds()
const {
return lowerBounds; }
2520 llvm::ArrayRef<llvm::Value *> getUpperBounds()
const {
return upperBounds; }
2521 llvm::ArrayRef<llvm::Value *> getSteps()
const {
return steps; }
2522 llvm::ArrayRef<llvm::Value *> getTrips()
const {
return trips; }
2523 llvm::Value *getTotalTrips()
const {
return totalTrips; }
/// Collects the converted types of the privatizers into `privateVarTypes`,
/// builds `structTy` from them and emits a malloc whose result is stored in
/// `structPtr`.
/// NOTE(review): the statements taken by the early-exit conditions and some
/// call arguments are elided from this listing.
2528void TaskContextStructManager::generateTaskContextStruct() {
2529 if (privateDecls.empty())
2531 privateVarTypes.reserve(privateDecls.size());
2533 for (omp::PrivateClauseOp &privOp : privateDecls) {
// Privatizers that do not read from the mold are special-cased (action
// elided here).
2536 if (!privOp.readsFromMold())
2538 Type mlirType = privOp.getType();
2539 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
2542 if (privateVarTypes.empty())
2545 structTy = llvm::StructType::get(moduleTranslation.
getLLVMContext(),
// The allocation size is the size of `structTy`, expressed in the target's
// pointer-sized integer type.
2548 llvm::DataLayout dataLayout =
2549 builder.GetInsertBlock()->getModule()->getDataLayout();
2550 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2551 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2554 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2556 "omp.task.context_ptr");
/// Emits GEPs into the struct at `altStructPtr` and returns them in a list
/// with capacity reserved for one entry per privatizer.
/// NOTE(review): the declaration and update of the field index `i`, the
/// push of each GEP and the final return are elided from this listing.
2559SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2560 llvm::Value *altStructPtr)
const {
2561 SmallVector<llvm::Value *> ret;
2564 ret.reserve(privateDecls.size());
2565 llvm::Value *zero = builder.getInt32(0);
2567 for (
auto privDecl : privateDecls) {
// Privatizers that do not read from the mold get a null placeholder.
2568 if (!privDecl.readsFromMold()) {
2570 ret.push_back(
nullptr);
// Field address: GEP {0, i} into `structTy`.
2573 llvm::Value *iVal = builder.getInt32(i);
2574 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
/// Computes the GEPs against this manager's own `structPtr` and caches them
/// in `llvmPrivateVarGEPs`.
/// NOTE(review): one source line before the assert is elided from this
/// listing — presumably a guarding condition; confirm before relying on the
/// assert as unconditional.
2581void TaskContextStructManager::createGEPsToPrivateVars() {
2583 assert(privateVarTypes.empty());
2587 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
/// Emits a free of `structPtr` immediately before the terminator of the
/// builder's current block, leaving the builder's insertion point untouched
/// afterwards.
/// NOTE(review): any early exit at the top of this function is elided from
/// this listing.
2590void TaskContextStructManager::freeStructPtr() {
// The guard restores the original insertion point on scope exit.
2594 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2596 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2597 builder.CreateFree(structPtr);
// NOTE(review): the head of this signature is elided from this listing.
// Writes one affinity-info record into `affinityList` at position `index`:
// field 0 receives `addr`, field 1 receives `len`, field 2 receives zero
// flags.
2601 llvm::OpenMPIRBuilder &ompBuilder,
2602 llvm::Value *affinityList, llvm::Value *
index,
2603 llvm::Value *addr, llvm::Value *len) {
2604 llvm::StructType *kmpTaskAffinityInfoTy =
2605 ompBuilder.getKmpTaskAffinityInfoTy();
2606 llvm::Value *entry = builder.CreateInBoundsGEP(
2607 kmpTaskAffinityInfoTy, affinityList,
index,
"omp.affinity.entry");
// Convert the address and length to the types of struct fields 0 and 1.
2609 addr = builder.CreatePtrToInt(addr, kmpTaskAffinityInfoTy->getElementType(0));
2610 len = builder.CreateIntCast(len, kmpTaskAffinityInfoTy->getElementType(1),
2612 llvm::Value *flags = builder.getInt32(0);
2614 builder.CreateStore(addr,
2615 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 0));
2616 builder.CreateStore(len,
2617 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 1));
2618 builder.CreateStore(flags,
2619 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 2));
// NOTE(review): the head of this signature and the name of the callee that
// stores each entry are elided from this listing.
// For the i-th value in `affinityVars`, the address and length of its
// defining `omp.affinity_entry` are looked up and passed on, together with
// the i64 index `i`, to fill `affinityList`.
2623 llvm::IRBuilderBase &builder,
2625 llvm::Value *affinityList) {
2626 for (
auto [i, affinityVar] : llvm::enumerate(affinityVars)) {
2627 auto entryOp = affinityVar.getDefiningOp<mlir::omp::AffinityEntryOp>();
2628 assert(entryOp &&
"affinity item must be omp.affinity_entry");
2630 llvm::Value *addr = moduleTranslation.
lookupValue(entryOp.getAddr());
2631 llvm::Value *len = moduleTranslation.
lookupValue(entryOp.getLen());
2632 assert(addr && len &&
"expect affinity addr and len to be non-null");
2634 affinityList, builder.getInt64(i), addr, len);
// NOTE(review): the name line and most parameters of this function are
// elided from this listing.
// Decomposes the linear induction variable into one index per dimension,
// turns each index into a physical induction value, and then converts the
// iterator region's block at the current insertion block.
2638static mlir::LogicalResult
2641 llvm::IRBuilderBase &builder,
2643 llvm::Value *tmp = linearIV;
// Dimensions are peeled from the last to the first: the remainder by the
// dimension's trip count is its index, the quotient carries on.
2644 for (
int d = (
int)iterInfo.getDims() - 1; d >= 0; --d) {
2645 llvm::Value *trip = iterInfo.getTrips()[d];
2647 llvm::Value *idx = builder.CreateURem(tmp, trip);
2649 tmp = builder.CreateUDiv(tmp, trip);
// physIV = lowerBound[d] + idx * step[d]
2652 llvm::Value *physIV = builder.CreateAdd(
2653 iterInfo.getLowerBounds()[d],
2654 builder.CreateMul(idx, iterInfo.getSteps()[d]),
"omp.it.phys_iv");
2660 moduleTranslation.
mapBlock(&iteratorRegionBlock, builder.GetInsertBlock());
2661 if (mlir::failed(moduleTranslation.
convertBlock(iteratorRegionBlock,
2664 return mlir::failure();
2666 return mlir::success();
// NOTE(review): the name line and several parameters of this function are
// elided from this listing, as is the OpenMPIRBuilder call that creates the
// loop.
// Builds a loop over `iterInfo.getTotalTrips()` iterations named `loopName`.
// Each iteration converts the iterator region and passes the linear
// induction variable and the region's `omp.yield` to `genStoreEntry`.
2672static mlir::LogicalResult
2675 IteratorInfo &iterInfo, llvm::StringRef loopName,
2680 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
2682 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy bodyIP,
2683 llvm::Value *linearIV) -> llvm::Error {
// The guard restores the builder's insertion point when the body generator
// returns.
2684 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2685 builder.restoreIP(bodyIP);
2688 builder, moduleTranslation))) {
2689 return llvm::make_error<llvm::StringError>(
2690 "failed to convert iterator region", llvm::inconvertibleErrorCode());
// The region's terminator is expected to be an omp.yield with exactly one
// result.
2694 mlir::dyn_cast<mlir::omp::YieldOp>(iteratorRegionBlock.
getTerminator());
2695 assert(yield && yield.getResults().size() == 1 &&
2696 "expect omp.yield in iterator region to have one result");
2698 genStoreEntry(linearIV, yield);
2704 return llvm::Error::success();
2707 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2709 loc, iterInfo.getTotalTrips(), bodyGen, loopName);
// Code generation continues after the loop.
2713 builder.restoreIP(*afterIP);
2715 return mlir::success();
// NOTE(review): the name line and most parameters of this function are
// elided from this listing.
// Builds the affinity list(s) of `taskOp` — one for the plain affinity
// variables and one per iterated entry — and stores the total count and the
// (possibly packed) list pointer into `ad`.
2718static mlir::LogicalResult
2721 llvm::OpenMPIRBuilder::AffinityData &ad) {
// Nothing to build when the task has neither affinity variables nor
// iterated entries.
2723 if (taskOp.getAffinityVars().empty() && taskOp.getIterated().empty()) {
2726 return mlir::success();
2730 llvm::StructType *kmpTaskAffinityInfoTy =
// Allocates an array of `count` affinity-info records with an alloca.
// Counts that are constants or function arguments are special-cased first
// (the action is elided here).
2733 auto allocateAffinityList = [&](llvm::Value *count) -> llvm::Value * {
2734 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2735 if (llvm::isa<llvm::Constant>(count) || llvm::isa<llvm::Argument>(count))
2737 return builder.CreateAlloca(kmpTaskAffinityInfoTy, count,
2738 "omp.affinity_list");
// Packages a count (truncated to i32) and a list pointer (cast to address
// space 0) as an AffinityData value.
2741 auto createAffinity =
2742 [&](llvm::Value *count,
2743 llvm::Value *info) -> llvm::OpenMPIRBuilder::AffinityData {
2744 llvm::OpenMPIRBuilder::AffinityData ad{};
2745 ad.Count = builder.CreateTrunc(count, builder.getInt32Ty());
2747 builder.CreatePointerBitCastOrAddrSpaceCast(info, builder.getPtrTy(0));
// Plain affinity variables: the list size is the number of variables.
2751 if (!taskOp.getAffinityVars().empty()) {
2752 llvm::Value *count = llvm::ConstantInt::get(
2753 builder.getInt64Ty(), taskOp.getAffinityVars().size());
2754 llvm::Value *list = allocateAffinityList(count);
2757 ads.emplace_back(createAffinity(count, list));
// Iterated entries: one list per iterator, sized by its total trip count
// and filled inside a generated loop.
2760 if (!taskOp.getIterated().empty()) {
2761 for (
auto [i, iter] : llvm::enumerate(taskOp.getIterated())) {
2762 auto itersOp = iter.getDefiningOp<omp::IteratorOp>();
2763 assert(itersOp &&
"iterated value must be defined by omp.iterator");
2764 IteratorInfo iterInfo(itersOp, moduleTranslation, builder);
2765 llvm::Value *affList = allocateAffinityList(iterInfo.getTotalTrips());
2767 itersOp, builder, moduleTranslation, iterInfo,
"iterator",
2768 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2769 auto entryOp = yield.getResults()[0]
2770 .getDefiningOp<mlir::omp::AffinityEntryOp>();
2771 assert(entryOp &&
"expect yield produce an affinity entry");
2778 affList, linearIV, addr, len);
2780 return llvm::failure();
2781 ads.emplace_back(createAffinity(iterInfo.getTotalTrips(), affList));
// Sum the counts of all lists as an i32.
2785 llvm::Value *totalAffinityCount = builder.getInt32(0);
2786 for (
const auto &affinity : ads)
2787 totalAffinityCount = builder.CreateAdd(
2789 builder.CreateIntCast(affinity.Count, builder.getInt32Ty(),
// With a single list its pointer is used directly; several lists are copied
// back to back into one packed list.
2792 llvm::Value *affinityInfo = ads.front().Info;
2793 if (ads.size() > 1) {
2794 llvm::StructType *kmpTaskAffinityInfoTy =
2796 llvm::Value *affinityInfoElemSize = builder.getInt64(
2797 moduleTranslation.
getLLVMModule()->getDataLayout().getTypeAllocSize(
2798 kmpTaskAffinityInfoTy));
2800 llvm::Value *packedAffinityInfo = allocateAffinityList(totalAffinityCount);
2801 llvm::Value *packedAffinityInfoOffset = builder.getInt32(0);
2802 for (
const auto &affinity : ads) {
2803 llvm::Value *affinityCount = builder.CreateIntCast(
2804 affinity.Count, builder.getInt32Ty(),
false);
2805 llvm::Value *affinityCountInt64 = builder.CreateIntCast(
2806 affinityCount, builder.getInt64Ty(),
false);
// Byte size of this list = element count * allocation size of one record.
2807 llvm::Value *affinityInfoSize =
2808 builder.CreateMul(affinityCountInt64, affinityInfoElemSize);
// Destination: the record at the running offset inside the packed list.
2810 llvm::Value *packedAffinityInfoIndex = builder.CreateIntCast(
2811 packedAffinityInfoOffset, kmpTaskAffinityInfoTy->getElementType(0),
2813 packedAffinityInfoIndex = builder.CreateInBoundsGEP(
2814 kmpTaskAffinityInfoTy, packedAffinityInfo, packedAffinityInfoIndex);
// The source pointer is cast to the destination's address space; both sides
// of the copy are given alignment 1.
2816 builder.CreateMemCpy(
2817 packedAffinityInfoIndex, llvm::Align(1),
2818 builder.CreatePointerBitCastOrAddrSpaceCast(
2819 affinity.Info, builder.getPtrTy(packedAffinityInfoIndex->getType()
2820 ->getPointerAddressSpace())),
2821 llvm::Align(1), affinityInfoSize);
2823 packedAffinityInfoOffset =
2824 builder.CreateAdd(packedAffinityInfoOffset, affinityCount);
2827 affinityInfo = packedAffinityInfo;
2830 ad.Count = totalAffinityCount;
2831 ad.Info = affinityInfo;
2833 return mlir::success();
// NOTE(review): the name line and several parameters of this function are
// elided from this listing.
// Builds a single dependence array holding the locator dependences first and
// the iterator-generated dependences after them, then stores the array
// pointer and the entry count (as i32) into `taskDeps`.
2839static mlir::LogicalResult
2842 std::optional<ArrayAttr> dependIteratedKinds,
2843 llvm::IRBuilderBase &builder,
2845 llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps) {
// Without iterated dependences this path ends early with success (the work
// done before returning is elided here).
2846 if (dependIterated.empty()) {
2849 return mlir::success();
2853 llvm::Type *dependInfoTy = ompBuilder.DependInfo;
2854 unsigned numLocator = dependVars.size();
// totalCount = number of locator dependences + sum of the iterators' total
// trip counts.
2857 llvm::Value *totalCount =
2858 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2861 for (
auto iter : dependIterated) {
2862 auto itersOp = iter.getDefiningOp<mlir::omp::IteratorOp>();
2863 assert(itersOp &&
"depend_iterated value must be defined by omp.iterator");
2864 iterInfos.emplace_back(itersOp, moduleTranslation, builder);
2866 builder.CreateAdd(totalCount, iterInfos.back().getTotalTrips());
// The array is heap-allocated with CreateMalloc; no matching free is visible
// in this function.
2871 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(dependInfoTy);
2872 llvm::Value *depArray =
2873 builder.CreateMalloc(ompBuilder.SizeTy, dependInfoTy, allocSize,
2874 totalCount,
nullptr,
".dep.arr.addr");
// Locator dependences occupy indices [0, numLocator).
2877 if (numLocator > 0) {
2880 for (
auto [i, dd] : llvm::enumerate(dds)) {
2881 llvm::Value *idx = llvm::ConstantInt::get(builder.getInt64Ty(), i);
2882 llvm::Value *entry =
2883 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2884 ompBuilder.emitTaskDependency(builder, entry, dd);
// Iterated dependences follow; `offset` is the index of the first free
// entry.
2889 llvm::Value *offset =
2890 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2891 for (
auto [i, iterInfo] : llvm::enumerate(iterInfos)) {
2892 auto kindAttr = cast<mlir::omp::ClauseTaskDependAttr>(
2893 dependIteratedKinds->getValue()[i]);
2894 llvm::omp::RTLDependenceKindTy rtlKind =
2897 auto itersOp = dependIterated[i].getDefiningOp<mlir::omp::IteratorOp>();
2899 itersOp, builder, moduleTranslation, iterInfo,
"dep_iterator",
2900 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2902 moduleTranslation.
lookupValue(yield.getResults()[0]);
// The entry for this iteration lives at index offset + linearIV.
2903 llvm::Value *idx = builder.CreateAdd(offset, linearIV);
2904 llvm::Value *entry =
2905 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2906 ompBuilder.emitTaskDependency(
2908 llvm::OpenMPIRBuilder::DependData{rtlKind, addr->getType(),
2911 return mlir::failure();
// Advance past the entries written for this iterator.
2914 offset = builder.CreateAdd(offset, iterInfo.getTotalTrips());
2917 taskDeps.DepArray = depArray;
2918 taskDeps.NumDeps = builder.CreateTrunc(totalCount, builder.getInt32Ty());
2919 return mlir::success();
// NOTE(review): the signature of this function is elided from this listing,
// as are many body lines; the comments below describe only the visible
// steps.
// Translates `taskOp`: private variables that read from their mold are
// stored in a heap-allocated context struct that is initialized before the
// task is created; the remaining private variables are allocated and
// initialized inside the task body callback.
2926 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2931 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2943 InsertPointTy allocaIP =
// The builder must be at the end of its block: a new "omp.task.start" block
// is created and branched to, and the builder is then placed before that
// branch so setup code lands ahead of it.
2948 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2949 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2950 builder.getContext(),
"omp.task.start",
2951 builder.GetInsertBlock()->getParent());
2952 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2953 builder.SetInsertPoint(branchToTaskStartBlock);
// Carve out dedicated blocks for the firstprivate copies and the private
// initialization.
2956 llvm::BasicBlock *copyBlock =
2957 splitBB(builder,
true,
"omp.private.copy");
2958 llvm::BasicBlock *initBlock =
2959 splitBB(builder,
true,
"omp.private.init");
2975 moduleTranslation, allocaIP, deallocBlocks);
// Allocate the context struct and compute its field GEPs inside the init
// block.
2978 builder.SetInsertPoint(initBlock->getTerminator());
2981 taskStructMgr.generateTaskContextStruct();
2988 taskStructMgr.createGEPsToPrivateVars();
2990 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2993 taskStructMgr.getLLVMPrivateVarGEPs())) {
// Privatizers that do not read from the mold are special-cased here (action
// elided); the task body callback below has its own handling for them.
2995 if (!privDecl.readsFromMold())
2997 assert(llvmPrivateVarAlloc &&
2998 "reads from mold so shouldn't have been skipped");
3001 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3002 blockArg, llvmPrivateVarAlloc, initBlock);
3003 if (!privateVarOrErr)
3004 return handleError(privateVarOrErr, *taskOp.getOperation());
// When initialization produced a different value and the block argument is
// not a pointer, that value is stored into the struct field and reloaded.
3013 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3014 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3015 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3017 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3018 llvmPrivateVarAlloc);
3020 assert(llvmPrivateVarAlloc->getType() ==
3021 moduleTranslation.
convertType(blockArg.getType()));
// The struct fields are passed as the private variables to the firstprivate
// copy step.
3031 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3032 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
3033 taskOp.getPrivateNeedsBarrier())))
3034 return llvm::failure();
3036 llvm::OpenMPIRBuilder::AffinityData ad;
3038 return llvm::failure();
// From here on code is emitted into the task start block.
3041 builder.SetInsertPoint(taskStartBlock);
// Task body callback.
3044 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3049 moduleTranslation, allocaIP, deallocBlocks);
3052 builder.restoreIP(codegenIP);
3054 llvm::BasicBlock *privInitBlock =
nullptr;
3056 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3059 auto [blockArg, privDecl, mlirPrivVar] = zip;
// Privatizers that read from the mold are special-cased here (action
// elided); they were handled through the context struct above.
3061 if (privDecl.readsFromMold())
// Allocate with an alloca at the alloca insertion point; the guard restores
// the builder's position on scope exit.
3064 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3065 llvm::Type *llvmAllocType =
3066 moduleTranslation.
convertType(privDecl.getType());
3067 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3068 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3069 llvmAllocType,
nullptr,
"omp.private.alloc");
3072 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3073 blockArg, llvmPrivateVar, privInitBlock);
3074 if (!privateVarOrError)
3075 return privateVarOrError.takeError();
3076 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
3077 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
// Recompute the struct field GEPs at this point and use them as the LLVM
// private variables (the condition separating the assert from the
// assignment is elided here).
3080 taskStructMgr.createGEPsToPrivateVars();
3081 for (
auto [i, llvmPrivVar] :
3082 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3084 assert(privateVarsInfo.
llvmVars[i] &&
3085 "This is added in the loop above");
3088 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
// Map block arguments to the private variables; for non-pointer block
// arguments the value is loaded from the private variable first.
3093 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
3097 if (!privateDecl.readsFromMold())
3100 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3101 llvmPrivateVar = builder.CreateLoad(
3102 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
3104 assert(llvmPrivateVar->getType() ==
3105 moduleTranslation.
convertType(blockArg.getType()));
3106 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
// Convert the task region; cleanup code is emitted before the terminator of
// its continuation block.
3110 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
3111 if (failed(
handleError(continuationBlockOrError, *taskOp)))
3112 return llvm::make_error<PreviouslyReportedError>();
3114 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3117 taskOp.getLoc(), privateVarsInfo)))
3118 return llvm::make_error<PreviouslyReportedError>();
// Release the context struct at the end of the task body.
3121 taskStructMgr.freeStructPtr();
3123 return llvm::Error::success();
3132 llvm::omp::Directive::OMPD_taskgroup);
// Gather the task's dependences, plain and iterated.
3134 llvm::OpenMPIRBuilder::DependenciesInfo dependencies;
3135 if (failed(
buildDependData(taskOp.getDependVars(), taskOp.getDependKinds(),
3136 taskOp.getDependIterated(),
3137 taskOp.getDependIteratedKinds(), builder,
3138 moduleTranslation, dependencies)))
3141 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3142 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3144 ompLoc, allocaIP, deallocBlocks, bodyCB, !taskOp.getUntied(),
3146 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dependencies, ad,
3147 taskOp.getMergeable(),
3148 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
3149 moduleTranslation.
lookupValue(taskOp.getPriority()));
// Code generation continues after the task; a dependence array, if one was
// set, is freed there.
3157 builder.restoreIP(*afterIP);
3159 if (dependencies.DepArray)
3160 builder.CreateFree(dependencies.DepArray);
3169 llvm::IRBuilderBase &builder,
3177 loopWrapperOp.getRegion(),
"omp.taskloop.wrapper.region", builder,
3180 if (failed(
handleError(continuationBlockOrError, opInst)))
3183 builder.SetInsertPoint(continuationBlockOrError.get());
3191static llvm::Expected<llvm::Value *>
3194 llvm::IRBuilderBase &builder) {
3195 if (llvm::Value *mapped = moduleTranslation.
lookupValue(value))
3200 return llvm::make_error<llvm::StringError>(
3201 "value is a block argument and is not mapped",
3202 llvm::inconvertibleErrorCode());
3204 return llvm::make_error<llvm::StringError>(
3205 "unsupported op defining taskloop loop bound",
3206 llvm::inconvertibleErrorCode());
3216 if (!operandOrError)
3217 return operandOrError.takeError();
3218 moduleTranslation.
mapValue(operand, *operandOrError);
3219 mappingsToRemove.push_back(operand);
3223 return llvm::make_error<llvm::StringError>(
3224 "failed to convert op defining taskloop loop bound",
3225 llvm::inconvertibleErrorCode());
3228 assert(
result &&
"expected conversion of loop bound op to produce a value");
3232 mappingsToRemove.push_back(resultValue);
3234 for (
Value mappedValue : mappingsToRemove)
3243 llvm::Value *&lbVal, llvm::Value *&ubVal,
3244 llvm::Value *&stepVal) {
3252 return firstLbOrErr.takeError();
3254 llvm::Type *boundType = (*firstLbOrErr)->getType();
3255 ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3256 if (loopOp.getCollapseNumLoops() > 1) {
3274 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
3276 i == 0 ? std::move(firstLbOrErr)
3280 return lbOrErr.takeError();
3282 upperBounds[i], moduleTranslation, builder);
3284 return ubOrErr.takeError();
3288 return stepOrErr.takeError();
3290 llvm::Value *loopLb = *lbOrErr;
3291 llvm::Value *loopUb = *ubOrErr;
3292 llvm::Value *loopStep = *stepOrErr;
3298 llvm::Value *loopLbMinusOne = builder.CreateSub(
3299 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3300 llvm::Value *loopUbMinusOne = builder.CreateSub(
3301 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3302 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
3303 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
3304 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
3305 llvm::Value *loopTripCount =
3306 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
3307 loopTripCount = builder.CreateBinaryIntrinsic(
3308 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
3312 llvm::Value *loopTripCountDivStep =
3313 builder.CreateSDiv(loopTripCount, loopStep);
3314 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
3315 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
3316 llvm::Value *loopTripCountRem =
3317 builder.CreateSRem(loopTripCount, loopStep);
3318 loopTripCountRem = builder.CreateBinaryIntrinsic(
3319 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
3320 llvm::Value *needsRoundUp = builder.CreateICmpNE(
3322 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
3325 builder.CreateAdd(loopTripCountDivStep,
3326 builder.CreateZExtOrTrunc(
3327 needsRoundUp, loopTripCountDivStep->getType()));
3328 ubVal = builder.CreateMul(ubVal, loopTripCount);
3330 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3331 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3336 return ubOrErr.takeError();
3340 return stepOrErr.takeError();
3341 lbVal = *firstLbOrErr;
3343 stepVal = *stepOrErr;
3346 assert(lbVal !=
nullptr &&
"Expected value for lbVal");
3347 assert(ubVal !=
nullptr &&
"Expected value for ubVal");
3348 assert(stepVal !=
nullptr &&
"Expected value for stepVal");
3349 return llvm::Error::success();
3355 llvm::IRBuilderBase &builder,
3357 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3359 omp::TaskloopWrapperOp loopWrapperOp = contextOp.getLoopOp();
3367 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
3371 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3374 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
3375 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
3376 builder.getContext(),
"omp.taskloop.wrapper.start",
3377 builder.GetInsertBlock()->getParent());
3378 llvm::Instruction *branchToTaskloopStartBlock =
3379 builder.CreateBr(taskloopStartBlock);
3380 builder.SetInsertPoint(branchToTaskloopStartBlock);
3382 llvm::BasicBlock *copyBlock =
3383 splitBB(builder,
true,
"omp.private.copy");
3384 llvm::BasicBlock *initBlock =
3385 splitBB(builder,
true,
"omp.private.init");
3388 moduleTranslation, allocaIP, deallocBlocks);
3391 builder.SetInsertPoint(initBlock->getTerminator());
3394 taskStructMgr.generateTaskContextStruct();
3395 taskStructMgr.createGEPsToPrivateVars();
3397 llvmFirstPrivateVars.resize(privateVarsInfo.
blockArgs.size());
3399 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3401 privateVarsInfo.
blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
3402 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
3404 if (!privDecl.readsFromMold())
3406 assert(llvmPrivateVarAlloc &&
3407 "reads from mold so shouldn't have been skipped");
3410 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3411 blockArg, llvmPrivateVarAlloc, initBlock);
3412 if (!privateVarOrErr)
3413 return handleError(privateVarOrErr, *contextOp.getOperation());
3415 llvmFirstPrivateVars[i] = privateVarOrErr.get();
3417 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3418 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
3420 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3421 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3422 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3424 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3425 llvmPrivateVarAlloc);
3427 assert(llvmPrivateVarAlloc->getType() ==
3428 moduleTranslation.
convertType(blockArg.getType()));
3434 contextOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3435 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
3436 contextOp.getPrivateNeedsBarrier())))
3437 return llvm::failure();
3440 builder.SetInsertPoint(taskloopStartBlock);
3442 auto loopOp = cast<omp::LoopNestOp>(loopWrapperOp.getWrappedLoop());
3443 llvm::Value *lbVal =
nullptr;
3444 llvm::Value *ubVal =
nullptr;
3445 llvm::Value *stepVal =
nullptr;
3447 loopOp, builder, moduleTranslation, lbVal, ubVal, stepVal))
3451 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3456 moduleTranslation, allocaIP, deallocBlocks);
3459 builder.restoreIP(codegenIP);
3461 llvm::BasicBlock *privInitBlock =
nullptr;
3463 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3466 auto [blockArg, privDecl, mlirPrivVar] = zip;
3468 if (privDecl.readsFromMold())
3471 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3472 llvm::Type *llvmAllocType =
3473 moduleTranslation.
convertType(privDecl.getType());
3474 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3475 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3476 llvmAllocType,
nullptr,
"omp.private.alloc");
3479 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3480 blockArg, llvmPrivateVar, privInitBlock);
3481 if (!privateVarOrError)
3482 return privateVarOrError.takeError();
3483 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
3484 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
3487 taskStructMgr.createGEPsToPrivateVars();
3488 for (
auto [i, llvmPrivVar] :
3489 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3491 assert(privateVarsInfo.
llvmVars[i] &&
3492 "This is added in the loop above");
3495 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
3500 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
3504 if (!privateDecl.readsFromMold())
3507 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3508 llvmPrivateVar = builder.CreateLoad(
3509 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
3511 assert(llvmPrivateVar->getType() ==
3512 moduleTranslation.
convertType(blockArg.getType()));
3513 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
3519 contextOp.getRegion(),
"omp.taskloop.context.region", builder,
3522 if (failed(
handleError(continuationBlockOrError, opInst)))
3523 return llvm::make_error<PreviouslyReportedError>();
3525 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3533 contextOp.getLoc(), privateVarsInfo)))
3534 return llvm::make_error<PreviouslyReportedError>();
3537 taskStructMgr.freeStructPtr();
3539 return llvm::Error::success();
3545 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3546 llvm::Value *destPtr, llvm::Value *srcPtr)
3548 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3549 builder.restoreIP(codegenIP);
3552 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
3554 builder.CreateLoad(ptrTy, srcPtr,
"omp.taskloop.context.src");
3556 TaskContextStructManager &srcStructMgr = taskStructMgr;
3557 TaskContextStructManager destStructMgr(builder, moduleTranslation,
3559 destStructMgr.generateTaskContextStruct();
3560 llvm::Value *dest = destStructMgr.getStructPtr();
3561 dest->setName(
"omp.taskloop.context.dest");
3562 builder.CreateStore(dest, destPtr);
3565 srcStructMgr.createGEPsToPrivateVars(src);
3567 destStructMgr.createGEPsToPrivateVars(dest);
3570 for (
auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
3571 llvm::zip_equal(privateVarsInfo.
privatizers, srcGEPs,
3574 if (!privDecl.readsFromMold())
3576 assert(llvmPrivateVarAlloc &&
3577 "reads from mold so shouldn't have been skipped");
3580 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
3581 llvmPrivateVarAlloc, builder.GetInsertBlock());
3582 if (!privateVarOrErr)
3583 return privateVarOrErr.takeError();
3592 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3593 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3594 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3596 llvmPrivateVarAlloc = builder.CreateLoad(
3597 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
3599 assert(llvmPrivateVarAlloc->getType() ==
3600 moduleTranslation.
convertType(blockArg.getType()));
3608 moduleTranslation, srcGEPs, destGEPs,
3610 contextOp.getPrivateNeedsBarrier())))
3611 return llvm::make_error<PreviouslyReportedError>();
3613 return builder.saveIP();
3621 llvm::Value *ifCond =
nullptr;
3622 llvm::Value *grainsize =
nullptr;
3624 mlir::Value grainsizeVal = contextOp.getGrainsize();
3625 mlir::Value numTasksVal = contextOp.getNumTasks();
3626 if (
Value ifVar = contextOp.getIfExpr())
3629 grainsize = moduleTranslation.
lookupValue(grainsizeVal);
3631 }
else if (numTasksVal) {
3632 grainsize = moduleTranslation.
lookupValue(numTasksVal);
3636 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull =
nullptr;
3637 if (taskStructMgr.getStructPtr())
3638 taskDupOrNull = taskDupCB;
3648 llvm::omp::Directive::OMPD_taskgroup);
3650 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3651 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3653 ompLoc, allocaIP, deallocBlocks, bodyCB, loopInfo, lbVal, ubVal,
3654 stepVal, contextOp.getUntied(), ifCond, grainsize,
3655 contextOp.getNogroup(), sched,
3656 moduleTranslation.
lookupValue(contextOp.getFinal()),
3657 contextOp.getMergeable(),
3658 moduleTranslation.
lookupValue(contextOp.getPriority()),
3659 loopOp.getCollapseNumLoops(), taskDupOrNull,
3660 taskStructMgr.getStructPtr());
3667 builder.restoreIP(*afterIP);
3675 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3679 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3681 builder.restoreIP(codegenIP);
3683 builder, moduleTranslation)
3688 InsertPointTy allocaIP =
3690 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3691 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3693 ompLoc, allocaIP, deallocBlocks, bodyCB);
3698 builder.restoreIP(*afterIP);
3717 auto wsloopOp = cast<omp::WsloopOp>(opInst);
3721 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
3723 assert(isByRef.size() == wsloopOp.getNumReductionVars());
3727 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
3730 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
3731 llvm::Type *ivType = step->getType();
3732 llvm::Value *chunk =
nullptr;
3733 if (wsloopOp.getScheduleChunk()) {
3734 llvm::Value *chunkVar =
3735 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
3736 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3739 omp::DistributeOp distributeOp =
nullptr;
3740 llvm::Value *distScheduleChunk =
nullptr;
3741 bool hasDistSchedule =
false;
3742 if (llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())) {
3743 distributeOp = cast<omp::DistributeOp>(opInst.
getParentOp());
3744 hasDistSchedule = distributeOp.getDistScheduleStatic();
3745 if (distributeOp.getDistScheduleChunkSize()) {
3746 llvm::Value *chunkVar = moduleTranslation.
lookupValue(
3747 distributeOp.getDistScheduleChunkSize());
3748 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3757 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3761 wsloopOp.getNumReductionVars());
3764 wsloopOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
3771 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3776 moduleTranslation, allocaIP, reductionDecls,
3777 privateReductionVariables, reductionVariableMap,
3778 deferredStores, isByRef)))
3787 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3789 wsloopOp.getPrivateNeedsBarrier())))
3792 assert(afterAllocas.get()->getSinglePredecessor());
3793 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3795 afterAllocas.get()->getSinglePredecessor(),
3796 reductionDecls, privateReductionVariables,
3797 reductionVariableMap, isByRef, deferredStores)))
3801 bool isOrdered = wsloopOp.getOrdered().has_value();
3802 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3803 bool isSimd = wsloopOp.getScheduleSimd();
3804 bool loopNeedsBarrier = !wsloopOp.getNowait();
3809 llvm::omp::WorksharingLoopType workshareLoopType =
3810 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
3811 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3812 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3816 llvm::omp::Directive::OMPD_for);
3818 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3821 LinearClauseProcessor linearClauseProcessor;
3823 if (!wsloopOp.getLinearVars().empty()) {
3824 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3826 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3828 for (
auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3829 linearClauseProcessor.createLinearVar(
3830 builder, moduleTranslation, moduleTranslation.
lookupValue(linearVar),
3832 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
3833 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3836 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3838 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
3846 if (!wsloopOp.getLinearVars().empty()) {
3847 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3848 loopInfo->getPreheader());
3849 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3851 builder.saveIP(), llvm::omp::OMPD_barrier);
3854 builder.restoreIP(*afterBarrierIP);
3855 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3856 loopInfo->getIndVar());
3857 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3860 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3863 bool noLoopMode =
false;
3864 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3866 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3870 if (loopOp == targetCapturedOp) {
3871 if (targetOp.getKernelExecFlags(targetCapturedOp) ==
3872 omp::TargetExecMode::no_loop)
3877 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3878 ompBuilder->applyWorkshareLoop(
3879 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3880 convertToScheduleKind(schedule), chunk, isSimd,
3881 scheduleMod == omp::ScheduleModifier::monotonic,
3882 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3883 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3889 if (!wsloopOp.getLinearVars().empty()) {
3890 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3891 assert(loopInfo->getLastIter() &&
3892 "`lastiter` in CanonicalLoopInfo is nullptr");
3893 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3894 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3895 loopInfo->getLastIter());
3898 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
3899 linearClauseProcessor.rewriteInPlace(
3900 builder, sourceBlock->getSingleSuccessor(), *regionBlock,
3901 "omp.loop_nest.region",
index);
3903 builder.restoreIP(oldIP);
3911 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3912 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3917 wsloopOp.getLoc(), privateVarsInfo);
3924 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3926 assert(isByRef.size() == opInst.getNumReductionVars());
3939 opInst.getNumReductionVars());
3943 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3946 opInst, builder, moduleTranslation, privateVarsInfo, allocaIP);
3948 return llvm::make_error<PreviouslyReportedError>();
3954 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3957 InsertPointTy(allocaIP.getBlock(),
3958 allocaIP.getBlock()->getTerminator()->getIterator());
3961 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3962 reductionDecls, privateReductionVariables, reductionVariableMap,
3963 deferredStores, isByRef)))
3964 return llvm::make_error<PreviouslyReportedError>();
3966 assert(afterAllocas.get()->getSinglePredecessor());
3967 builder.restoreIP(codeGenIP);
3973 return llvm::make_error<PreviouslyReportedError>();
3976 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3978 opInst.getPrivateNeedsBarrier())))
3979 return llvm::make_error<PreviouslyReportedError>();
3982 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3983 afterAllocas.get()->getSinglePredecessor(),
3984 reductionDecls, privateReductionVariables,
3985 reductionVariableMap, isByRef, deferredStores)))
3986 return llvm::make_error<PreviouslyReportedError>();
3991 moduleTranslation, allocaIP, deallocBlocks);
3995 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
3997 return regionBlock.takeError();
4000 if (opInst.getNumReductionVars() > 0) {
4005 owningReductionGenRefDataPtrGens;
4007 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
4009 owningReductionGenRefDataPtrGens,
4010 privateReductionVariables, reductionInfos, isByRef);
4013 builder.SetInsertPoint((*regionBlock)->getTerminator());
4016 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
4017 builder.SetInsertPoint(tempTerminator);
4019 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
4020 ompBuilder->createReductions(
4021 builder.saveIP(), allocaIP, reductionInfos, isByRef,
4023 if (!contInsertPoint)
4024 return contInsertPoint.takeError();
4026 if (!contInsertPoint->getBlock())
4027 return llvm::make_error<PreviouslyReportedError>();
4029 tempTerminator->eraseFromParent();
4030 builder.restoreIP(*contInsertPoint);
4033 return llvm::Error::success();
4036 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
4037 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
4046 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
4047 InsertPointTy oldIP = builder.saveIP();
4048 builder.restoreIP(codeGenIP);
4053 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
4054 [](omp::DeclareReductionOp reductionDecl) {
4055 return &reductionDecl.getCleanupRegion();
4058 reductionCleanupRegions, privateReductionVariables,
4059 moduleTranslation, builder,
"omp.reduction.cleanup")))
4060 return llvm::createStringError(
4061 "failed to inline `cleanup` region of `omp.declare_reduction`");
4064 opInst.getLoc(), privateVarsInfo)))
4065 return llvm::make_error<PreviouslyReportedError>();
4069 if (isCancellable) {
4070 auto IPOrErr = ompBuilder->createBarrier(
4071 llvm::OpenMPIRBuilder::LocationDescription(builder),
4072 llvm::omp::Directive::OMPD_unknown,
4076 return IPOrErr.takeError();
4079 builder.restoreIP(oldIP);
4080 return llvm::Error::success();
4083 llvm::Value *ifCond =
nullptr;
4084 if (
auto ifVar = opInst.getIfExpr())
4086 llvm::Value *numThreads =
nullptr;
4087 if (!opInst.getNumThreadsVars().empty())
4088 numThreads = moduleTranslation.
lookupValue(opInst.getNumThreads(0));
4089 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
4090 if (
auto bind = opInst.getProcBindKind())
4094 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4096 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4098 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4099 ompBuilder->createParallel(ompLoc, allocaIP, deallocBlocks, bodyGenCB,
4100 privCB, finiCB, ifCond, numThreads, pbKind,
4106 builder.restoreIP(*afterIP);
4111static llvm::omp::OrderKind
4114 return llvm::omp::OrderKind::OMP_ORDER_unknown;
4116 case omp::ClauseOrderKind::Concurrent:
4117 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
4119 llvm_unreachable(
"Unknown ClauseOrderKind kind");
4127 auto simdOp = cast<omp::SimdOp>(opInst);
4135 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
4138 simdOp.getNumReductionVars());
4143 assert(isByRef.size() == simdOp.getNumReductionVars());
4145 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4149 simdOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
4154 LinearClauseProcessor linearClauseProcessor;
4156 if (!simdOp.getLinearVars().empty()) {
4157 auto linearVarTypes = simdOp.getLinearVarTypes().value();
4159 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
4160 for (
auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
4161 bool isImplicit =
false;
4162 for (
auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
4166 if (linearVar == mlirPrivVar) {
4168 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
4169 llvmPrivateVar, idx);
4175 linearClauseProcessor.createLinearVar(
4176 builder, moduleTranslation,
4179 for (
mlir::Value linearStep : simdOp.getLinearStepVars())
4180 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
4184 moduleTranslation, allocaIP, reductionDecls,
4185 privateReductionVariables, reductionVariableMap,
4186 deferredStores, isByRef)))
4197 assert(afterAllocas.get()->getSinglePredecessor());
4198 if (failed(initReductionVars(simdOp, reductionArgs, builder,
4200 afterAllocas.get()->getSinglePredecessor(),
4201 reductionDecls, privateReductionVariables,
4202 reductionVariableMap, isByRef, deferredStores)))
4205 llvm::ConstantInt *simdlen =
nullptr;
4206 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
4207 simdlen = builder.getInt64(simdlenVar.value());
4209 llvm::ConstantInt *safelen =
nullptr;
4210 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
4211 safelen = builder.getInt64(safelenVar.value());
4213 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
4216 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
4217 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
4219 for (
size_t i = 0; i < operands.size(); ++i) {
4220 llvm::Value *alignment =
nullptr;
4221 llvm::Value *llvmVal = moduleTranslation.
lookupValue(operands[i]);
4222 llvm::Type *ty = llvmVal->getType();
4224 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
4225 alignment = builder.getInt64(intAttr.getInt());
4226 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
4227 assert(alignment &&
"Invalid alignment value");
4231 if (!intAttr.getValue().isPowerOf2())
4234 auto curInsert = builder.saveIP();
4235 builder.SetInsertPoint(sourceBlock);
4236 llvmVal = builder.CreateLoad(ty, llvmVal);
4237 builder.restoreIP(curInsert);
4238 alignedVars[llvmVal] = alignment;
4242 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
4249 if (simdOp.getLinearVars().size()) {
4250 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
4251 loopInfo->getPreheader());
4253 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
4254 loopInfo->getIndVar());
4256 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4258 ompBuilder->applySimd(loopInfo, alignedVars,
4260 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
4262 order, simdlen, safelen);
4264 linearClauseProcessor.emitStoresForLinearVar(builder);
4267 bool hasOrderedRegions =
false;
4268 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
4269 hasOrderedRegions =
true;
4273 for (
size_t index = 0;
index < simdOp.getLinearVars().size();
index++) {
4274 llvm::BasicBlock *startBB = sourceBlock->getSingleSuccessor();
4275 llvm::BasicBlock *endBB = *regionBlock;
4276 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4277 "omp.loop_nest.region",
index);
4279 if (hasOrderedRegions) {
4281 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4282 "omp.ordered.region",
index);
4284 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4285 "omp_region.finalize",
index);
4293 for (
auto [i, tuple] : llvm::enumerate(
4294 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
4295 privateReductionVariables))) {
4296 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
4298 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
4299 llvm::Value *originalVariable = moduleTranslation.
lookupValue(reductionVar);
4300 llvm::Type *reductionType = moduleTranslation.
convertType(decl.getType());
4304 llvm::Value *redValue = originalVariable;
4307 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
4308 llvm::Value *privateRedValue = builder.CreateLoad(
4309 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
4310 llvm::Value *reduced;
4312 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
4315 builder.restoreIP(res.get());
4319 builder.CreateStore(reduced, originalVariable);
4324 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
4325 [](omp::DeclareReductionOp reductionDecl) {
4326 return &reductionDecl.getCleanupRegion();
4329 moduleTranslation, builder,
4330 "omp.reduction.cleanup")))
4342 auto loopOp = cast<omp::LoopNestOp>(opInst);
4348 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4353 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
4354 llvm::Value *iv) -> llvm::Error {
4357 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
4362 bodyInsertPoints.push_back(ip);
4364 if (loopInfos.size() != loopOp.getNumLoops() - 1)
4365 return llvm::Error::success();
4368 builder.restoreIP(ip);
4370 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
4372 return regionBlock.takeError();
4374 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4375 return llvm::Error::success();
4383 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
4384 llvm::Value *lowerBound =
4385 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
4386 llvm::Value *upperBound =
4387 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
4388 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
4393 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
4394 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
4396 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
4398 computeIP = loopInfos.front()->getPreheaderIP();
4402 ompBuilder->createCanonicalLoop(
4403 loc, bodyGen, lowerBound, upperBound, step,
4404 true, loopOp.getLoopInclusive(), computeIP);
4409 loopInfos.push_back(*loopResult);
4412 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
4413 loopInfos.front()->getAfterIP();
4416 if (
const auto &tiles = loopOp.getTileSizes()) {
4417 llvm::Type *ivType = loopInfos.front()->getIndVarType();
4420 for (
auto tile : tiles.value()) {
4421 llvm::Value *tileVal = llvm::ConstantInt::get(ivType,
tile);
4422 tileSizes.push_back(tileVal);
4425 std::vector<llvm::CanonicalLoopInfo *> newLoops =
4426 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
4430 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
4431 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
4432 afterIP = {afterAfterBB, afterAfterBB->begin()};
4436 for (
const auto &newLoop : newLoops)
4437 loopInfos.push_back(newLoop);
4441 const auto &numCollapse = loopOp.getCollapseNumLoops();
4443 loopInfos.begin(), loopInfos.begin() + (numCollapse));
4445 auto newTopLoopInfo =
4446 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
4448 assert(newTopLoopInfo &&
"New top loop information is missing");
4449 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
4450 [&](OpenMPLoopInfoStackFrame &frame) {
4451 frame.loopInfo = newTopLoopInfo;
4459 builder.restoreIP(afterIP);
4469 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
4470 Value loopIV = op.getInductionVar();
4471 Value loopTC = op.getTripCount();
4473 llvm::Value *llvmTC = moduleTranslation.
lookupValue(loopTC);
4476 ompBuilder->createCanonicalLoop(
4478 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
4481 moduleTranslation.
mapValue(loopIV, llvmIV);
4483 builder.restoreIP(ip);
4488 return bodyGenStatus.takeError();
4490 llvmTC,
"omp.loop");
4492 return op.emitError(llvm::toString(llvmOrError.takeError()));
4494 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
4495 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
4496 builder.restoreIP(afterIP);
4499 if (
Value cli = op.getCli())
4512 Value applyee = op.getApplyee();
4513 assert(applyee &&
"Loop to apply unrolling on required");
4515 llvm::CanonicalLoopInfo *consBuilderCLI =
4517 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4518 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
// Applies the loop tiling requested by an omp::TileOp. The visible statements
// gather the translated `sizes` operands and the canonical loops of the
// applyees, pass both lists to ompBuilder->tileLoops, and then walk the op's
// generatees together with the loops that call returned.
// NOTE(review): this listing omits several lines of the function (among them
// the declarations of ompBuilder, translatedSizes and translatedLoops, the
// initializer of consBuilderCLI and the bodies of the last two loops), so the
// comments below describe only what is shown.
4526static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
4529 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
// Look up the LLVM value mapped to each `sizes` operand and append it to
// translatedSizes; a missing mapping trips the assert.
4534 for (
Value size : op.getSizes()) {
4535 llvm::Value *translatedSize = moduleTranslation.
lookupValue(size);
4536 assert(translatedSize &&
4537 "sizes clause arguments must already be translated");
4538 translatedSizes.push_back(translatedSize);
// Append one CanonicalLoopInfo per applyee to translatedLoops.
4541 for (
Value applyee : op.getApplyees()) {
4542 llvm::CanonicalLoopInfo *consBuilderCLI =
// NOTE(review): the assert tests `applyee` (the MLIR value being iterated),
// not `consBuilderCLI`; the message reads as if the looked-up loop was the
// intended operand - confirm.
4544 assert(applyee &&
"Canonical loop must already been translated");
4545 translatedLoops.push_back(consBuilderCLI);
// Tile the collected loops with the collected sizes.
4548 auto generatedLoops =
4549 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
// Only when the op lists generatees: iterate them in lockstep with the
// generated loops (zip_equal); the loop body is not shown in this listing.
4550 if (!op.getGeneratees().empty()) {
4551 for (
auto [mlirLoop,
genLoop] :
4552 zip_equal(op.getGeneratees(), generatedLoops))
// Second pass over the applyees; its body is not shown in this listing.
4557 for (
Value applyee : op.getApplyees())
/// Lowers an `omp.fuse` operation through the OpenMPIRBuilder.
///
/// The applyees are split into three groups according to the optional
/// `first`/`count` values: loops before the fused range, loops to fuse, and
/// loops after the fused range. Only the middle group is passed to
/// `OpenMPIRBuilder::fuseLoops`. When generatees are present they are laid out
/// as [before..., fused loop, after...].
4565static LogicalResult
applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
4568 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
// Classify every applyee. `first` is one-based, hence the `- 1` adjustments.
4572 for (
size_t i = 0; i < op.getApplyees().size(); i++) {
4573 Value applyee = op.getApplyees()[i];
// The assertion below checks the looked-up loop rather than the MLIR operand,
// which is what its message is about.
4574 llvm::CanonicalLoopInfo *consBuilderCLI =
4576 assert(consBuilderCLI &&
"Canonical loop must already been translated");
4577 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
4578 beforeFuse.push_back(consBuilderCLI);
// NOTE(review): this branch reads getFirst().value() guarded only by
// getCount(); presumably `count` never appears without `first` -- confirm.
4579 else if (op.getCount().has_value() &&
4580 i >= op.getFirst().value() + op.getCount().value() - 1)
4581 afterFuse.push_back(consBuilderCLI);
4583 toFuse.push_back(consBuilderCLI);
// One generatee per untouched loop plus one for the fused loop.
4586 (op.getGeneratees().empty() ||
4587 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
4588 "Wrong number of generatees");
// Perform the fusion.
4591 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
4592 if (!op.getGeneratees().empty()) {
// Leading generatees map to the loops that precede the fused range.
4594 for (; i < beforeFuse.size(); i++)
4595 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
// The next generatee is the fused loop; `i` then points at the first
// generatee that belongs to a trailing loop.
4596 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
// Trailing generatees: `afterFuse` is indexed from zero with its own counter,
// offset by `i` on the generatee side. Reusing `i` for both would skip or
// mis-pair the trailing loops whenever any loop precedes them.
4597 for (size_t j = 0; j < afterFuse.size(); j++)
4598 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i + j], afterFuse[j]);
// Final pass over the applyees once fusion is done.
4602 for (
Value applyee : op.getApplyees())
/// Translates an OpenMP memory-order clause kind into the corresponding
/// `llvm::AtomicOrdering`.
4609static llvm::AtomicOrdering
// NOTE(review): presumably the fallback used when no memory-order clause is
// given -- confirm against the guard of this early return.
4612 return llvm::AtomicOrdering::Monotonic;
4615 case omp::ClauseMemoryOrderKind::Seq_cst:
4616 return llvm::AtomicOrdering::SequentiallyConsistent;
4617 case omp::ClauseMemoryOrderKind::Acq_rel:
4618 return llvm::AtomicOrdering::AcquireRelease;
4619 case omp::ClauseMemoryOrderKind::Acquire:
4620 return llvm::AtomicOrdering::Acquire;
4621 case omp::ClauseMemoryOrderKind::Release:
4622 return llvm::AtomicOrdering::Release;
// `relaxed` is expressed with LLVM's monotonic ordering.
4623 case omp::ClauseMemoryOrderKind::Relaxed:
4624 return llvm::AtomicOrdering::Monotonic;
// Any kind not returned above is treated as a programming error.
4626 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
4633 auto readOp = cast<omp::AtomicReadOp>(opInst);
4638 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4641 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4644 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
4645 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
4647 llvm::Type *elementType =
4648 moduleTranslation.
convertType(readOp.getElementType());
4650 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
4651 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
4652 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
4660 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
4665 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4668 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4670 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
4671 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
4672 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
4673 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
4676 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
4684 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
4685 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
4686 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
4687 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
4688 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
4689 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
4690 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
4691 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
4692 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
4693 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
4697 bool &isIgnoreDenormalMode,
4698 bool &isFineGrainedMemory,
4699 bool &isRemoteMemory) {
4700 isIgnoreDenormalMode =
false;
4701 isFineGrainedMemory =
false;
4702 isRemoteMemory =
false;
4703 if (atomicUpdateOp &&
4704 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
4705 mlir::omp::AtomicControlAttr atomicControlAttr =
4706 atomicUpdateOp.getAtomicControlAttr();
4707 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
4708 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
4709 isRemoteMemory = atomicControlAttr.getRemoteMemory();
4716 llvm::IRBuilderBase &builder,
4723 auto &innerOpList = opInst.getRegion().front().getOperations();
4724 bool isXBinopExpr{
false};
4725 llvm::AtomicRMWInst::BinOp binop;
4727 llvm::Value *llvmExpr =
nullptr;
4728 llvm::Value *llvmX =
nullptr;
4729 llvm::Type *llvmXElementType =
nullptr;
4730 if (innerOpList.size() == 2) {
4736 opInst.getRegion().getArgument(0))) {
4737 return opInst.emitError(
"no atomic update operation with region argument"
4738 " as operand found inside atomic.update region");
4741 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
4743 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4747 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4749 llvmX = moduleTranslation.
lookupValue(opInst.getX());
4751 opInst.getRegion().getArgument(0).getType());
4752 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4756 llvm::AtomicOrdering atomicOrdering =
4761 [&opInst, &moduleTranslation](
4762 llvm::Value *atomicx,
4765 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
4766 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4767 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4768 return llvm::make_error<PreviouslyReportedError>();
4770 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4771 assert(yieldop && yieldop.getResults().size() == 1 &&
4772 "terminator must be omp.yield op and it must have exactly one "
4774 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
4777 bool isIgnoreDenormalMode;
4778 bool isFineGrainedMemory;
4779 bool isRemoteMemory;
4784 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4785 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4786 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
4787 atomicOrdering, binop, updateFn,
4788 isXBinopExpr, isIgnoreDenormalMode,
4789 isFineGrainedMemory, isRemoteMemory);
4794 builder.restoreIP(*afterIP);
4800 llvm::IRBuilderBase &builder,
4807 bool isXBinopExpr =
false, isPostfixUpdate =
false;
4808 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4810 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
4811 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
4813 assert((atomicUpdateOp || atomicWriteOp) &&
4814 "internal op must be an atomic.update or atomic.write op");
4816 if (atomicWriteOp) {
4817 isPostfixUpdate =
true;
4818 mlirExpr = atomicWriteOp.getExpr();
4820 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
4821 atomicCaptureOp.getAtomicUpdateOp().getOperation();
4822 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
4825 if (innerOpList.size() == 2) {
4828 atomicUpdateOp.getRegion().getArgument(0))) {
4829 return atomicUpdateOp.emitError(
4830 "no atomic update operation with region argument"
4831 " as operand found inside atomic.update region");
4835 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
4838 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4842 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4843 llvm::Value *llvmX =
4844 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
4845 llvm::Value *llvmV =
4846 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
4847 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
4848 atomicCaptureOp.getAtomicReadOp().getElementType());
4849 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4852 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
4856 llvm::AtomicOrdering atomicOrdering =
4860 [&](llvm::Value *atomicx,
4863 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
4864 Block &bb = *atomicUpdateOp.getRegion().
begin();
4865 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
4867 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4868 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4869 return llvm::make_error<PreviouslyReportedError>();
4871 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4872 assert(yieldop && yieldop.getResults().size() == 1 &&
4873 "terminator must be omp.yield op and it must have exactly one "
4875 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
4878 bool isIgnoreDenormalMode;
4879 bool isFineGrainedMemory;
4880 bool isRemoteMemory;
4882 isFineGrainedMemory, isRemoteMemory);
4885 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4886 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4887 ompBuilder->createAtomicCapture(
4888 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4889 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4890 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4892 if (failed(
handleError(afterIP, *atomicCaptureOp)))
4895 builder.restoreIP(*afterIP);
4901static std::optional<llvm::omp::OMPAtomicCompareOp>
4903 switch (predicate) {
4904 case LLVM::ICmpPredicate::eq:
4905 return llvm::omp::OMPAtomicCompareOp::EQ;
4906 case LLVM::ICmpPredicate::slt:
4907 case LLVM::ICmpPredicate::ult:
4908 return llvm::omp::OMPAtomicCompareOp::MIN;
4909 case LLVM::ICmpPredicate::sgt:
4910 case LLVM::ICmpPredicate::ugt:
4911 return llvm::omp::OMPAtomicCompareOp::MAX;
4913 return std::nullopt;
4919static std::optional<llvm::omp::OMPAtomicCompareOp>
4921 switch (predicate) {
4922 case LLVM::FCmpPredicate::oeq:
4923 case LLVM::FCmpPredicate::ueq:
4924 return llvm::omp::OMPAtomicCompareOp::EQ;
4925 case LLVM::FCmpPredicate::olt:
4926 case LLVM::FCmpPredicate::ult:
4927 return llvm::omp::OMPAtomicCompareOp::MIN;
4928 case LLVM::FCmpPredicate::ogt:
4929 case LLVM::FCmpPredicate::ugt:
4930 return llvm::omp::OMPAtomicCompareOp::MAX;
4932 return std::nullopt;
4954 llvm::IRBuilderBase &builder,
4960 Region ®ion = atomicCompareOp.getRegion();
4964 llvm::Type *llvmXElementType =
4966 if (!llvmXElementType)
4967 return atomicCompareOp.emitError(
4968 "unable to determine element type for atomic compare");
4970 llvm::Value *llvmX = moduleTranslation.
lookupValue(atomicCompareOp.getX());
4975 bool isSigned =
false;
4976 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4980 llvm::AtomicOrdering atomicOrdering =
4983 auto isAtomicComparePatternOp = [](
Operation &op) {
4984 return llvm::isa<LLVM::ICmpOp, LLVM::FCmpOp, LLVM::SelectOp, LLVM::AndOp,
5005 if (isAtomicComparePatternOp(op))
5010 return moduleTranslation.lookupValue(v) != nullptr;
5012 if (!allOperandsMapped)
5016 return atomicCompareOp.emitError(
5017 "failed to translate operation inside atomic compare region");
5022 auto materializeValue = [&](
mlir::Value val) -> llvm::Value * {
5024 if (llvm::Value *existing = moduleTranslation.
lookupValue(val))
5029 if (loadOp->getParentRegion() == ®ion) {
5030 llvm::Value *loadAddr = moduleTranslation.
lookupValue(loadOp.getAddr());
5033 llvm::Type *loadType =
5034 moduleTranslation.
convertType(loadOp.getResult().getType());
5035 return builder.CreateLoad(loadType, loadAddr);
5043 llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
5044 llvm::Value *eVal =
nullptr;
5045 llvm::Value *dVal =
nullptr;
5046 bool isXBinopExpr =
false;
5049 if (
auto extractOp = v.getDefiningOp<LLVM::ExtractValueOp>())
5050 return extractOp.getContainer();
5064 bool isComplexPattern =
false;
5066 if (!isa<LLVM::AndOp, LLVM::OrOp>(op))
5072 if (!lhsFcmp || !rhsFcmp)
5077 mlir::Value lhsAgg0 = traceToAggregate(lhsFcmp.getOperand(0));
5078 mlir::Value lhsAgg1 = traceToAggregate(lhsFcmp.getOperand(1));
5079 bool lhsXIsOp0 = (lhsAgg0 == block.
getArgument(0));
5080 bool lhsXIsOp1 = (lhsAgg1 == block.
getArgument(0));
5081 if (!lhsXIsOp0 && !lhsXIsOp1)
5083 mlir::Value eAggregate = lhsXIsOp0 ? lhsAgg1 : lhsAgg0;
5087 if (isa<LLVM::AndOp>(op))
5088 compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
5091 return atomicCompareOp.emitError(
5092 "unsupported comparison predicate (NE) for complex atomic compare");
5094 isXBinopExpr = lhsXIsOp0;
5095 eVal = materializeValue(eAggregate);
5096 isComplexPattern =
true;
5100 if (isComplexPattern) {
5103 if (
auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5104 dVal = materializeValue(selectOp.getTrueValue());
5110 if (yieldOp.getResults().empty())
5111 return atomicCompareOp.emitError(
5112 "failed to extract desired value (d) from atomic compare region");
5113 dVal = materializeValue(yieldOp.getResults()[0]);
5116 const llvm::DataLayout &DL =
5117 builder.GetInsertBlock()->getModule()->getDataLayout();
5118 unsigned totalBits =
5119 DL.getTypeStoreSizeInBits(llvmXElementType).getFixedValue();
5121 llvm::IntegerType *intTy =
5122 llvm::IntegerType::get(builder.getContext(), totalBits);
5124 llvm::Align complexAlign = DL.getABITypeAlign(llvmXElementType);
5125 llvm::Align intAlign = DL.getABITypeAlign(intTy);
5126 llvm::Align maxAlign = std::max(complexAlign, intAlign);
5128 llvm::AllocaInst *eAlloca =
5129 builder.CreateAlloca(llvmXElementType,
nullptr,
"cmplx.e");
5130 eAlloca->setAlignment(maxAlign);
5131 llvm::AllocaInst *dAlloca =
5132 builder.CreateAlloca(llvmXElementType,
nullptr,
"cmplx.d");
5133 dAlloca->setAlignment(maxAlign);
5135 builder.CreateAlignedStore(eVal, eAlloca, maxAlign);
5137 builder.CreateAlignedLoad(intTy, eAlloca, maxAlign,
"cmplx.e.int");
5138 builder.CreateAlignedStore(dVal, dAlloca, maxAlign);
5140 builder.CreateAlignedLoad(intTy, dAlloca, maxAlign,
"cmplx.d.int");
5142 llvm::AtomicOrdering failOrdering =
5143 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(atomicOrdering);
5144 builder.CreateAtomicCmpXchg(llvmX, eInt, dInt, maxAlign, atomicOrdering,
5149 if (atomicOrdering == llvm::AtomicOrdering::Release ||
5150 atomicOrdering == llvm::AtomicOrdering::AcquireRelease ||
5151 atomicOrdering == llvm::AtomicOrdering::SequentiallyConsistent) {
5152 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5153 ompBuilder->createFlush(ompLoc);
5159 if (
auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
5163 return atomicCompareOp.emitError(
5164 "unsupported comparison predicate in atomic compare");
5165 compareOp = *maybeOp;
5167 LLVM::ICmpPredicate pred = icmpOp.getPredicate();
5168 isSigned = (pred == LLVM::ICmpPredicate::slt ||
5169 pred == LLVM::ICmpPredicate::sgt ||
5170 pred == LLVM::ICmpPredicate::sle ||
5171 pred == LLVM::ICmpPredicate::sge);
5174 isXBinopExpr = (icmpOp.getOperand(0) == block.
getArgument(0));
5176 isXBinopExpr ? icmpOp.getOperand(1) : icmpOp.getOperand(0);
5177 eVal = materializeValue(eOperand);
5178 }
else if (
auto fcmpOp = dyn_cast<LLVM::FCmpOp>(op)) {
5182 return atomicCompareOp.emitError(
5183 "unsupported comparison predicate in atomic compare");
5184 compareOp = *maybeOp;
5186 isXBinopExpr = (fcmpOp.getOperand(0) == block.
getArgument(0));
5188 isXBinopExpr ? fcmpOp.getOperand(1) : fcmpOp.getOperand(0);
5189 eVal = materializeValue(eOperand);
5190 }
else if (
auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5192 dVal = materializeValue(selectOp.getTrueValue());
5200 if (
auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5201 dVal = materializeValue(selectOp.getTrueValue());
5208 return atomicCompareOp.emitError(
5209 "failed to extract expected value (e) from atomic compare region");
5213 if (yieldOp.getResults().empty())
5214 return atomicCompareOp.emitError(
5215 "failed to extract desired value (d) from atomic compare region");
5216 dVal = materializeValue(yieldOp.getResults()[0]);
5219 llvmAtomicX.IsSigned = isSigned;
5221 llvm::OpenMPIRBuilder::AtomicOpValue vOpVal = {
nullptr,
nullptr,
false,
5223 llvm::OpenMPIRBuilder::AtomicOpValue rOpVal = {
nullptr,
nullptr,
false,
5225 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5227 bool savedHandleFPNegZero = ompBuilder->setHandleFPNegZero(
true);
5228 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5229 ompBuilder->createAtomicCompare(ompLoc, llvmAtomicX, vOpVal, rOpVal, eVal,
5230 dVal, atomicOrdering, compareOp,
5231 isXBinopExpr,
false,
false);
5232 ompBuilder->setHandleFPNegZero(savedHandleFPNegZero);
5234 if (failed(
handleError(afterIP, *atomicCompareOp)))
5237 builder.restoreIP(*afterIP);
5242 omp::ClauseCancellationConstructType directive) {
5243 switch (directive) {
5244 case omp::ClauseCancellationConstructType::Loop:
5245 return llvm::omp::Directive::OMPD_for;
5246 case omp::ClauseCancellationConstructType::Parallel:
5247 return llvm::omp::Directive::OMPD_parallel;
5248 case omp::ClauseCancellationConstructType::Sections:
5249 return llvm::omp::Directive::OMPD_sections;
5250 case omp::ClauseCancellationConstructType::Taskgroup:
5251 return llvm::omp::Directive::OMPD_taskgroup;
5253 llvm_unreachable(
"Unhandled cancellation construct type");
5262 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5265 llvm::Value *ifCond =
nullptr;
5266 if (
Value ifVar = op.getIfExpr())
5269 llvm::omp::Directive cancelledDirective =
5272 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5273 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
5275 if (failed(
handleError(afterIP, *op.getOperation())))
5278 builder.restoreIP(afterIP.get());
5285 llvm::IRBuilderBase &builder,
5290 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5293 llvm::omp::Directive cancelledDirective =
5296 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5297 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
5299 if (failed(
handleError(afterIP, *op.getOperation())))
5302 builder.restoreIP(afterIP.get());
5312 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5314 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
5319 Value symAddr = threadprivateOp.getSymAddr();
5322 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
5325 if (!isa<LLVM::AddressOfOp>(symOp))
5326 return opInst.
emitError(
"Addressing symbol not found");
5327 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
5329 LLVM::GlobalOp global =
5330 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
5331 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
5332 llvm::Type *type = globalValue->getValueType();
5333 llvm::TypeSize typeSize =
5334 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
5336 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
5337 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
5338 ompLoc, globalValue, size, global.getSymName() +
".cache");
/// Translates a declare-target device type (`host`, `nohost`, `any`) into the
/// matching `OffloadEntriesInfoManager::OMPTargetDeviceClauseKind` value.
5344static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
5346 switch (deviceClause) {
5347 case mlir::omp::DeclareTargetDeviceType::host:
5348 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
5350 case mlir::omp::DeclareTargetDeviceType::nohost:
5351 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
5353 case mlir::omp::DeclareTargetDeviceType::any:
5354 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
// Any device type not returned above is treated as a programming error.
5357 llvm_unreachable(
"unhandled device clause");
/// Translates a declare-target capture clause (`to`, `link`, `enter`, `none`)
/// into the matching `OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind`
/// value.
5360static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
5362 mlir::omp::DeclareTargetCaptureClause captureClause) {
5363 switch (captureClause) {
5364 case mlir::omp::DeclareTargetCaptureClause::to:
5365 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
5366 case mlir::omp::DeclareTargetCaptureClause::link:
5367 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
5368 case mlir::omp::DeclareTargetCaptureClause::enter:
5369 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
5370 case mlir::omp::DeclareTargetCaptureClause::none:
5371 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
// Any capture clause not returned above is treated as a programming error.
5373 llvm_unreachable(
"unhandled capture clause");
5378 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
5380 if (
auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
5381 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
5382 return modOp.lookupSymbol(addressOfOp.getGlobalName());
5389 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
5390 value = addrCast.getOperand();
5407static llvm::SmallString<64>
5409 llvm::OpenMPIRBuilder &ompBuilder,
5410 llvm::vfs::FileSystem &vfs) {
5412 llvm::raw_svector_ostream os(suffix);
5415 auto fileInfoCallBack = [&loc]() {
5416 return std::pair<std::string, uint64_t>(
5417 llvm::StringRef(loc.getFilename()), loc.getLine());
5422 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs).FileID);
5424 os <<
"_decl_tgt_ref_ptr";
5430 if (
auto declareTargetGlobal =
5431 dyn_cast_if_present<omp::DeclareTargetInterface>(
5433 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5434 omp::DeclareTargetCaptureClause::link)
5440 if (
auto declareTargetGlobal =
5441 dyn_cast_if_present<omp::DeclareTargetInterface>(
5443 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5444 omp::DeclareTargetCaptureClause::to ||
5445 declareTargetGlobal.getDeclareTargetCaptureClause() ==
5446 omp::DeclareTargetCaptureClause::enter)
5460 if (
auto declareTargetGlobal =
5461 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
5464 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
5465 omp::DeclareTargetCaptureClause::link) ||
5466 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5467 omp::DeclareTargetCaptureClause::to &&
5468 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
5472 if (gOp.getSymName().contains(suffix))
5477 (gOp.getSymName().str() + suffix.str()).str());
/// Extends `llvm::OpenMPIRBuilder::MapInfosTy` with one mapper operation per
/// map entry.
5486struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
// Mapper operation associated with each entry; entries may be null.
5487 SmallVector<Operation *, 4> Mappers;
/// Appends the mappers of \a curInfo, then delegates to the base-class
/// `append` for the inherited arrays.
5490 void append(MapInfosTy &curInfo) {
5491 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
5492 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
/// Per-entry bookkeeping gathered from map operands, kept as parallel arrays
/// on top of `MapInfosTy`: every entry contributes one element to each vector.
5501struct MapInfoData : MapInfosTy {
// Whether the entry refers to a declare-target variable.
5502 llvm::SmallVector<bool, 4> IsDeclareTarget;
// Whether the entry's map operation is listed as a member of another map.
5503 llvm::SmallVector<bool, 4> IsAMember;
// True for entries coming from map operands, false for entries that were
// added for device pointer/address operands.
5505 llvm::SmallVector<bool, 4> IsAMapping;
// The map operation each entry was created from.
5506 llvm::SmallVector<mlir::Operation *, 4> MapClause;
// The translated LLVM value of the mapped pointer.
5507 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
// The converted LLVM type recorded for the entry.
5510 llvm::SmallVector<llvm::Type *, 4> BaseType;
/// Appends every per-entry array of \a CurInfo, including the inherited ones.
/// All parallel vectors are extended together so that indices remain aligned
/// after the append (`IsAMember` and `IsAMapping` included).
5513 void append(MapInfoData &CurInfo) {
5514 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
5515 CurInfo.IsDeclareTarget.end());
 IsAMember.append(CurInfo.IsAMember.begin(), CurInfo.IsAMember.end());
 IsAMapping.append(CurInfo.IsAMapping.begin(), CurInfo.IsAMapping.end());
5516 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
5517 OriginalValue.append(CurInfo.OriginalValue.begin(),
5518 CurInfo.OriginalValue.end());
5519 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
5520 MapInfosTy::append(CurInfo);
5524enum class TargetDirectiveEnumTy : uint32_t {
5528 TargetEnterData = 3,
/// Classifies \p op as one of the target directives by dispatching on its
/// operation type; any operation that is not one of the listed target
/// operations yields `TargetDirectiveEnumTy::None`.
5533static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
5534 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
5535 .Case([](omp::TargetDataOp) {
return TargetDirectiveEnumTy::TargetData; })
5536 .Case([](omp::TargetEnterDataOp) {
5537 return TargetDirectiveEnumTy::TargetEnterData;
5539 .Case([&](omp::TargetExitDataOp) {
5540 return TargetDirectiveEnumTy::TargetExitData;
5542 .Case([&](omp::TargetUpdateOp) {
5543 return TargetDirectiveEnumTy::TargetUpdate;
5545 .Case([&](omp::TargetOp) {
return TargetDirectiveEnumTy::Target; })
// Fallback for every other operation type.
5546 .Default([&](Operation *op) {
return TargetDirectiveEnumTy::None; });
5553 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
5554 arrTy.getElementType()))
5571 llvm::Value *basePointer,
5572 llvm::Type *baseType,
5573 llvm::IRBuilderBase &builder,
5575 if (
auto memberClause =
5576 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
5581 if (!memberClause.getBounds().empty()) {
5582 llvm::Value *elementCount = builder.getInt64(1);
5583 for (
auto bounds : memberClause.getBounds()) {
5584 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
5585 bounds.getDefiningOp())) {
5590 elementCount = builder.CreateMul(
5594 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
5595 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
5596 builder.getInt64(1)));
5603 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
5611 return builder.CreateMul(elementCount,
5612 builder.getInt64(underlyingTypeSzInBits / 8));
/// Translates the dialect-level map flags into the runtime's
/// `OpenMPOffloadMappingFlags` bit set. Each dialect flag tested below
/// contributes the runtime flag of the same meaning; `is_device_ptr` is
/// handled specially at the end.
5623static llvm::omp::OpenMPOffloadMappingFlags
// True when any flag other than `is_device_ptr` is set.
5625 const bool hasExplicitMap =
5626 (mlirFlags &
~omp::ClauseMapFlags::is_device_ptr) !=
5627 omp::ClauseMapFlags::none;
// Start from an empty flag set and OR in one runtime flag per dialect flag.
5629 llvm::omp::OpenMPOffloadMappingFlags mapType =
5630 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5632 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::to))
5633 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
5635 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::from))
5636 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
5638 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::always))
5639 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5641 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::del))
5642 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
5644 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::return_param))
5645 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5647 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::priv))
5648 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
5650 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::literal))
5651 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5653 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::implicit))
5654 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
5656 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::close))
5657 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5659 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::present))
5660 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
5662 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::ompx_hold))
5663 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
5665 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::attach))
5666 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
// `is_device_ptr` always becomes a target parameter; it is additionally
// marked literal only when no other map flag accompanies it.
5668 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::is_device_ptr)) {
5669 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5670 if (!hasExplicitMap)
5671 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5681 ArrayRef<Value> useDevAddrOperands = {},
5682 ArrayRef<Value> hasDevAddrOperands = {}) {
5684 auto checkRefPtrOrPteeMapWithAttach = [](omp::ClauseMapFlags mapType) {
5686 bitEnumContainsAll(mapType, omp::ClauseMapFlags::ref_ptr) ||
5687 bitEnumContainsAll(mapType, omp::ClauseMapFlags::ref_ptee);
5688 return hasRefType &&
5689 bitEnumContainsAll(mapType, omp::ClauseMapFlags::attach);
5692 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
5700 for (Value mapValue : mapVars) {
5701 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5702 for (
auto member : map.getMembers())
5703 if (member == mapOp)
5710 for (Value mapValue : mapVars) {
5711 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5712 bool isRefPtrOrPteeMapWithAttach =
5713 checkRefPtrOrPteeMapWithAttach(mapOp.getMapType());
5714 Value offloadPtr = (mapOp.getVarPtrPtr() && !isRefPtrOrPteeMapWithAttach)
5715 ? mapOp.getVarPtrPtr()
5716 : mapOp.getVarPtr();
5717 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
5718 mapData.Pointers.push_back(
5719 isRefPtrOrPteeMapWithAttach
5720 ? moduleTranslation.
lookupValue(mapOp.getVarPtrPtr())
5721 : mapData.OriginalValue.back());
5723 if (llvm::Value *refPtr =
5725 mapData.IsDeclareTarget.push_back(
true);
5726 mapData.BasePointers.push_back(refPtr);
5728 mapData.IsDeclareTarget.push_back(
true);
5729 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5731 mapData.IsDeclareTarget.push_back(
false);
5732 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5738 mapData.BaseType.push_back(moduleTranslation.
convertType(
5739 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtrType().value()
5740 : mapOp.getVarPtrType()));
5747 mlir::Type sizeType = (isRefPtrOrPteeMapWithAttach || !mapOp.getVarPtrPtr())
5748 ? mapOp.getVarPtrType()
5749 : mapOp.getVarPtrPtrType().value();
5751 dl, sizeType, isRefPtrOrPteeMapWithAttach ?
nullptr : mapOp,
5752 mapData.Pointers.back(), moduleTranslation.
convertType(sizeType),
5753 builder, moduleTranslation));
5754 mapData.MapClause.push_back(mapOp.getOperation());
5758 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5759 if (mapOp.getMapperId())
5760 mapData.Mappers.push_back(
5762 mapOp, mapOp.getMapperIdAttr()));
5764 mapData.Mappers.push_back(
nullptr);
5765 mapData.IsAMapping.push_back(
true);
5766 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
5769 auto findMapInfo = [&mapData](llvm::Value *val,
5770 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy,
5771 size_t memberCount) {
5774 for (llvm::Value *basePtr : mapData.OriginalValue) {
5775 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[index]);
5786 (mapData.Types[index] &
5787 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
5788 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
5789 if (!isAttachMap && basePtr == val && mapData.IsAMapping[index] &&
5790 memberCount == mapOp.getMembers().size()) {
5792 mapData.Types[index] |=
5793 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5794 mapData.DevicePointers[index] = devInfoTy;
5802 auto addDevInfos = [&](
const llvm::ArrayRef<Value> &useDevOperands,
5803 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5804 for (Value mapValue : useDevOperands) {
5805 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5807 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5808 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5811 if (!findMapInfo(origValue, devInfoTy, mapOp.getMembers().size())) {
5812 mapData.OriginalValue.push_back(origValue);
5813 mapData.Pointers.push_back(mapData.OriginalValue.back());
5814 mapData.IsDeclareTarget.push_back(
false);
5815 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5816 mlir::Type baseTy = mapOp.getVarPtrPtr()
5817 ? mapOp.getVarPtrPtrType().value()
5818 : mapOp.getVarPtrType();
5819 mapData.BaseType.push_back(moduleTranslation.
convertType(baseTy));
5820 mapData.Sizes.push_back(builder.getInt64(0));
5821 mapData.MapClause.push_back(mapOp.getOperation());
5822 mapData.Types.push_back(
5823 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
5826 mapData.DevicePointers.push_back(devInfoTy);
5827 mapData.Mappers.push_back(
nullptr);
5828 mapData.IsAMapping.push_back(
false);
5829 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
5834 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5835 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
5837 for (Value mapValue : hasDevAddrOperands) {
5838 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5840 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5841 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5843 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5845 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
5846 omp::ClauseMapFlags::none;
5848 mapData.OriginalValue.push_back(origValue);
5849 mapData.BasePointers.push_back(origValue);
5850 mapData.Pointers.push_back(origValue);
5851 mapData.IsDeclareTarget.push_back(
false);
5853 mlir::Type baseTy = mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtrType().value()
5854 : mapOp.getVarPtrType();
5855 mapData.BaseType.push_back(moduleTranslation.
convertType(baseTy));
5856 mapData.Sizes.push_back(builder.getInt64(dl.
getTypeSize(baseTy)));
5858 mapData.MapClause.push_back(mapOp.getOperation());
5859 if (llvm::to_underlying(mapType & mapTypeAlways)) {
5863 mapData.Types.push_back(mapType);
5867 if (mapOp.getMapperId()) {
5868 mapData.Mappers.push_back(
5870 mapOp, mapOp.getMapperIdAttr()));
5872 mapData.Mappers.push_back(
nullptr);
5877 mapData.Types.push_back(
5878 isDevicePtr ? mapType
5879 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
5880 mapData.Mappers.push_back(
nullptr);
5884 mapData.DevicePointers.push_back(
5885 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
5886 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5887 mapData.IsAMapping.push_back(
false);
5888 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
5893 auto *res = llvm::find(mapData.MapClause, memberOp);
5894 assert(res != mapData.MapClause.end() &&
5895 "MapInfoOp for member not found in MapData, cannot return index");
5896 return std::distance(mapData.MapClause.begin(), res);
5900 omp::MapInfoOp mapInfo,
bool first =
true) {
5901 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5911 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
5912 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
5914 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
5915 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
5916 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
5918 if (aIndex == bIndex)
5921 if (aIndex < bIndex)
5924 if (aIndex > bIndex)
5931 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
5933 occludedChildren.push_back(
b);
5935 occludedChildren.push_back(a);
5936 return memberAParent;
5939 for (
auto v : occludedChildren)
5946 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5948 if (indexAttr.size() == 1)
5949 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
5953 return llvm::cast<omp::MapInfoOp>(
5954 mapInfo.getMembers()[
indices.front()].getDefiningOp());
// Builds and returns a vector of LLVM index values derived from the
// omp::MapBoundsOp operations that define `bounds`.
// NOTE(review): the function name and the branch on `isArrayTy` are not
// visible in this excerpt; which of the two loops below serves array types
// should be confirmed against the full source.
5977static std::vector<llvm::Value *>
5979 llvm::IRBuilderBase &builder,
bool isArrayTy,
5981 std::vector<llvm::Value *> idx;
// First form: a leading constant 0 index followed by one lower bound per
// bounds op, visiting `bounds` from last to first. Entries whose defining op
// is absent or not a MapBoundsOp are skipped (dyn_cast_if_present).
5992 idx.push_back(builder.getInt64(0));
5993 for (
int i = bounds.size() - 1; i >= 0; --i) {
5994 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5995 bounds[i].getDefiningOp())) {
5996 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
// Second form: again walking `bounds` from last to first, the last bound is
// handled separately (i == size - 1); every other bound folds into the
// running value as idx.back() = idx.back() * extent + lowerBound.
6014 std::vector<llvm::Value *> dimensionIndexSizeOffset;
6015 for (
int i = bounds.size() - 1; i >= 0; --i) {
6016 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
6017 bounds[i].getDefiningOp())) {
6018 if (i == ((
int)bounds.size() - 1))
6020 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
6022 idx.back() = builder.CreateAdd(
6023 builder.CreateMul(idx.back(), moduleTranslation.
lookupValue(
6024 boundOp.getExtent())),
6025 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
6034 llvm::transform(values, std::back_inserter(ints), [](
Attribute value) {
6035 return cast<IntegerAttr>(value).getInt();
// Collects into overlapMapDataIdxs the positions of parentOp's members that
// are not nested underneath another listed member.
6043 omp::MapInfoOp parentOp) {
// Special cases: no members, and exactly one member (position 0 is recorded).
6045 if (parentOp.getMembers().empty())
6049 if (parentOp.getMembers().size() == 1) {
6050 overlapMapDataIdxs.push_back(0);
// Decode each member's index path (an ArrayAttr of integers) into
// memberIndices[i].
6054 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
6055 size_t numMembers = indexAttr.size();
6059 for (
auto [i, indicesAttr] : llvm::enumerate(indexAttr))
6060 getAsIntegers(cast<ArrayAttr>(indicesAttr), memberIndices[i]);
// Pairwise scan: member i is marked for skipping when some member j has a
// strictly shorter index path that is a prefix of i's path.
6066 llvm::SmallDenseSet<size_t> skipIndices;
6067 for (
size_t i = 0; i < numMembers; ++i) {
6068 const auto &iIndices = memberIndices[i];
6069 for (
size_t j = 0;
j < numMembers; ++
j) {
6072 const auto &jIndices = memberIndices[
j];
6074 if (jIndices.size() < iIndices.size() &&
6075 std::equal(jIndices.begin(), jIndices.end(), iIndices.begin())) {
6076 skipIndices.insert(i);
// Emit the surviving member positions in their original order.
6083 for (
size_t i = 0; i < numMembers; ++i)
6084 if (!skipIndices.contains(i))
6085 overlapMapDataIdxs.push_back(i);
6097 if (mapOp.getVarPtrPtr())
6120 llvm::OpenMPIRBuilder &ompBuilder, MapInfoData &mapData,
6121 size_t mapDataIdx, MapInfosTy &combinedInfo,
6122 TargetDirectiveEnumTy targetDirective,
6123 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
6124 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE,
6125 bool isTargetParam =
true,
int mapDataParentIdx = -1) {
6126 auto mapFlag = mapData.Types[mapDataIdx];
6127 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
6131 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
6132 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH);
6138 if (isTargetParam &&
6139 (targetDirective == TargetDirectiveEnumTy::Target &&
6140 !mapData.IsDeclareTarget[mapDataIdx]) &&
6142 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
6144 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
6146 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
6155 if (memberOfFlag != llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE) {
6156 if (!isPtrTy && !isAttachMap)
6157 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
6164 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
6174 if (isPtrTy && !isAttachMap && mapData.IsDeclareTarget[mapDataIdx])
6175 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
6184 !bitEnumContainsAll(mapInfoOp.getMapType(),
6185 omp::ClauseMapFlags::ref_ptr) &&
6186 bitEnumContainsAll(mapInfoOp.getMapType(), omp::ClauseMapFlags::ref_ptee);
6187 bool isRefPtrPtee = bitEnumContainsAll(mapInfoOp.getMapType(),
6188 omp::ClauseMapFlags::ref_ptr |
6189 omp::ClauseMapFlags::ref_ptee);
6191 if (!mapInfoOp->getParentOfType<omp::DeclareMapperOp>() &&
6192 mapDataParentIdx >= 0 && !(isRefPtee || (isRefPtrPtee && isPtrTy))) {
6193 combinedInfo.BasePointers.emplace_back(
6194 mapData.BasePointers[mapDataParentIdx]);
6196 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
6199 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
6200 combinedInfo.DevicePointers.emplace_back(
6201 memberOfFlag != llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE
6202 ? llvm::OpenMPIRBuilder::DeviceInfoTy::None
6203 : mapData.DevicePointers[mapDataIdx]);
6204 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
6205 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
6206 combinedInfo.Types.emplace_back(mapFlag);
6207 combinedInfo.Sizes.emplace_back(
6208 isPtrTy ? builder.CreateSelect(
6209 builder.CreateIsNull(mapData.Pointers[mapDataIdx]),
6210 builder.getInt64(0), mapData.Sizes[mapDataIdx])
6211 : mapData.Sizes[mapDataIdx]);
6231 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
6232 MapInfoData &mapData, uint64_t mapDataIndex,
6233 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
6234 TargetDirectiveEnumTy targetDirective) {
6235 using MapFlags = llvm::omp::OpenMPOffloadMappingFlags;
6236 assert(!ompBuilder.Config.isTargetDevice() &&
6237 "function only supported for host device codegen");
6239 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6240 auto *parentMapper = mapData.Mappers[mapDataIndex];
6246 MapFlags baseFlag = (targetDirective == TargetDirectiveEnumTy::Target &&
6247 !mapData.IsDeclareTarget[mapDataIndex])
6248 ? MapFlags::OMP_MAP_TARGET_PARAM
6249 : MapFlags::OMP_MAP_NONE;
6255 MapFlags parentFlags = mapData.Types[mapDataIndex];
6256 MapFlags preserve = MapFlags::OMP_MAP_TO | MapFlags::OMP_MAP_FROM |
6257 MapFlags::OMP_MAP_ALWAYS | MapFlags::OMP_MAP_CLOSE |
6258 MapFlags::OMP_MAP_PRESENT |
6259 MapFlags::OMP_MAP_OMPX_HOLD |
6260 MapFlags::OMP_MAP_IMPLICIT;
6261 baseFlag |= (parentFlags & preserve);
6263 MapFlags parentFlags = mapData.Types[mapDataIndex];
6265 MapFlags::OMP_MAP_PRESENT | MapFlags::OMP_MAP_RETURN_PARAM;
6266 baseFlag |= (parentFlags & preserve);
6269 combinedInfo.Types.emplace_back(baseFlag);
6270 combinedInfo.DevicePointers.emplace_back(
6271 mapData.DevicePointers[mapDataIndex]);
6275 combinedInfo.Mappers.emplace_back(
6276 parentMapper && !parentClause.getPartialMap() ? parentMapper :
nullptr);
6278 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6279 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
6288 llvm::Value *lowAddr, *highAddr;
6289 if (!parentClause.getPartialMap()) {
6290 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
6291 builder.getPtrTy());
6292 highAddr = builder.CreatePointerCast(
6293 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
6294 mapData.Pointers[mapDataIndex], 1),
6295 builder.getPtrTy());
6296 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
6298 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6301 lowAddr = builder.CreatePointerCast(mapData.BasePointers[firstMemberIdx],
6302 builder.getPtrTy());
6306 auto lastMemberMapInfo =
6307 cast<omp::MapInfoOp>(mapData.MapClause[lastMemberIdx]);
6316 bool isRefPteeMap = bitEnumContainsAll(lastMemberMapInfo.getMapType(),
6317 omp::ClauseMapFlags::ref_ptee) &&
6318 !bitEnumContainsAll(lastMemberMapInfo.getMapType(),
6319 omp::ClauseMapFlags::ref_ptr);
6320 llvm::Type *castType = mapData.BaseType[lastMemberIdx];
6323 moduleTranslation.
convertType(lastMemberMapInfo.getVarPtrType());
6324 highAddr = builder.CreatePointerCast(
6325 builder.CreateGEP(castType, mapData.BasePointers[lastMemberIdx],
6326 builder.getInt64(1)),
6327 builder.getPtrTy());
6328 combinedInfo.Pointers.emplace_back(mapData.BasePointers[firstMemberIdx]);
6331 llvm::Value *size = builder.CreateIntCast(
6332 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
6333 builder.getInt64Ty(),
6335 combinedInfo.Sizes.push_back(size);
6343 if (!parentClause.getPartialMap()) {
6348 MapFlags mapFlag = mapData.Types[mapDataIndex];
6349 bool hasMapClose = (MapFlags(mapFlag) & MapFlags::OMP_MAP_CLOSE) ==
6350 MapFlags::OMP_MAP_CLOSE;
6351 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
6367 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose ||
6368 overlapIdxs.size() == 1) {
6369 combinedInfo.Types.emplace_back(mapFlag);
6370 combinedInfo.DevicePointers.emplace_back(
6371 mapData.DevicePointers[mapDataIndex]);
6373 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6374 combinedInfo.BasePointers.emplace_back(
6375 mapData.BasePointers[mapDataIndex]);
6376 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
6377 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
6378 combinedInfo.Mappers.emplace_back(
nullptr);
6384 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
6385 builder.getPtrTy());
6386 highAddr = builder.CreatePointerCast(
6387 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
6388 mapData.Pointers[mapDataIndex], 1),
6389 builder.getPtrTy());
6396 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
6403 for (
auto v : overlapIdxs) {
6406 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
6408 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataOverlapIdx]));
6409 combinedInfo.Types.emplace_back(mapFlag);
6410 combinedInfo.DevicePointers.emplace_back(
6411 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
6413 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6414 combinedInfo.BasePointers.emplace_back(
6415 mapData.BasePointers[mapDataIndex]);
6416 combinedInfo.Mappers.emplace_back(
nullptr);
6417 combinedInfo.Pointers.emplace_back(lowAddr);
6418 auto sizeCalc = builder.CreateIntCast(
6419 builder.CreatePtrDiff(builder.getInt8Ty(),
6420 mapData.OriginalValue[mapDataOverlapIdx],
6422 builder.getInt64Ty(),
true);
6427 auto sizeSel = builder.CreateSelect(
6428 builder.CreateICmpNE(builder.getInt64(0), sizeCalc), sizeCalc,
6429 isPtrMap ? llvm::ConstantExpr::getSizeOf(builder.getPtrTy())
6430 : mapData.Sizes[mapDataOverlapIdx]);
6431 combinedInfo.Sizes.emplace_back(sizeSel);
6432 lowAddr = builder.CreateConstGEP1_32(
6433 isPtrMap ? builder.getPtrTy() : mapData.BaseType[mapDataOverlapIdx],
6434 mapData.BasePointers[mapDataOverlapIdx], 1);
6437 combinedInfo.Types.emplace_back(mapFlag);
6438 combinedInfo.DevicePointers.emplace_back(
6439 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
6441 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6442 combinedInfo.BasePointers.emplace_back(
6443 mapData.BasePointers[mapDataIndex]);
6444 combinedInfo.Mappers.emplace_back(
nullptr);
6445 combinedInfo.Pointers.emplace_back(lowAddr);
6446 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
6447 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
6448 builder.getInt64Ty(),
true));
6454 llvm::IRBuilderBase &builder,
6455 llvm::OpenMPIRBuilder &ompBuilder,
6457 MapInfoData &mapData, uint64_t mapDataIndex,
6458 TargetDirectiveEnumTy targetDirective) {
6459 assert(!ompBuilder.Config.isTargetDevice() &&
6460 "function only supported for host device codegen");
6463 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6468 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
6469 auto memberClause = llvm::cast<omp::MapInfoOp>(
6470 parentClause.getMembers()[0].getDefiningOp());
6483 builder, ompBuilder, mapData, memberDataIdx, combinedInfo,
6485 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE,
6486 true, mapDataIndex);
6490 auto collectMapInfoIdxs =
6493 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6495 for (
auto member : parentClause.getMembers())
6497 mapData, llvm::cast<omp::MapInfoOp>(member.getDefiningOp())));
6501 collectMapInfoIdxs(mapInfoIdx);
6503 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
6504 ompBuilder.getMemberOfFlag(combinedInfo.Types.size());
6505 for (
size_t i = 0; i < mapInfoIdx.size(); i++) {
6510 combinedInfo, mapData, mapInfoIdx[i], memberOfFlag,
6514 combinedInfo, targetDirective, memberOfFlag,
6515 false, mapDataIndex);
6527 llvm::IRBuilderBase &builder) {
6529 "function only supported for host device codegen");
// Rewrite mapData.Pointers (and, for by-copy captures, BasePointers) for each
// map clause according to its variable capture kind.
6530 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6531 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
// Tests whether the OMP_MAP_ATTACH bit is set.
6534 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
6535 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH);
// Declare-target entries are only processed here when they are attach maps;
// all non-declare-target entries are processed.
6540 if (!mapData.IsDeclareTarget[i] ||
6541 (mapData.IsDeclareTarget[i] && isAttachMap)) {
6542 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
6552 switch (captureKind) {
// ByRef: start from the recorded pointer and, when offset indices were
// computed, step to the addressed element with an in-bounds GEP.
6553 case omp::VariableCaptureKind::ByRef: {
6554 llvm::Value *newV = mapData.Pointers[i];
6556 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
6559 newV = builder.CreateLoad(builder.getPtrTy(), newV);
6561 if (!offsetIdx.empty())
6562 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
6564 mapData.Pointers[i] = newV;
// ByCopy: obtain the value itself (loading it when the recorded pointer has
// pointer type), then round-trip it through a pointer-sized ".casted" stack
// slot so it is reloaded as a pointer.
6566 case omp::VariableCaptureKind::ByCopy: {
6567 llvm::Type *type = mapData.BaseType[i];
6569 if (mapData.Pointers[i]->getType()->isPointerTy())
6570 newV = builder.CreateLoad(type, mapData.Pointers[i]);
6572 newV = mapData.Pointers[i];
// The insertion point and debug location are saved before creating the
// alloca and restored right after it.
6575 auto curInsert = builder.saveIP();
6576 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
6578 auto *memTempAlloc =
6579 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
6580 builder.SetCurrentDebugLocation(DbgLoc);
6581 builder.restoreIP(curInsert);
6583 builder.CreateStore(newV, memTempAlloc);
6584 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
6587 mapData.Pointers[i] = newV;
6588 mapData.BasePointers[i] = newV;
// This / VLAType captures are not handled: an op error is emitted.
6590 case omp::VariableCaptureKind::This:
6591 case omp::VariableCaptureKind::VLAType:
6592 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
6603 MapInfoData &mapData,
6604 TargetDirectiveEnumTy targetDirective) {
6606 "function only supported for host device codegen");
// Walk every collected map clause; entries flagged as members are tested
// first (IsAMember), and clauses that themselves carry members take the
// member-aware path below.
6627 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6628 if (mapData.IsAMember[i])
// cast<> (not dyn_cast<>): every MapClause entry is expected to be an
// omp::MapInfoOp and the result is dereferenced immediately, so a mismatch
// should assert rather than yield a null op.
6631 auto mapInfoOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6632 if (!mapInfoOp.getMembers().empty()) {
6634 combinedInfo, mapData, i, targetDirective);
6643static llvm::Expected<llvm::Function *>
6645 LLVM::ModuleTranslation &moduleTranslation,
6646 llvm::StringRef mapperFuncName,
6647 TargetDirectiveEnumTy targetDirective);
6649static llvm::Expected<llvm::Function *>
6652 TargetDirectiveEnumTy targetDirective) {
6654 "function only supported for host device codegen");
6655 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6656 std::string mapperFuncName =
6658 {
"omp_mapper", declMapperOp.getSymName()});
6660 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
6668 if (llvm::Function *existingFunc =
6669 moduleTranslation.
getLLVMModule()->getFunction(mapperFuncName)) {
6670 moduleTranslation.
mapFunction(mapperFuncName, existingFunc);
6671 return existingFunc;
6675 mapperFuncName, targetDirective);
6678static llvm::Expected<llvm::Function *>
6681 llvm::StringRef mapperFuncName,
6682 TargetDirectiveEnumTy targetDirective) {
6684 "function only supported for host device codegen");
6685 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6686 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
6689 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
6692 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6695 MapInfosTy combinedInfo;
6697 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
6698 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
6699 builder.restoreIP(codeGenIP);
6700 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
6701 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
6702 builder.GetInsertBlock());
6703 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
6706 return llvm::make_error<PreviouslyReportedError>();
6707 MapInfoData mapData;
6710 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
6716 return combinedInfo;
6720 if (!combinedInfo.Mappers[i])
6723 moduleTranslation, targetDirective);
6727 genMapInfoCB, varType, mapperFuncName, customMapperCB,
6730 return newFn.takeError();
6731 if ([[maybe_unused]] llvm::Function *mappedFunc =
6733 assert(mappedFunc == *newFn &&
6734 "mapper function mapping disagrees with emitted function");
6736 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
6744 llvm::Value *ifCond =
nullptr;
6745 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6749 llvm::omp::RuntimeFunction RTLFn;
6751 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
6754 llvm::OpenMPIRBuilder::TargetDataInfo info(
6757 assert(!ompBuilder->Config.isTargetDevice() &&
6758 "target data/enter/exit/update are host ops");
6759 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
6761 auto getDeviceID = [&](
mlir::Value dev) -> llvm::Value * {
6762 llvm::Value *v = moduleTranslation.
lookupValue(dev);
6763 return builder.CreateIntCast(v, builder.getInt64Ty(),
true);
6768 .Case([&](omp::TargetDataOp dataOp) {
6772 if (
auto ifVar = dataOp.getIfExpr())
6776 deviceID = getDeviceID(devId);
6778 mapVars = dataOp.getMapVars();
6779 useDevicePtrVars = dataOp.getUseDevicePtrVars();
6780 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
6783 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
6787 if (
auto ifVar = enterDataOp.getIfExpr())
6791 deviceID = getDeviceID(devId);
6794 enterDataOp.getNowait()
6795 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
6796 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
6797 mapVars = enterDataOp.getMapVars();
6798 info.HasNoWait = enterDataOp.getNowait();
6801 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
6805 if (
auto ifVar = exitDataOp.getIfExpr())
6809 deviceID = getDeviceID(devId);
6811 RTLFn = exitDataOp.getNowait()
6812 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
6813 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
6814 mapVars = exitDataOp.getMapVars();
6815 info.HasNoWait = exitDataOp.getNowait();
6818 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
6822 if (
auto ifVar = updateDataOp.getIfExpr())
6826 deviceID = getDeviceID(devId);
6829 updateDataOp.getNowait()
6830 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
6831 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
6832 mapVars = updateDataOp.getMapVars();
6833 info.HasNoWait = updateDataOp.getNowait();
6836 .DefaultUnreachable(
"unexpected operation");
6841 if (!isOffloadEntry)
6842 ifCond = builder.getFalse();
6844 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6845 MapInfoData mapData;
6847 builder, useDevicePtrVars, useDeviceAddrVars);
6850 MapInfosTy combinedInfo;
6851 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
6852 builder.restoreIP(codeGenIP);
6853 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
6855 return combinedInfo;
6861 [&moduleTranslation](
6862 llvm::OpenMPIRBuilder::DeviceInfoTy type,
6866 for (
auto [arg, useDevVar] :
6867 llvm::zip_equal(blockArgs, useDeviceVars)) {
6869 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
6870 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
6871 : mapInfoOp.getVarPtr();
6874 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
6875 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
6876 mapInfoData.MapClause, mapInfoData.DevicePointers,
6877 mapInfoData.BasePointers)) {
6878 auto mapOp = cast<omp::MapInfoOp>(mapClause);
6879 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
6880 devicePointer != type)
6883 if (llvm::Value *devPtrInfoMap =
6884 mapper ? mapper(basePointer) : basePointer) {
6885 moduleTranslation.
mapValue(arg, devPtrInfoMap);
6892 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
// Body-generation callback. It restores the builder to codeGenIP, dispatches
// on the requested BodyGenTy phase, and returns the builder's insertion point
// afterwards (or an error when a phase reports failure). Only valid for
// omp::TargetDataOp, which the assert below enforces.
6893 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
6894 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6897 builder.restoreIP(codeGenIP);
6898 assert(isa<omp::TargetDataOp>(op) &&
6899 "BodyGen requested for non TargetDataOp");
6900 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
// Region of the omp::TargetDataOp, bound by reference.
6901 Region &region = cast<omp::TargetDataOp>(op).getRegion();
6902 switch (bodyGenType) {
// Priv: when device pointer info was recorded, bind the use_device_addr and
// use_device_ptr block arguments to the values stored in
// info.DevicePtrInfoMap (the address variant goes through a load).
6903 case BodyGenTy::Priv:
6905 if (!info.DevicePtrInfoMap.empty()) {
6906 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6907 blockArgIface.getUseDeviceAddrBlockArgs(),
6908 useDeviceAddrVars, mapData,
6909 [&](llvm::Value *basePointer) -> llvm::Value * {
6910 if (!info.DevicePtrInfoMap[basePointer].second)
6912 return builder.CreateLoad(
6914 info.DevicePtrInfoMap[basePointer].second);
6916 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6917 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6918 mapData, [&](llvm::Value *basePointer) {
6919 return info.DevicePtrInfoMap[basePointer].second;
6923 moduleTranslation)))
6924 return llvm::make_error<PreviouslyReportedError>();
// DupNoPriv: block arguments are bound without a remapping callback, i.e.
// directly to the recorded base pointers.
6927 case BodyGenTy::DupNoPriv:
6928 if (info.DevicePtrInfoMap.empty()) {
6931 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6932 blockArgIface.getUseDeviceAddrBlockArgs(),
6933 useDeviceAddrVars, mapData);
6934 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6935 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
// NoPriv: only acts when no device pointer info was recorded.
6939 case BodyGenTy::NoPriv:
6941 if (info.DevicePtrInfoMap.empty()) {
6943 moduleTranslation)))
6944 return llvm::make_error<PreviouslyReportedError>();
6948 return builder.saveIP();
6951 auto customMapperCB =
6953 if (!combinedInfo.Mappers[i])
6955 info.HasMapper =
true;
6957 moduleTranslation, targetDirective);
6960 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6962 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6964 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
6965 if (isa<omp::TargetDataOp>(op))
6966 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6967 deallocBlocks, deviceID, ifCond, info,
6968 genMapInfoCB, customMapperCB,
6971 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6972 deallocBlocks, deviceID, ifCond, info,
6973 genMapInfoCB, customMapperCB, &RTLFn);
6979 builder.restoreIP(*afterIP);
6987 auto distributeOp = cast<omp::DistributeOp>(opInst);
6994 bool doDistributeReduction =
6998 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
7003 if (doDistributeReduction) {
7004 isByRef =
getIsByRef(teamsOp.getReductionByref());
7005 assert(isByRef.size() == teamsOp.getNumReductionVars());
7008 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7012 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
7013 .getReductionBlockArgs();
7016 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
7017 reductionDecls, privateReductionVariables, reductionVariableMap,
7022 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
7024 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
7029 moduleTranslation, allocaIP, deallocBlocks);
7032 builder.restoreIP(codeGenIP);
7036 distributeOp, builder, moduleTranslation, privVarsInfo, allocaIP);
7038 return llvm::make_error<PreviouslyReportedError>();
7043 return llvm::make_error<PreviouslyReportedError>();
7046 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
7048 distributeOp.getPrivateNeedsBarrier())))
7049 return llvm::make_error<PreviouslyReportedError>();
7052 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7055 builder, moduleTranslation);
7057 return regionBlock.takeError();
7058 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// Apply the distribute worksharing loop here only when the nested wrapper is
// not an omp::WsloopOp.
7063 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
// A static dist_schedule selects the Distribute schedule kind; otherwise
// plain Static is used.
7066 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
7067 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
7068 : omp::ClauseScheduleKind::Static;
// isOrdered simply mirrors hasDistSchedule. scheduleMod is left empty, so
// both modifier comparisons passed to applyWorkshareLoop evaluate to false.
7070 bool isOrdered = hasDistSchedule;
7071 std::optional<omp::ScheduleModifier> scheduleMod;
7072 bool isSimd =
false;
7073 llvm::omp::WorksharingLoopType workshareLoopType =
7074 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
7075 bool loopNeedsBarrier =
false;
// LLVM value for the dist_schedule chunk size; it is passed twice to
// applyWorkshareLoop below.
7076 llvm::Value *chunk = moduleTranslation.
lookupValue(
7077 distributeOp.getDistScheduleChunkSize());
7078 llvm::CanonicalLoopInfo *loopInfo =
7080 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
7081 ompBuilder->applyWorkshareLoop(
7082 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
7083 convertToScheduleKind(schedule), chunk, isSimd,
7084 scheduleMod == omp::ScheduleModifier::monotonic,
7085 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
7086 workshareLoopType,
false, hasDistSchedule, chunk);
// Propagate any error reported by the OpenMP IR builder.
7089 return wsloopIP.takeError();
7092 distributeOp.getLoc(), privVarsInfo)))
7093 return llvm::make_error<PreviouslyReportedError>();
7095 return llvm::Error::success();
7099 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7101 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7102 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7103 ompBuilder->createDistribute(ompLoc, allocaIP, deallocBlocks, bodyGenCB);
7108 builder.restoreIP(*afterIP);
7110 if (doDistributeReduction) {
7113 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
7114 privateReductionVariables, isByRef,
// Guard: the flags attribute is only handled on a module operation. isa<> is
// the type test; cast<> asserts on a mismatch instead of returning null, so
// it cannot be used to detect a non-module op.
7126 if (!isa<mlir::ModuleOp>(op))
// Record the OpenMP device version as a module flag with Max merge behaviour.
7131 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
7132 attribute.getOpenmpDeviceVersion());
7134 if (attribute.getNoGpuLib())
// Emit one global flag per attribute field, under the "__omp_rtl_*" names
// given below.
7137 ompBuilder->createGlobalFlag(
7138 attribute.getDebugKind() ,
7139 "__omp_rtl_debug_kind");
7140 ompBuilder->createGlobalFlag(
7142 .getAssumeTeamsOversubscription()
7144 "__omp_rtl_assume_teams_oversubscription");
7145 ompBuilder->createGlobalFlag(
7147 .getAssumeThreadsOversubscription()
7149 "__omp_rtl_assume_threads_oversubscription");
7150 ompBuilder->createGlobalFlag(
7151 attribute.getAssumeNoThreadState() ,
7152 "__omp_rtl_assume_no_thread_state");
7153 ompBuilder->createGlobalFlag(
7155 .getAssumeNoNestedParallelism()
7157 "__omp_rtl_assume_no_nested_parallelism");
// Derives the unique target-entry info for `targetOp` from its source
// location by delegating to OpenMPIRBuilder::getTargetEntryUniqueInfo.
// `parentName` defaults to the empty string.
7162 omp::TargetOp targetOp,
7163 llvm::OpenMPIRBuilder &ompBuilder,
7164 llvm::vfs::FileSystem &vfs,
7165 llvm::StringRef parentName =
"") {
// The op's location must contain a FileLineColLoc; this is only checked by
// the assert.
7166 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
7167 assert(fileLoc &&
"No file found from location");
// Callback handing the (filename, line) pair to the OpenMP IR builder.
7169 auto fileInfoCallBack = [&fileLoc]() {
7170 return std::pair<std::string, uint64_t>(
7171 llvm::StringRef(fileLoc.getFilename()), fileLoc.getLine());
7175 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs, parentName);
7181 llvm::IRBuilderBase &builder, llvm::Function *
func) {
7183 "function only supported for target device codegen");
// The guard restores the builder's insertion point when this scope exits.
7184 llvm::IRBuilderBase::InsertPointGuard guard(builder);
// For each map clause, redirect uses of the original value inside `func` to
// the recorded base pointer (or to a load from it, see below).
7185 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
7198 if (!mapData.IsDeclareTarget[i])
// Constants may be used through constant expressions; turn those users into
// instructions within `func` so they show up in the user walk below.
7206 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
7207 convertUsersOfConstantsToInstructions(constant,
func,
false);
// Snapshot the users first: the use list is modified while rewriting.
7214 for (llvm::User *user : mapData.OriginalValue[i]->users())
7215 userVec.push_back(user);
7217 for (llvm::User *user : userVec) {
// Only instruction users that live in `func` are considered.
7218 auto *insn = dyn_cast<llvm::Instruction>(user);
7219 if (!insn || insn->getFunction() !=
func)
7221 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
7222 llvm::Value *substitute = mapData.BasePointers[i];
7224 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
// Under the (partially elided) condition ending here, the substitute becomes
// a load from the base pointer, placed directly before the using instruction
// and carrying that instruction's debug location.
7228 ->Config.hasRequiresUnifiedSharedMemory())) {
7229 builder.SetCurrentDebugLocation(insn->getDebugLoc());
7230 substitute = builder.CreateLoad(mapData.BasePointers[i]->getType(),
7231 mapData.BasePointers[i]);
7232 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
7234 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
7279 omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg,
7280 llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder,
7281 llvm::OpenMPIRBuilder &ompBuilder,
7283 llvm::IRBuilderBase::InsertPoint allocaIP,
7284 llvm::IRBuilderBase::InsertPoint codeGenIP,
7286 assert(ompBuilder.Config.isTargetDevice() &&
7287 "function only supported for target device codegen");
7288 builder.restoreIP(allocaIP);
7290 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
7292 ompBuilder.M.getContext());
7293 unsigned alignmentValue = 0;
7296 cast<omp::BlockArgOpenMPOpInterface>(*targetOp).getBlockArgsPairs(
7299 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
7300 if (mapData.OriginalValue[i] == input) {
7301 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
7302 capture = mapOp.getMapCaptureType();
7305 mapOp.getVarPtrType(), ompBuilder.M.getDataLayout());
7309 for (
auto &[val, arg] : blockArgsPairs) {
7310 if (mapOp.getResult() == val) {
7315 assert(mlirArg &&
"expected to find entry block argument for map clause");
7320 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
7321 unsigned int defaultAS =
7322 ompBuilder.M.getDataLayout().getProgramAddressSpace();
7325 llvm::Value *v =
nullptr;
7333 builder.SetInsertPoint(codeGenIP.getBlock()->getFirstInsertionPt());
7334 v = ompBuilder.createOMPAllocShared(builder, arg.getType());
7338 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7339 for (
auto deallocIP : deallocIPs) {
7340 builder.SetInsertPoint(deallocIP.getBlock(), deallocIP.getPoint());
7341 ompBuilder.createOMPFreeShared(builder, v, arg.getType());
7345 v = builder.CreateAlloca(arg.getType(), allocaAS);
7347 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
7348 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
7351 builder.CreateStore(&arg, v);
7353 builder.restoreIP(codeGenIP);
7356 case omp::VariableCaptureKind::ByCopy: {
7360 case omp::VariableCaptureKind::ByRef: {
7361 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
7363 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
7378 if (v->getType()->isPointerTy() && alignmentValue) {
7379 llvm::MDBuilder MDB(builder.getContext());
7380 loadInst->setMetadata(
7381 llvm::LLVMContext::MD_align,
7382 llvm::MDNode::get(builder.getContext(),
7383 MDB.createConstant(llvm::ConstantInt::get(
7384 llvm::Type::getInt64Ty(builder.getContext()),
7391 case omp::VariableCaptureKind::This:
7392 case omp::VariableCaptureKind::VLAType:
7395 assert(
false &&
"Currently unsupported capture kind");
7399 return builder.saveIP();
// Pair every host_eval variable of the target op with its entry block
// argument, then match the block argument against the clause operands of the
// ops handled below to recover which clause the host value feeds.
7416 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
7417 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
7418 blockArgIface.getHostEvalBlockArgs())) {
7419 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
// teams: num_teams lower bound, num_teams upper bound(s), or the first
// thread_limit value; anything else is an unsupported use.
7423 .Case([&](omp::TeamsOp teamsOp) {
7424 if (teamsOp.getNumTeamsLower() == blockArg)
7425 numTeamsLower = hostEvalVar;
7426 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
7428 numTeamsUpper = hostEvalVar;
7429 else if (!teamsOp.getThreadLimitVars().empty() &&
7430 teamsOp.getThreadLimit(0) == blockArg)
7431 threadLimit = hostEvalVar;
7433 llvm_unreachable(
"unsupported host_eval use");
// parallel: only the first num_threads value is recognised.
7435 .Case([&](omp::ParallelOp parallelOp) {
7436 if (!parallelOp.getNumThreadsVars().empty() &&
7437 parallelOp.getNumThreads(0) == blockArg)
7438 numThreads = hostEvalVar;
7440 llvm_unreachable(
"unsupported host_eval use");
// loop_nest: the block argument may appear among the lower bounds, upper
// bounds and/or steps; each matching position in the corresponding output
// vector is set to the host value, and at least one match is asserted.
7442 .Case([&](omp::LoopNestOp loopOp) {
7443 auto processBounds =
7447 for (
auto [i, lb] : llvm::enumerate(opBounds)) {
7448 if (lb == blockArg) {
7451 (*outBounds)[i] = hostEvalVar;
7457 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
7458 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
7460 found = processBounds(loopOp.getLoopSteps(), steps) || found;
7462 assert(found &&
"unsupported host_eval use");
// Any other user of a host_eval block argument is unsupported.
7464 .DefaultUnreachable(
"unsupported host_eval use");
7476template <
typename OpTy>
7481 if (OpTy casted = dyn_cast<OpTy>(op))
7484 if (immediateParent)
7485 return dyn_cast_if_present<OpTy>(op->
getParentOp());
7494 return std::nullopt;
7497 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
7498 return constAttr.getInt();
7500 return std::nullopt;
7505 uint64_t sizeInBytes = sizeInBits / 8;
7509template <
typename OpTy>
7511 if (op.getNumReductionVars() > 0) {
7516 members.reserve(reductions.size());
7517 for (omp::DeclareReductionOp &red : reductions) {
7521 if (red.getByrefElementType())
7522 members.push_back(*red.getByrefElementType());
7524 members.push_back(red.getType());
7527 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
7543 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
7544 bool isTargetDevice,
bool isGPU) {
7547 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
7548 if (!isTargetDevice) {
7556 numTeamsLower = teamsOp.getNumTeamsLower();
7558 if (!teamsOp.getNumTeamsUpperVars().empty())
7559 numTeamsUpper = teamsOp.getNumTeams(0);
7560 if (!teamsOp.getThreadLimitVars().empty())
7561 threadLimit = teamsOp.getThreadLimit(0);
7565 if (!parallelOp.getNumThreadsVars().empty())
7566 numThreads = parallelOp.getNumThreads(0);
7572 int32_t minTeamsVal = 1, maxTeamsVal = -1;
7576 if (numTeamsUpper) {
7578 minTeamsVal = maxTeamsVal = *val;
7580 minTeamsVal = maxTeamsVal = 0;
7586 minTeamsVal = maxTeamsVal = 1;
7588 minTeamsVal = maxTeamsVal = -1;
7593 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &
result) {
7607 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
7608 if (!targetOp.getThreadLimitVars().empty())
7609 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
7610 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
7613 int32_t maxThreadsVal = -1;
7615 setMaxValueFromClause(numThreads, maxThreadsVal);
7623 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
7624 if (combinedMaxThreadsVal < 0 ||
7625 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
7626 combinedMaxThreadsVal = teamsThreadLimitVal;
7628 if (combinedMaxThreadsVal < 0 ||
7629 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
7630 combinedMaxThreadsVal = maxThreadsVal;
7632 int32_t reductionDataSize = 0;
7633 if (isGPU && capturedOp) {
7639 omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
7641 case omp::TargetExecMode::bare:
7642 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
7644 case omp::TargetExecMode::generic:
7645 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
7647 case omp::TargetExecMode::spmd:
7648 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
7650 case omp::TargetExecMode::no_loop:
7651 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
7654 attrs.MinTeams = minTeamsVal;
7655 attrs.MaxTeams.front() = maxTeamsVal;
7656 attrs.MinThreads = 1;
7657 attrs.MaxThreads.front() = combinedMaxThreadsVal;
7658 attrs.ReductionDataSize = reductionDataSize;
7661 if (attrs.ReductionDataSize != 0)
7662 attrs.ReductionBufferLength = 1024;
7674 omp::TargetOp targetOp,
Operation *capturedOp,
7675 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
7677 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
7679 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
7683 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
7686 if (!targetOp.getThreadLimitVars().empty()) {
7687 Value targetThreadLimit = targetOp.getThreadLimit(0);
7688 attrs.TargetThreadLimit.front() =
7696 attrs.MinTeams = builder.CreateSExtOrTrunc(
7697 moduleTranslation.
lookupValue(numTeamsLower), builder.getInt32Ty());
7700 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
7701 moduleTranslation.
lookupValue(numTeamsUpper), builder.getInt32Ty());
7703 if (teamsThreadLimit)
7704 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
7705 moduleTranslation.
lookupValue(teamsThreadLimit), builder.getInt32Ty());
7708 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
7710 bool hostEvalTripCount;
7711 targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
7712 if (hostEvalTripCount) {
7714 attrs.LoopTripCount =
nullptr;
7719 for (
auto [loopLower, loopUpper, loopStep] :
7720 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
7721 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
7722 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
7723 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
7725 if (!lowerBound || !upperBound || !step) {
7726 attrs.LoopTripCount =
nullptr;
7730 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
7731 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
7732 loc, lowerBound, upperBound, step,
true,
7733 loopOp.getLoopInclusive());
7735 if (!attrs.LoopTripCount) {
7736 attrs.LoopTripCount = tripCount;
7741 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
7746 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
7748 attrs.DeviceID = moduleTranslation.
lookupValue(devId);
7750 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
7754static llvm::omp::OMPDynGroupprivateFallbackType
7756 omp::FallbackModifier fb = fallbackAttr ? fallbackAttr.getValue()
7757 : omp::FallbackModifier::default_mem;
7759 case omp::FallbackModifier::abort:
7760 return llvm::omp::OMPDynGroupprivateFallbackType::Abort;
7761 case omp::FallbackModifier::null:
7762 return llvm::omp::OMPDynGroupprivateFallbackType::Null;
7763 case omp::FallbackModifier::default_mem:
7764 return llvm::omp::OMPDynGroupprivateFallbackType::DefaultMem;
7767 llvm_unreachable(
"unexpected dyn_groupprivate fallback type");
7773 auto targetOp = cast<omp::TargetOp>(opInst);
7778 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
7787 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
7788 assert(parentBB &&
"No insert block is set for the builder");
7789 llvm::Function *parentLLVMFn = parentBB->getParent();
7790 assert(parentLLVMFn &&
"Parent Function must be valid");
7791 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
7792 builder.SetCurrentDebugLocation(llvm::DILocation::get(
7793 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
7794 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
7797 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
7798 bool isGPU = ompBuilder->Config.isGPU();
7801 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
7802 auto &targetRegion = targetOp.getRegion();
7819 llvm::Function *llvmOutlinedFn =
nullptr;
7820 TargetDirectiveEnumTy targetDirective =
7821 getTargetDirectiveEnumTyFromOp(&opInst);
7825 bool isOffloadEntry =
7826 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
7833 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
7835 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
7836 std::optional<DenseI64ArrayAttr> privateMapIndices =
7837 targetOp.getPrivateMapsAttr();
7839 for (
auto [privVarIdx, privVarSymPair] :
7840 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
7841 auto privVar = std::get<0>(privVarSymPair);
7842 auto privSym = std::get<1>(privVarSymPair);
7844 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
7845 omp::PrivateClauseOp privatizer =
7848 if (!privatizer.needsMap())
7852 targetOp.getMappedValueForPrivateVar(privVarIdx);
7853 assert(mappedValue &&
"Expected to find mapped value for a privatized "
7854 "variable that needs mapping");
7859 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
7860 [[maybe_unused]]
Type varType = mapInfoOp.getVarPtrType();
7864 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
7866 varType == privVar.getType() &&
7867 "Type of private var doesn't match the type of the mapped value");
7871 mappedPrivateVars.insert(
7873 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
7874 (*privateMapIndices)[privVarIdx])});
7878 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
7879 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
7881 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
7882 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7883 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7886 llvm::Function *llvmParentFn =
7888 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
7889 assert(llvmParentFn && llvmOutlinedFn &&
7890 "Both parent and outlined functions must exist at this point");
7892 if (outlinedFnLoc && llvmParentFn->getSubprogram())
7893 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
7895 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
7896 attr.isStringAttribute())
7897 llvmOutlinedFn->addFnAttr(attr);
7899 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
7900 attr.isStringAttribute())
7901 llvmOutlinedFn->addFnAttr(attr);
7903 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
7904 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7905 llvm::Value *mapOpValue =
7906 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
7907 moduleTranslation.
mapValue(arg, mapOpValue);
7909 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
7910 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7911 llvm::Value *mapOpValue =
7912 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
7913 moduleTranslation.
mapValue(arg, mapOpValue);
7922 privateVarsInfo, allocaIP, &mappedPrivateVars);
7925 return llvm::make_error<PreviouslyReportedError>();
7927 builder.restoreIP(codeGenIP);
7929 &mappedPrivateVars),
7932 return llvm::make_error<PreviouslyReportedError>();
7935 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
7937 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
7938 return llvm::make_error<PreviouslyReportedError>();
7941 moduleTranslation, allocaIP, deallocBlocks);
7943 targetRegion,
"omp.target", builder, moduleTranslation);
7946 return llvm::make_error<PreviouslyReportedError>();
7948 builder.SetInsertPoint(exitBlock.get()->getTerminator());
7951 targetOp.getLoc(), privateVarsInfo)))
7952 return llvm::make_error<PreviouslyReportedError>();
7954 return builder.saveIP();
7957 StringRef parentName = parentFn.getName();
7959 llvm::TargetRegionEntryInfo entryInfo;
7965 MapInfoData mapData;
7970 MapInfosTy combinedInfos;
7972 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
7973 builder.restoreIP(codeGenIP);
7974 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
7979 auto *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy());
7980 combinedInfos.BasePointers.push_back(nullPtr);
7981 combinedInfos.Pointers.push_back(nullPtr);
7982 combinedInfos.DevicePointers.push_back(
7983 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
7984 combinedInfos.Sizes.push_back(builder.getInt64(0));
7985 combinedInfos.Types.push_back(
7986 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
7987 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
7988 if (!combinedInfos.Names.empty())
7989 combinedInfos.Names.push_back(nullPtr);
7990 combinedInfos.Mappers.push_back(
nullptr);
7992 return combinedInfos;
7995 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
7996 llvm::Value *&retVal, InsertPointTy allocaIP,
7997 InsertPointTy codeGenIP,
7999 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
8000 llvm::IRBuilderBase::InsertPointGuard guard(builder);
8001 builder.SetCurrentDebugLocation(llvm::DebugLoc());
8007 if (!isTargetDevice) {
8008 retVal = cast<llvm::Value>(&arg);
8013 builder, *ompBuilder, moduleTranslation,
8014 allocaIP, codeGenIP, deallocIPs);
8017 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
8018 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
8019 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
8021 isTargetDevice, isGPU);
8025 if (!isTargetDevice)
8027 targetCapturedOp, runtimeAttrs);
8035 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
8036 llvm::Value *value = moduleTranslation.
lookupValue(var);
8037 moduleTranslation.
mapValue(arg, value);
8039 if (!llvm::isa<llvm::Constant>(value))
8040 kernelInput.push_back(value);
8043 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
8052 bool isAttachMap = (mapData.Types[i] &
8053 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
8054 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
8055 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i] && !isAttachMap)
8056 kernelInput.push_back(mapData.OriginalValue[i]);
8060 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
8063 llvm::OpenMPIRBuilder::DependenciesInfo dds;
8065 targetOp.getDependVars(), targetOp.getDependKinds(),
8066 targetOp.getDependIterated(), targetOp.getDependIteratedKinds(),
8067 builder, moduleTranslation, dds)))
8070 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8072 llvm::OpenMPIRBuilder::TargetDataInfo info(
8076 auto customMapperCB =
8078 if (!combinedInfos.Mappers[i])
8080 info.HasMapper =
true;
8082 moduleTranslation, targetDirective);
8085 llvm::Value *ifCond =
nullptr;
8086 if (
Value targetIfCond = targetOp.getIfExpr())
8087 ifCond = moduleTranslation.
lookupValue(targetIfCond);
8089 Value dynGroupPrivateSize = targetOp.getDynGroupprivateSize();
8090 llvm::Value *dynSizeVal =
nullptr;
8091 if (dynGroupPrivateSize) {
8092 dynSizeVal = moduleTranslation.
lookupValue(dynGroupPrivateSize);
8093 dynSizeVal = builder.CreateIntCast(dynSizeVal, builder.getInt32Ty(),
8097 llvm::omp::OMPDynGroupprivateFallbackType fallbackType =
8100 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
8102 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), deallocBlocks,
8103 info, entryInfo, defaultAttrs, runtimeAttrs, ifCond, kernelInput,
8104 genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
8105 targetOp.getNowait(), dynSizeVal, fallbackType);
8110 builder.restoreIP(*afterIP);
8113 builder.CreateFree(dds.DepArray);
8126 llvm::OpenMPIRBuilder *ompBuilder,
8135 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
8136 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
8138 if (!offloadMod.getIsTargetDevice())
8141 omp::DeclareTargetDeviceType declareType =
8142 attribute.getDeviceType().getValue();
8144 if (declareType == omp::DeclareTargetDeviceType::host) {
8145 llvm::Function *llvmFunc =
8147 llvmFunc->dropAllReferences();
8148 llvmFunc->eraseFromParent();
8152 ompBuilder->Builder.ClearInsertionPoint();
8153 ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc());
8159 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
8160 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8161 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
8163 bool isDeclaration = gOp.isDeclaration();
8164 bool isExternallyVisible =
8167 llvm::StringRef mangledName = gOp.getSymName();
8168 auto captureClause =
8174 std::vector<llvm::GlobalVariable *> generatedRefs;
8176 std::vector<llvm::Triple> targetTriple;
8177 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
8179 LLVM::LLVMDialect::getTargetTripleAttrName()));
8180 if (targetTripleAttr)
8181 targetTriple.emplace_back(targetTripleAttr.data());
8183 auto fileInfoCallBack = [&loc]() {
8184 std::string filename =
"";
8185 std::uint64_t lineNo = 0;
8188 filename = loc.getFilename().str();
8189 lineNo = loc.getLine();
8192 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
8196 llvm::vfs::FileSystem &vfs = moduleTranslation.
getFileSystem();
8198 ompBuilder->registerTargetGlobalVariable(
8199 captureClause, deviceClause, isDeclaration, isExternallyVisible,
8200 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
8201 mangledName, generatedRefs,
false, targetTriple,
8203 gVal->getType(), gVal);
8205 if (ompBuilder->Config.isTargetDevice() &&
8206 (attribute.getCaptureClause().getValue() !=
8207 mlir::omp::DeclareTargetCaptureClause::to ||
8208 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
8209 ompBuilder->getAddrOfDeclareTargetVar(
8210 captureClause, deviceClause, isDeclaration, isExternallyVisible,
8211 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
8212 mangledName, generatedRefs,
false, targetTriple,
8213 gVal->getType(),
nullptr,
/// Dialect translation interface for the OpenMP dialect. It derives from
/// LLVMTranslationDialectInterface and provides the final implementations of
/// convertOperation and amendOperation. It also keeps a Value -> llvm::Value*
/// table (ompAllocatedPtrs) that is filled by registerAllocatedPtr and queried
/// by lookupAllocatedPtr.
8226class OpenMPDialectLLVMIRTranslationInterface
8227 :
public LLVMTranslationDialectInterface {
// Reuse the base-class constructors unchanged.
8229 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
/// Translates one operation using the given IR builder and module
/// translation state.
8234 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
8235 LLVM::ModuleTranslation &moduleTranslation)
const final;
/// Handles a named attribute attached to `op`; `instructions` and
/// `moduleTranslation` are supplied by the caller.
8240 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
8241 NamedAttribute attribute,
8242 LLVM::ModuleTranslation &moduleTranslation)
const final;
/// Stores `ptr` as the entry for `var`, replacing any entry already present.
// NOTE(review): this method is const yet assigns into ompAllocatedPtrs; the
// member's declaration is not shown here, so it is presumably declared
// `mutable` - confirm.
8247 void registerAllocatedPtr(Value var, llvm::Value *ptr)
const {
8248 ompAllocatedPtrs[var] = ptr;
/// Returns the pointer stored for `var`, or nullptr when `var` has no entry.
8253 llvm::Value *lookupAllocatedPtr(Value var)
const {
8254 auto it = ompAllocatedPtrs.find(var);
8255 return it != ompAllocatedPtrs.end() ? it->second :
nullptr;
/// Applies a module- or operation-level OpenMP attribute to the translation
/// state. A llvm::StringSwitch selects a handler of type
/// LogicalResult(Attribute); every visible handler first dyn_casts the
/// attribute to the class it expects and only then acts on it.
// NOTE(review): the switch key and the invocation of the selected handler are
// not shown here; presumably the key is the attribute's name and the handler
// receives the attribute's value - confirm.
8267LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
8268 Operation *op, ArrayRef<llvm::Instruction *> instructions,
8269 NamedAttribute attribute,
8270 LLVM::ModuleTranslation &moduleTranslation)
const {
8271 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
// BoolAttr: forwarded to OpenMPIRBuilderConfig::setIsTargetDevice.
8273 .Case(
"omp.is_target_device",
8274 [&](Attribute attr) {
8275 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
8276 llvm::OpenMPIRBuilderConfig &config =
8278 config.setIsTargetDevice(deviceAttr.getValue());
// BoolAttr: forwarded to OpenMPIRBuilderConfig::setIsGPU.
8284 [&](Attribute attr) {
8285 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
8286 llvm::OpenMPIRBuilderConfig &config =
8288 config.setIsGPU(gpuAttr.getValue());
// StringAttr: the path is passed, together with the translation's file
// system, to OpenMPIRBuilder::loadOffloadInfoMetadata.
8293 .Case(
"omp.host_ir_filepath",
8294 [&](Attribute attr) {
8295 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
8296 llvm::OpenMPIRBuilder *ompBuilder =
8298 ompBuilder->loadOffloadInfoMetadata(
8299 moduleTranslation.
getFileSystem(), filepathAttr.getValue());
// Handler that expects an omp::FlagsAttr.
8305 [&](Attribute attr) {
8306 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
// omp::VersionAttr: recorded as the "openmp" module flag using the
// llvm::Module::Max merge behavior.
8310 .Case(
"omp.version",
8311 [&](Attribute attr) {
8312 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
8313 llvm::OpenMPIRBuilder *ompBuilder =
8315 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
8316 versionAttr.getVersion());
// omp::DeclareTargetAttr: handled by a call that receives the OpenMP
// builder and the module translation.
8321 .Case(
"omp.declare_target",
8322 [&](Attribute attr) {
8323 if (
auto declareTargetAttr =
8324 dyn_cast<omp::DeclareTargetAttr>(attr)) {
8325 llvm::OpenMPIRBuilder *ompBuilder =
8328 ompBuilder, moduleTranslation);
// omp::ClauseRequiresAttr: each of the four config flags is set to whether
// the corresponding bit is contained in the attribute's value.
8332 .Case(
"omp.requires",
8333 [&](Attribute attr) {
8334 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
8335 using Requires = omp::ClauseRequires;
8336 Requires flags = requiresAttr.getValue();
8337 llvm::OpenMPIRBuilderConfig &config =
8339 config.setHasRequiresReverseOffload(
8340 bitEnumContainsAll(flags, Requires::reverse_offload));
8341 config.setHasRequiresUnifiedAddress(
8342 bitEnumContainsAll(flags, Requires::unified_address));
8343 config.setHasRequiresUnifiedSharedMemory(
8344 bitEnumContainsAll(flags, Requires::unified_shared_memory));
8345 config.setHasRequiresDynamicAllocators(
8346 bitEnumContainsAll(flags, Requires::dynamic_allocators));
// ArrayAttr: config.TargetTriples is cleared, capacity is reserved for the
// array's size, and a triple is appended for each StringAttr element.
8351 .Case(
"omp.target_triples",
8352 [&](Attribute attr) {
8353 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
8354 llvm::OpenMPIRBuilderConfig &config =
8356 config.TargetTriples.clear();
8357 config.TargetTriples.reserve(triplesAttr.size());
8358 for (Attribute tripleAttr : triplesAttr) {
8359 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
8360 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
// Fallback handler for any key that matches none of the cases above.
8368 .Default([](Attribute) {
8384 if (
auto declareTargetIface =
8385 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
8386 parentFn.getOperation()))
8387 if (declareTargetIface.isDeclareTarget() &&
8388 declareTargetIface.getDeclareTargetDeviceType() !=
8389 mlir::omp::DeclareTargetDeviceType::host)
8399 llvm::Module *llvmModule) {
8400 llvm::Type *i64Ty = builder.getInt64Ty();
8401 llvm::Type *i32Ty = builder.getInt32Ty();
8402 llvm::Type *returnType = builder.getPtrTy(0);
8403 llvm::FunctionType *fnType =
8404 llvm::FunctionType::get(returnType, {i64Ty, i32Ty},
false);
8405 llvm::Function *
func = cast<llvm::Function>(
8406 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
8410template <
typename T>
8414 llvm::DataLayout dataLayout =
8416 llvm::Type *llvmHeapTy =
8417 moduleTranslation.
convertType(op.getMemElemTypeAttr().getValue());
8419 auto alignment = op.getMemAlignment();
8420 llvm::TypeSize typeSize = llvm::alignTo(
8421 dataLayout.getTypeStoreSize(llvmHeapTy),
8422 alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
8424 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
8425 return builder.CreateMul(
8427 builder.CreateIntCast(moduleTranslation.
lookupValue(op.getMemArraySize()),
8428 builder.getInt64Ty(),
8435 omp::TargetAllocMemOp op) {
8436 llvm::DataLayout dataLayout =
8438 llvm::Type *llvmHeapTy = moduleTranslation.
convertType(op.getAllocatedType());
8439 llvm::TypeSize typeSize = dataLayout.getTypeAllocSize(llvmHeapTy);
8440 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
8441 for (
auto typeParam : op.getTypeparams()) {
8442 allocSize = builder.CreateMul(
8444 builder.CreateIntCast(moduleTranslation.
lookupValue(typeParam),
8445 builder.getInt64Ty(),
8454 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
8459 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8463 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
8465 llvm::Value *allocSize =
8468 llvm::CallInst *call =
8469 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
8470 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
8473 moduleTranslation.
mapValue(allocMemOp.getResult(), resultI64);
8479 llvm::IRBuilderBase &builder,
8483 moduleTranslation.
mapValue(allocMemOp.getResult(),
8484 ompBuilder->createOMPAllocShared(builder, size));
8491 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8492 auto allocateDirOp = cast<omp::AllocateDirOp>(opInst);
8495 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8496 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8497 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
8499 std::optional<int64_t> alignAttr = allocateDirOp.getAlign();
8501 llvm::Value *allocator;
8502 if (
auto allocatorVar = allocateDirOp.getAllocator()) {
8503 allocator = moduleTranslation.
lookupValue(allocatorVar);
8504 if (allocator->getType()->isIntegerTy())
8505 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8506 else if (allocator->getType()->isPointerTy())
8507 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8508 allocator, builder.getPtrTy());
8510 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
8513 for (
Value var : vars) {
8514 llvm::Type *llvmVarTy = moduleTranslation.
convertType(var.getType());
8518 llvm::Type *typeToInspect = llvmVarTy;
8519 if (llvmVarTy->isPointerTy()) {
8522 if (
auto gop = dyn_cast<LLVM::GlobalOp>(globalOp))
8523 typeToInspect = moduleTranslation.
convertType(gop.getGlobalType());
8528 if (
auto arrTy = llvm::dyn_cast<llvm::ArrayType>(typeToInspect)) {
8529 llvm::Value *elementCount = builder.getInt64(1);
8530 llvm::Type *currentType = arrTy;
8531 while (
auto nestedArrTy = llvm::dyn_cast<llvm::ArrayType>(currentType)) {
8532 elementCount = builder.CreateMul(
8533 elementCount, builder.getInt64(nestedArrTy->getNumElements()));
8534 currentType = nestedArrTy->getElementType();
8536 uint64_t elemSizeInBits = dataLayout.getTypeSizeInBits(currentType);
8538 builder.CreateMul(elementCount, builder.getInt64(elemSizeInBits / 8));
8540 size = builder.getInt64(
8541 dataLayout.getTypeStoreSize(typeToInspect).getFixedValue());
8544 uint64_t alignValue =
8545 alignAttr ? alignAttr.value()
8546 : dataLayout.getABITypeAlign(typeToInspect).value();
8547 llvm::Value *alignConst = builder.getInt64(alignValue);
8549 size = builder.CreateAdd(size, builder.getInt64(alignValue - 1),
"",
true);
8550 size = builder.CreateUDiv(size, alignConst);
8551 size = builder.CreateMul(size, alignConst,
"",
true);
8553 std::string allocName =
8554 ompBuilder->createPlatformSpecificName({
".void.addr"});
8555 llvm::CallInst *allocCall;
8556 if (alignAttr.has_value()) {
8557 allocCall = ompBuilder->createOMPAlignedAlloc(
8558 ompLoc, builder.getInt64(alignAttr.value()), size, allocator,
8562 ompBuilder->createOMPAlloc(ompLoc, size, allocator, allocName);
8565 ompIface.registerAllocatedPtr(var, allocCall);
8574 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8575 auto freeOp = cast<omp::AllocateFreeOp>(opInst);
8577 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8579 llvm::Value *allocator;
8580 if (
auto allocatorVar = freeOp.getAllocator()) {
8581 allocator = moduleTranslation.
lookupValue(allocatorVar);
8582 if (allocator->getType()->isIntegerTy())
8583 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8584 else if (allocator->getType()->isPointerTy())
8585 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8586 allocator, builder.getPtrTy());
8588 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
8593 for (
Value var : llvm::reverse(vars)) {
8594 llvm::Value *allocPtr = ompIface.lookupAllocatedPtr(var);
8596 return opInst.
emitError(
"omp.allocate_free: no allocation recorded");
8597 ompBuilder->createOMPFree(ompLoc, allocPtr, allocator,
"");
8604 llvm::Module *llvmModule) {
8605 llvm::Type *ptrTy = builder.getPtrTy(0);
8606 llvm::Type *i32Ty = builder.getInt32Ty();
8607 llvm::Type *voidTy = builder.getVoidTy();
8608 llvm::FunctionType *fnType =
8609 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty},
false);
8610 llvm::Function *
func = dyn_cast<llvm::Function>(
8611 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
8618 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
8623 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8627 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
8630 llvm::Value *llvmHeapref = moduleTranslation.
lookupValue(heapref);
8632 llvm::Value *intToPtr =
8633 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
8634 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
8640 llvm::IRBuilderBase &builder,
8644 ompBuilder->createOMPFreeShared(
8645 builder, moduleTranslation.
lookupValue(freeMemOp.getHeapref()), size);
8654 auto groupprivateOp = cast<omp::GroupprivateOp>(opInst);
8659 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
8663 bool shouldAllocate =
true;
8664 switch (groupprivateOp.getDeviceType().value_or(
8665 mlir::omp::DeclareTargetDeviceType::any)) {
8666 case mlir::omp::DeclareTargetDeviceType::host:
8667 shouldAllocate = !isTargetDevice;
8669 case mlir::omp::DeclareTargetDeviceType::nohost:
8670 shouldAllocate = isTargetDevice;
8672 case mlir::omp::DeclareTargetDeviceType::any:
8673 shouldAllocate =
true;
8679 &opInst, groupprivateOp.getSymNameAttr());
8682 <<
"expected symbol '" << groupprivateOp.getSymName()
8683 <<
"' to reference an LLVM global variable";
8685 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
8686 llvm::Type *varType = moduleTranslation.
convertType(global.getType());
8687 std::string varName = globalValue->getName().str();
8689 llvm::Value *resultPtr;
8690 if (shouldAllocate && isTargetDevice) {
8691 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8692 llvm::Triple targetTriple(llvmModule->getTargetTriple());
8693 unsigned sharedAddressSpace;
8694 if (targetTriple.isAMDGCN())
8695 sharedAddressSpace = llvm::AMDGPUAS::LOCAL_ADDRESS;
8696 else if (targetTriple.isNVPTX())
8697 sharedAddressSpace = llvm::NVPTXAS::ADDRESS_SPACE_SHARED;
8699 return opInst.
emitError() <<
"groupprivate is not supported for target: "
8700 << targetTriple.str();
8701 llvm::GlobalVariable *sharedVar =
new llvm::GlobalVariable(
8702 *llvmModule, varType,
false,
8703 llvm::GlobalValue::InternalLinkage, llvm::PoisonValue::get(varType),
8704 varName,
nullptr, llvm::GlobalValue::NotThreadLocal,
8707 resultPtr = sharedVar;
8709 if (shouldAllocate && !isTargetDevice)
8710 opInst.
emitWarning(
"groupprivate directive is currently ignored on the "
8711 "host, using original global");
8712 resultPtr = globalValue;
/// Translates a single OpenMP dialect operation. After a device-side
/// legality check and loop-wrapper bookkeeping, the operation is dispatched
/// by type through llvm::TypeSwitch<Operation *, LogicalResult>.
8721LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
8722 Operation *op, llvm::IRBuilderBase &builder,
8723 LLVM::ModuleTranslation &moduleTranslation)
const {
// When the builder is configured for a target device, an op that is not a
// TargetOp, MapInfoOp, TerminatorOp or YieldOp can be rejected with the error
// below.
// NOTE(review): the rest of this condition is not shown here - confirm which
// additional ops are exempt.
8726 if (ompBuilder->Config.isTargetDevice() &&
8727 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
8730 return op->
emitOpError() <<
"unsupported host op found in device";
// True when `op` implements LoopWrapperInterface and its parent op (if any)
// does not.
8738 bool isOutermostLoopWrapper =
8739 isa_and_present<omp::LoopWrapperInterface>(op) &&
8740 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
// Taskloop ops override the generic answer: TaskloopContextOp is always
// treated as the outermost wrapper, TaskloopWrapperOp never is.
8749 if (isa<omp::TaskloopContextOp>(op))
8750 isOutermostLoopWrapper =
true;
8751 else if (isa<omp::TaskloopWrapperOp>(op))
8752 isOutermostLoopWrapper =
false;
// An outermost wrapper gets its own OpenMPLoopInfoStackFrame on the
// translation stack before its handler runs.
8754 if (isOutermostLoopWrapper)
8755 moduleTranslation.
stackPush<OpenMPLoopInfoStackFrame>();
8758 llvm::TypeSwitch<Operation *, LogicalResult>(op)
// Barrier: emitted through OpenMPIRBuilder::createBarrier with the
// OMPD_barrier directive kind; once `res` reports success the builder
// continues at the insertion point that createBarrier returned.
8759 .Case([&](omp::BarrierOp op) -> LogicalResult {
8763 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
8764 ompBuilder->createBarrier(builder.saveIP(),
8765 llvm::omp::OMPD_barrier);
8767 if (res.succeeded()) {
8770 builder.restoreIP(*afterIP);
// Taskyield and flush are both emitted at the builder's current insertion
// point.
8774 .Case([&](omp::TaskyieldOp op) {
8778 ompBuilder->createTaskyield(builder.saveIP());
8781 .Case([&](omp::FlushOp op) {
8793 ompBuilder->createFlush(builder.saveIP());
// Remaining constructs: one case per op type, or one case for a group of op
// types that share a generic (`auto op`) handler.
8796 .Case([&](omp::ParallelOp op) {
8799 .Case([&](omp::MaskedOp) {
8802 .Case([&](omp::MasterOp) {
8805 .Case([&](omp::CriticalOp) {
8808 .Case([&](omp::OrderedRegionOp) {
8811 .Case([&](omp::OrderedOp) {
8814 .Case([&](omp::WsloopOp) {
8817 .Case([&](omp::SimdOp) {
8820 .Case([&](omp::AtomicReadOp) {
8823 .Case([&](omp::AtomicWriteOp) {
8826 .Case([&](omp::AtomicUpdateOp op) {
8829 .Case([&](omp::AtomicCaptureOp op) {
8832 .Case([&](omp::AtomicCompareOp op) {
8835 .Case([&](omp::CancelOp op) {
8838 .Case([&](omp::CancellationPointOp op) {
8841 .Case([&](omp::SectionsOp) {
8844 .Case([&](omp::ScopeOp op) {
8847 .Case([&](omp::SingleOp op) {
8850 .Case([&](omp::TeamsOp op) {
8853 .Case([&](omp::TaskOp op) {
8856 .Case([&](omp::TaskloopWrapperOp op) {
8859 .Case([&](omp::TaskloopContextOp op) {
8862 .Case([&](omp::TaskgroupOp op) {
8865 .Case([&](omp::TaskwaitOp op) {
// These six op types share one handler; the lambda captures nothing.
8868 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
8869 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
8870 omp::CriticalDeclareOp>([](
auto op) {
8883 .Case([&](omp::ThreadprivateOp) {
// The four target data-movement ops share one handler.
8886 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
8887 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
8890 .Case([&](omp::TargetOp) {
8893 .Case([&](omp::DistributeOp) {
8896 .Case([&](omp::LoopNestOp) {
// These five op types share one handler.
8899 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp,
8900 omp::AffinityEntryOp, omp::IteratorOp>([&](
auto op) {
8906 .Case([&](omp::NewCliOp op) {
8911 .Case([&](omp::CanonicalLoopOp op) {
8914 .Case([&](omp::UnrollHeuristicOp op) {
// Tile and fuse delegate directly to applyTile / applyFuse.
8923 .Case([&](omp::TileOp op) {
8924 return applyTile(op, builder, moduleTranslation);
8926 .Case([&](omp::FuseOp op) {
8927 return applyFuse(op, builder, moduleTranslation);
8929 .Case([&](omp::TargetAllocMemOp) {
8932 .Case([&](omp::TargetFreeMemOp) {
8935 .Case([&](omp::AllocateDirOp) {
8938 .Case([&](omp::AllocateFreeOp) {
8942 .Case([&](omp::AllocSharedMemOp op) {
8945 .Case([&](omp::FreeSharedMemOp op) {
8948 .Case([&](omp::GroupprivateOp) {
// Any op type without a case above: the message "not yet implemented: "
// followed by the operation's name is streamed into a diagnostic.
8951 .Default([&](Operation *inst) {
8953 <<
"not yet implemented: " << inst->
getName();
// NOTE(review): counterpart of the stackPush above; presumably the frame is
// popped here - confirm.
8956 if (isOutermostLoopWrapper)
8963 registry.
insert<omp::OpenMPDialect>();
8965 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static mlir::LogicalResult buildDependData(OperandRange dependVars, std::optional< ArrayAttr > dependKinds, OperandRange dependIterated, std::optional< ArrayAttr > dependIteratedKinds, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static void mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static void processIndividualMap(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag=llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE, bool isTargetParam=true, int mapDataParentIdx=-1)
This function handles the insertion of a single item of map data from MapInfoData into the OMPIRBuild...
static llvm::OpenMPIRBuilder::InsertPointTy findAllocInsertPoints(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::SmallVectorImpl< llvm::BasicBlock * > *deallocBlocks=nullptr)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo, bool first=true)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static LogicalResult convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static mlir::LogicalResult convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo, mlir::Block &iteratorRegionBlock, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static void popCancelFinalizationCB(const ArrayRef< llvm::UncondBrInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static llvm::Expected< llvm::Value * > lookupOrTranslatePureValue(Value value, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
Look up the given value in the mapping, and if it's not there, translate its defining operation at th...
static LogicalResult allocReductionVars(T op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static mlir::LogicalResult fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, IteratorInfo &iterInfo, llvm::StringRef loopName, IteratorStoreEntryTy genStoreEntry)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
llvm::function_ref< void(llvm::Value *linearIV, mlir::omp::YieldOp yield)> IteratorStoreEntryTy
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::DistributeOp getDistributeCapturingTeamsReduction(omp::TeamsOp teamsOp)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static mlir::LogicalResult buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::AffinityData &ad)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Value * getAllocationSize(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, T op)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::omp::OMPDynGroupprivateFallbackType getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult cleanupPrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, PrivateVarsInfo &privateVarsInfo)
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpScope(omp::ScopeOp &scopeOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP scope construct into LLVM IR.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void storeAffinityEntry(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, llvm::Value *affinityList, llvm::Value *index, llvm::Value *addr, llvm::Value *len)
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::UncondBrInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs, llvm::StringRef parentName="")
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP groupprivate operation into LLVM IR.
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs)
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.compare operation to LLVM IR.
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static void buildDependDataLocator(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static std::optional< llvm::omp::OMPAtomicCompareOp > convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate)
Helper to extract the OMPAtomicCompareOp from a floating-point comparison predicate....
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void fillAffinityLocators(Operation::operand_range affinityVars, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *affinityList)
static LogicalResult convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
The correct entry point is convertOmpTaskloopContextOp. This gets called whilst lowering the body of ...
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static std::optional< llvm::omp::OMPAtomicCompareOp > convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate)
Helper to extract the OMPAtomicCompareOp from an integer comparison predicate. Returns std::nullopt f...
static llvm::Error computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *&lbVal, llvm::Value *&ubVal, llvm::Value *&stepVal)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP, llvm::ArrayRef< llvm::IRBuilderBase::InsertPoint > deallocIPs)
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static llvm::omp::RTLDependenceKindTy convertDependKind(mlir::omp::ClauseTaskDepend kind)
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Value getBaseValueForTypeLookup(Value value)
static bool isHostDeviceOp(Operation *op)
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, llvm::OpenMPIRBuilder *ompBuilder, LLVM::ModuleTranslation &moduleTranslation)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static LogicalResult convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::vfs::FileSystem & getFileSystem()
Returns the virtual filesystem to use for file operations.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder)
Converts the given MLIR operation into LLVM IR using this translator.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
bool opInSharedDeviceContext(Operation &op)
Check whether the given operation is located in a context where an allocation to be used by multiple ...
bool allocaUsesRequireSharedMem(Value alloc)
Check whether the value representing an allocation, assumed to have been defined in a shared device c...
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable and does not touch memory.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.