25#include "llvm/ADT/ArrayRef.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Frontend/OpenMP/OMPConstants.h"
29#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
30#include "llvm/IR/Constants.h"
31#include "llvm/IR/DebugInfoMetadata.h"
32#include "llvm/IR/DerivedTypes.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/MDBuilder.h"
35#include "llvm/IR/ReplaceConstant.h"
36#include "llvm/Support/FileSystem.h"
37#include "llvm/Support/VirtualFileSystem.h"
38#include "llvm/TargetParser/Triple.h"
39#include "llvm/Transforms/Utils/ModuleUtils.h"
50static llvm::omp::ScheduleKind
51convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
52 if (!schedKind.has_value())
53 return llvm::omp::OMP_SCHEDULE_Default;
54 switch (schedKind.value()) {
55 case omp::ClauseScheduleKind::Static:
56 return llvm::omp::OMP_SCHEDULE_Static;
57 case omp::ClauseScheduleKind::Dynamic:
58 return llvm::omp::OMP_SCHEDULE_Dynamic;
59 case omp::ClauseScheduleKind::Guided:
60 return llvm::omp::OMP_SCHEDULE_Guided;
61 case omp::ClauseScheduleKind::Auto:
62 return llvm::omp::OMP_SCHEDULE_Auto;
63 case omp::ClauseScheduleKind::Runtime:
64 return llvm::omp::OMP_SCHEDULE_Runtime;
65 case omp::ClauseScheduleKind::Distribute:
66 return llvm::omp::OMP_SCHEDULE_Distribute;
68 llvm_unreachable(
"unhandled schedule clause argument");
// Stack frame that carries an OpenMPIRBuilder insertion point for allocas.
// A stack walk later in this file copies `allocaInsertPoint` out of the frame.
73class OpenMPAllocaStackFrame
// Builds a frame that remembers `allocaIP`.
78 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
79 : allocaInsertPoint(allocaIP) {}
// The insertion point captured by the constructor.
80 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
// Stack frame that carries a pointer to a CanonicalLoopInfo. A stack walk
// later in this file reads `loopInfo` back out of the frame.
86class OpenMPLoopInfoStackFrame
// Loop information held by this frame; defaults to null.
90 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
// Custom llvm::ErrorInfo subclass used as a marker error in this file: region
// conversion returns it when converting a block fails, and the error handler
// further down maps it to `failure()`.
109class PreviouslyReportedError
110 :
public llvm::ErrorInfo<PreviouslyReportedError> {
// Logging hook required by llvm::ErrorInfo; the stream parameter is unnamed.
112 void log(raw_ostream &)
const override {
// Error-code conversion hook required by llvm::ErrorInfo. The message below
// states that conversion to an error code is not supported.
116 std::error_code convertToErrorCode()
const override {
118 "PreviouslyReportedError doesn't support ECError conversion");
// Out-of-class definition of the static `ID` member.
125char PreviouslyReportedError::ID = 0;
// Helper that accumulates per-variable state for the OpenMP `linear` clause
// and emits the LLVM IR that initializes, updates and finalizes the linear
// variables. All vectors below are indexed by the position of the linear
// variable.
136class LinearClauseProcessor {
// Allocas named ".linear_var"; initLinearVar stores the value loaded from the
// original variable into them.
139 SmallVector<llvm::Value *> linearPreconditionVars;
// Allocas named ".linear_result"; updateLinearVar stores start + iv * step
// into them.
140 SmallVector<llvm::Value *> linearLoopBodyTemps;
// The original linear variables as passed to createLinearVar; loaded from in
// initLinearVar and stored to during finalization.
141 SmallVector<llvm::Value *> linearOrigVal;
// LLVM values of the linear steps, looked up by initLinearStep.
142 SmallVector<llvm::Value *> linearSteps;
// Converted LLVM types of the linear variables, filled by registerType.
143 SmallVector<llvm::Type *> linearVarTypes;
// Basic blocks created by splitLinearFiniBB.
144 llvm::BasicBlock *linearFinalizationBB;
145 llvm::BasicBlock *linearExitBB;
146 llvm::BasicBlock *linearLastIterExitBB;
150 void registerType(LLVM::ModuleTranslation &moduleTranslation,
151 mlir::Attribute &ty) {
152 linearVarTypes.push_back(moduleTranslation.
convertType(
153 mlir::cast<mlir::TypeAttr>(ty).getValue()));
157 void createLinearVar(llvm::IRBuilderBase &builder,
158 LLVM::ModuleTranslation &moduleTranslation,
159 llvm::Value *linearVar,
int idx) {
160 linearPreconditionVars.push_back(
161 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_var"));
162 llvm::Value *linearLoopBodyTemp =
163 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_result");
164 linearOrigVal.push_back(linearVar);
165 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
169 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
170 mlir::Value &linearStep) {
171 linearSteps.push_back(moduleTranslation.
lookupValue(linearStep));
175 void initLinearVar(llvm::IRBuilderBase &builder,
176 LLVM::ModuleTranslation &moduleTranslation,
177 llvm::BasicBlock *loopPreHeader) {
178 builder.SetInsertPoint(loopPreHeader->getTerminator());
179 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
180 llvm::LoadInst *linearVarLoad =
181 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
182 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
// Emits, before the terminator of `loopBody`, the per-iteration update of
// every linear variable: result = start + iv * step, stored into the matching
// ".linear_result" slot. `loopInductionVar` is the loop induction variable
// and must have an integer type.
187 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
188 llvm::Value *loopInductionVar) {
189 builder.SetInsertPoint(loopBody->getTerminator());
190 for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
191 llvm::Type *linearVarType = linearVarTypes[index];
192 llvm::Value *iv = loopInductionVar;
193 llvm::Value *step = linearSteps[index];
// A non-integer induction variable is treated as unreachable.
195 if (!iv->getType()->isIntegerTy())
196 llvm_unreachable(
"OpenMP loop induction variable must be an integer "
// Integer linear variable: sign-extend or truncate both the induction
// variable and the step to the variable's type, then multiply and add in that
// type.
199 if (linearVarType->isIntegerTy()) {
201 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
202 step = builder.CreateSExtOrTrunc(step, linearVarType);
204 llvm::LoadInst *linearVarStart =
205 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
206 llvm::Value *mulInst = builder.CreateMul(iv, step);
207 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
208 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
209 }
// Floating-point linear variable: the product iv * step is computed in the
// induction variable's integer type, converted with a signed int-to-fp cast,
// and added with a floating-point add.
else if (linearVarType->isFloatingPointTy()) {
211 step = builder.CreateSExtOrTrunc(step, iv->getType());
212 llvm::Value *mulInst = builder.CreateMul(iv, step);
214 llvm::LoadInst *linearVarStart =
215 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
216 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
217 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
218 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
// Diagnostic text for variable types that are neither integer nor
// floating-point.
221 "Linear variable must be of integer or floating-point type");
228 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
229 llvm::BasicBlock *loopExit) {
230 linearFinalizationBB = loopExit->splitBasicBlock(
231 loopExit->getTerminator(),
"omp_loop.linear_finalization");
232 linearExitBB = linearFinalizationBB->splitBasicBlock(
233 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
234 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
235 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
// Emits the finalization of the linear variables into the blocks created by
// splitLinearFiniBB. `lastIter` is loaded as an i32 flag; when it is non-zero
// control goes through `linearLastIterExitBB`, where the ".linear_result"
// values are copied back to the original variables. Returns an
// InsertPointOrErrorTy.
239 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
240 finalizeLinearVar(llvm::IRBuilderBase &builder,
241 LLVM::ModuleTranslation &moduleTranslation,
242 llvm::Value *lastIter) {
// Load the flag in the finalization block and compare it against zero.
244 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
245 llvm::Value *loopLastIterLoad = builder.CreateLoad(
246 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
247 llvm::Value *isLast =
248 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
249 llvm::ConstantInt::get(
250 llvm::Type::getInt32Ty(builder.getContext()), 0));
// Copy every ".linear_result" value back to its original variable.
252 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
253 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
254 llvm::LoadInst *linearVarTemp =
255 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
256 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// The conditional branch is inserted ahead of the block's existing
// terminator, which is then erased.
262 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
263 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
264 linearFinalizationBB->getTerminator()->eraseFromParent();
// NOTE(review): the callee receiving the arguments below is not visible in
// this view; given OMPD_barrier it presumably creates a barrier at the exit
// block - confirm.
266 builder.SetInsertPoint(linearExitBB->getTerminator());
268 builder.saveIP(), llvm::omp::OMPD_barrier);
273 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
274 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
275 llvm::LoadInst *linearVarTemp =
276 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
277 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// Redirects uses of the original linear variable at `varIndex` to its
// ".linear_result" slot, but only for instructions whose parent block name
// contains `BBName`.
283 void rewriteInPlace(llvm::IRBuilderBase &builder,
const std::string &BBName,
// The users are copied into a local vector first, so the use list is not
// iterated while uses are being replaced.
285 llvm::SmallVector<llvm::User *> users;
286 for (llvm::User *user : linearOrigVal[varIndex]->users())
287 users.push_back(user);
288 for (
auto *user : users) {
// Only instruction users are considered; the block-name match is a substring
// search on the parent block's name.
289 if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
290 if (userInst->getParent()->getName().str().find(BBName) !=
292 user->replaceUsesOfWith(linearOrigVal[varIndex],
293 linearLoopBodyTemps[varIndex]);
304 SymbolRefAttr symbolName) {
305 omp::PrivateClauseOp privatizer =
308 assert(privatizer &&
"privatizer not found in the symbol table");
319 auto todo = [&op](StringRef clauseName) {
320 return op.
emitError() <<
"not yet implemented: Unhandled clause "
321 << clauseName <<
" in " << op.
getName()
325 auto checkAllocate = [&todo](
auto op, LogicalResult &
result) {
326 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
327 result = todo(
"allocate");
329 auto checkBare = [&todo](
auto op, LogicalResult &
result) {
331 result = todo(
"ompx_bare");
333 auto checkDepend = [&todo](
auto op, LogicalResult &
result) {
334 if (!op.getDependVars().empty() || op.getDependKinds())
337 auto checkHint = [](
auto op, LogicalResult &) {
341 auto checkInReduction = [&todo](
auto op, LogicalResult &
result) {
342 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
343 op.getInReductionSyms())
344 result = todo(
"in_reduction");
346 auto checkNowait = [&todo](
auto op, LogicalResult &
result) {
350 auto checkOrder = [&todo](
auto op, LogicalResult &
result) {
351 if (op.getOrder() || op.getOrderMod())
354 auto checkPrivate = [&todo](
auto op, LogicalResult &
result) {
355 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
356 result = todo(
"privatization");
358 auto checkReduction = [&todo](
auto op, LogicalResult &
result) {
359 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopContextOp>(op))
360 if (!op.getReductionVars().empty() || op.getReductionByref() ||
361 op.getReductionSyms())
362 result = todo(
"reduction");
363 if (op.getReductionMod() &&
364 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
365 result = todo(
"reduction with modifier");
367 auto checkTaskReduction = [&todo](
auto op, LogicalResult &
result) {
368 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
369 op.getTaskReductionSyms())
370 result = todo(
"task_reduction");
372 auto checkNumTeams = [&todo](
auto op, LogicalResult &
result) {
373 if (op.hasNumTeamsMultiDim())
374 result = todo(
"num_teams with multi-dimensional values");
376 auto checkNumThreads = [&todo](
auto op, LogicalResult &
result) {
377 if (op.hasNumThreadsMultiDim())
378 result = todo(
"num_threads with multi-dimensional values");
381 auto checkThreadLimit = [&todo](
auto op, LogicalResult &
result) {
382 if (op.hasThreadLimitMultiDim())
383 result = todo(
"thread_limit with multi-dimensional values");
388 .Case([&](omp::DistributeOp op) {
389 checkAllocate(op,
result);
392 .Case([&](omp::SectionsOp op) {
393 checkAllocate(op,
result);
395 checkReduction(op,
result);
397 .Case([&](omp::SingleOp op) {
398 checkAllocate(op,
result);
401 .Case([&](omp::TeamsOp op) {
402 checkAllocate(op,
result);
404 checkNumTeams(op,
result);
405 checkThreadLimit(op,
result);
407 .Case([&](omp::TaskOp op) {
408 checkAllocate(op,
result);
409 checkInReduction(op,
result);
411 .Case([&](omp::TaskgroupOp op) {
412 checkAllocate(op,
result);
413 checkTaskReduction(op,
result);
415 .Case([&](omp::TaskwaitOp op) {
419 .Case([&](omp::TaskloopContextOp op) {
420 checkAllocate(op,
result);
421 checkInReduction(op,
result);
422 checkReduction(op,
result);
424 .Case([&](omp::WsloopOp op) {
425 checkAllocate(op,
result);
427 checkReduction(op,
result);
429 .Case([&](omp::ParallelOp op) {
430 checkAllocate(op,
result);
431 checkReduction(op,
result);
432 checkNumThreads(op,
result);
434 .Case([&](omp::SimdOp op) { checkReduction(op,
result); })
435 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
436 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op,
result); })
437 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
438 [&](
auto op) { checkDepend(op,
result); })
439 .Case([&](omp::TargetUpdateOp op) { checkDepend(op,
result); })
440 .Case([&](omp::TargetOp op) {
441 checkAllocate(op,
result);
443 checkInReduction(op,
result);
444 checkThreadLimit(op,
result);
456 llvm::handleAllErrors(
458 [&](
const PreviouslyReportedError &) {
result = failure(); },
459 [&](
const llvm::ErrorInfoBase &err) {
476static llvm::OpenMPIRBuilder::InsertPointTy
482 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
484 [&](OpenMPAllocaStackFrame &frame) {
485 allocaInsertPoint = frame.allocaInsertPoint;
493 allocaInsertPoint.getBlock()->getParent() ==
494 builder.GetInsertBlock()->getParent())
495 return allocaInsertPoint;
504 if (builder.GetInsertBlock() ==
505 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
506 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
507 "Assuming end of basic block");
508 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
509 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
510 builder.GetInsertBlock()->getNextNode());
511 builder.CreateBr(entryBB);
512 builder.SetInsertPoint(entryBB);
515 llvm::BasicBlock &funcEntryBlock =
516 builder.GetInsertBlock()->getParent()->getEntryBlock();
517 return llvm::OpenMPIRBuilder::InsertPointTy(
518 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
524static llvm::CanonicalLoopInfo *
526 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
527 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
528 [&](OpenMPLoopInfoStackFrame &frame) {
529 loopInfo = frame.loopInfo;
541 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
544 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
546 llvm::BasicBlock *continuationBlock =
547 splitBB(builder,
true,
"omp.region.cont");
548 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
550 llvm::LLVMContext &llvmContext = builder.getContext();
551 for (
Block &bb : region) {
552 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
553 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
554 builder.GetInsertBlock()->getNextNode());
555 moduleTranslation.
mapBlock(&bb, llvmBB);
558 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
565 unsigned numYields = 0;
567 if (!isLoopWrapper) {
568 bool operandsProcessed =
false;
570 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
571 if (!operandsProcessed) {
572 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
573 continuationBlockPHITypes.push_back(
574 moduleTranslation.
convertType(yield->getOperand(i).getType()));
576 operandsProcessed =
true;
578 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
579 "mismatching number of values yielded from the region");
580 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
581 llvm::Type *operandType =
582 moduleTranslation.
convertType(yield->getOperand(i).getType());
584 assert(continuationBlockPHITypes[i] == operandType &&
585 "values of mismatching types yielded from the region");
595 if (!continuationBlockPHITypes.empty())
597 continuationBlockPHIs &&
598 "expected continuation block PHIs if converted regions yield values");
599 if (continuationBlockPHIs) {
600 llvm::IRBuilderBase::InsertPointGuard guard(builder);
601 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
602 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
603 for (llvm::Type *ty : continuationBlockPHITypes)
604 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
610 for (
Block *bb : blocks) {
611 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
614 if (bb->isEntryBlock()) {
615 assert(sourceTerminator->getNumSuccessors() == 1 &&
616 "provided entry block has multiple successors");
617 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
618 "ContinuationBlock is not the successor of the entry block");
619 sourceTerminator->setSuccessor(0, llvmBB);
622 llvm::IRBuilderBase::InsertPointGuard guard(builder);
624 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
625 return llvm::make_error<PreviouslyReportedError>();
630 builder.CreateBr(continuationBlock);
641 Operation *terminator = bb->getTerminator();
642 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
643 builder.CreateBr(continuationBlock);
645 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
646 (*continuationBlockPHIs)[i]->addIncoming(
660 return continuationBlock;
666 case omp::ClauseProcBindKind::Close:
667 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
668 case omp::ClauseProcBindKind::Master:
669 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
670 case omp::ClauseProcBindKind::Primary:
671 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
672 case omp::ClauseProcBindKind::Spread:
673 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
675 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
682 auto maskedOp = cast<omp::MaskedOp>(opInst);
683 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
688 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
690 auto ®ion = maskedOp.getRegion();
691 builder.restoreIP(codeGenIP);
699 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
701 llvm::Value *filterVal =
nullptr;
702 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
703 filterVal = moduleTranslation.
lookupValue(filterVar);
705 llvm::LLVMContext &llvmContext = builder.getContext();
707 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
709 assert(filterVal !=
nullptr);
710 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
711 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
718 builder.restoreIP(*afterIP);
726 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
727 auto masterOp = cast<omp::MasterOp>(opInst);
732 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
734 auto ®ion = masterOp.getRegion();
735 builder.restoreIP(codeGenIP);
743 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
745 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
746 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
753 builder.restoreIP(*afterIP);
761 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
762 auto criticalOp = cast<omp::CriticalOp>(opInst);
767 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
769 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
770 builder.restoreIP(codeGenIP);
778 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
780 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
781 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
782 llvm::Constant *hint =
nullptr;
785 if (criticalOp.getNameAttr()) {
788 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
789 auto criticalDeclareOp =
793 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
794 static_cast<int>(criticalDeclareOp.getHint()));
796 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
798 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
803 builder.restoreIP(*afterIP);
810 template <
typename OP>
813 cast<
omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
816 collectPrivatizationDecls<OP>(op);
831 void collectPrivatizationDecls(OP op) {
832 std::optional<ArrayAttr> attr = op.getPrivateSyms();
837 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
848 std::optional<ArrayAttr> attr = op.getReductionSyms();
852 reductions.reserve(reductions.size() + op.getNumReductionVars());
853 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
854 reductions.push_back(
866 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
875 llvm::Instruction *potentialTerminator =
876 builder.GetInsertBlock()->empty() ?
nullptr
877 : &builder.GetInsertBlock()->back();
879 if (potentialTerminator && potentialTerminator->isTerminator())
880 potentialTerminator->removeFromParent();
881 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
884 region.
front(),
true, builder)))
888 if (continuationBlockArgs)
890 *continuationBlockArgs,
897 if (potentialTerminator && potentialTerminator->isTerminator()) {
898 llvm::BasicBlock *block = builder.GetInsertBlock();
899 if (block->empty()) {
905 potentialTerminator->insertInto(block, block->begin());
907 potentialTerminator->insertAfter(&block->back());
921 if (continuationBlockArgs)
922 llvm::append_range(*continuationBlockArgs, phis);
923 builder.SetInsertPoint(*continuationBlock,
924 (*continuationBlock)->getFirstInsertionPt());
// std::function callback types that return an InsertPointOrErrorTy. The
// lambdas built later in this file inline regions of `omp.declare_reduction`
// and are stored in these types.
// Callback used for the `combiner` region (non-atomic reduction).
931using OwningReductionGen =
932 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
933 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
// Callback used for the `atomic` region; it additionally takes an llvm::Type.
935using OwningAtomicReductionGen =
936 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
937 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
// Callback used for the `data_ptr_ptr` region; the result is returned through
// the trailing reference parameter.
939using OwningDataPtrPtrReductionGen =
940 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
941 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
947static OwningReductionGen
953 OwningReductionGen gen =
954 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
955 llvm::Value *
lhs, llvm::Value *
rhs,
956 llvm::Value *&
result)
mutable
957 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
958 moduleTranslation.
mapValue(decl.getReductionLhsArg(),
lhs);
959 moduleTranslation.
mapValue(decl.getReductionRhsArg(),
rhs);
960 builder.restoreIP(insertPoint);
963 "omp.reduction.nonatomic.body", builder,
964 moduleTranslation, &phis)))
965 return llvm::createStringError(
966 "failed to inline `combiner` region of `omp.declare_reduction`");
967 result = llvm::getSingleElement(phis);
968 return builder.saveIP();
977static OwningAtomicReductionGen
979 llvm::IRBuilderBase &builder,
981 if (decl.getAtomicReductionRegion().empty())
982 return OwningAtomicReductionGen();
988 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
989 llvm::Value *
lhs, llvm::Value *
rhs)
mutable
990 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
991 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(),
lhs);
992 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(),
rhs);
993 builder.restoreIP(insertPoint);
996 "omp.reduction.atomic.body", builder,
997 moduleTranslation, &phis)))
998 return llvm::createStringError(
999 "failed to inline `atomic` region of `omp.declare_reduction`");
1000 assert(phis.empty());
1001 return builder.saveIP();
1010static OwningDataPtrPtrReductionGen
1014 return OwningDataPtrPtrReductionGen();
1016 OwningDataPtrPtrReductionGen refDataPtrGen =
1017 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1018 llvm::Value *byRefVal, llvm::Value *&
result)
mutable
1019 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1020 moduleTranslation.
mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1021 builder.restoreIP(insertPoint);
1024 "omp.data_ptr_ptr.body", builder,
1025 moduleTranslation, &phis)))
1026 return llvm::createStringError(
1027 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1028 result = llvm::getSingleElement(phis);
1029 return builder.saveIP();
1032 return refDataPtrGen;
1039 auto orderedOp = cast<omp::OrderedOp>(opInst);
1044 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1045 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1046 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1048 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
1050 size_t indexVecValues = 0;
1051 while (indexVecValues < vecValues.size()) {
1053 storeValues.reserve(numLoops);
1054 for (
unsigned i = 0; i < numLoops; i++) {
1055 storeValues.push_back(vecValues[indexVecValues]);
1058 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1060 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1061 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
1062 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
1072 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1073 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1078 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1080 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1081 builder.restoreIP(codeGenIP);
1089 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1091 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1092 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1094 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1099 builder.restoreIP(*afterIP);
1105struct DeferredStore {
1106 DeferredStore(llvm::Value *value, llvm::Value *address)
1107 : value(value), address(address) {}
1110 llvm::Value *address;
1117template <
typename T>
1120 llvm::IRBuilderBase &builder,
1122 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1128 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1129 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1132 deferredStores.reserve(loop.getNumReductionVars());
1134 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1135 Region &allocRegion = reductionDecls[i].getAllocRegion();
1137 if (allocRegion.
empty())
1142 builder, moduleTranslation, &phis)))
1143 return loop.emitError(
1144 "failed to inline `alloc` region of `omp.declare_reduction`");
1146 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1147 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1151 llvm::Value *var = builder.CreateAlloca(
1152 moduleTranslation.
convertType(reductionDecls[i].getType()));
1154 llvm::Type *ptrTy = builder.getPtrTy();
1155 llvm::Value *castVar =
1156 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1157 llvm::Value *castPhi =
1158 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1160 deferredStores.emplace_back(castPhi, castVar);
1162 privateReductionVariables[i] = castVar;
1163 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1164 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1166 assert(allocRegion.
empty() &&
1167 "allocaction is implicit for by-val reduction");
1168 llvm::Value *var = builder.CreateAlloca(
1169 moduleTranslation.
convertType(reductionDecls[i].getType()));
1171 llvm::Type *ptrTy = builder.getPtrTy();
1172 llvm::Value *castVar =
1173 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1175 moduleTranslation.
mapValue(reductionArgs[i], castVar);
1176 privateReductionVariables[i] = castVar;
1177 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1185template <
typename T>
1188 llvm::IRBuilderBase &builder,
1193 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1194 Region &initializerRegion = reduction.getInitializerRegion();
1197 mlir::Value mlirSource = loop.getReductionVars()[i];
1198 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1199 llvm::Value *origVal = llvmSource;
1201 if (!isa<LLVM::LLVMPointerType>(
1202 reduction.getInitializerMoldArg().getType()) &&
1203 isa<LLVM::LLVMPointerType>(mlirSource.
getType())) {
1206 reduction.getInitializerMoldArg().getType()),
1207 llvmSource,
"omp_orig");
1209 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), origVal);
1212 llvm::Value *allocation =
1213 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1214 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1220 llvm::BasicBlock *block =
nullptr) {
1221 if (block ==
nullptr)
1222 block = builder.GetInsertBlock();
1224 if (!block->hasTerminator())
1225 builder.SetInsertPoint(block);
1227 builder.SetInsertPoint(block->getTerminator());
1235template <
typename OP>
1238 llvm::IRBuilderBase &builder,
1240 llvm::BasicBlock *latestAllocaBlock,
1246 if (op.getNumReductionVars() == 0)
1249 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1250 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1251 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1252 builder.restoreIP(allocaIP);
1255 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1257 if (!reductionDecls[i].getAllocRegion().empty())
1263 byRefVars[i] = builder.CreateAlloca(
1264 moduleTranslation.
convertType(reductionDecls[i].getType()));
1272 for (
auto [data, addr] : deferredStores)
1273 builder.CreateStore(data, addr);
1278 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1283 reductionVariableMap, i);
1291 "omp.reduction.neutral", builder,
1292 moduleTranslation, &phis)))
1295 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1296 "reduction neutral element declaration region");
1301 if (!reductionDecls[i].getAllocRegion().empty())
1310 builder.CreateStore(phis[0], byRefVars[i]);
1312 privateReductionVariables[i] = byRefVars[i];
1313 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1314 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1317 builder.CreateStore(phis[0], privateReductionVariables[i]);
1324 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
// Collects, for every reduction variable of `loop`, the information appended
// to `reductionInfos`. NOTE(review): several parameter lines and statements
// of this function are not visible in this view; the comments below cover
// only the visible parts.
1331template <
typename T>
1332static void collectReductionInfo(
1333 T loop, llvm::IRBuilderBase &builder,
1342 unsigned numReductions = loop.getNumReductionVars();
// First pass: build the owning generator callbacks, one per reduction.
1344 for (
unsigned i = 0; i < numReductions; ++i) {
1347 owningAtomicReductionGens.push_back(
1350 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
// Second pass: assemble one entry per reduction. The output is reserved up
// front since the final size is known.
1354 reductionInfos.reserve(numReductions);
1355 for (
unsigned i = 0; i < numReductions; ++i) {
// The atomic generator stays null when the owning callback is empty.
1356 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1357 if (owningAtomicReductionGens[i])
1358 atomicGen = owningAtomicReductionGens[i];
1359 llvm::Value *variable =
1360 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
// When `op` is an LLVM alloca, its element type is recorded as the allocated
// type.
1363 if (
auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1364 allocatedType = alloca.getElemType();
1371 reductionInfos.push_back(
1373 privateReductionVariables[i],
1374 llvm::OpenMPIRBuilder::EvalKind::Scalar,
// A null LLVM type is passed when no allocated type was found.
1378 allocatedType ? moduleTranslation.
convertType(allocatedType) :
nullptr,
1379 reductionDecls[i].getByrefElementType()
1381 *reductionDecls[i].getByrefElementType())
1391 llvm::IRBuilderBase &builder, StringRef regionName,
1392 bool shouldLoadCleanupRegionArg =
true) {
1393 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1394 if (cleanupRegion->empty())
1400 llvm::Instruction *potentialTerminator =
1401 builder.GetInsertBlock()->empty() ?
nullptr
1402 : &builder.GetInsertBlock()->back();
1403 if (potentialTerminator && potentialTerminator->isTerminator())
1404 builder.SetInsertPoint(potentialTerminator);
1405 llvm::Value *privateVarValue =
1406 shouldLoadCleanupRegionArg
1407 ? builder.CreateLoad(
1409 privateVariables[i])
1410 : privateVariables[i];
1415 moduleTranslation)))
1428 OP op, llvm::IRBuilderBase &builder,
1430 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1433 bool isNowait =
false,
bool isTeamsReduction =
false) {
1435 if (op.getNumReductionVars() == 0)
1447 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1449 owningReductionGenRefDataPtrGens,
1450 privateReductionVariables, reductionInfos, isByRef);
1455 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1456 builder.SetInsertPoint(tempTerminator);
1457 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1458 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1459 isByRef, isNowait, isTeamsReduction);
1464 if (!contInsertPoint->getBlock())
1465 return op->emitOpError() <<
"failed to convert reductions";
1467 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
1468 if (!isTeamsReduction) {
1469 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1470 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1474 afterIP = *barrierIP;
1477 tempTerminator->eraseFromParent();
1478 builder.restoreIP(afterIP);
1482 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1483 [](omp::DeclareReductionOp reductionDecl) {
1484 return &reductionDecl.getCleanupRegion();
1487 moduleTranslation, builder,
1488 "omp.reduction.cleanup");
1499template <
typename OP>
1503 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1508 if (op.getNumReductionVars() == 0)
1514 allocaIP, reductionDecls,
1515 privateReductionVariables, reductionVariableMap,
1516 deferredStores, isByRef)))
1519 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1520 allocaIP.getBlock(), reductionDecls,
1521 privateReductionVariables, reductionVariableMap,
1522 isByRef, deferredStores);
1536 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1539 Value blockArg = (*mappedPrivateVars)[privateVar];
1542 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1543 "A block argument corresponding to a mapped var should have "
1546 if (privVarType == blockArgType)
1553 if (!isa<LLVM::LLVMPointerType>(privVarType))
1554 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1567 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1569 llvm::BasicBlock *privInitBlock,
1571 Region &initRegion = privDecl.getInitRegion();
1572 if (initRegion.
empty())
1573 return llvmPrivateVar;
1575 assert(nonPrivateVar);
1576 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1577 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1582 moduleTranslation, &phis)))
1583 return llvm::createStringError(
1584 "failed to inline `init` region of `omp.private`");
1586 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1603 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1606 builder, moduleTranslation, privDecl,
1609 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1618 return llvm::Error::success();
1620 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1623 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1626 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1628 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1629 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1632 return privVarOrErr.takeError();
1634 llvmPrivateVar = privVarOrErr.get();
1635 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1640 return llvm::Error::success();
1650 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1653 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1654 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1655 allocaTerminator->getIterator()),
1656 true, allocaTerminator->getStableDebugLoc(),
1657 "omp.region.after_alloca");
1659 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1661 allocaTerminator = allocaIP.getBlock()->getTerminator();
1662 builder.SetInsertPoint(allocaTerminator);
1664 assert(allocaTerminator->getNumSuccessors() == 1 &&
1665 "This is an unconditional branch created by splitBB");
1667 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1668 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1670 unsigned int allocaAS =
1671 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1674 .getProgramAddressSpace();
1676 for (
auto [privDecl, mlirPrivVar, blockArg] :
1679 llvm::Type *llvmAllocType =
1680 moduleTranslation.
convertType(privDecl.getType());
1681 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1682 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1683 llvmAllocType,
nullptr,
"omp.private.alloc");
1684 if (allocaAS != defaultAS)
1685 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1686 builder.getPtrTy(defaultAS));
1688 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1691 return afterAllocas;
1699 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1708 if (mlir::isa<omp::ParallelOp>(parent))
1722 bool needsFirstprivate =
1723 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1724 return privOp.getDataSharingType() ==
1725 omp::DataSharingClauseType::FirstPrivate;
1728 if (!needsFirstprivate)
1731 llvm::BasicBlock *copyBlock =
1732 splitBB(builder,
true,
"omp.private.copy");
1735 for (
auto [decl, moldVar, llvmVar] :
1736 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1737 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1741 Region ©Region = decl.getCopyRegion();
1743 moduleTranslation.
mapValue(decl.getCopyMoldArg(), moldVar);
1746 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1750 moduleTranslation)))
1751 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1765 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1766 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1782 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](
mlir::Value mlirVar) {
1784 llvm::Value *moldVar = findAssociatedValue(
1785 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1790 llvmPrivateVars, privateDecls, insertBarrier,
1801 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1802 [](omp::PrivateClauseOp privatizer) {
1803 return &privatizer.getDeallocRegion();
1807 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1808 "omp.private.dealloc",
false)))
1809 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1810 "`omp.private` op in");
1822 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1832 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1833 using StorableBodyGenCallbackTy =
1834 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1836 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1842 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1846 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1850 sectionsOp.getNumReductionVars());
1854 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1857 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1858 reductionDecls, privateReductionVariables, reductionVariableMap,
1865 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1869 Region ®ion = sectionOp.getRegion();
1870 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1871 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1872 builder.restoreIP(codeGenIP);
1879 sectionsOp.getRegion().getNumArguments());
1880 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1881 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1882 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
1884 moduleTranslation.
mapValue(sectionArg, llvmVal);
1891 sectionCBs.push_back(sectionCB);
1897 if (sectionCBs.empty())
1900 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1905 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1906 llvm::Value &vPtr, llvm::Value *&replacementValue)
1907 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1908 replacementValue = &vPtr;
1914 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1918 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1919 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1921 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1922 sectionsOp.getNowait());
1927 builder.restoreIP(*afterIP);
1931 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1932 privateReductionVariables, isByRef, sectionsOp.getNowait());
1939 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1940 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1945 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1946 builder.restoreIP(codegenIP);
1948 builder, moduleTranslation)
1951 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1955 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1958 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1959 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
1961 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1962 llvmCPFuncs.push_back(
1966 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1968 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1974 builder.restoreIP(*afterIP);
1978static omp::DistributeOp
1982 omp::DistributeOp distOp;
1983 WalkResult walk = teamsOp.getRegion().walk([&](omp::DistributeOp op) {
1989 if (walk.wasInterrupted() || !distOp)
1993 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1997 for (
auto ra : iface.getReductionBlockArgs())
1998 for (
auto &use : ra.getUses()) {
1999 auto *useOp = use.getOwner();
2001 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
2002 debugUses.push_back(useOp);
2005 if (!distOp->isProperAncestor(useOp))
2012 for (
auto *use : debugUses)
2021 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2026 unsigned numReductionVars = op.getNumReductionVars();
2030 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2036 if (doTeamsReduction) {
2037 isByRef =
getIsByRef(op.getReductionByref());
2039 assert(isByRef.size() == op.getNumReductionVars());
2042 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2047 op, reductionArgs, builder, moduleTranslation, allocaIP,
2048 reductionDecls, privateReductionVariables, reductionVariableMap,
2053 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2055 moduleTranslation, allocaIP);
2056 builder.restoreIP(codegenIP);
2062 llvm::Value *numTeamsLower =
nullptr;
2063 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
2064 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
2066 llvm::Value *numTeamsUpper =
nullptr;
2067 if (!op.getNumTeamsUpperVars().empty())
2068 numTeamsUpper = moduleTranslation.
lookupValue(op.getNumTeams(0));
2070 llvm::Value *threadLimit =
nullptr;
2071 if (!op.getThreadLimitVars().empty())
2072 threadLimit = moduleTranslation.
lookupValue(op.getThreadLimit(0));
2074 llvm::Value *ifExpr =
nullptr;
2075 if (
Value ifVar = op.getIfExpr())
2078 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2079 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2081 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2086 builder.restoreIP(*afterIP);
2087 if (doTeamsReduction) {
2090 op, builder, moduleTranslation, allocaIP, reductionDecls,
2091 privateReductionVariables, isByRef,
2097static llvm::omp::RTLDependenceKindTy
2100 case mlir::omp::ClauseTaskDepend::taskdependin:
2101 return llvm::omp::RTLDependenceKindTy::DepIn;
2105 case mlir::omp::ClauseTaskDepend::taskdependout:
2106 case mlir::omp::ClauseTaskDepend::taskdependinout:
2107 return llvm::omp::RTLDependenceKindTy::DepInOut;
2108 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2109 return llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2110 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2111 return llvm::omp::RTLDependenceKindTy::DepInOutSet;
2113 llvm_unreachable(
"unhandled depend kind");
2117 std::optional<ArrayAttr> dependKinds,
OperandRange dependVars,
2120 if (dependVars.empty())
2122 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2124 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue();
2126 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2127 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2128 dds.emplace_back(dd);
2140 llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder,
2142 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2143 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2147 llvmBuilder.restoreIP(ip);
2153 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2154 return llvm::Error::success();
2159 ompBuilder.pushFinalizationCB(
2169 llvm::OpenMPIRBuilder &ompBuilder,
2170 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2171 ompBuilder.popFinalizationCB();
2172 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2173 for (llvm::UncondBrInst *cancelBranch : cancelTerminators)
2174 cancelBranch->setSuccessor(constructFini);
2180class TaskContextStructManager {
2182 TaskContextStructManager(llvm::IRBuilderBase &builder,
2183 LLVM::ModuleTranslation &moduleTranslation,
2184 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2185 : builder{builder}, moduleTranslation{moduleTranslation},
2186 privateDecls{privateDecls} {}
2192 void generateTaskContextStruct();
2198 void createGEPsToPrivateVars();
2204 SmallVector<llvm::Value *>
2205 createGEPsToPrivateVars(llvm::Value *altStructPtr)
const;
2208 void freeStructPtr();
2210 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2211 return llvmPrivateVarGEPs;
2214 llvm::Value *getStructPtr() {
return structPtr; }
2217 llvm::IRBuilderBase &builder;
2218 LLVM::ModuleTranslation &moduleTranslation;
2219 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
2222 SmallVector<llvm::Type *> privateVarTypes;
2226 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
2229 llvm::Value *structPtr =
nullptr;
2231 llvm::Type *structTy =
nullptr;
2242 llvm::SmallVector<llvm::Value *> lowerBounds;
2243 llvm::SmallVector<llvm::Value *> upperBounds;
2244 llvm::SmallVector<llvm::Value *> steps;
2245 llvm::SmallVector<llvm::Value *> trips;
2247 llvm::Value *totalTrips;
2249 llvm::Value *lookUpAsI64(mlir::Value val,
const LLVM::ModuleTranslation &mt,
2250 llvm::IRBuilderBase &builder) {
2254 if (v->getType()->isIntegerTy(64))
2256 if (v->getType()->isIntegerTy())
2257 return builder.CreateSExtOrTrunc(v, builder.getInt64Ty());
2262 IteratorInfo(mlir::omp::IteratorOp itersOp,
2263 mlir::LLVM::ModuleTranslation &moduleTranslation,
2264 llvm::IRBuilderBase &builder) {
2265 dims = itersOp.getLoopLowerBounds().size();
2266 lowerBounds.resize(dims);
2267 upperBounds.resize(dims);
2271 for (
unsigned d = 0; d < dims; ++d) {
2272 llvm::Value *lb = lookUpAsI64(itersOp.getLoopLowerBounds()[d],
2273 moduleTranslation, builder);
2274 llvm::Value *ub = lookUpAsI64(itersOp.getLoopUpperBounds()[d],
2275 moduleTranslation, builder);
2277 lookUpAsI64(itersOp.getLoopSteps()[d], moduleTranslation, builder);
2278 assert(lb && ub && st &&
2279 "Expect lowerBounds, upperBounds, and steps in IteratorOp");
2280 assert((!llvm::isa<llvm::ConstantInt>(st) ||
2281 !llvm::cast<llvm::ConstantInt>(st)->isZero()) &&
2282 "Expect non-zero step in IteratorOp");
2284 lowerBounds[d] = lb;
2285 upperBounds[d] = ub;
2289 llvm::Value *diff = builder.CreateSub(ub, lb);
2290 llvm::Value *
div = builder.CreateSDiv(diff, st);
2291 trips[d] = builder.CreateAdd(
2292 div, llvm::ConstantInt::get(builder.getInt64Ty(), 1));
2295 totalTrips = llvm::ConstantInt::get(builder.getInt64Ty(), 1);
2296 for (
unsigned d = 0; d < dims; ++d)
2297 totalTrips = builder.CreateMul(totalTrips, trips[d]);
/// Returns the number of iterator dimensions. The constructor sets this to the
/// number of loop lower bounds on the omp.iterator op.
2300 unsigned getDims()
const {
return dims; }
/// Returns the per-dimension lower bounds. The constructor fills entry `d`
/// with the value obtained from lookUpAsI64 for the op's d-th loop lower bound.
2301 llvm::ArrayRef<llvm::Value *> getLowerBounds()
const {
return lowerBounds; }
/// Returns the per-dimension upper bounds. The constructor fills entry `d`
/// with the value obtained from lookUpAsI64 for the op's d-th loop upper bound.
2302 llvm::ArrayRef<llvm::Value *> getUpperBounds()
const {
return upperBounds; }
/// Returns the contents of the `steps` vector.
/// NOTE(review): presumably one step value per dimension, stored by the
/// constructor alongside the bounds -- the store is not visible here; confirm.
2303 llvm::ArrayRef<llvm::Value *> getSteps()
const {
return steps; }
/// Returns the per-dimension trip counts. The constructor computes entry `d`
/// as ((ub - lb) sdiv step) + 1 from that dimension's bounds and step.
2304 llvm::ArrayRef<llvm::Value *> getTrips()
const {
return trips; }
/// Returns the total iteration count: an i64 value the constructor builds by
/// starting from the constant 1 and multiplying in each dimension's trip
/// count, so it stays the constant 1 when there are no dimensions.
2305 llvm::Value *getTotalTrips()
const {
return totalTrips; }
2310void TaskContextStructManager::generateTaskContextStruct() {
2311 if (privateDecls.empty())
2313 privateVarTypes.reserve(privateDecls.size());
2315 for (omp::PrivateClauseOp &privOp : privateDecls) {
2318 if (!privOp.readsFromMold())
2320 Type mlirType = privOp.getType();
2321 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
2324 if (privateVarTypes.empty())
2327 structTy = llvm::StructType::get(moduleTranslation.
getLLVMContext(),
2330 llvm::DataLayout dataLayout =
2331 builder.GetInsertBlock()->getModule()->getDataLayout();
2332 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2333 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2336 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2338 "omp.task.context_ptr");
2341SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2342 llvm::Value *altStructPtr)
const {
2343 SmallVector<llvm::Value *> ret;
2346 ret.reserve(privateDecls.size());
2347 llvm::Value *zero = builder.getInt32(0);
2349 for (
auto privDecl : privateDecls) {
2350 if (!privDecl.readsFromMold()) {
2352 ret.push_back(
nullptr);
2355 llvm::Value *iVal = builder.getInt32(i);
2356 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2363void TaskContextStructManager::createGEPsToPrivateVars() {
2365 assert(privateVarTypes.empty());
2369 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2372void TaskContextStructManager::freeStructPtr() {
2376 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2378 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2379 builder.CreateFree(structPtr);
2383 llvm::OpenMPIRBuilder &ompBuilder,
2384 llvm::Value *affinityList, llvm::Value *
index,
2385 llvm::Value *addr, llvm::Value *len) {
2386 llvm::StructType *kmpTaskAffinityInfoTy =
2387 ompBuilder.getKmpTaskAffinityInfoTy();
2388 llvm::Value *entry = builder.CreateInBoundsGEP(
2389 kmpTaskAffinityInfoTy, affinityList,
index,
"omp.affinity.entry");
2391 addr = builder.CreatePtrToInt(addr, kmpTaskAffinityInfoTy->getElementType(0));
2392 len = builder.CreateIntCast(len, kmpTaskAffinityInfoTy->getElementType(1),
2394 llvm::Value *flags = builder.getInt32(0);
2396 builder.CreateStore(addr,
2397 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 0));
2398 builder.CreateStore(len,
2399 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 1));
2400 builder.CreateStore(flags,
2401 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 2));
2405 llvm::IRBuilderBase &builder,
2407 llvm::Value *affinityList) {
2408 for (
auto [i, affinityVar] : llvm::enumerate(affinityVars)) {
2409 auto entryOp = affinityVar.getDefiningOp<mlir::omp::AffinityEntryOp>();
2410 assert(entryOp &&
"affinity item must be omp.affinity_entry");
2412 llvm::Value *addr = moduleTranslation.
lookupValue(entryOp.getAddr());
2413 llvm::Value *len = moduleTranslation.
lookupValue(entryOp.getLen());
2414 assert(addr && len &&
"expect affinity addr and len to be non-null");
2416 affinityList, builder.getInt64(i), addr, len);
2420static mlir::LogicalResult
2423 llvm::IRBuilderBase &builder,
2425 llvm::Value *tmp = linearIV;
2426 for (
int d = (
int)iterInfo.getDims() - 1; d >= 0; --d) {
2427 llvm::Value *trip = iterInfo.getTrips()[d];
2429 llvm::Value *idx = builder.CreateURem(tmp, trip);
2431 tmp = builder.CreateUDiv(tmp, trip);
2434 llvm::Value *physIV = builder.CreateAdd(
2435 iterInfo.getLowerBounds()[d],
2436 builder.CreateMul(idx, iterInfo.getSteps()[d]),
"omp.it.phys_iv");
2442 moduleTranslation.
mapBlock(&iteratorRegionBlock, builder.GetInsertBlock());
2443 if (mlir::failed(moduleTranslation.
convertBlock(iteratorRegionBlock,
2446 return mlir::failure();
2448 return mlir::success();
2454static mlir::LogicalResult
2457 IteratorInfo &iterInfo, llvm::StringRef loopName,
2462 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
2464 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy bodyIP,
2465 llvm::Value *linearIV) -> llvm::Error {
2466 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2467 builder.restoreIP(bodyIP);
2470 builder, moduleTranslation))) {
2471 return llvm::make_error<llvm::StringError>(
2472 "failed to convert iterator region", llvm::inconvertibleErrorCode());
2476 mlir::dyn_cast<mlir::omp::YieldOp>(iteratorRegionBlock.
getTerminator());
2477 assert(yield && yield.getResults().size() == 1 &&
2478 "expect omp.yield in iterator region to have one result");
2480 genStoreEntry(linearIV, yield);
2486 return llvm::Error::success();
2489 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2491 loc, iterInfo.getTotalTrips(), bodyGen, loopName);
2495 builder.restoreIP(*afterIP);
2497 return mlir::success();
2500static mlir::LogicalResult
2503 llvm::OpenMPIRBuilder::AffinityData &ad) {
2505 if (taskOp.getAffinityVars().empty() && taskOp.getIterated().empty()) {
2508 return mlir::success();
2512 llvm::StructType *kmpTaskAffinityInfoTy =
2515 auto allocateAffinityList = [&](llvm::Value *count) -> llvm::Value * {
2516 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2517 if (llvm::isa<llvm::Constant>(count) || llvm::isa<llvm::Argument>(count))
2519 return builder.CreateAlloca(kmpTaskAffinityInfoTy, count,
2520 "omp.affinity_list");
2523 auto createAffinity =
2524 [&](llvm::Value *count,
2525 llvm::Value *info) -> llvm::OpenMPIRBuilder::AffinityData {
2526 llvm::OpenMPIRBuilder::AffinityData ad{};
2527 ad.Count = builder.CreateTrunc(count, builder.getInt32Ty());
2529 builder.CreatePointerBitCastOrAddrSpaceCast(info, builder.getPtrTy(0));
2533 if (!taskOp.getAffinityVars().empty()) {
2534 llvm::Value *count = llvm::ConstantInt::get(
2535 builder.getInt64Ty(), taskOp.getAffinityVars().size());
2536 llvm::Value *list = allocateAffinityList(count);
2539 ads.emplace_back(createAffinity(count, list));
2542 if (!taskOp.getIterated().empty()) {
2543 for (
auto [i, iter] : llvm::enumerate(taskOp.getIterated())) {
2544 auto itersOp = iter.getDefiningOp<omp::IteratorOp>();
2545 assert(itersOp &&
"iterated value must be defined by omp.iterator");
2546 IteratorInfo iterInfo(itersOp, moduleTranslation, builder);
2547 llvm::Value *affList = allocateAffinityList(iterInfo.getTotalTrips());
2549 itersOp, builder, moduleTranslation, iterInfo,
"iterator",
2550 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2551 auto entryOp = yield.getResults()[0]
2552 .getDefiningOp<mlir::omp::AffinityEntryOp>();
2553 assert(entryOp &&
"expect yield produce an affinity entry");
2560 affList, linearIV, addr, len);
2562 return llvm::failure();
2563 ads.emplace_back(createAffinity(iterInfo.getTotalTrips(), affList));
2567 llvm::Value *totalAffinityCount = builder.getInt32(0);
2568 for (
const auto &affinity : ads)
2569 totalAffinityCount = builder.CreateAdd(
2571 builder.CreateIntCast(affinity.Count, builder.getInt32Ty(),
2574 llvm::Value *affinityInfo = ads.front().Info;
2575 if (ads.size() > 1) {
2576 llvm::StructType *kmpTaskAffinityInfoTy =
2578 llvm::Value *affinityInfoElemSize = builder.getInt64(
2579 moduleTranslation.
getLLVMModule()->getDataLayout().getTypeAllocSize(
2580 kmpTaskAffinityInfoTy));
2582 llvm::Value *packedAffinityInfo = allocateAffinityList(totalAffinityCount);
2583 llvm::Value *packedAffinityInfoOffset = builder.getInt32(0);
2584 for (
const auto &affinity : ads) {
2585 llvm::Value *affinityCount = builder.CreateIntCast(
2586 affinity.Count, builder.getInt32Ty(),
false);
2587 llvm::Value *affinityCountInt64 = builder.CreateIntCast(
2588 affinityCount, builder.getInt64Ty(),
false);
2589 llvm::Value *affinityInfoSize =
2590 builder.CreateMul(affinityCountInt64, affinityInfoElemSize);
2592 llvm::Value *packedAffinityInfoIndex = builder.CreateIntCast(
2593 packedAffinityInfoOffset, kmpTaskAffinityInfoTy->getElementType(0),
2595 packedAffinityInfoIndex = builder.CreateInBoundsGEP(
2596 kmpTaskAffinityInfoTy, packedAffinityInfo, packedAffinityInfoIndex);
2598 builder.CreateMemCpy(
2599 packedAffinityInfoIndex, llvm::Align(1),
2600 builder.CreatePointerBitCastOrAddrSpaceCast(
2601 affinity.Info, builder.getPtrTy(packedAffinityInfoIndex->getType()
2602 ->getPointerAddressSpace())),
2603 llvm::Align(1), affinityInfoSize);
2605 packedAffinityInfoOffset =
2606 builder.CreateAdd(packedAffinityInfoOffset, affinityCount);
2609 affinityInfo = packedAffinityInfo;
2612 ad.Count = totalAffinityCount;
2613 ad.Info = affinityInfo;
2615 return mlir::success();
2621static mlir::LogicalResult
2624 std::optional<ArrayAttr> dependIteratedKinds,
2625 llvm::IRBuilderBase &builder,
2627 llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps) {
2628 if (dependIterated.empty()) {
2631 return mlir::success();
2635 llvm::Type *dependInfoTy = ompBuilder.DependInfo;
2636 unsigned numLocator = dependVars.size();
2639 llvm::Value *totalCount =
2640 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2643 for (
auto iter : dependIterated) {
2644 auto itersOp = iter.getDefiningOp<mlir::omp::IteratorOp>();
2645 assert(itersOp &&
"depend_iterated value must be defined by omp.iterator");
2646 iterInfos.emplace_back(itersOp, moduleTranslation, builder);
2648 builder.CreateAdd(totalCount, iterInfos.back().getTotalTrips());
2653 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(dependInfoTy);
2654 llvm::Value *depArray =
2655 builder.CreateMalloc(ompBuilder.SizeTy, dependInfoTy, allocSize,
2656 totalCount,
nullptr,
".dep.arr.addr");
2659 if (numLocator > 0) {
2662 for (
auto [i, dd] : llvm::enumerate(dds)) {
2663 llvm::Value *idx = llvm::ConstantInt::get(builder.getInt64Ty(), i);
2664 llvm::Value *entry =
2665 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2666 ompBuilder.emitTaskDependency(builder, entry, dd);
2671 llvm::Value *offset =
2672 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2673 for (
auto [i, iterInfo] : llvm::enumerate(iterInfos)) {
2674 auto kindAttr = cast<mlir::omp::ClauseTaskDependAttr>(
2675 dependIteratedKinds->getValue()[i]);
2676 llvm::omp::RTLDependenceKindTy rtlKind =
2679 auto itersOp = dependIterated[i].getDefiningOp<mlir::omp::IteratorOp>();
2681 itersOp, builder, moduleTranslation, iterInfo,
"dep_iterator",
2682 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2684 moduleTranslation.
lookupValue(yield.getResults()[0]);
2685 llvm::Value *idx = builder.CreateAdd(offset, linearIV);
2686 llvm::Value *entry =
2687 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2688 ompBuilder.emitTaskDependency(
2690 llvm::OpenMPIRBuilder::DependData{rtlKind, addr->getType(),
2693 return mlir::failure();
2696 offset = builder.CreateAdd(offset, iterInfo.getTotalTrips());
2699 taskDeps.DepArray = depArray;
2700 taskDeps.NumDeps = builder.CreateTrunc(totalCount, builder.getInt32Ty());
2701 return mlir::success();
2708 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2713 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2725 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2730 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2731 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2732 builder.getContext(),
"omp.task.start",
2733 builder.GetInsertBlock()->getParent());
2734 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2735 builder.SetInsertPoint(branchToTaskStartBlock);
2738 llvm::BasicBlock *copyBlock =
2739 splitBB(builder,
true,
"omp.private.copy");
2740 llvm::BasicBlock *initBlock =
2741 splitBB(builder,
true,
"omp.private.init");
2757 moduleTranslation, allocaIP);
2760 builder.SetInsertPoint(initBlock->getTerminator());
2763 taskStructMgr.generateTaskContextStruct();
2770 taskStructMgr.createGEPsToPrivateVars();
2772 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2775 taskStructMgr.getLLVMPrivateVarGEPs())) {
2777 if (!privDecl.readsFromMold())
2779 assert(llvmPrivateVarAlloc &&
2780 "reads from mold so shouldn't have been skipped");
2783 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2784 blockArg, llvmPrivateVarAlloc, initBlock);
2785 if (!privateVarOrErr)
2786 return handleError(privateVarOrErr, *taskOp.getOperation());
2795 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2796 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2797 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2799 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2800 llvmPrivateVarAlloc);
2802 assert(llvmPrivateVarAlloc->getType() ==
2803 moduleTranslation.
convertType(blockArg.getType()));
2813 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2814 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2815 taskOp.getPrivateNeedsBarrier())))
2816 return llvm::failure();
2818 llvm::OpenMPIRBuilder::AffinityData ad;
2820 return llvm::failure();
2823 builder.SetInsertPoint(taskStartBlock);
2825 auto bodyCB = [&](InsertPointTy allocaIP,
2826 InsertPointTy codegenIP) -> llvm::Error {
2830 moduleTranslation, allocaIP);
2833 builder.restoreIP(codegenIP);
2835 llvm::BasicBlock *privInitBlock =
nullptr;
2837 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2840 auto [blockArg, privDecl, mlirPrivVar] = zip;
2842 if (privDecl.readsFromMold())
2845 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2846 llvm::Type *llvmAllocType =
2847 moduleTranslation.
convertType(privDecl.getType());
2848 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2849 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2850 llvmAllocType,
nullptr,
"omp.private.alloc");
2853 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2854 blockArg, llvmPrivateVar, privInitBlock);
2855 if (!privateVarOrError)
2856 return privateVarOrError.takeError();
2857 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2858 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2861 taskStructMgr.createGEPsToPrivateVars();
2862 for (
auto [i, llvmPrivVar] :
2863 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2865 assert(privateVarsInfo.
llvmVars[i] &&
2866 "This is added in the loop above");
2869 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2874 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2878 if (!privateDecl.readsFromMold())
2881 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2882 llvmPrivateVar = builder.CreateLoad(
2883 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2885 assert(llvmPrivateVar->getType() ==
2886 moduleTranslation.
convertType(blockArg.getType()));
2887 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2891 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
2892 if (failed(
handleError(continuationBlockOrError, *taskOp)))
2893 return llvm::make_error<PreviouslyReportedError>();
2895 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2900 return llvm::make_error<PreviouslyReportedError>();
2903 taskStructMgr.freeStructPtr();
2905 return llvm::Error::success();
2914 llvm::omp::Directive::OMPD_taskgroup);
2916 llvm::OpenMPIRBuilder::DependenciesInfo dependencies;
2917 if (failed(
buildDependData(taskOp.getDependVars(), taskOp.getDependKinds(),
2918 taskOp.getDependIterated(),
2919 taskOp.getDependIteratedKinds(), builder,
2920 moduleTranslation, dependencies)))
2923 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2924 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2926 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2928 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dependencies, ad,
2929 taskOp.getMergeable(),
2930 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
2931 moduleTranslation.
lookupValue(taskOp.getPriority()));
2939 builder.restoreIP(*afterIP);
2941 if (dependencies.DepArray)
2942 builder.CreateFree(dependencies.DepArray);
2951 llvm::IRBuilderBase &builder,
2959 loopWrapperOp.getRegion(),
"omp.taskloop.wrapper.region", builder,
2962 if (failed(
handleError(continuationBlockOrError, opInst)))
2965 builder.SetInsertPoint(continuationBlockOrError.get());
2973static llvm::Expected<llvm::Value *>
2976 llvm::IRBuilderBase &builder) {
2977 if (llvm::Value *mapped = moduleTranslation.
lookupValue(value))
2982 return llvm::make_error<llvm::StringError>(
2983 "value is a block argument and is not mapped",
2984 llvm::inconvertibleErrorCode());
2986 return llvm::make_error<llvm::StringError>(
2987 "unsupported op defining taskloop loop bound",
2988 llvm::inconvertibleErrorCode());
2998 if (!operandOrError)
2999 return operandOrError.takeError();
3000 moduleTranslation.
mapValue(operand, *operandOrError);
3001 mappingsToRemove.push_back(operand);
3005 return llvm::make_error<llvm::StringError>(
3006 "failed to convert op defining taskloop loop bound",
3007 llvm::inconvertibleErrorCode());
3010 assert(
result &&
"expected conversion of loop bound op to produce a value");
3014 mappingsToRemove.push_back(resultValue);
3016 for (
Value mappedValue : mappingsToRemove)
3025 llvm::Value *&lbVal, llvm::Value *&ubVal,
3026 llvm::Value *&stepVal) {
3034 return firstLbOrErr.takeError();
3036 llvm::Type *boundType = (*firstLbOrErr)->getType();
3037 ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3038 if (loopOp.getCollapseNumLoops() > 1) {
3056 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
3058 i == 0 ? std::move(firstLbOrErr)
3062 return lbOrErr.takeError();
3064 upperBounds[i], moduleTranslation, builder);
3066 return ubOrErr.takeError();
3070 return stepOrErr.takeError();
3072 llvm::Value *loopLb = *lbOrErr;
3073 llvm::Value *loopUb = *ubOrErr;
3074 llvm::Value *loopStep = *stepOrErr;
3080 llvm::Value *loopLbMinusOne = builder.CreateSub(
3081 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3082 llvm::Value *loopUbMinusOne = builder.CreateSub(
3083 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3084 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
3085 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
3086 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
3087 llvm::Value *loopTripCount =
3088 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
3089 loopTripCount = builder.CreateBinaryIntrinsic(
3090 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
3094 llvm::Value *loopTripCountDivStep =
3095 builder.CreateSDiv(loopTripCount, loopStep);
3096 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
3097 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
3098 llvm::Value *loopTripCountRem =
3099 builder.CreateSRem(loopTripCount, loopStep);
3100 loopTripCountRem = builder.CreateBinaryIntrinsic(
3101 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
3102 llvm::Value *needsRoundUp = builder.CreateICmpNE(
3104 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
3107 builder.CreateAdd(loopTripCountDivStep,
3108 builder.CreateZExtOrTrunc(
3109 needsRoundUp, loopTripCountDivStep->getType()));
3110 ubVal = builder.CreateMul(ubVal, loopTripCount);
3112 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3113 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3118 return ubOrErr.takeError();
3122 return stepOrErr.takeError();
3123 lbVal = *firstLbOrErr;
3125 stepVal = *stepOrErr;
3128 assert(lbVal !=
nullptr &&
"Expected value for lbVal");
3129 assert(ubVal !=
nullptr &&
"Expected value for ubVal");
3130 assert(stepVal !=
nullptr &&
"Expected value for stepVal");
3131 return llvm::Error::success();
/// Excerpt from the lowering of an OpenMP taskloop. `contextOp` provides the
/// clause operands and privatization flags; the `omp::LoopNestOp` wrapped by
/// its `omp::TaskloopWrapperOp` provides the loop that is lowered.
/// NOTE(review): the start of the signature and several statements are absent
/// from this excerpt, so the comments describe only the statements shown.
3137 llvm::IRBuilderBase &builder,
3139 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3141 omp::TaskloopWrapperOp loopWrapperOp = contextOp.getLoopOp();
3149 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
3152 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// The builder has to sit at the end of its block. A new block named
// "omp.taskloop.wrapper.start" is added to the function and branched to, and
// emission then continues immediately before that branch.
3155 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
3156 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
3157 builder.getContext(),
"omp.taskloop.wrapper.start",
3158 builder.GetInsertBlock()->getParent());
3159 llvm::Instruction *branchToTaskloopStartBlock =
3160 builder.CreateBr(taskloopStartBlock);
3161 builder.SetInsertPoint(branchToTaskloopStartBlock);
// Split off the "omp.private.copy" and "omp.private.init" blocks.
3163 llvm::BasicBlock *copyBlock =
3164 splitBB(builder,
true,
"omp.private.copy");
3165 llvm::BasicBlock *initBlock =
3166 splitBB(builder,
true,
"omp.private.init");
3169 moduleTranslation, allocaIP);
// Emit ahead of the init block's terminator.
3172 builder.SetInsertPoint(initBlock->getTerminator());
// Generate the task context struct and the GEPs to its private variables.
3175 taskStructMgr.generateTaskContextStruct();
3176 taskStructMgr.createGEPsToPrivateVars();
3178 llvmFirstPrivateVars.resize(privateVarsInfo.
blockArgs.size());
// Run initPrivateVar for privatizers that read from their mold, targeting the
// matching context-struct GEP, and remember the result per index.
// NOTE(review): the statement guarded by `!privDecl.readsFromMold()` is not
// shown; the assert text implies such entries are skipped.
3180 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3182 privateVarsInfo.
blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
3183 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
3185 if (!privDecl.readsFromMold())
3187 assert(llvmPrivateVarAlloc &&
3188 "reads from mold so shouldn't have been skipped");
3191 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3192 blockArg, llvmPrivateVarAlloc, initBlock);
3193 if (!privateVarOrErr)
3194 return handleError(privateVarOrErr, *contextOp.getOperation());
3196 llvmFirstPrivateVars[i] = privateVarOrErr.get();
// Temporarily move the insertion point before the current block's terminator;
// the guard restores it when it goes out of scope.
3198 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3199 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
// A non-pointer block argument whose initialised value is not the GEP itself
// is stored through the GEP; the load that follows reads it back.
3201 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3202 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3203 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3205 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3206 llvmPrivateVarAlloc);
3208 assert(llvmPrivateVarAlloc->getType() ==
3209 moduleTranslation.
convertType(blockArg.getType()));
3215 contextOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3216 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
3217 contextOp.getPrivateNeedsBarrier())))
3218 return llvm::failure();
// Position the builder at the taskloop start block.
3221 builder.SetInsertPoint(taskloopStartBlock);
3223 auto loopOp = cast<omp::LoopNestOp>(loopWrapperOp.getWrappedLoop());
// Lower bound, upper bound and step; they are passed to the call whose
// argument list appears right after the declarations.
3224 llvm::Value *lbVal =
nullptr;
3225 llvm::Value *ubVal =
nullptr;
3226 llvm::Value *stepVal =
nullptr;
3228 loopOp, builder, moduleTranslation, lbVal, ubVal, stepVal))
// Body callback: starts emitting at `codegenIP`, sets up the private
// variables seen by the task body, and converts the region named
// "omp.taskloop.context.region".
3231 auto bodyCB = [&](InsertPointTy allocaIP,
3232 InsertPointTy codegenIP) -> llvm::Error {
3236 moduleTranslation, allocaIP);
3239 builder.restoreIP(codegenIP);
3241 llvm::BasicBlock *privInitBlock =
nullptr;
// Each privatizer handled here gets an "omp.private.alloc" alloca placed
// before the terminator of the alloca block and is initialised via
// initPrivateVar; the result is mapped to the block argument.
// NOTE(review): the statement guarded by `privDecl.readsFromMold()` is not
// shown; presumably those entries are skipped in this loop.
3243 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3246 auto [blockArg, privDecl, mlirPrivVar] = zip;
3248 if (privDecl.readsFromMold())
3251 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3252 llvm::Type *llvmAllocType =
3253 moduleTranslation.
convertType(privDecl.getType());
3254 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3255 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3256 llvmAllocType,
nullptr,
"omp.private.alloc");
3259 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3260 blockArg, llvmPrivateVar, privInitBlock);
3261 if (!privateVarOrError)
3262 return privateVarOrError.takeError();
3263 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
3264 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
// Recreate the context-struct GEPs at the current insertion point and record
// them in `privateVarsInfo.llvmVars`.
// NOTE(review): the condition separating the assert from the assignment is
// not shown in this excerpt.
3267 taskStructMgr.createGEPsToPrivateVars();
3268 for (
auto [i, llvmPrivVar] :
3269 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3271 assert(privateVarsInfo.
llvmVars[i] &&
3272 "This is added in the loop above");
3275 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
// Map block arguments to their private values; a block argument that is not
// an LLVM pointer is loaded from the private variable before being mapped.
3280 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
3284 if (!privateDecl.readsFromMold())
3287 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3288 llvmPrivateVar = builder.CreateLoad(
3289 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
3291 assert(llvmPrivateVar->getType() ==
3292 moduleTranslation.
convertType(blockArg.getType()));
3293 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
3299 contextOp.getRegion(),
"omp.taskloop.context.region", builder,
// A failure that handleError has already reported is propagated as
// PreviouslyReportedError.
3302 if (failed(
handleError(continuationBlockOrError, opInst)))
3303 return llvm::make_error<PreviouslyReportedError>();
// Continue before the terminator of the region's continuation block.
3305 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3313 contextOp.getLoc(), privateVarsInfo.
llvmVars,
3315 return llvm::make_error<PreviouslyReportedError>();
// Free the task context struct pointer once the body has been emitted.
3318 taskStructMgr.freeStructPtr();
3320 return llvm::Error::success();
// Task duplication callback: loads the source context struct pointer from
// `srcPtr`, generates a destination context struct and stores its pointer
// through `destPtr`, then initialises the mold-reading private variables of
// the destination using the source GEPs as molds.
3326 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3327 llvm::Value *destPtr, llvm::Value *srcPtr)
3329 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3330 builder.restoreIP(codegenIP);
3333 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
3335 builder.CreateLoad(ptrTy, srcPtr,
"omp.taskloop.context.src");
3337 TaskContextStructManager &srcStructMgr = taskStructMgr;
3338 TaskContextStructManager destStructMgr(builder, moduleTranslation,
3340 destStructMgr.generateTaskContextStruct();
3341 llvm::Value *dest = destStructMgr.getStructPtr();
3342 dest->setName(
"omp.taskloop.context.dest");
3343 builder.CreateStore(dest, destPtr);
3346 srcStructMgr.createGEPsToPrivateVars(src);
3348 destStructMgr.createGEPsToPrivateVars(dest);
3351 for (
auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
3352 llvm::zip_equal(privateVarsInfo.
privatizers, srcGEPs,
3355 if (!privDecl.readsFromMold())
3357 assert(llvmPrivateVarAlloc &&
3358 "reads from mold so shouldn't have been skipped");
3361 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
3362 llvmPrivateVarAlloc, builder.GetInsertBlock());
3363 if (!privateVarOrErr)
3364 return privateVarOrErr.takeError();
// Same store-then-load handling as in the first-private loop above.
3373 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3374 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3375 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3377 llvmPrivateVarAlloc = builder.CreateLoad(
3378 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
3380 assert(llvmPrivateVarAlloc->getType() ==
3381 moduleTranslation.
convertType(blockArg.getType()));
3389 moduleTranslation, srcGEPs, destGEPs,
3391 contextOp.getPrivateNeedsBarrier())))
3392 return llvm::make_error<PreviouslyReportedError>();
3394 return builder.saveIP();
// Clause operands. The single `grainsize` value is filled from the grainsize
// operand or, in the else-branch, from the num_tasks operand.
// NOTE(review): the conditions/assignments that tell the two apart (e.g. the
// `sched` argument used below) are on lines not shown here.
3402 llvm::Value *ifCond =
nullptr;
3403 llvm::Value *grainsize =
nullptr;
3405 mlir::Value grainsizeVal = contextOp.getGrainsize();
3406 mlir::Value numTasksVal = contextOp.getNumTasks();
3407 if (
Value ifVar = contextOp.getIfExpr())
3410 grainsize = moduleTranslation.
lookupValue(grainsizeVal);
3412 }
else if (numTasksVal) {
3413 grainsize = moduleTranslation.
lookupValue(numTasksVal);
// The duplication callback is only supplied when a context struct exists.
3417 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull =
nullptr;
3418 if (taskStructMgr.getStructPtr())
3419 taskDupOrNull = taskDupCB;
3429 llvm::omp::Directive::OMPD_taskgroup);
// Hand bounds, clauses and callbacks to the OpenMPIRBuilder.
// NOTE(review): the name of the builder entry point is not shown here.
3431 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3432 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3434 ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
3435 contextOp.getUntied(), ifCond, grainsize, contextOp.getNogroup(),
3436 sched, moduleTranslation.
lookupValue(contextOp.getFinal()),
3437 contextOp.getMergeable(),
3438 moduleTranslation.
lookupValue(contextOp.getPriority()),
3439 loopOp.getCollapseNumLoops(), taskDupOrNull,
3440 taskStructMgr.getStructPtr());
// Resume emission at the insertion point returned for the construct.
3447 builder.restoreIP(*afterIP);
/// Excerpt of an op translation built around a body callback.
/// NOTE(review): the op being translated and the OpenMPIRBuilder entry point
/// that consumes `bodyCB` are not visible in this excerpt.
3455 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// Body callback: emission starts at `codegenIP`.
3459 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
3460 builder.restoreIP(codegenIP);
3462 builder, moduleTranslation)
3467 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3468 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
// Continue at the insertion point produced for the construct.
3475 builder.restoreIP(*afterIP);
/// Excerpt from the translation of `omp.wsloop` (with its wrapped
/// `omp.loop_nest`) into a worksharing loop via
/// `OpenMPIRBuilder::applyWorkshareLoop`.
/// NOTE(review): the signature and a number of statements are missing from
/// this excerpt; comments cover only the statements shown.
3494 auto wsloopOp = cast<omp::WsloopOp>(opInst);
3498 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
3500 assert(isByRef.size() == wsloopOp.getNumReductionVars());
// An absent schedule clause is treated as a static schedule.
3504 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
// The induction variable type is taken from the first loop step; the
// schedule chunk, when given, is sign-extended or truncated to that type.
3507 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
3508 llvm::Type *ivType = step->getType();
3509 llvm::Value *chunk =
nullptr;
3510 if (wsloopOp.getScheduleChunk()) {
3511 llvm::Value *chunkVar =
3512 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
3513 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
// When the parent op is `omp.distribute`, pick up its dist_schedule flag and
// chunk size (converted to the induction variable type as well).
3516 omp::DistributeOp distributeOp =
nullptr;
3517 llvm::Value *distScheduleChunk =
nullptr;
3518 bool hasDistSchedule =
false;
3519 if (llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())) {
3520 distributeOp = cast<omp::DistributeOp>(opInst.
getParentOp());
3521 hasDistSchedule = distributeOp.getDistScheduleStatic();
3522 if (distributeOp.getDistScheduleChunkSize()) {
3523 llvm::Value *chunkVar = moduleTranslation.
lookupValue(
3524 distributeOp.getDistScheduleChunkSize());
3525 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3533 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3537 wsloopOp.getNumReductionVars());
3540 builder, moduleTranslation, privateVarsInfo, allocaIP);
3547 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3552 moduleTranslation, allocaIP, reductionDecls,
3553 privateReductionVariables, reductionVariableMap,
3554 deferredStores, isByRef)))
3563 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3565 wsloopOp.getPrivateNeedsBarrier())))
// Reduction variables are initialised in the single predecessor of the
// block that follows the allocas.
3568 assert(afterAllocas.get()->getSinglePredecessor());
3569 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3571 afterAllocas.get()->getSinglePredecessor(),
3572 reductionDecls, privateReductionVariables,
3573 reductionVariableMap, isByRef, deferredStores)))
// Clause-derived flags forwarded to applyWorkshareLoop below; a barrier is
// requested unless `nowait` is set.
3577 bool isOrdered = wsloopOp.getOrdered().has_value();
3578 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3579 bool isSimd = wsloopOp.getScheduleSimd();
3580 bool loopNeedsBarrier = !wsloopOp.getNowait();
// A loop nested directly in `omp.distribute` uses the distribute-for flavour.
3585 llvm::omp::WorksharingLoopType workshareLoopType =
3586 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
3587 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3588 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3592 llvm::omp::Directive::OMPD_for);
3594 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Linear clause setup: register the variable types, create the linear
// variables, and initialise their steps.
3597 LinearClauseProcessor linearClauseProcessor;
3599 if (!wsloopOp.getLinearVars().empty()) {
3600 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3602 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3604 for (
auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3605 linearClauseProcessor.createLinearVar(
3606 builder, moduleTranslation, moduleTranslation.
lookupValue(linearVar),
3608 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
3609 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3613 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
// With linear variables present: initialise them in the loop preheader, emit
// a barrier, update them in the loop body using the induction variable, and
// split the loop exit block for the finalisation code.
3621 if (!wsloopOp.getLinearVars().empty()) {
3622 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3623 loopInfo->getPreheader());
3624 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3626 builder.saveIP(), llvm::omp::OMPD_barrier);
3629 builder.restoreIP(*afterBarrierIP);
3630 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3631 loopInfo->getIndVar());
3632 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3635 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// `noLoopMode` starts out false. For a loop nest that is the innermost
// captured op of an enclosing `omp.target`, the kernel flags are tested for
// spmd + no_loop without generic.
// NOTE(review): the statement guarded by that test, and any null check on
// `targetOp`, are not shown in this excerpt.
3638 bool noLoopMode =
false;
3639 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3641 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3645 if (loopOp == targetCapturedOp) {
3646 omp::TargetRegionFlags kernelFlags =
3647 targetOp.getKernelExecFlags(targetCapturedOp);
3648 if (omp::bitEnumContainsAll(kernelFlags,
3649 omp::TargetRegionFlags::spmd |
3650 omp::TargetRegionFlags::no_loop) &&
3651 !omp::bitEnumContainsAny(kernelFlags,
3652 omp::TargetRegionFlags::generic))
// Turn the canonical loop into a worksharing loop; the schedule modifier is
// passed as two booleans (monotonic / nonmonotonic).
3657 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3658 ompBuilder->applyWorkshareLoop(
3659 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3660 convertToScheduleKind(schedule), chunk, isSimd,
3661 scheduleMod == omp::ScheduleModifier::monotonic,
3662 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3663 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
// Finalise linear variables using the loop's last-iteration value, rewrite
// their uses in the "omp.loop_nest.region" blocks, then restore the
// insertion point saved beforehand.
3669 if (!wsloopOp.getLinearVars().empty()) {
3670 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3671 assert(loopInfo->getLastIter() &&
3672 "`lastiter` in CanonicalLoopInfo is nullptr");
3673 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3674 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3675 loopInfo->getLastIter());
3678 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
3679 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
3681 builder.restoreIP(oldIP);
3689 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3690 privateReductionVariables, isByRef, wsloopOp.getNowait(),
/// Excerpt from the translation of an OpenMP parallel construct through
/// `OpenMPIRBuilder::createParallel`, using three callbacks: `bodyGenCB`
/// (privatization, reductions, region body), `privCB` and `finiCB`.
/// NOTE(review): the signature and several statements are missing from this
/// excerpt; comments cover only the statements shown.
3703 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3705 assert(isByRef.size() == opInst.getNumReductionVars());
3718 opInst.getNumReductionVars());
// Body callback. Failures already reported elsewhere are surfaced as
// PreviouslyReportedError.
3721 auto bodyGenCB = [&](InsertPointTy allocaIP,
3722 InsertPointTy codeGenIP) -> llvm::Error {
3724 builder, moduleTranslation, privateVarsInfo, allocaIP);
3726 return llvm::make_error<PreviouslyReportedError>();
3732 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
// Insertion point at the terminator of the alloca block.
3735 InsertPointTy(allocaIP.getBlock(),
3736 allocaIP.getBlock()->getTerminator()->getIterator());
3739 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3740 reductionDecls, privateReductionVariables, reductionVariableMap,
3741 deferredStores, isByRef)))
3742 return llvm::make_error<PreviouslyReportedError>();
3744 assert(afterAllocas.get()->getSinglePredecessor());
3745 builder.restoreIP(codeGenIP);
3751 return llvm::make_error<PreviouslyReportedError>();
3754 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3756 opInst.getPrivateNeedsBarrier())))
3757 return llvm::make_error<PreviouslyReportedError>();
// Reduction variables are initialised in the single predecessor of the block
// that follows the allocas.
3760 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3761 afterAllocas.get()->getSinglePredecessor(),
3762 reductionDecls, privateReductionVariables,
3763 reductionVariableMap, isByRef, deferredStores)))
3764 return llvm::make_error<PreviouslyReportedError>();
3769 moduleTranslation, allocaIP);
3773 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
3775 return regionBlock.takeError();
// With reductions present, collect the reduction descriptors and emit the
// reduction code at the end of the converted region.
3778 if (opInst.getNumReductionVars() > 0) {
3783 owningReductionGenRefDataPtrGens;
3785 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3787 owningReductionGenRefDataPtrGens,
3788 privateReductionVariables, reductionInfos, isByRef);
// A temporary `unreachable` is inserted before the region block's terminator
// to serve as the insertion anchor for createReductions; it is erased again
// once the reductions have been emitted.
3791 builder.SetInsertPoint((*regionBlock)->getTerminator());
3794 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3795 builder.SetInsertPoint(tempTerminator);
3797 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3798 ompBuilder->createReductions(
3799 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3801 if (!contInsertPoint)
3802 return contInsertPoint.takeError();
// A returned insertion point without a block is treated as a failure.
3804 if (!contInsertPoint->getBlock())
3805 return llvm::make_error<PreviouslyReportedError>();
3807 tempTerminator->eraseFromParent();
3808 builder.restoreIP(*contInsertPoint);
3811 return llvm::Error::success();
// Privatization callback passed to createParallel.
// NOTE(review): its body is not shown in this excerpt.
3814 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3815 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
// Finalisation callback: works at `codeGenIP` and restores the previous
// insertion point before returning. It inlines the reduction declarations'
// cleanup regions and, for a cancellable region, emits a barrier.
3824 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
3825 InsertPointTy oldIP = builder.saveIP();
3826 builder.restoreIP(codeGenIP);
3831 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
3832 [](omp::DeclareReductionOp reductionDecl) {
3833 return &reductionDecl.getCleanupRegion();
3836 reductionCleanupRegions, privateReductionVariables,
3837 moduleTranslation, builder,
"omp.reduction.cleanup")))
3838 return llvm::createStringError(
3839 "failed to inline `cleanup` region of `omp.declare_reduction`");
3844 return llvm::make_error<PreviouslyReportedError>();
3848 if (isCancellable) {
3849 auto IPOrErr = ompBuilder->createBarrier(
3850 llvm::OpenMPIRBuilder::LocationDescription(builder),
3851 llvm::omp::Directive::OMPD_unknown,
3855 return IPOrErr.takeError();
3858 builder.restoreIP(oldIP);
3859 return llvm::Error::success();
// Clause operands: optional `if` condition, num_threads (only the first
// num_threads operand is looked up), and proc_bind (default when absent).
3862 llvm::Value *ifCond =
nullptr;
3863 if (
auto ifVar = opInst.getIfExpr())
3865 llvm::Value *numThreads =
nullptr;
3866 if (!opInst.getNumThreadsVars().empty())
3867 numThreads = moduleTranslation.
lookupValue(opInst.getNumThreads(0));
3868 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
3869 if (
auto bind = opInst.getProcBindKind())
3872 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3874 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3876 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3877 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
3878 ifCond, numThreads, pbKind, isCancellable);
// Resume emission after the parallel construct.
3883 builder.restoreIP(*afterIP);
/// Maps an `omp::ClauseOrderKind` onto `llvm::omp::OrderKind`: `Concurrent`
/// becomes `OMP_ORDER_concurrent`, and any value not handled by the switch is
/// treated as unreachable.
/// NOTE(review): the condition guarding the `OMP_ORDER_unknown` return is not
/// shown in this excerpt; presumably it covers an absent order clause.
3888static llvm::omp::OrderKind
3891 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3893 case omp::ClauseOrderKind::Concurrent:
3894 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3896 llvm_unreachable(
"Unknown ClauseOrderKind kind");
/// Excerpt from the translation of `omp.simd`: privatization, linear and
/// reduction handling around the wrapped loop, followed by
/// `OpenMPIRBuilder::applySimd` on the canonical loop.
/// NOTE(review): the signature and several statements are missing from this
/// excerpt; comments cover only the statements shown.
3904 auto simdOp = cast<omp::SimdOp>(opInst);
3912 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3915 simdOp.getNumReductionVars());
3920 assert(isByRef.size() == simdOp.getNumReductionVars());
3922 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3926 builder, moduleTranslation, privateVarsInfo, allocaIP);
// Linear clause setup. A linear variable that also appears among the private
// variables is created from its private LLVM value.
// NOTE(review): how `isImplicit` is updated and which value the second
// createLinearVar call receives are on lines not shown here.
3931 LinearClauseProcessor linearClauseProcessor;
3933 if (!simdOp.getLinearVars().empty()) {
3934 auto linearVarTypes = simdOp.getLinearVarTypes().value();
3936 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3937 for (
auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
3938 bool isImplicit =
false;
3939 for (
auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
3943 if (linearVar == mlirPrivVar) {
3945 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
3946 llvmPrivateVar, idx);
3952 linearClauseProcessor.createLinearVar(
3953 builder, moduleTranslation,
3956 for (
mlir::Value linearStep : simdOp.getLinearStepVars())
3957 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3961 moduleTranslation, allocaIP, reductionDecls,
3962 privateReductionVariables, reductionVariableMap,
3963 deferredStores, isByRef)))
// Reduction variables are initialised in the single predecessor of the block
// that follows the allocas.
3974 assert(afterAllocas.get()->getSinglePredecessor());
3975 if (failed(initReductionVars(simdOp, reductionArgs, builder,
3977 afterAllocas.get()->getSinglePredecessor(),
3978 reductionDecls, privateReductionVariables,
3979 reductionVariableMap, isByRef, deferredStores)))
// simdlen / safelen become 64-bit integer constants when present and stay
// null otherwise.
3982 llvm::ConstantInt *simdlen =
nullptr;
3983 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
3984 simdlen = builder.getInt64(simdlenVar.value());
3986 llvm::ConstantInt *safelen =
nullptr;
3987 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
3988 safelen = builder.getInt64(safelenVar.value());
// Aligned clause: each operand is paired with its alignment attribute (as an
// i64 constant). The operand is loaded in the block that was current before
// this loop, and the loaded value is the key stored in `alignedVars`.
// NOTE(review): the statement guarded by the power-of-two test is not shown.
3990 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
3993 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3994 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
3996 for (
size_t i = 0; i < operands.size(); ++i) {
3997 llvm::Value *alignment =
nullptr;
3998 llvm::Value *llvmVal = moduleTranslation.
lookupValue(operands[i]);
3999 llvm::Type *ty = llvmVal->getType();
4001 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
4002 alignment = builder.getInt64(intAttr.getInt());
4003 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
4004 assert(alignment &&
"Invalid alignment value");
4008 if (!intAttr.getValue().isPowerOf2())
4011 auto curInsert = builder.saveIP();
4012 builder.SetInsertPoint(sourceBlock);
4013 llvmVal = builder.CreateLoad(ty, llvmVal);
4014 builder.restoreIP(curInsert);
4015 alignedVars[llvmVal] = alignment;
4019 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
// With linear variables present: initialise them in the loop preheader and
// update them in the loop body from the induction variable.
4026 if (simdOp.getLinearVars().size()) {
4027 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
4028 loopInfo->getPreheader());
4030 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
4031 loopInfo->getIndVar());
4033 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// Apply the simd transformation to the loop with the collected aligned
// variables, optional `if` condition, order, simdlen and safelen.
4035 ompBuilder->applySimd(loopInfo, alignedVars,
4037 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
4039 order, simdlen, safelen);
4041 linearClauseProcessor.emitStoresForLinearVar(builder);
// Detect `omp.ordered.region` ops nested in the simd region; when any exist,
// linear variable uses are also rewritten in "omp.ordered.region" blocks.
4044 bool hasOrderedRegions =
false;
4045 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
4046 hasOrderedRegions =
true;
4050 for (
size_t index = 0;
index < simdOp.getLinearVars().size();
index++) {
4051 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
4053 if (hasOrderedRegions) {
4055 linearClauseProcessor.rewriteInPlace(builder,
"omp.ordered.region",
4058 linearClauseProcessor.rewriteInPlace(builder,
"omp_region.finalize",
// Combine each private reduction value with the original variable using the
// generator built from the reduction declaration, then store the result
// back to the original variable.
// NOTE(review): `byRef` handling and the error check on `res` are on lines
// not shown here.
4067 for (
auto [i, tuple] : llvm::enumerate(
4068 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
4069 privateReductionVariables))) {
4070 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
4072 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
4073 llvm::Value *originalVariable = moduleTranslation.
lookupValue(reductionVar);
4074 llvm::Type *reductionType = moduleTranslation.
convertType(decl.getType());
4078 llvm::Value *redValue = originalVariable;
4081 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
4082 llvm::Value *privateRedValue = builder.CreateLoad(
4083 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
4084 llvm::Value *reduced;
4086 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
4089 builder.restoreIP(res.get());
4093 builder.CreateStore(reduced, originalVariable);
// Gather the cleanup regions of the reduction declarations for inlining
// under the "omp.reduction.cleanup" name.
4098 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
4099 [](omp::DeclareReductionOp reductionDecl) {
4100 return &reductionDecl.getCleanupRegion();
4103 moduleTranslation, builder,
4104 "omp.reduction.cleanup")))
/// Excerpt from the translation of `omp.loop_nest`: one canonical loop is
/// created per nest level, optionally tiled and collapsed, and the resulting
/// top loop is published through the `OpenMPLoopInfoStackFrame`.
/// NOTE(review): the signature and several statements are missing from this
/// excerpt; comments cover only the statements shown.
4117 auto loopOp = cast<omp::LoopNestOp>(opInst);
4123 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Body generator shared by all levels. The induction variable is tied to the
// region argument whose index equals the number of loops created so far, and
// the body insertion point is recorded. Only the callback invocation for the
// last level converts the "omp.loop_nest.region" region.
4128 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
4129 llvm::Value *iv) -> llvm::Error {
4132 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
4137 bodyInsertPoints.push_back(ip);
4139 if (loopInfos.size() != loopOp.getNumLoops() - 1)
4140 return llvm::Error::success();
4143 builder.restoreIP(ip);
4145 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
4147 return regionBlock.takeError();
4149 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4150 return llvm::Error::success();
// Create one canonical loop per level from its lower bound, upper bound and
// step. `computeIP` may be redirected to the outermost loop's preheader and
// `loc` to the most recent body insertion point.
// NOTE(review): the condition selecting that redirection is not shown.
4158 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
4159 llvm::Value *lowerBound =
4160 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
4161 llvm::Value *upperBound =
4162 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
4163 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
4168 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
4169 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
4171 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
4173 computeIP = loopInfos.front()->getPreheaderIP();
4177 ompBuilder->createCanonicalLoop(
4178 loc, bodyGen, lowerBound, upperBound, step,
4179 true, loopOp.getLoopInclusive(), computeIP);
4184 loopInfos.push_back(*loopResult);
// By default emission resumes after the outermost loop.
4187 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
4188 loopInfos.front()->getAfterIP();
// Tiling: tile sizes become constants of the induction variable type, the
// loops are tiled, `afterIP` moves to the start of the single successor of
// the first new loop's after-block, and the new loops are appended to
// `loopInfos`.
4191 if (
const auto &tiles = loopOp.getTileSizes()) {
4192 llvm::Type *ivType = loopInfos.front()->getIndVarType();
4195 for (
auto tile : tiles.value()) {
4196 llvm::Value *tileVal = llvm::ConstantInt::get(ivType,
tile);
4197 tileSizes.push_back(tileVal);
4200 std::vector<llvm::CanonicalLoopInfo *> newLoops =
4201 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
4205 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
4206 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
4207 afterIP = {afterAfterBB, afterAfterBB->begin()};
4211 for (
const auto &newLoop : newLoops)
4212 loopInfos.push_back(newLoop);
// Collapse the first `numCollapse` loops into one and record the new top
// loop in the loop-info stack frame.
4216 const auto &numCollapse = loopOp.getCollapseNumLoops();
4218 loopInfos.begin(), loopInfos.begin() + (numCollapse));
4220 auto newTopLoopInfo =
4221 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
4223 assert(newTopLoopInfo &&
"New top loop information is missing");
4224 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
4225 [&](OpenMPLoopInfoStackFrame &frame) {
4226 frame.loopInfo = newTopLoopInfo;
4234 builder.restoreIP(afterIP);
/// Excerpt from the translation of an op exposing an induction variable, a
/// trip count and an optional CLI value into a canonical loop named
/// "omp.loop" via `OpenMPIRBuilder::createCanonicalLoop`.
/// NOTE(review): the signature and several statements are missing from this
/// excerpt; comments cover only the statements shown.
4244 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
4245 Value loopIV = op.getInductionVar();
4246 Value loopTC = op.getTripCount();
4248 llvm::Value *llvmTC = moduleTranslation.
lookupValue(loopTC);
4251 ompBuilder->createCanonicalLoop(
// Body callback: bind the MLIR induction variable to the LLVM one and emit
// the body at the provided insertion point.
4253 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
4256 moduleTranslation.
mapValue(loopIV, llvmIV);
4258 builder.restoreIP(ip);
4263 return bodyGenStatus.takeError();
4265 llvmTC,
"omp.loop");
// A builder failure is reported as a diagnostic on the op.
4267 return op.emitError(llvm::toString(llvmOrError.takeError()));
// Resume emission after the generated loop.
4269 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
4270 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
4271 builder.restoreIP(afterIP);
// NOTE(review): the statement executed when the op has a CLI value is not
// shown in this excerpt.
4274 if (
Value cli = op.getCli())
/// Excerpt: applies `OpenMPIRBuilder::unrollLoopHeuristic` to the canonical
/// loop associated with the op's applyee.
/// NOTE(review): the signature and the expression that initialises
/// `consBuilderCLI` are not shown in this excerpt.
4287 Value applyee = op.getApplyee();
4288 assert(applyee &&
"Loop to apply unrolling on required");
4290 llvm::CanonicalLoopInfo *consBuilderCLI =
4292 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4293 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
/// Applies loop tiling for `omp::TileOp`: the translated sizes and the
/// canonical loops of the applyees are passed to
/// `OpenMPIRBuilder::tileLoops`, and the generated loops are paired with the
/// op's generatees when there are any.
/// NOTE(review): part of the signature and several statements are missing
/// from this excerpt; comments cover only the statements shown.
4301static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
4304 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
// Collect the already-translated LLVM values of the sizes clause.
4309 for (
Value size : op.getSizes()) {
4310 llvm::Value *translatedSize = moduleTranslation.
lookupValue(size);
4311 assert(translatedSize &&
4312 "sizes clause arguments must already be translated");
4313 translatedSizes.push_back(translatedSize);
// Collect the canonical loop of every applyee.
// NOTE(review): the assert tests `applyee`, not the looked-up
// `consBuilderCLI` — confirm this is the intended check.
4316 for (
Value applyee : op.getApplyees()) {
4317 llvm::CanonicalLoopInfo *consBuilderCLI =
4319 assert(applyee &&
"Canonical loop must already been translated");
4320 translatedLoops.push_back(consBuilderCLI);
4323 auto generatedLoops =
4324 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
// Generatees and generated loops are zipped with zip_equal, which requires
// both ranges to have the same length.
4325 if (!op.getGeneratees().empty()) {
4326 for (
auto [mlirLoop,
genLoop] :
4327 zip_equal(op.getGeneratees(), generatedLoops))
// NOTE(review): the per-applyee statement of this loop is not shown.
4332 for (
Value applyee : op.getApplyees())
/// Applies loop fusion for `omp::FuseOp`: the applyees are partitioned into
/// loops before the fused range, loops after it, and the loops handed to
/// `OpenMPIRBuilder::fuseLoops`; generatees (if any) are then mapped to the
/// resulting loops.
/// NOTE(review): part of the signature and several statements are missing
/// from this excerpt; comments cover only the statements shown.
4340static LogicalResult
applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
4343 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4347 for (
size_t i = 0; i < op.getApplyees().size(); i++) {
4348 Value applyee = op.getApplyees()[i];
4349 llvm::CanonicalLoopInfo *consBuilderCLI =
// NOTE(review): the assert tests `applyee`, not the looked-up
// `consBuilderCLI` — confirm this is the intended check.
4351 assert(applyee &&
"Canonical loop must already been translated");
// Indices below `first - 1` go to `beforeFuse`; indices at or beyond
// `first + count - 1` go to `afterFuse`.
// NOTE(review): the second condition reads `op.getFirst().value()` while only
// `getCount()` is checked for presence — confirm `count` cannot be set
// without `first`.
4352 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
4353 beforeFuse.push_back(consBuilderCLI);
4354 else if (op.getCount().has_value() &&
4355 i >= op.getFirst().value() + op.getCount().value() - 1)
4356 afterFuse.push_back(consBuilderCLI);
4358 toFuse.push_back(consBuilderCLI);
// When generatees are given there must be one per unfused loop plus one for
// the fused loop.
4361 (op.getGeneratees().empty() ||
4362 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
4363 "Wrong number of generatees");
4366 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
// Map generatees in order: loops before the fused range, the fused loop,
// then loops after it.
// NOTE(review): `i` keeps counting across all three steps, yet the last loop
// bounds it by `afterFuse.size()` and indexes `afterFuse[i]`; the generatee
// index and the `afterFuse` index look conflated. Verify with a case that
// has loops on both sides of the fused range.
4367 if (!op.getGeneratees().empty()) {
4369 for (; i < beforeFuse.size(); i++)
4370 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
4371 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
4372 for (; i < afterFuse.size(); i++)
4373 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
// NOTE(review): the per-applyee statement of this loop is not shown.
4377 for (
Value applyee : op.getApplyees())
/// Maps an `omp::ClauseMemoryOrderKind` onto `llvm::AtomicOrdering`:
/// Seq_cst -> SequentiallyConsistent, Acq_rel -> AcquireRelease,
/// Acquire -> Acquire, Release -> Release, Relaxed -> Monotonic. Any value
/// not handled by the switch is treated as unreachable.
/// NOTE(review): the condition guarding the first `Monotonic` return is not
/// shown in this excerpt; presumably it covers an absent memory-order clause.
4384static llvm::AtomicOrdering
4387 return llvm::AtomicOrdering::Monotonic;
4390 case omp::ClauseMemoryOrderKind::Seq_cst:
4391 return llvm::AtomicOrdering::SequentiallyConsistent;
4392 case omp::ClauseMemoryOrderKind::Acq_rel:
4393 return llvm::AtomicOrdering::AcquireRelease;
4394 case omp::ClauseMemoryOrderKind::Acquire:
4395 return llvm::AtomicOrdering::Acquire;
4396 case omp::ClauseMemoryOrderKind::Release:
4397 return llvm::AtomicOrdering::Release;
4398 case omp::ClauseMemoryOrderKind::Relaxed:
4399 return llvm::AtomicOrdering::Monotonic;
4401 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
/// Excerpt from the translation of `omp::AtomicReadOp` through
/// `OpenMPIRBuilder::createAtomicRead`.
/// NOTE(review): the signature and the definitions of `AO` and `allocaIP`
/// are not fully shown in this excerpt.
4408 auto readOp = cast<omp::AtomicReadOp>(opInst);
4413 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4416 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// `x` is the op's X operand and `v` its V operand; both are described to the
// builder with the op's element type.
4419 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
4420 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
4422 llvm::Type *elementType =
4423 moduleTranslation.
convertType(readOp.getElementType());
4425 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
4426 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
// Emit the atomic read and continue at the insertion point it returns.
4427 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
/// Excerpt from the translation of `omp::AtomicWriteOp` through
/// `OpenMPIRBuilder::createAtomicWrite`.
/// NOTE(review): the signature, the definition of `ao`, and the tail of the
/// `x` initialiser are not shown in this excerpt.
4435 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
4440 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4443 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// `expr` is the value being written, `dest` the op's X operand; the element
// type handed to the builder is the type of the written expression.
4445 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
4446 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
4447 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
4448 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
4451 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
// Case chain that maps an LLVM dialect arithmetic/bitwise op to the matching
// `llvm::AtomicRMWInst::BinOp`; any op without a case yields `BAD_BINOP`.
// NOTE(review): the head of this expression (the switch subject and the
// enclosing function) is not shown in this excerpt.
4459 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
4460 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
4461 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
4462 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
4463 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
4464 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
4465 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
4466 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
4467 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
4468 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
/// Reads the atomic-control flags of `atomicUpdateOp` into the three output
/// booleans. All three are first reset to false; they are overwritten only
/// when the op is non-null and carries the atomic control attribute.
/// NOTE(review): the function name and leading parameters are not shown in
/// this excerpt.
4472 bool &isIgnoreDenormalMode,
4473 bool &isFineGrainedMemory,
4474 bool &isRemoteMemory) {
// Defaults used when the op or the attribute is absent.
4475 isIgnoreDenormalMode =
false;
4476 isFineGrainedMemory =
false;
4477 isRemoteMemory =
false;
4478 if (atomicUpdateOp &&
4479 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
4480 mlir::omp::AtomicControlAttr atomicControlAttr =
4481 atomicUpdateOp.getAtomicControlAttr();
4482 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
4483 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
4484 isRemoteMemory = atomicControlAttr.getRemoteMemory();
/// Excerpt from the translation of an atomic update op through
/// `OpenMPIRBuilder::createAtomicUpdate`.
/// NOTE(review): the start of the signature and several statements are
/// missing from this excerpt; comments cover only the statements shown.
4491 llvm::IRBuilderBase &builder,
4498 auto &innerOpList = opInst.getRegion().front().getOperations();
4499 bool isXBinopExpr{
false};
4500 llvm::AtomicRMWInst::BinOp binop;
4502 llvm::Value *llvmExpr =
nullptr;
4503 llvm::Value *llvmX =
nullptr;
4504 llvm::Type *llvmXElementType =
nullptr;
// A region whose block holds exactly two operations is inspected for an
// update of the region argument; an error is emitted when no operand of the
// inner op is that argument. `isXBinopExpr` records whether the region
// argument is the inner op's first operand.
4505 if (innerOpList.size() == 2) {
4511 opInst.getRegion().getArgument(0))) {
4512 return opInst.emitError(
"no atomic update operation with region argument"
4513 " as operand found inside atomic.update region");
4516 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
4518 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
// NOTE(review): the branch in which `binop` is set to BAD_BINOP is not
// delimited in this excerpt; presumably it is the non-two-op case.
4522 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4524 llvmX = moduleTranslation.
lookupValue(opInst.getX());
4526 opInst.getRegion().getArgument(0).getType());
4527 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4531 llvm::AtomicOrdering atomicOrdering =
// Update callback: binds the region argument to the current atomic value,
// converts the region's block in place at the builder's current block, and
// returns the LLVM value of the single `omp.yield` result.
4536 [&opInst, &moduleTranslation](
4537 llvm::Value *atomicx,
4540 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
4541 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4542 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4543 return llvm::make_error<PreviouslyReportedError>();
4545 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4546 assert(yieldop && yieldop.getResults().size() == 1 &&
4547 "terminator must be omp.yield op and it must have exactly one "
4549 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
// Atomic-control flags passed to createAtomicUpdate below.
// NOTE(review): the call that fills them in is not shown in this excerpt.
4552 bool isIgnoreDenormalMode;
4553 bool isFineGrainedMemory;
4554 bool isRemoteMemory;
4559 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4560 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4561 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
4562 atomicOrdering, binop, updateFn,
4563 isXBinopExpr, isIgnoreDenormalMode,
4564 isFineGrainedMemory, isRemoteMemory);
// Resume emission after the atomic update.
4569 builder.restoreIP(*afterIP);
// Lowers an atomic capture: the captured operation is either an atomic update
// or an atomic write, paired with an atomic read; the result is emitted through
// OpenMPIRBuilder::createAtomicCapture.
// NOTE(review): several lines of this routine (including its signature) are
// not visible in this listing; comments below cover only what is shown.
4575 llvm::IRBuilderBase &builder,
4582 bool isXBinopExpr =
false, isPostfixUpdate =
false;
4583 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4585 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
4586 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
4588 assert((atomicUpdateOp || atomicWriteOp) &&
4589 "internal op must be an atomic.update or atomic.write op");
// Write form: always treated as a postfix update, expression comes from the
// write op.
4591 if (atomicWriteOp) {
4592 isPostfixUpdate =
true;
4593 mlirExpr = atomicWriteOp.getExpr();
// Update form: postfix when the update op is the second op of the capture.
4595 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
4596 atomicCaptureOp.getAtomicUpdateOp().getOperation();
4597 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
// Same two-operation inspection as for a plain atomic update; an error is
// emitted when the region argument is not an operand of the update op.
4600 if (innerOpList.size() == 2) {
4603 atomicUpdateOp.getRegion().getArgument(0))) {
4604 return atomicUpdateOp.emitError(
4605 "no atomic update operation with region argument"
4606 " as operand found inside atomic.update region");
4610 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
4613 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
// x and v both come from the atomic read op and share one element type.
4617 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4618 llvm::Value *llvmX =
4619 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
4620 llvm::Value *llvmV =
4621 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
4622 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
4623 atomicCaptureOp.getAtomicReadOp().getElementType());
4624 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4627 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
4631 llvm::AtomicOrdering atomicOrdering =
// Update callback: one path returns the write op's expression value; the
// other converts the update region and returns the omp.yield result.
4635 [&](llvm::Value *atomicx,
4638 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
4639 Block &bb = *atomicUpdateOp.getRegion().
begin();
4640 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
4642 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
// A failed block conversion is surfaced as a PreviouslyReportedError.
4643 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4644 return llvm::make_error<PreviouslyReportedError>();
4646 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4647 assert(yieldop && yieldop.getResults().size() == 1 &&
4648 "terminator must be omp.yield op and it must have exactly one "
4650 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
// Declared without initializers; they are passed to the (partially visible)
// call ending on the next statement line — presumably as output arguments,
// TODO confirm.
4653 bool isIgnoreDenormalMode;
4654 bool isFineGrainedMemory;
4655 bool isRemoteMemory;
4657 isFineGrainedMemory, isRemoteMemory);
4660 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// atomicUpdateOp is passed as an argument in its own right here; presumably
// it converts to a flag telling the builder whether the captured op is an
// update — TODO confirm against the createAtomicCapture signature.
4661 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4662 ompBuilder->createAtomicCapture(
4663 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4664 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4665 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4667 if (failed(
handleError(afterIP, *atomicCaptureOp)))
// Resume code generation at the insertion point returned by the builder.
4670 builder.restoreIP(*afterIP);
// Maps an MLIR cancellation construct type onto the matching
// llvm::omp::Directive: Loop -> OMPD_for, Parallel -> OMPD_parallel,
// Sections -> OMPD_sections, Taskgroup -> OMPD_taskgroup.
4675 omp::ClauseCancellationConstructType directive) {
4676 switch (directive) {
4677 case omp::ClauseCancellationConstructType::Loop:
4678 return llvm::omp::Directive::OMPD_for;
4679 case omp::ClauseCancellationConstructType::Parallel:
4680 return llvm::omp::Directive::OMPD_parallel;
4681 case omp::ClauseCancellationConstructType::Sections:
4682 return llvm::omp::Directive::OMPD_sections;
4683 case omp::ClauseCancellationConstructType::Taskgroup:
4684 return llvm::omp::Directive::OMPD_taskgroup;
// Reached only if a construct type is not covered by the cases above.
4686 llvm_unreachable(
"Unhandled cancellation construct type");
// Emits a cancel through OpenMPIRBuilder::createCancel and moves the builder
// to the insertion point it returns.
4695 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// ifCond starts out null; the branch below handles an op that carries an if
// expression (the assignment itself is not visible in this listing).
4698 llvm::Value *ifCond =
nullptr;
4699 if (
Value ifVar = op.getIfExpr())
4702 llvm::omp::Directive cancelledDirective =
4705 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4706 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
// Builder errors are routed through handleError against this operation.
4708 if (failed(
handleError(afterIP, *op.getOperation())))
4711 builder.restoreIP(afterIP.get());
// Emits a cancellation point through
// OpenMPIRBuilder::createCancellationPoint and moves the builder to the
// insertion point it returns.
4718 llvm::IRBuilderBase &builder,
4723 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4726 llvm::omp::Directive cancelledDirective =
4729 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4730 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
// Builder errors are routed through handleError against this operation.
4732 if (failed(
handleError(afterIP, *op.getOperation())))
4735 builder.restoreIP(afterIP.get());
// Lowers a threadprivate op: resolves the global behind the op's symbol
// address and requests a cached thread-private copy from the OpenMPIRBuilder.
4745 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4747 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4752 Value symAddr = threadprivateOp.getSymAddr();
// The defining op may be an address-space cast; its handling is not visible
// in this listing.
4755 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
// Anything other than an llvm address-of op at this point is an error.
4758 if (!isa<LLVM::AddressOfOp>(symOp))
4759 return opInst.
emitError(
"Addressing symbol not found");
4760 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4762 LLVM::GlobalOp global =
4763 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
4764 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
// The size handed to the runtime is the store size of the global's value
// type, as a 64-bit constant.
4765 llvm::Type *type = globalValue->getValueType();
4766 llvm::TypeSize typeSize =
4767 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4769 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
// The cache is named after the global's symbol with a ".cache" suffix.
4770 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4771 ompLoc, globalValue, size, global.getSymName() +
".cache");
// Maps a declare-target device type (host / nohost / any) onto the
// corresponding OffloadEntriesInfoManager device clause kind.
4777static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4779 switch (deviceClause) {
4780 case mlir::omp::DeclareTargetDeviceType::host:
4781 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4783 case mlir::omp::DeclareTargetDeviceType::nohost:
4784 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4786 case mlir::omp::DeclareTargetDeviceType::any:
4787 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
// Reached only if a device type is not covered by the cases above.
4790 llvm_unreachable(
"unhandled device clause");
// Maps a declare-target capture clause (to / link / enter / none) onto the
// corresponding OffloadEntriesInfoManager global-variable entry kind.
4793static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4795 mlir::omp::DeclareTargetCaptureClause captureClause) {
4796 switch (captureClause) {
4797 case mlir::omp::DeclareTargetCaptureClause::to:
4798 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4799 case mlir::omp::DeclareTargetCaptureClause::link:
4800 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4801 case mlir::omp::DeclareTargetCaptureClause::enter:
4802 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4803 case mlir::omp::DeclareTargetCaptureClause::none:
4804 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
// Reached only if a capture clause is not covered by the cases above.
4806 llvm_unreachable(
"unhandled capture clause");
// Resolves the symbol behind a value's defining op: an address-space cast is
// recognised first (its handling is not visible in this listing), then an
// llvm address-of op is looked up by global name in its parent module.
4811 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4813 if (
auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4814 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4815 return modOp.lookupSymbol(addressOfOp.getGlobalName());
// When the defining op is an address-space cast, continue with the cast's
// operand instead of the cast result.
4822 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4823 value = addrCast.getOperand();
// Builds, into a SmallString<64>, the suffix used for declare-target
// reference pointers. The visible pieces are a file ID obtained from
// getTargetEntryUniqueInfo and the literal "_decl_tgt_ref_ptr".
4840static llvm::SmallString<64>
4842 llvm::OpenMPIRBuilder &ompBuilder) {
4844 llvm::raw_svector_ostream os(suffix);
// Callback supplying (filename, line) of the source location to the builder.
4847 auto fileInfoCallBack = [&loc]() {
4848 return std::pair<std::string, uint64_t>(
4849 llvm::StringRef(loc.getFilename()), loc.getLine());
// The real file system is handed to getTargetEntryUniqueInfo.
4852 auto vfs = llvm::vfs::getRealFileSystem();
4855 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
4857 os <<
"_decl_tgt_ref_ptr";
// Tests whether an operation implementing DeclareTargetInterface has the
// `link` capture clause.
4863 if (
auto declareTargetGlobal =
4864 dyn_cast_if_present<omp::DeclareTargetInterface>(
4866 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4867 omp::DeclareTargetCaptureClause::link)
// Tests whether an operation implementing DeclareTargetInterface has the
// `to` or `enter` capture clause.
4873 if (
auto declareTargetGlobal =
4874 dyn_cast_if_present<omp::DeclareTargetInterface>(
4876 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4877 omp::DeclareTargetCaptureClause::to ||
4878 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4879 omp::DeclareTargetCaptureClause::enter)
// For a declare-target global, derives the name of its reference pointer
// (the global's symbol name plus the ref-ptr suffix).
4893 if (
auto declareTargetGlobal =
4894 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
// Applies to `link` globals, and to `to` globals when unified shared memory
// is required by the builder configuration.
4897 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
4898 omp::DeclareTargetCaptureClause::link) ||
4899 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4900 omp::DeclareTargetCaptureClause::to &&
4901 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
// A symbol name that already contains the suffix is handled separately (the
// branch body is not visible in this listing).
4905 if (gOp.getSymName().contains(suffix))
4910 (gOp.getSymName().str() + suffix.str()).str());
// Extends llvm::OpenMPIRBuilder::MapInfosTy with one mapper operation per
// entry.
4919struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
4920 SmallVector<Operation *, 4> Mappers;
// Appends curInfo's mappers and then the base-class map information.
4923 void append(MapInfosTy &curInfo) {
4924 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
4925 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
// Per-map-entry bookkeeping kept alongside MapInfosTy: every vector below is
// pushed to once per collected map entry elsewhere in this file.
4934struct MapInfoData : MapInfosTy {
4935 llvm::SmallVector<bool, 4> IsDeclareTarget;
4936 llvm::SmallVector<bool, 4> IsAMember;
4938 llvm::SmallVector<bool, 4> IsAMapping;
4939 llvm::SmallVector<mlir::Operation *, 4> MapClause;
4940 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
4943 llvm::SmallVector<llvm::Type *, 4> BaseType;
// Appends CurInfo's entries to this object.
// NOTE(review): the visible statements forward IsDeclareTarget, MapClause,
// OriginalValue, BaseType and the MapInfosTy part, but not IsAMember or
// IsAMapping — confirm against the full source whether that is intended.
4946 void append(MapInfoData &CurInfo) {
4947 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
4948 CurInfo.IsDeclareTarget.end());
4949 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
4950 OriginalValue.append(CurInfo.OriginalValue.begin(),
4951 CurInfo.OriginalValue.end());
4952 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
4953 MapInfosTy::append(CurInfo);
// Identifies which target-family directive an operation represents.
// Only one enumerator is visible in this listing; the others used in this
// file are None, Target, TargetData, TargetExitData and TargetUpdate.
4957enum class TargetDirectiveEnumTy : uint32_t {
4961 TargetEnterData = 3,
// Classifies an operation by its concrete OpenMP target op type and returns
// the matching TargetDirectiveEnumTy; any other operation yields None.
4966static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4967 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4968 .Case([](omp::TargetDataOp) {
return TargetDirectiveEnumTy::TargetData; })
4969 .Case([](omp::TargetEnterDataOp) {
4970 return TargetDirectiveEnumTy::TargetEnterData;
4972 .Case([&](omp::TargetExitDataOp) {
4973 return TargetDirectiveEnumTy::TargetExitData;
4975 .Case([&](omp::TargetUpdateOp) {
4976 return TargetDirectiveEnumTy::TargetUpdate;
4978 .Case([&](omp::TargetOp) {
return TargetDirectiveEnumTy::Target; })
// Fallback for operations that are none of the target ops above.
4979 .Default([&](Operation *op) {
return TargetDirectiveEnumTy::None; });
// Checks whether the array's element type is itself an LLVM array type, i.e.
// whether the array is nested.
4986 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4987 arrTy.getElementType()))
// Computes the byte size of a mapped object as an LLVM value. When the map
// clause carries bounds, the size is an element count multiplied by the
// element size in bytes.
// NOTE(review): the signature and several statements are not visible in this
// listing; comments below cover only what is shown.
5004 llvm::Value *basePointer,
5005 llvm::Type *baseType,
5006 llvm::IRBuilderBase &builder,
5008 if (
auto memberClause =
5009 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
5014 if (!memberClause.getBounds().empty()) {
// Start at 1 and multiply in a factor for every bounds op.
5015 llvm::Value *elementCount = builder.getInt64(1);
5016 for (
auto bounds : memberClause.getBounds()) {
5017 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
5018 bounds.getDefiningOp())) {
// The per-dimension factor is built from the upper bound, the lower bound
// and the constant 1 (the combining operations are only partially visible).
5023 elementCount = builder.CreateMul(
5027 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
5028 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
5029 builder.getInt64(1)));
5036 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
// Bits are converted to bytes with an integer division by 8.
5044 return builder.CreateMul(elementCount,
5045 builder.getInt64(underlyingTypeSzInBits / 8));
// Translates MLIR map clause flags into the OpenMP offload mapping flag set,
// starting from OMP_MAP_NONE and OR-ing in one flag at a time.
// NOTE(review): the condition guarding each `mapType |=` below is not visible
// in this listing.
5056static llvm::omp::OpenMPOffloadMappingFlags
// True when every bit of `flag` is set in mlirFlags.
5058 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
5059 return (mlirFlags & flag) == flag;
// True when any flag other than is_device_ptr is set.
5061 const bool hasExplicitMap =
5062 (mlirFlags &
~omp::ClauseMapFlags::is_device_ptr) !=
5063 omp::ClauseMapFlags::none;
5065 llvm::omp::OpenMPOffloadMappingFlags mapType =
5066 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5069 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
5072 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
5075 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5078 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
5081 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5084 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
5087 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5090 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
5093 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5096 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
5099 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
5102 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
5105 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
// Without any explicit map flag the entry is additionally marked as a
// literal.
5106 if (!hasExplicitMap)
5107 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
// Fills mapData from map operands in three passes: the regular map
// variables, the use_device_addr / use_device_ptr operands, and the
// has_device_addr operands. Each pass pushes one element onto every parallel
// vector of mapData for each entry it adds.
// NOTE(review): the start of the signature and several statements are not
// visible in this listing.
5117 ArrayRef<Value> useDevAddrOperands = {},
5118 ArrayRef<Value> hasDevAddrOperands = {}) {
// Reports whether mapOp appears in the member list of any map in mapVars.
5119 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
5127 for (Value mapValue : mapVars) {
5128 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5129 for (
auto member : map.getMembers())
5130 if (member == mapOp)
// Pass 1: regular map variables.
5137 for (Value mapValue : mapVars) {
5138 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
// var_ptr_ptr takes precedence over var_ptr when it is present.
5140 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5141 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
5142 mapData.Pointers.push_back(mapData.OriginalValue.back());
// Three outcomes: declare target with a reference pointer as base pointer,
// declare target with the original value as base pointer, or not declare
// target.
5144 if (llvm::Value *refPtr =
5146 mapData.IsDeclareTarget.push_back(
true);
5147 mapData.BasePointers.push_back(refPtr);
5149 mapData.IsDeclareTarget.push_back(
true);
5150 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5152 mapData.IsDeclareTarget.push_back(
false);
5153 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5156 mapData.BaseType.push_back(
5157 moduleTranslation.
convertType(mapOp.getVarType()));
5158 mapData.Sizes.push_back(
5159 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
5160 mapData.BaseType.back(), builder, moduleTranslation));
5161 mapData.MapClause.push_back(mapOp.getOperation());
5163 mapData.Names.push_back(LLVM::createMappingInformation(
5165 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
// A mapper is recorded only when the map op names one; otherwise nullptr.
5166 if (mapOp.getMapperId())
5167 mapData.Mappers.push_back(
5169 mapOp, mapOp.getMapperIdAttr()));
5171 mapData.Mappers.push_back(
nullptr);
5172 mapData.IsAMapping.push_back(
true);
5173 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
// Looks for an existing real mapping of `val`; on a match the entry is tagged
// OMP_MAP_RETURN_PARAM and its device-pointer kind is overwritten.
5176 auto findMapInfo = [&mapData](llvm::Value *val,
5177 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5180 for (llvm::Value *basePtr : mapData.OriginalValue) {
5181 if (basePtr == val && mapData.IsAMapping[index]) {
5183 mapData.Types[index] |=
5184 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5185 mapData.DevicePointers[index] = devInfoTy;
// Pass 2 helper: for each use-device operand, reuse an existing mapping when
// findMapInfo finds one, otherwise add a zero-sized RETURN_PARAM entry that
// is not a mapping.
5193 auto addDevInfos = [&](
const llvm::ArrayRef<Value> &useDevOperands,
5194 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5195 for (Value mapValue : useDevOperands) {
5196 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5198 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5199 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5202 if (!findMapInfo(origValue, devInfoTy)) {
5203 mapData.OriginalValue.push_back(origValue);
5204 mapData.Pointers.push_back(mapData.OriginalValue.back());
5205 mapData.IsDeclareTarget.push_back(
false);
5206 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5207 mapData.BaseType.push_back(
5208 moduleTranslation.
convertType(mapOp.getVarType()));
5209 mapData.Sizes.push_back(builder.getInt64(0));
5210 mapData.MapClause.push_back(mapOp.getOperation());
5211 mapData.Types.push_back(
5212 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
5213 mapData.Names.push_back(LLVM::createMappingInformation(
5215 mapData.DevicePointers.push_back(devInfoTy);
5216 mapData.Mappers.push_back(
nullptr);
5217 mapData.IsAMapping.push_back(
false);
5218 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
// Address operands are processed before pointer operands.
5223 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5224 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
// Pass 3: has_device_addr operands; these are always added as new entries
// that are not mappings.
5226 for (Value mapValue : hasDevAddrOperands) {
5227 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5229 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5230 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5232 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
// Set when the map type carries the is_device_ptr flag.
5234 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
5235 omp::ClauseMapFlags::none;
5237 mapData.OriginalValue.push_back(origValue);
5238 mapData.BasePointers.push_back(origValue);
5239 mapData.Pointers.push_back(origValue);
5240 mapData.IsDeclareTarget.push_back(
false);
5241 mapData.BaseType.push_back(
5242 moduleTranslation.
convertType(mapOp.getVarType()));
5243 mapData.Sizes.push_back(
5244 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
5245 mapData.MapClause.push_back(mapOp.getOperation());
// With OMP_MAP_ALWAYS set, the map type and any named mapper are kept.
5246 if (llvm::to_underlying(mapType & mapTypeAlways)) {
5250 mapData.Types.push_back(mapType);
5254 if (mapOp.getMapperId()) {
5255 mapData.Mappers.push_back(
5257 mapOp, mapOp.getMapperIdAttr()));
5259 mapData.Mappers.push_back(
nullptr);
// Otherwise: device pointers keep their map type, everything else becomes
// OMP_MAP_LITERAL, and no mapper is attached.
5264 mapData.Types.push_back(
5265 isDevicePtr ? mapType
5266 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
5267 mapData.Mappers.push_back(
nullptr);
5269 mapData.Names.push_back(LLVM::createMappingInformation(
5271 mapData.DevicePointers.push_back(
5272 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
5273 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5274 mapData.IsAMapping.push_back(
false);
5275 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
// Returns the position of memberOp within mapData.MapClause; the op is
// asserted to be present.
5280 auto *res = llvm::find(mapData.MapClause, memberOp);
5281 assert(res != mapData.MapClause.end() &&
5282 "MapInfoOp for member not found in MapData, cannot return index");
5283 return std::distance(mapData.MapClause.begin(), res);
// Orders member indices of a map op by comparing their member-index paths
// element by element, and records occluded children along the way.
// NOTE(review): the sort call, the branch bodies and the final loop body are
// not visible in this listing.
5287 omp::MapInfoOp mapInfo) {
5288 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5298 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
5299 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
// Walk both index paths in lockstep over their common prefix length.
5301 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
5302 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
5303 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
5305 if (aIndex == bIndex)
5308 if (aIndex < bIndex)
5311 if (aIndex > bIndex)
// Reached after the loop: A counts as the parent when its path is the
// shorter one; the other member is recorded as an occluded child.
5318 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
5320 occludedChildren.push_back(
b);
5322 occludedChildren.push_back(a);
5323 return memberAParent;
5329 for (
auto v : occludedChildren)
// Picks one member map op of mapInfo. With a single member index entry the
// only member is returned directly; the general case is only partially
// visible in this listing.
5336 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5338 if (indexAttr.size() == 1)
5339 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
5343 return llvm::cast<omp::MapInfoOp>(
// Builds the list of index values derived from a map's bounds ops. Two
// shapes are visible: one index per dimension prefixed by a constant 0, and a
// single index accumulated across dimensions.
// NOTE(review): the branch conditions selecting between the two shapes are
// not visible in this listing.
5368static std::vector<llvm::Value *>
5370 llvm::IRBuilderBase &builder,
bool isArrayTy,
5372 std::vector<llvm::Value *> idx;
// Per-dimension form: leading 0, then each lower bound, walking the bounds
// from last to first.
5383 idx.push_back(builder.getInt64(0));
5384 for (
int i = bounds.size() - 1; i >= 0; --i) {
5385 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5386 bounds[i].getDefiningOp())) {
5387 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
// Accumulated form: idx.back() = idx.back() * extent + lowerBound for each
// dimension after the last one, again walking from last to first.
5405 std::vector<llvm::Value *> dimensionIndexSizeOffset;
5406 for (
int i = bounds.size() - 1; i >= 0; --i) {
5407 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5408 bounds[i].getDefiningOp())) {
5409 if (i == ((
int)bounds.size() - 1))
5411 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
5413 idx.back() = builder.CreateAdd(
5414 builder.CreateMul(idx.back(), moduleTranslation.
lookupValue(
5415 boundOp.getExtent())),
5416 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
// Appends to `ints` the integer value of every attribute in `values`; each
// attribute is cast to IntegerAttr.
5425 llvm::transform(values, std::back_inserter(ints), [](
Attribute value) {
5426 return cast<IntegerAttr>(value).getInt();
// Collects into overlapMapDataIdxs the member positions of parentOp that
// remain after members on a skip list are excluded.
// NOTE(review): how the skip list is populated is only partially visible in
// this listing.
5434 omp::MapInfoOp parentOp) {
// No members: nothing to collect.
5436 if (parentOp.getMembers().empty())
// A single member is recorded directly at position 0.
5440 if (parentOp.getMembers().size() == 1) {
5441 overlapMapDataIdxs.push_back(0);
// Pair each member position with its member-index path.
5447 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
5448 for (
auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
5449 memberByIndex.push_back(
5450 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
// Shortest index paths first.
5455 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
5456 [&](
auto a,
auto b) { return a.second.size() < b.second.size(); });
5462 for (
auto v : memberByIndex) {
// The predicate matches an entry x whose path starts with v's path and is at
// least as long.
5466 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](
auto x) {
5469 llvm::SmallVector<int64_t> xArr(x.second.size());
5470 getAsIntegers(x.second, xArr);
5471 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
5472 xArr.size() >= vArr.size();
// Keep every member that is not on the skip list.
5478 for (
auto v : memberByIndex)
5479 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
5480 overlapMapDataIdxs.push_back(v.first);
// Branches on whether the map op has a var_ptr_ptr operand.
5492 if (mapOp.getVarPtrPtr())
// Emits the combinedInfo entries for a parent map that has mapped members and
// returns the member-of flag that ties members to that parent entry.
// Host-side code generation only (asserted below).
// NOTE(review): the start of the signature and several statements are not
// visible in this listing.
5521 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
5522 MapInfoData &mapData, uint64_t mapDataIndex,
5523 TargetDirectiveEnumTy targetDirective) {
5524 assert(!ompBuilder.Config.isTargetDevice() &&
5525 "function only supported for host device codegen");
5528 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5530 auto *parentMapper = mapData.Mappers[mapDataIndex];
// TARGET_PARAM is set only for a target directive on a non-declare-target
// entry.
5536 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
5537 (targetDirective == TargetDirectiveEnumTy::Target &&
5538 !mapData.IsDeclareTarget[mapDataIndex])
5539 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
5540 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5543 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
// Carry over TO/FROM/ALWAYS/CLOSE/PRESENT/OMPX_HOLD from the parent's flags.
5547 mapFlags parentFlags = mapData.Types[mapDataIndex];
5548 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
5549 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
5550 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
5551 baseFlag |= (parentFlags & preserve);
5554 combinedInfo.Types.emplace_back(baseFlag);
5555 combinedInfo.DevicePointers.emplace_back(
5556 mapData.DevicePointers[mapDataIndex]);
// The parent's mapper is attached only for a full (non-partial) map.
5560 combinedInfo.Mappers.emplace_back(
5561 parentMapper && !parentClause.getPartialMap() ? parentMapper :
nullptr);
5563 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5564 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
// [lowAddr, highAddr) is the address range covered by the parent entry.
5573 llvm::Value *lowAddr, *highAddr;
// Full map: the range spans one whole object of the parent's base type.
5574 if (!parentClause.getPartialMap()) {
5575 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5576 builder.getPtrTy());
5577 highAddr = builder.CreatePointerCast(
5578 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5579 mapData.Pointers[mapDataIndex], 1),
5580 builder.getPtrTy());
5581 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
// Partial map: the range runs from the first mapped member to one element
// past the last mapped member.
5583 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5586 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
5587 builder.getPtrTy());
5590 highAddr = builder.CreatePointerCast(
5591 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
5592 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
5593 builder.getPtrTy());
5594 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
// Size of the parent entry: byte distance between the two addresses, as i64.
5597 llvm::Value *size = builder.CreateIntCast(
5598 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5599 builder.getInt64Ty(),
5601 combinedInfo.Sizes.push_back(size);
// The member-of flag refers to the parent entry just appended.
5603 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
5604 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
5612 if (!parentClause.getPartialMap()) {
5617 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
5618 bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
5619 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
5620 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5621 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
// For target update, or when `close` is set, the parent is emitted as one
// additional entry with its recorded size.
5623 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
5624 combinedInfo.Types.emplace_back(mapFlag);
5625 combinedInfo.DevicePointers.emplace_back(
5626 mapData.DevicePointers[mapDataIndex]);
5628 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5629 combinedInfo.BasePointers.emplace_back(
5630 mapData.BasePointers[mapDataIndex]);
5631 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5632 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
5633 combinedInfo.Mappers.emplace_back(
nullptr);
// Otherwise the parent range is recomputed and emitted in pieces around the
// overlapping members.
5644 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5645 builder.getPtrTy());
5646 highAddr = builder.CreatePointerCast(
5647 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5648 mapData.Pointers[mapDataIndex], 1),
5649 builder.getPtrTy());
// One entry per overlapping member; afterwards lowAddr is advanced by one
// element from that member's base pointer.
5656 for (
auto v : overlapIdxs) {
5659 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
5660 combinedInfo.Types.emplace_back(mapFlag);
5661 combinedInfo.DevicePointers.emplace_back(
5662 mapData.DevicePointers[mapDataOverlapIdx]);
5664 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5665 combinedInfo.BasePointers.emplace_back(
5666 mapData.BasePointers[mapDataIndex]);
5667 combinedInfo.Mappers.emplace_back(
nullptr);
5668 combinedInfo.Pointers.emplace_back(lowAddr);
5669 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5670 builder.CreatePtrDiff(builder.getInt8Ty(),
5671 mapData.OriginalValue[mapDataOverlapIdx],
5673 builder.getInt64Ty(),
true));
5674 lowAddr = builder.CreateConstGEP1_32(
5676 mapData.MapClause[mapDataOverlapIdx]))
5677 ? builder.getPtrTy()
5678 : mapData.BaseType[mapDataOverlapIdx],
5679 mapData.BasePointers[mapDataOverlapIdx], 1);
// Final piece: from the current lowAddr up to highAddr.
5682 combinedInfo.Types.emplace_back(mapFlag);
5683 combinedInfo.DevicePointers.emplace_back(
5684 mapData.DevicePointers[mapDataIndex]);
5686 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5687 combinedInfo.BasePointers.emplace_back(
5688 mapData.BasePointers[mapDataIndex]);
5689 combinedInfo.Mappers.emplace_back(
nullptr);
5690 combinedInfo.Pointers.emplace_back(lowAddr);
5691 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5692 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5693 builder.getInt64Ty(),
true));
5696 return memberOfFlag;
// Emits combinedInfo entries for every mapped member of a parent map, tagging
// each with the given member-of flag. Host-side code generation only
// (asserted below).
// NOTE(review): the start of the signature and several statements are not
// visible in this listing.
5702 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
5703 MapInfoData &mapData, uint64_t mapDataIndex,
5704 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5705 TargetDirectiveEnumTy targetDirective) {
5706 assert(!ompBuilder.Config.isTargetDevice() &&
5707 "function only supported for host device codegen");
5710 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5712 for (
auto mappedMembers : parentClause.getMembers()) {
5714 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
5717 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
// First entry shape: a pointer-sized entry whose pointer is the member's
// base pointer, with no mapper and no device-pointer kind. The condition
// selecting it is not visible in this listing.
5728 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5729 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5730 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5731 combinedInfo.Types.emplace_back(mapFlag);
5732 combinedInfo.DevicePointers.emplace_back(
5733 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5734 combinedInfo.Mappers.emplace_back(
nullptr);
5735 combinedInfo.Names.emplace_back(
5737 combinedInfo.BasePointers.emplace_back(
5738 mapData.BasePointers[mapDataIndex]);
5739 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
5740 combinedInfo.Sizes.emplace_back(builder.getInt64(
5741 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
// Second entry shape: the member itself, with TARGET_PARAM cleared and
// MEMBER_OF set.
5747 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5748 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5749 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5751 ? parentClause.getVarPtr()
5752 : parentClause.getVarPtrPtr());
// PTR_AND_OBJ is never added for target update or target data directives.
5755 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5756 targetDirective != TargetDirectiveEnumTy::TargetData))) {
5757 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5760 combinedInfo.Types.emplace_back(mapFlag);
5761 combinedInfo.DevicePointers.emplace_back(
5762 mapData.DevicePointers[memberDataIdx]);
5763 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5764 combinedInfo.Names.emplace_back(
5766 uint64_t basePointerIndex =
5768 combinedInfo.BasePointers.emplace_back(
5769 mapData.BasePointers[basePointerIndex]);
5770 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
// The size may be replaced by a select that yields 0 when the member pointer
// is null (the guarding condition is not visible in this listing).
5772 llvm::Value *size = mapData.Sizes[memberDataIdx];
5774 size = builder.CreateSelect(
5775 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5776 builder.getInt64(0), size);
5779 combinedInfo.Sizes.emplace_back(size);
// Copies one mapData entry into combinedInfo after adjusting its flags. When
// mapDataParentIdx is non-negative, the parent's base pointer is used instead
// of the entry's own.
// NOTE(review): the start of the signature and some flag conditions are not
// visible in this listing.
5784 MapInfosTy &combinedInfo,
5785 TargetDirectiveEnumTy targetDirective,
5786 int mapDataParentIdx = -1) {
5790 auto mapFlag = mapData.Types[mapDataIdx];
5791 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5795 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
// TARGET_PARAM only for a target directive on a non-declare-target entry.
5797 if (targetDirective == TargetDirectiveEnumTy::Target &&
5798 !mapData.IsDeclareTarget[mapDataIdx])
5799 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
// By-copy captures may additionally be marked LITERAL (second half of the
// condition is not visible).
5801 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5803 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5808 if (mapDataParentIdx >= 0)
5809 combinedInfo.BasePointers.emplace_back(
5810 mapData.BasePointers[mapDataParentIdx]);
5812 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5814 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5815 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5816 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5817 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5818 combinedInfo.Types.emplace_back(mapFlag);
5819 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
// Handles a parent map that has members. A partial map with exactly one
// member gets special treatment; otherwise a member-of flag is obtained for
// the parent and then passed on for the members. Host-side code generation
// only (asserted below).
// NOTE(review): the callee names and the special-case body are not visible in
// this listing.
5823 llvm::IRBuilderBase &builder,
5824 llvm::OpenMPIRBuilder &ompBuilder,
5826 MapInfoData &mapData, uint64_t mapDataIndex,
5827 TargetDirectiveEnumTy targetDirective) {
5828 assert(!ompBuilder.Config.isTargetDevice() &&
5829 "function only supported for host device codegen");
5832 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5837 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
5838 auto memberClause = llvm::cast<omp::MapInfoOp>(
5839 parentClause.getMembers()[0].getDefiningOp());
5856 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
5858 combinedInfo, mapData, mapDataIndex,
5861 combinedInfo, mapData, mapDataIndex,
5862 memberOfParentFlag, targetDirective);
// Rewrites mapData.Pointers (and, for by-copy captures, BasePointers) of
// every non-declare-target entry according to the map's capture kind.
// NOTE(review): the signature and some conditions are not visible in this
// listing.
5872 llvm::IRBuilderBase &builder) {
5874 "function only supported for host device codegen");
5875 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5877 if (!mapData.IsDeclareTarget[i]) {
5878 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5879 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
5889 switch (captureKind) {
// By reference: optionally load through the pointer, then apply the bounds
// offset with an in-bounds GEP when there is one.
5890 case omp::VariableCaptureKind::ByRef: {
5891 llvm::Value *newV = mapData.Pointers[i];
5893 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
5896 newV = builder.CreateLoad(builder.getPtrTy(), newV);
5898 if (!offsetIdx.empty())
5899 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
5901 mapData.Pointers[i] = newV;
// By copy: load the value when the entry is a pointer, otherwise use it
// as is.
5903 case omp::VariableCaptureKind::ByCopy: {
5904 llvm::Type *type = mapData.BaseType[i];
5906 if (mapData.Pointers[i]->getType()->isPointerTy())
5907 newV = builder.CreateLoad(type, mapData.Pointers[i]);
5909 newV = mapData.Pointers[i];
// The value is stored into a pointer-typed ".casted" alloca and reloaded as a
// pointer; the insertion point and debug location are saved and restored
// around the alloca.
5912 auto curInsert = builder.saveIP();
5913 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
5915 auto *memTempAlloc =
5916 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
5917 builder.SetCurrentDebugLocation(DbgLoc);
5918 builder.restoreIP(curInsert);
5920 builder.CreateStore(newV, memTempAlloc);
5921 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
5924 mapData.Pointers[i] = newV;
5925 mapData.BasePointers[i] = newV;
// These capture kinds are reported as unhandled.
5927 case omp::VariableCaptureKind::This:
5928 case omp::VariableCaptureKind::VLAType:
5929 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
// Walks every collected map clause and emits its combinedInfo entries.
// Entries that are members of another map are skipped at this level; maps
// that have members take a dedicated path.
// NOTE(review): the signature and the callee names are not visible in this
// listing.
5940 MapInfoData &mapData,
5941 TargetDirectiveEnumTy targetDirective) {
5943 "function only supported for host device codegen");
5964 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5967 if (mapData.IsAMember[i])
5970 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5971 if (!mapInfoOp.getMembers().empty()) {
5973 combinedInfo, mapData, i, targetDirective);
// Forward declaration of a routine returning the llvm::Function for a mapper,
// or an error. presumably the mapper emitter defined further down, which
// ends with the same parameters — TODO confirm.
5981static llvm::Expected<llvm::Function *>
5983 LLVM::ModuleTranslation &moduleTranslation,
5984 llvm::StringRef mapperFuncName,
5985 TargetDirectiveEnumTy targetDirective);
// Returns the mapper function for a declare-mapper op, reusing one that is
// already known before falling back to emitting a new one.
// NOTE(review): the function name and part of the signature are not visible
// in this listing.
5987static llvm::Expected<llvm::Function *>
5990 TargetDirectiveEnumTy targetDirective) {
5992 "function only supported for host device codegen");
5993 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
// The function name is derived from "omp_mapper" and the op's symbol name.
5994 std::string mapperFuncName =
5996 {
"omp_mapper", declMapperOp.getSymName()});
// First choice: a function already registered with the module translation.
5998 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
// Second choice: a function of that name already in the LLVM module; it is
// registered with the module translation before being returned.
6006 if (llvm::Function *existingFunc =
6007 moduleTranslation.
getLLVMModule()->getFunction(mapperFuncName)) {
6008 moduleTranslation.
mapFunction(mapperFuncName, existingFunc);
6009 return existingFunc;
6013 mapperFuncName, targetDirective);
// Emits the LLVM function for an omp::DeclareMapperOp under the given
// mapperFuncName (name line and leading parameters are elided here).
// Returns the emitted function, or an error if translation of the mapper
// region or the emission itself fails.
6016static llvm::Expected<llvm::Function *>
6019 llvm::StringRef mapperFuncName,
6020 TargetDirectiveEnumTy targetDirective) {
6022 "function only supported for host device codegen");
6023 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6024 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
// LLVM type of the variable the mapper is declared for.
6027 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
6030 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// Filled by the callback below; captured by reference.
6033 MapInfosTy combinedInfo;
// Callback producing the map information. It binds the mapper's symbol value
// to ptrPHI, maps the first block of the mapper region onto the builder's
// current insertion block, and converts that block.
6035 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
6036 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
6037 builder.restoreIP(codeGenIP);
6038 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
6039 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
6040 builder.GetInsertBlock());
6041 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
// A failed block conversion is surfaced as PreviouslyReportedError.
6044 return llvm::make_error<PreviouslyReportedError>();
6045 MapInfoData mapData;
6048 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
6054 return combinedInfo;
// Fragment of the nested-mapper callback: entries of combinedInfo.Mappers that
// are null are tested first; the handling of each case is on elided lines.
6058 if (!combinedInfo.Mappers[i])
6061 moduleTranslation, targetDirective);
6065 genMapInfoCB, varType, mapperFuncName, customMapperCB);
6067 return newFn.takeError();
// If a function is already mapped (lookup expression elided), it must be the
// very function that was just emitted.
6068 if ([[maybe_unused]] llvm::Function *mappedFunc =
6070 assert(mappedFunc == *newFn &&
6071 "mapper function mapping disagrees with emitted function");
// Register the new function under the mapper name.
6073 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
// Body fragment of the host-side lowering shared by omp.target_data,
// omp.target_enter_data, omp.target_exit_data and omp.target_update (the
// function header is elided in this excerpt).
// Defaults: no `if` condition and an undefined device id.
6081 llvm::Value *ifCond =
nullptr;
6082 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
// Runtime entry point chosen per operation kind in the type switch below.
6086 llvm::omp::RuntimeFunction RTLFn;
6088 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
6091 llvm::OpenMPIRBuilder::TargetDataInfo info(
6094 assert(!ompBuilder->Config.isTargetDevice() &&
6095 "target data/enter/exit/update are host ops");
// Offloading is considered active only when target triples are configured.
6096 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
// Looks up the translated device value and converts it to a signed i64.
6098 auto getDeviceID = [&](
mlir::Value dev) -> llvm::Value * {
6099 llvm::Value *v = moduleTranslation.
lookupValue(dev);
6100 return builder.CreateIntCast(v, builder.getInt64Ty(),
true);
// omp.target_data: collects map and use_device_ptr/use_device_addr operands.
6105 .Case([&](omp::TargetDataOp dataOp) {
6109 if (
auto ifVar = dataOp.getIfExpr())
6113 deviceID = getDeviceID(devId);
6115 mapVars = dataOp.getMapVars();
6116 useDevicePtrVars = dataOp.getUseDevicePtrVars();
6117 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
// omp.target_enter_data: "begin" mapper entry point, nowait variant on demand.
6120 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
6124 if (
auto ifVar = enterDataOp.getIfExpr())
6128 deviceID = getDeviceID(devId);
6131 enterDataOp.getNowait()
6132 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
6133 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
6134 mapVars = enterDataOp.getMapVars();
6135 info.HasNoWait = enterDataOp.getNowait();
// omp.target_exit_data: "end" mapper entry point, nowait variant on demand.
6138 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
6142 if (
auto ifVar = exitDataOp.getIfExpr())
6146 deviceID = getDeviceID(devId);
6148 RTLFn = exitDataOp.getNowait()
6149 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
6150 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
6151 mapVars = exitDataOp.getMapVars();
6152 info.HasNoWait = exitDataOp.getNowait();
// omp.target_update: "update" mapper entry point, nowait variant on demand.
6155 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
6159 if (
auto ifVar = updateDataOp.getIfExpr())
6163 deviceID = getDeviceID(devId);
6166 updateDataOp.getNowait()
6167 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
6168 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
6169 mapVars = updateDataOp.getMapVars();
6170 info.HasNoWait = updateDataOp.getNowait();
6173 .DefaultUnreachable(
"unexpected operation");
// Without offload targets the `if` condition is forced to false, overriding
// any condition taken from the operation.
6178 if (!isOffloadEntry)
6179 ifCond = builder.getFalse();
6181 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6182 MapInfoData mapData;
6184 builder, useDevicePtrVars, useDeviceAddrVars);
// Lazily fills combinedInfo at the insertion point supplied by the builder.
6187 MapInfosTy combinedInfo;
6188 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
6189 builder.restoreIP(codeGenIP);
6190 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
6192 return combinedInfo;
// Helper lambda (its variable name is on an elided line; it is invoked below
// as mapUseDevice). For each (block argument, use_device variable) pair it
// searches the map data for the entry with the same base pointer and the
// requested device-info kind, then binds the block argument to the base
// pointer, optionally transformed by `mapper`.
6198 [&moduleTranslation](
6199 llvm::OpenMPIRBuilder::DeviceInfoTy type,
6203 for (
auto [arg, useDevVar] :
6204 llvm::zip_equal(blockArgs, useDeviceVars)) {
// The base pointer of a map is var_ptr_ptr when present, var_ptr otherwise.
6206 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
6207 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
6208 : mapInfoOp.getVarPtr();
6211 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
6212 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
6213 mapInfoData.MapClause, mapInfoData.DevicePointers,
6214 mapInfoData.BasePointers)) {
6215 auto mapOp = cast<omp::MapInfoOp>(mapClause);
6216 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
6217 devicePointer != type)
// A null result from `mapper` leaves the block argument unmapped.
6220 if (llvm::Value *devPtrInfoMap =
6221 mapper ? mapper(basePointer) : basePointer) {
6222 moduleTranslation.
mapValue(arg, devPtrInfoMap);
// Body generation callback; only valid for omp.target_data (asserted below).
6229 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
6230 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
6231 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6234 builder.restoreIP(codeGenIP);
6235 assert(isa<omp::TargetDataOp>(op) &&
6236 "BodyGen requested for non TargetDataOp");
6237 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
// NOTE(review): the declarator on the next line looks mis-encoded ("&reg"
// rendered as the registered-trademark sign); it presumably reads
// `Region &region` -- confirm against the original file.
6238 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
6239 switch (bodyGenType) {
// Priv: device pointer info is available, so block arguments are bound to the
// values recorded in info.DevicePtrInfoMap.
6240 case BodyGenTy::Priv:
6242 if (!info.DevicePtrInfoMap.empty()) {
6243 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6244 blockArgIface.getUseDeviceAddrBlockArgs(),
6245 useDeviceAddrVars, mapData,
6246 [&](llvm::Value *basePointer) -> llvm::Value * {
6247 if (!info.DevicePtrInfoMap[basePointer].second)
6249 return builder.CreateLoad(
6251 info.DevicePtrInfoMap[basePointer].second);
6253 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6254 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6255 mapData, [&](llvm::Value *basePointer) {
6256 return info.DevicePtrInfoMap[basePointer].second;
6260 moduleTranslation)))
6261 return llvm::make_error<PreviouslyReportedError>();
// DupNoPriv: no device pointer info, so block arguments are bound without a
// mapper (the host base pointers are used as-is).
6264 case BodyGenTy::DupNoPriv:
6265 if (info.DevicePtrInfoMap.empty()) {
6268 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6269 blockArgIface.getUseDeviceAddrBlockArgs(),
6270 useDeviceAddrVars, mapData);
6271 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6272 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6276 case BodyGenTy::NoPriv:
6278 if (info.DevicePtrInfoMap.empty()) {
6280 moduleTranslation)))
6281 return llvm::make_error<PreviouslyReportedError>();
6285 return builder.saveIP();
// Custom mapper callback: flags info.HasMapper when an entry has a mapper.
6288 auto customMapperCB =
6290 if (!combinedInfo.Mappers[i])
6292 info.HasMapper =
true;
6294 moduleTranslation, targetDirective);
6297 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6298 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// omp.target_data takes the first createTargetData call (trailing arguments
// elided); the stand-alone directives pass the selected runtime function.
6300 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
6301 if (isa<omp::TargetDataOp>(op))
6302 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6303 deviceID, ifCond, info, genMapInfoCB,
6307 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6308 deviceID, ifCond, info, genMapInfoCB,
6309 customMapperCB, &RTLFn);
// Continue emitting code after the generated construct.
6315 builder.restoreIP(*afterIP);
// Body fragment of the omp.distribute lowering (function header elided).
6323 auto distributeOp = cast<omp::DistributeOp>(opInst);
6330 bool doDistributeReduction =
// Reduction state comes from `teamsOp` (defined on an elided line); zero
// reduction variables when there is no such op.
6334 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
6339 if (doDistributeReduction) {
6340 isByRef =
getIsByRef(teamsOp.getReductionByref());
6341 assert(isByRef.size() == teamsOp.getNumReductionVars());
6344 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6348 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
6349 .getReductionBlockArgs();
6352 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
6353 reductionDecls, privateReductionVariables, reductionVariableMap,
// Callback that emits the distribute region body; any failure that was already
// diagnosed is reported as PreviouslyReportedError.
6358 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6359 auto bodyGenCB = [&](InsertPointTy allocaIP,
6360 InsertPointTy codeGenIP) -> llvm::Error {
6364 moduleTranslation, allocaIP);
6367 builder.restoreIP(codeGenIP);
6373 return llvm::make_error<PreviouslyReportedError>();
6378 return llvm::make_error<PreviouslyReportedError>();
6381 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
6383 distributeOp.getPrivateNeedsBarrier())))
6384 return llvm::make_error<PreviouslyReportedError>();
6387 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6390 builder, moduleTranslation);
6392 return regionBlock.takeError();
// Continue at the start of the block returned for the converted region.
6393 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// The workshare loop is applied here only when the nested wrapper is not an
// omp.wsloop.
6398 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
// dist_schedule(static) selects the Distribute schedule kind, otherwise Static.
6401 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
6402 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
6403 : omp::ClauseScheduleKind::Static;
6405 bool isOrdered = hasDistSchedule;
// No assignment to scheduleMod is visible in this excerpt; while it stays
// empty both modifier comparisons below evaluate to false.
6406 std::optional<omp::ScheduleModifier> scheduleMod;
6407 bool isSimd =
false;
6408 llvm::omp::WorksharingLoopType workshareLoopType =
6409 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
6410 bool loopNeedsBarrier =
false;
6411 llvm::Value *chunk = moduleTranslation.
lookupValue(
6412 distributeOp.getDistScheduleChunkSize());
6413 llvm::CanonicalLoopInfo *loopInfo =
// The same `chunk` value is passed twice to applyWorkshareLoop.
6415 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
6416 ompBuilder->applyWorkshareLoop(
6417 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
6418 convertToScheduleKind(schedule), chunk, isSimd,
6419 scheduleMod == omp::ScheduleModifier::monotonic,
6420 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
6421 workshareLoopType,
false, hasDistSchedule, chunk);
6424 return wsloopIP.takeError();
6427 distributeOp.getLoc(), privVarsInfo.
llvmVars,
6429 return llvm::make_error<PreviouslyReportedError>();
6431 return llvm::Error::success();
// Emit the distribute construct and resume after it.
6434 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6436 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6437 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6438 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
6443 builder.restoreIP(*afterIP);
// Reduction finalisation for the teams op (callee name elided).
6445 if (doDistributeReduction) {
6448 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
6449 privateReductionVariables, isByRef,
// Body fragment that lowers an OpenMP flags attribute on the module to LLVM
// module flags and global flag variables (function header elided).
// NOTE(review): cast<> asserts on a type mismatch instead of returning null,
// so this condition can never be true for a non-module op; dyn_cast<> looks
// like the intended check -- confirm.
6461 if (!cast<mlir::ModuleOp>(op))
// Records the device OpenMP version; llvm::Module::Max is the merge behaviour
// passed for the flag.
6466 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
6467 attribute.getOpenmpDeviceVersion());
6469 if (attribute.getNoGpuLib())
// One global flag per attribute field; the symbol names are the string
// literals below.
6472 ompBuilder->createGlobalFlag(
6473 attribute.getDebugKind() ,
6474 "__omp_rtl_debug_kind");
6475 ompBuilder->createGlobalFlag(
6477 .getAssumeTeamsOversubscription()
6479 "__omp_rtl_assume_teams_oversubscription");
6480 ompBuilder->createGlobalFlag(
6482 .getAssumeThreadsOversubscription()
6484 "__omp_rtl_assume_threads_oversubscription");
6485 ompBuilder->createGlobalFlag(
6486 attribute.getAssumeNoThreadState() ,
6487 "__omp_rtl_assume_no_thread_state");
6488 ompBuilder->createGlobalFlag(
6490 .getAssumeNoNestedParallelism()
6492 "__omp_rtl_assume_no_nested_parallelism");
// Fragment that builds an llvm::TargetRegionEntryInfo for an omp.target op
// from its source location (name line and first parameter elided).
// parentName defaults to the empty string.
6497 omp::TargetOp targetOp,
6498 llvm::StringRef parentName =
"") {
// A file/line/column location must be present somewhere in the op's location.
6499 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
6501 assert(fileLoc &&
"No file found from location");
6502 StringRef fileName = fileLoc.getFilename().getValue();
6504 llvm::sys::fs::UniqueID id;
6505 uint64_t line = fileLoc.getLine();
// If the file system cannot provide a unique id for the file, fall back to a
// hash of the file name and a fixed placeholder device id.
6506 if (
auto ec = llvm::sys::fs::getUniqueID(fileName,
id)) {
6507 size_t fileHash = llvm::hash_value(fileName.str());
6508 size_t deviceId = 0xdeadf17e;
6510 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
// Otherwise (presumably the else branch, its opening is elided) the device and
// file ids reported by the file system are used.
6512 targetInfo = llvm::TargetRegionEntryInfo(parentName,
id.
getDevice(),
6513 id.getFile(), line);
// Device-side fragment that rewrites uses of declare-target mapped values
// inside `func` so they refer to the mapped base pointer instead (name line
// and leading parameters elided).
6520 llvm::IRBuilderBase &builder, llvm::Function *
func) {
6522 "function only supported for target device codegen");
// Restores the builder's insertion point when this scope ends.
6523 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6524 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6537 if (mapData.IsDeclareTarget[i]) {
// Constant users are first expanded into instructions within `func` so they
// can be rewritten individually.
6544 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
6545 convertUsersOfConstantsToInstructions(constant,
func,
false);
// Snapshot the users before rewriting; the loop below modifies the use list
// it would otherwise be iterating.
6552 for (llvm::User *user : mapData.OriginalValue[i]->users())
6553 userVec.push_back(user);
6555 for (llvm::User *user : userVec) {
6556 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
// Only instructions that belong to `func` are touched.
6557 if (insn->getFunction() ==
func) {
6558 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6559 llvm::Value *substitute = mapData.BasePointers[i];
6561 : mapOp.getVarPtr())) {
// Under the (partly elided) condition above, the substitute is a load through
// the base pointer, placed directly before the using instruction and given
// that instruction's debug location.
6562 builder.SetCurrentDebugLocation(insn->getDebugLoc());
6563 substitute = builder.CreateLoad(
6564 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
6565 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
6567 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
// Device-side helper that materialises access to a kernel argument and stores
// the value to use for `input` in `retVal` (name line and some parameters are
// elided). Returns the builder's insertion point after the emitted code.
6614static llvm::IRBuilderBase::InsertPoint
6616 llvm::Value *input, llvm::Value *&retVal,
6617 llvm::IRBuilderBase &builder,
6618 llvm::OpenMPIRBuilder &ompBuilder,
6620 llvm::IRBuilderBase::InsertPoint allocaIP,
6621 llvm::IRBuilderBase::InsertPoint codeGenIP) {
6622 assert(ompBuilder.Config.isTargetDevice() &&
6623 "function only supported for target device codegen");
6624 builder.restoreIP(allocaIP);
// Capture kind defaults to ByRef and is overridden by the map clause whose
// original value equals `input`, if one is found.
6626 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
6628 ompBuilder.M.getContext());
6629 unsigned alignmentValue = 0;
6631 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
6632 if (mapData.OriginalValue[i] == input) {
6633 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6634 capture = mapOp.getMapCaptureType();
6637 mapOp.getVarType(), ompBuilder.M.getDataLayout());
// The argument is spilled to a stack slot in the alloca address space; when
// that differs from the program address space and the argument is a pointer,
// the slot address is cast to the program address space.
6641 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
6642 unsigned int defaultAS =
6643 ompBuilder.M.getDataLayout().getProgramAddressSpace();
6646 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
6648 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
6649 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
6651 builder.CreateStore(&arg, v);
// Remaining code is emitted at the code generation insertion point.
6653 builder.restoreIP(codeGenIP);
6656 case omp::VariableCaptureKind::ByCopy: {
6660 case omp::VariableCaptureKind::ByRef: {
6661 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
6663 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
// A non-zero alignment recorded for the mapped variable is attached to the
// load as `align` metadata when the slot is a pointer.
6678 if (v->getType()->isPointerTy() && alignmentValue) {
6679 llvm::MDBuilder MDB(builder.getContext());
6680 loadInst->setMetadata(
6681 llvm::LLVMContext::MD_align,
6682 llvm::MDNode::get(builder.getContext(),
6683 MDB.createConstant(llvm::ConstantInt::get(
6684 llvm::Type::getInt64Ty(builder.getContext()),
6691 case omp::VariableCaptureKind::This:
6692 case omp::VariableCaptureKind::VLAType:
// NOTE(review): assert(false) compiles away in builds without assertions, so
// these capture kinds would then fall through silently -- confirm whether
// llvm_unreachable or an emitted error is wanted here.
6695 assert(
false &&
"Currently unsupported capture kind");
6699 return builder.saveIP();
// Body fragment that maps each host_eval block argument of an omp.target op
// back to the clause it feeds and records the corresponding host value
// (function header and the dispatch on the argument's users are elided).
6716 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6717 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6718 blockArgIface.getHostEvalBlockArgs())) {
6719 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
// omp.teams: the argument may be the num_teams lower bound, one of the
// num_teams upper values, or the first thread_limit value.
6723 .Case([&](omp::TeamsOp teamsOp) {
6724 if (teamsOp.getNumTeamsLower() == blockArg)
6725 numTeamsLower = hostEvalVar;
6726 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6728 numTeamsUpper = hostEvalVar;
6729 else if (!teamsOp.getThreadLimitVars().empty() &&
6730 teamsOp.getThreadLimit(0) == blockArg)
6731 threadLimit = hostEvalVar;
6733 llvm_unreachable(
"unsupported host_eval use");
// omp.parallel: only the first num_threads value is recognised.
6735 .Case([&](omp::ParallelOp parallelOp) {
6736 if (!parallelOp.getNumThreadsVars().empty() &&
6737 parallelOp.getNumThreads(0) == blockArg)
6738 numThreads = hostEvalVar;
6740 llvm_unreachable(
"unsupported host_eval use");
// omp.loop_nest: the argument may appear among lower bounds, upper bounds
// and/or steps; every matching position is overwritten with the host value.
6742 .Case([&](omp::LoopNestOp loopOp) {
6743 auto processBounds =
6747 for (
auto [i, lb] : llvm::enumerate(opBounds)) {
6748 if (lb == blockArg) {
6751 (*outBounds)[i] = hostEvalVar;
// All three operand groups are processed; `found` accumulates whether any of
// them matched.
6757 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6758 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6760 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6762 assert(found &&
"unsupported host_eval use");
6764 .DefaultUnreachable(
"unsupported host_eval use");
// Template helper fragment (signature elided): tries `op` itself as OpTy
// first; when `immediateParent` is set, only the direct parent operation is
// considered (null if it is absent or of another type). The remaining path is
// on elided lines.
6776template <
typename OpTy>
6781 if (OpTy casted = dyn_cast<OpTy>(op))
6784 if (immediateParent)
6785 return dyn_cast_if_present<OpTy>(op->
getParentOp());
// Fragment of a helper that yields an optional integer for a constant
// operation (signature elided): the integer is returned only when the op's
// value is an IntegerAttr, std::nullopt otherwise.
6794 return std::nullopt;
6797 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6798 return constAttr.getInt();
6800 return std::nullopt;
// Converts a size in bits to bytes using integer division, so any remainder
// bits are dropped (the surrounding function is elided in this excerpt).
6805 uint64_t sizeInBytes = sizeInBits / 8;
// Template helper fragment (signature elided): for an op with reduction
// variables it gathers one member type per reduction declaration and forms a
// literal LLVM struct type from them.
6809template <
typename OpTy>
6811 if (op.getNumReductionVars() > 0) {
6816 members.reserve(reductions.size());
6817 for (omp::DeclareReductionOp &red : reductions) {
// By-ref reductions contribute their element type, others their own type.
6821 if (red.getByrefElementType())
6822 members.push_back(*red.getByrefElementType());
6824 members.push_back(red.getType());
6827 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
// Fragment that fills the compile-time default kernel attributes (`attrs`) for
// an omp.target op (name line and leading parameters elided).
6843 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
6844 bool isTargetDevice,
bool isGPU) {
6847 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
// Branch taken when not compiling for the target device; the lines below read
// the clause values from the teams and parallel ops.
6848 if (!isTargetDevice) {
6856 numTeamsLower = teamsOp.getNumTeamsLower();
6858 if (!teamsOp.getNumTeamsUpperVars().empty())
6859 numTeamsUpper = teamsOp.getNumTeams(0);
6860 if (!teamsOp.getThreadLimitVars().empty())
6861 threadLimit = teamsOp.getThreadLimit(0);
6865 if (!parallelOp.getNumThreadsVars().empty())
6866 numThreads = parallelOp.getNumThreads(0);
// Team bounds start at min = 1 / max = -1 and are both set to one common value
// in each visible branch below (the branch conditions are partly elided).
6872 int32_t minTeamsVal = 1, maxTeamsVal = -1;
6876 if (numTeamsUpper) {
6878 minTeamsVal = maxTeamsVal = *val;
6880 minTeamsVal = maxTeamsVal = 0;
6886 minTeamsVal = maxTeamsVal = 1;
6888 minTeamsVal = maxTeamsVal = -1;
6893 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &
result) {
// Thread limits default to -1; negative values are treated as "not set" by the
// comparisons further down.
6907 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
6908 if (!targetOp.getThreadLimitVars().empty())
6909 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
6910 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
6913 int32_t maxThreadsVal = -1;
6915 setMaxValueFromClause(numThreads, maxThreadsVal);
// Combined maximum: the smallest non-negative value among the target thread
// limit, the teams thread limit and the max-threads value; stays negative when
// none of them is set.
6923 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
6924 if (combinedMaxThreadsVal < 0 ||
6925 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
6926 combinedMaxThreadsVal = teamsThreadLimitVal;
6928 if (combinedMaxThreadsVal < 0 ||
6929 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
6930 combinedMaxThreadsVal = maxThreadsVal;
// Reduction data size is only computed for GPU targets with a captured op.
6932 int32_t reductionDataSize = 0;
6933 if (isGPU && capturedOp) {
// Execution mode: generic+spmd -> GENERIC_SPMD, generic only -> GENERIC,
// otherwise SPMD. The assignment target of the conditional expression is on an
// elided line (presumably attrs.ExecFlags -- confirm).
6939 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
6941 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
6942 omp::TargetRegionFlags::spmd) &&
6943 "invalid kernel flags");
6945 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
6946 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
6947 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
6948 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
6949 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
// spmd together with no_loop, and without generic, selects SPMD_NO_LOOP.
6950 if (omp::bitEnumContainsAll(kernelFlags,
6951 omp::TargetRegionFlags::spmd |
6952 omp::TargetRegionFlags::no_loop) &&
6953 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
6954 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
// Publish the computed defaults; MinThreads is always 1.
6956 attrs.MinTeams = minTeamsVal;
6957 attrs.MaxTeams.front() = maxTeamsVal;
6958 attrs.MinThreads = 1;
6959 attrs.MaxThreads.front() = combinedMaxThreadsVal;
6960 attrs.ReductionDataSize = reductionDataSize;
// A fixed reduction buffer length of 1024 is used whenever reduction data is
// present.
6963 if (attrs.ReductionDataSize != 0)
6964 attrs.ReductionBufferLength = 1024;
// Fragment that fills the runtime kernel attributes (`attrs`) for an
// omp.target op from values available at the launch site (name line and
// leading parameters elided).
6976 omp::TargetOp targetOp,
Operation *capturedOp,
6977 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
// No loop op means zero loops.
6979 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
6981 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
6985 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
// Only the first thread_limit value of the target op is used.
6988 if (!targetOp.getThreadLimitVars().empty()) {
6989 Value targetThreadLimit = targetOp.getThreadLimit(0);
6990 attrs.TargetThreadLimit.front() =
// Team and thread-limit values are sign-extended or truncated to i32.
6998 attrs.MinTeams = builder.CreateSExtOrTrunc(
6999 moduleTranslation.
lookupValue(numTeamsLower), builder.getInt32Ty());
7002 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
7003 moduleTranslation.
lookupValue(numTeamsUpper), builder.getInt32Ty());
7005 if (teamsThreadLimit)
7006 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
7007 moduleTranslation.
lookupValue(teamsThreadLimit), builder.getInt32Ty());
7010 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
// The trip count is computed only when the kernel flags request it.
7012 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
7013 omp::TargetRegionFlags::trip_count)) {
7015 attrs.LoopTripCount =
nullptr;
7020 for (
auto [loopLower, loopUpper, loopStep] :
7021 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
7022 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
7023 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
7024 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
// If any bound or step has no translated value, the trip count is reset to
// null.
7026 if (!lowerBound || !upperBound || !step) {
7027 attrs.LoopTripCount =
nullptr;
7031 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
7032 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
7033 loc, lowerBound, upperBound, step,
true,
7034 loopOp.getLoopInclusive());
// The first loop seeds the total; later loops multiply into it.
7036 if (!attrs.LoopTripCount) {
7037 attrs.LoopTripCount = tripCount;
7042 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
// Device id defaults to "undefined" and is replaced by the translated device
// value on the visible line after it (its guard is elided).
7047 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
7049 attrs.DeviceID = moduleTranslation.
lookupValue(devId);
// NOTE(review): the result of this conversion is not assigned on the visible
// line; the left-hand side is presumably on the elided preceding line --
// confirm, otherwise the converted value is discarded.
7051 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
// Body fragment of the omp.target lowering (function header elided).
7058 auto targetOp = cast<omp::TargetOp>(opInst);
// Debug location in effect on entry; reused below for the outlined function.
7063 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
7072 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
7073 assert(parentBB &&
"No insert block is set for the builder");
7074 llvm::Function *parentLLVMFn = parentBB->getParent();
7075 assert(parentLLVMFn &&
"Parent Function must be valid");
// When the parent function has a subprogram, the current debug location is
// rebuilt with the same line/column but scoped to that subprogram.
7076 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
7077 builder.SetCurrentDebugLocation(llvm::DILocation::get(
7078 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
7079 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
7082 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
7083 bool isGPU = ompBuilder->Config.isGPU();
7086 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
7087 auto &targetRegion = targetOp.getRegion();
// Set by bodyCB once the outlined function exists.
7104 llvm::Function *llvmOutlinedFn =
nullptr;
7105 TargetDirectiveEnumTy targetDirective =
7106 getTargetDirectiveEnumTyFromOp(&opInst);
// An offload entry is produced on the device, or on the host when target
// triples are configured.
7110 bool isOffloadEntry =
7111 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
// Privatized variables that also need a map are associated with the block
// argument of their map entry.
7118 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
7120 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
7121 std::optional<DenseI64ArrayAttr> privateMapIndices =
7122 targetOp.getPrivateMapsAttr();
7124 for (
auto [privVarIdx, privVarSymPair] :
7125 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
7126 auto privVar = std::get<0>(privVarSymPair);
7127 auto privSym = std::get<1>(privVarSymPair);
7129 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
7130 omp::PrivateClauseOp privatizer =
7133 if (!privatizer.needsMap())
7137 targetOp.getMappedValueForPrivateVar(privVarIdx);
7138 assert(mappedValue &&
"Expected to find mapped value for a privatized "
7139 "variable that needs mapping");
7144 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
7145 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
// Type agreement is only asserted for non-pointer private variables.
7149 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
7151 varType == privVar.getType() &&
7152 "Type of private var doesn't match the type of the mapped value");
7156 mappedPrivateVars.insert(
7158 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
7159 (*privateMapIndices)[privVarIdx])});
// Callback that emits the body of the outlined target function.
7163 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
7164 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
7165 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
// The guard restores the builder state on exit; the debug location is cleared
// while the body is emitted.
7166 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7167 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7170 llvm::Function *llvmParentFn =
7172 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
7173 assert(llvmParentFn && llvmOutlinedFn &&
7174 "Both parent and outlined functions must exist at this point");
7176 if (outlinedFnLoc && llvmParentFn->getSubprogram())
7177 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
// Propagate "target-cpu" and "target-features" string attributes from the
// parent function to the outlined function.
7179 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
7180 attr.isStringAttribute())
7181 llvmOutlinedFn->addFnAttr(attr);
7183 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
7184 attr.isStringAttribute())
7185 llvmOutlinedFn->addFnAttr(attr);
// Bind the map and hda block arguments to the translated var_ptr of their map
// info ops.
7187 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
7188 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7189 llvm::Value *mapOpValue =
7190 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
7191 moduleTranslation.
mapValue(arg, mapOpValue);
7193 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
7194 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7195 llvm::Value *mapOpValue =
7196 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
7197 moduleTranslation.
mapValue(arg, mapOpValue);
// Privatization steps (callee names elided); each failure is reported as
// PreviouslyReportedError.
7206 allocaIP, &mappedPrivateVars);
7209 return llvm::make_error<PreviouslyReportedError>();
7211 builder.restoreIP(codeGenIP);
7213 &mappedPrivateVars),
7216 return llvm::make_error<PreviouslyReportedError>();
7219 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
7221 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
7222 return llvm::make_error<PreviouslyReportedError>();
// Collect the dealloc region of every privatizer for cleanup after the body.
7226 std::back_inserter(privateCleanupRegions),
7227 [](omp::PrivateClauseOp privatizer) {
7228 return &privatizer.getDeallocRegion();
7232 targetRegion,
"omp.target", builder, moduleTranslation);
7235 return exitBlock.takeError();
// Private cleanup is emitted in the exit block of the converted region.
7237 builder.SetInsertPoint(*exitBlock);
7238 if (!privateCleanupRegions.empty()) {
7240 privateCleanupRegions, privateVarsInfo.
llvmVars,
7241 moduleTranslation, builder,
"omp.targetop.private.cleanup",
7243 return llvm::createStringError(
7244 "failed to inline `dealloc` region of `omp.private` "
7245 "op in the target region");
7247 return builder.saveIP();
// Without cleanup regions the insertion point is the end of the exit block.
7250 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
7253 StringRef parentName = parentFn.getName();
7255 llvm::TargetRegionEntryInfo entryInfo;
7259 MapInfoData mapData;
// Callback producing the combined map information. The visible tail appends
// one extra entry: null base pointer and pointer, no device info, size 0,
// flags TARGET_PARAM | LITERAL and no mapper (the condition guarding it is
// elided).
7264 MapInfosTy combinedInfos;
7266 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
7267 builder.restoreIP(codeGenIP);
7268 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
7273 auto *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy());
7274 combinedInfos.BasePointers.push_back(nullPtr);
7275 combinedInfos.Pointers.push_back(nullPtr);
7276 combinedInfos.DevicePointers.push_back(
7277 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
7278 combinedInfos.Sizes.push_back(builder.getInt64(0));
7279 combinedInfos.Types.push_back(
7280 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
7281 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
// A name is appended only when names are already being tracked.
7282 if (!combinedInfos.Names.empty())
7283 combinedInfos.Names.push_back(nullPtr);
7284 combinedInfos.Mappers.push_back(
nullptr);
7286 return combinedInfos;
// Callback that yields the value to use for a kernel argument: on the host the
// argument itself, on the device the result of the accessor helper.
7289 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
7290 llvm::Value *&retVal, InsertPointTy allocaIP,
7291 InsertPointTy codeGenIP)
7292 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
7293 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7294 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7300 if (!isTargetDevice) {
7301 retVal = cast<llvm::Value>(&arg);
7306 *ompBuilder, moduleTranslation,
7307 allocaIP, codeGenIP);
// Kernel attributes: defaults always, runtime values only on the host.
7310 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
7311 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
7312 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
7314 isTargetDevice, isGPU);
7318 if (!isTargetDevice)
7320 targetCapturedOp, runtimeAttrs);
// host_eval block arguments are bound to their host values; only non-constant
// values become kernel inputs.
7328 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
7329 llvm::Value *value = moduleTranslation.
lookupValue(var);
7330 moduleTranslation.
mapValue(arg, value);
7332 if (!llvm::isa<llvm::Constant>(value))
7333 kernelInput.push_back(value);
// Mapped values become kernel inputs unless they are declare-target or members
// of another map.
7336 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
7343 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
7344 kernelInput.push_back(mapData.OriginalValue[i]);
7347 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// Dependence information gathered from the depend clauses.
7350 llvm::OpenMPIRBuilder::DependenciesInfo dds;
7352 targetOp.getDependVars(), targetOp.getDependKinds(),
7353 targetOp.getDependIterated(), targetOp.getDependIteratedKinds(),
7354 builder, moduleTranslation, dds)))
7357 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7359 llvm::OpenMPIRBuilder::TargetDataInfo info(
// Custom mapper callback: flags info.HasMapper when an entry has a mapper.
7363 auto customMapperCB =
7365 if (!combinedInfos.Mappers[i])
7367 info.HasMapper =
true;
7369 moduleTranslation, targetDirective);
// Optional `if` clause condition.
7372 llvm::Value *ifCond =
nullptr;
7373 if (
Value targetIfCond = targetOp.getIfExpr())
7374 ifCond = moduleTranslation.
lookupValue(targetIfCond);
// Emit the target construct (callee name elided) and resume after it.
7376 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7378 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
7379 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
7380 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
7385 builder.restoreIP(*afterIP);
// Releases the dependence array (the guarding condition, if any, is elided).
7388 builder.CreateFree(dds.DepArray);
// Body fragment that applies a declare-target attribute to functions and
// LLVM dialect globals (function header elided).
// Functions: during device compilation, host-only declare-target functions are
// removed from the LLVM module.
7409 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
7410 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
7412 if (!offloadMod.getIsTargetDevice())
7415 omp::DeclareTargetDeviceType declareType =
7416 attribute.getDeviceType().getValue();
7418 if (declareType == omp::DeclareTargetDeviceType::host) {
7419 llvm::Function *llvmFunc =
// References are dropped before erasing the function from its module.
7421 llvmFunc->dropAllReferences();
7422 llvmFunc->eraseFromParent();
// Globals: the already-translated LLVM value is looked up by symbol name and
// registered with the OpenMP IR builder.
7428 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
7429 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7430 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
7432 bool isDeclaration = gOp.isDeclaration();
7433 bool isExternallyVisible =
7436 llvm::StringRef mangledName = gOp.getSymName();
7437 auto captureClause =
7443 std::vector<llvm::GlobalVariable *> generatedRefs;
// The target triple is taken from the attribute named by
// LLVMDialect::getTargetTripleAttrName(), when it is a string attribute.
7445 std::vector<llvm::Triple> targetTriple;
7446 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
7448 LLVM::LLVMDialect::getTargetTripleAttrName()));
7449 if (targetTripleAttr)
7450 targetTriple.emplace_back(targetTripleAttr.data());
// Supplies (file name, line) for the unique entry info; defaults to an empty
// name and line 0 when the location check (elided) does not apply.
7452 auto fileInfoCallBack = [&loc]() {
7453 std::string filename =
"";
7454 std::uint64_t lineNo = 0;
7457 filename = loc.getFilename().str();
7458 lineNo = loc.getLine();
7461 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
7465 auto vfs = llvm::vfs::getRealFileSystem();
7467 ompBuilder->registerTargetGlobalVariable(
7468 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7469 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
7470 mangledName, generatedRefs,
false, targetTriple,
7472 gVal->getType(), gVal);
// On the device, the address of the declare-target variable is additionally
// requested when the capture clause is not `to` or unified shared memory is
// required.
7474 if (ompBuilder->Config.isTargetDevice() &&
7475 (attribute.getCaptureClause().getValue() !=
7476 mlir::omp::DeclareTargetCaptureClause::to ||
7477 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
7478 ompBuilder->getAddrOfDeclareTargetVar(
7479 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7480 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
7481 mangledName, generatedRefs,
false, targetTriple,
7482 gVal->getType(),
nullptr,
// Translation interface of the OpenMP dialect: derives from
// LLVMTranslationDialectInterface, inherits its constructors, and declares the
// operation-conversion and attribute-amendment hooks as final overrides.
7495class OpenMPDialectLLVMIRTranslationInterface
7496 :
public LLVMTranslationDialectInterface {
7498 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
// Hook that translates a single operation using the given IR builder.
7503 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
7504 LLVM::ModuleTranslation &moduleTranslation)
const final;
// Hook that receives a named attribute of `op` together with the LLVM
// instructions associated with it.
7509 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
7510 NamedAttribute attribute,
7511 LLVM::ModuleTranslation &moduleTranslation)
const final;
// Records the LLVM pointer associated with an MLIR value. The method is const
// yet writes to ompAllocatedPtrs, so that member (declared on an elided line)
// is presumably mutable -- confirm.
7516 void registerAllocatedPtr(Value var, llvm::Value *ptr)
const {
7517 ompAllocatedPtrs[var] = ptr;
// Returns the pointer recorded for `var`, or nullptr when none is recorded.
7522 llvm::Value *lookupAllocatedPtr(Value var)
const {
7523 auto it = ompAllocatedPtrs.find(var);
7524 return it != ompAllocatedPtrs.end() ? it->second :
nullptr;
7536LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
7537 Operation *op, ArrayRef<llvm::Instruction *> instructions,
7538 NamedAttribute attribute,
7539 LLVM::ModuleTranslation &moduleTranslation)
const {
7540 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
7542 .Case(
"omp.is_target_device",
7543 [&](Attribute attr) {
7544 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
7545 llvm::OpenMPIRBuilderConfig &config =
7547 config.setIsTargetDevice(deviceAttr.getValue());
7553 [&](Attribute attr) {
7554 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
7555 llvm::OpenMPIRBuilderConfig &config =
7557 config.setIsGPU(gpuAttr.getValue());
7562 .Case(
"omp.host_ir_filepath",
7563 [&](Attribute attr) {
7564 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
7565 llvm::OpenMPIRBuilder *ompBuilder =
7567 auto VFS = llvm::vfs::getRealFileSystem();
7568 ompBuilder->loadOffloadInfoMetadata(*VFS,
7569 filepathAttr.getValue());
7575 [&](Attribute attr) {
7576 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
7580 .Case(
"omp.version",
7581 [&](Attribute attr) {
7582 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
7583 llvm::OpenMPIRBuilder *ompBuilder =
7585 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
7586 versionAttr.getVersion());
7591 .Case(
"omp.declare_target",
7592 [&](Attribute attr) {
7593 if (
auto declareTargetAttr =
7594 dyn_cast<omp::DeclareTargetAttr>(attr))
7599 .Case(
"omp.requires",
7600 [&](Attribute attr) {
7601 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
7602 using Requires = omp::ClauseRequires;
7603 Requires flags = requiresAttr.getValue();
7604 llvm::OpenMPIRBuilderConfig &config =
7606 config.setHasRequiresReverseOffload(
7607 bitEnumContainsAll(flags, Requires::reverse_offload));
7608 config.setHasRequiresUnifiedAddress(
7609 bitEnumContainsAll(flags, Requires::unified_address));
7610 config.setHasRequiresUnifiedSharedMemory(
7611 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7612 config.setHasRequiresDynamicAllocators(
7613 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7618 .Case(
"omp.target_triples",
7619 [&](Attribute attr) {
7620 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7621 llvm::OpenMPIRBuilderConfig &config =
7623 config.TargetTriples.clear();
7624 config.TargetTriples.reserve(triplesAttr.size());
7625 for (Attribute tripleAttr : triplesAttr) {
7626 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7627 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7635 .Default([](Attribute) {
7651 if (
auto declareTargetIface =
7652 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
7653 parentFn.getOperation()))
7654 if (declareTargetIface.isDeclareTarget() &&
7655 declareTargetIface.getDeclareTargetDeviceType() !=
7656 mlir::omp::DeclareTargetDeviceType::host)
7666 llvm::Module *llvmModule) {
7667 llvm::Type *i64Ty = builder.getInt64Ty();
7668 llvm::Type *i32Ty = builder.getInt32Ty();
7669 llvm::Type *returnType = builder.getPtrTy(0);
7670 llvm::FunctionType *fnType =
7671 llvm::FunctionType::get(returnType, {i64Ty, i32Ty},
false);
7672 llvm::Function *
func = cast<llvm::Function>(
7673 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
7680 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
7685 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7689 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
7691 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
7692 mlir::Type heapTy = allocMemOp.getAllocatedType();
7693 llvm::Type *llvmHeapTy = moduleTranslation.
convertType(heapTy);
7694 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
7695 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7696 for (
auto typeParam : allocMemOp.getTypeparams())
7698 builder.CreateMul(allocSize, moduleTranslation.
lookupValue(typeParam));
7700 llvm::CallInst *call =
7701 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
7702 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
7705 moduleTranslation.
mapValue(allocMemOp.getResult(), resultI64);
7712 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
7713 auto allocateDirOp = cast<omp::AllocateDirOp>(opInst);
7716 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7717 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7718 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
7720 std::optional<int64_t> alignAttr = allocateDirOp.getAlign();
7722 llvm::Value *allocator;
7723 if (
auto allocatorVar = allocateDirOp.getAllocator()) {
7724 allocator = moduleTranslation.
lookupValue(allocatorVar);
7725 if (allocator->getType()->isIntegerTy())
7726 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
7727 else if (allocator->getType()->isPointerTy())
7728 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
7729 allocator, builder.getPtrTy());
7731 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
7734 for (
Value var : vars) {
7735 llvm::Type *llvmVarTy = moduleTranslation.
convertType(var.getType());
7739 llvm::Type *typeToInspect = llvmVarTy;
7740 if (llvmVarTy->isPointerTy()) {
7743 if (
auto gop = dyn_cast<LLVM::GlobalOp>(globalOp))
7744 typeToInspect = moduleTranslation.
convertType(gop.getGlobalType());
7749 if (
auto arrTy = llvm::dyn_cast<llvm::ArrayType>(typeToInspect)) {
7750 llvm::Value *elementCount = builder.getInt64(1);
7751 llvm::Type *currentType = arrTy;
7752 while (
auto nestedArrTy = llvm::dyn_cast<llvm::ArrayType>(currentType)) {
7753 elementCount = builder.CreateMul(
7754 elementCount, builder.getInt64(nestedArrTy->getNumElements()));
7755 currentType = nestedArrTy->getElementType();
7757 uint64_t elemSizeInBits = dataLayout.getTypeSizeInBits(currentType);
7759 builder.CreateMul(elementCount, builder.getInt64(elemSizeInBits / 8));
7761 size = builder.getInt64(
7762 dataLayout.getTypeStoreSize(typeToInspect).getFixedValue());
7765 uint64_t alignValue =
7766 alignAttr ? alignAttr.value()
7767 : dataLayout.getABITypeAlign(typeToInspect).value();
7768 llvm::Value *alignConst = builder.getInt64(alignValue);
7770 size = builder.CreateAdd(size, builder.getInt64(alignValue - 1),
"",
true);
7771 size = builder.CreateUDiv(size, alignConst);
7772 size = builder.CreateMul(size, alignConst,
"",
true);
7774 std::string allocName =
7775 ompBuilder->createPlatformSpecificName({
".void.addr"});
7776 llvm::CallInst *allocCall;
7777 if (alignAttr.has_value()) {
7778 allocCall = ompBuilder->createOMPAlignedAlloc(
7779 ompLoc, builder.getInt64(alignAttr.value()), size, allocator,
7783 ompBuilder->createOMPAlloc(ompLoc, size, allocator, allocName);
7786 ompIface.registerAllocatedPtr(var, allocCall);
7795 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
7796 auto freeOp = cast<omp::AllocateFreeOp>(opInst);
7798 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7800 llvm::Value *allocator;
7801 if (
auto allocatorVar = freeOp.getAllocator()) {
7802 allocator = moduleTranslation.
lookupValue(allocatorVar);
7803 if (allocator->getType()->isIntegerTy())
7804 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
7805 else if (allocator->getType()->isPointerTy())
7806 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
7807 allocator, builder.getPtrTy());
7809 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
7814 for (
Value var : llvm::reverse(vars)) {
7815 llvm::Value *allocPtr = ompIface.lookupAllocatedPtr(var);
7817 return opInst.
emitError(
"omp.allocate_free: no allocation recorded");
7818 ompBuilder->createOMPFree(ompLoc, allocPtr, allocator,
"");
7825 llvm::Module *llvmModule) {
7826 llvm::Type *ptrTy = builder.getPtrTy(0);
7827 llvm::Type *i32Ty = builder.getInt32Ty();
7828 llvm::Type *voidTy = builder.getVoidTy();
7829 llvm::FunctionType *fnType =
7830 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty},
false);
7831 llvm::Function *
func = dyn_cast<llvm::Function>(
7832 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
7839 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
7844 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7848 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
7851 llvm::Value *llvmHeapref = moduleTranslation.
lookupValue(heapref);
7853 llvm::Value *intToPtr =
7854 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
7855 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
7861LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7862 Operation *op, llvm::IRBuilderBase &builder,
7863 LLVM::ModuleTranslation &moduleTranslation)
const {
7866 if (ompBuilder->Config.isTargetDevice() &&
7867 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
7870 return op->
emitOpError() <<
"unsupported host op found in device";
7878 bool isOutermostLoopWrapper =
7879 isa_and_present<omp::LoopWrapperInterface>(op) &&
7880 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
7889 if (isa<omp::TaskloopContextOp>(op))
7890 isOutermostLoopWrapper =
true;
7891 else if (isa<omp::TaskloopWrapperOp>(op))
7892 isOutermostLoopWrapper =
false;
7894 if (isOutermostLoopWrapper)
7895 moduleTranslation.
stackPush<OpenMPLoopInfoStackFrame>();
7898 llvm::TypeSwitch<Operation *, LogicalResult>(op)
7899 .Case([&](omp::BarrierOp op) -> LogicalResult {
7903 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7904 ompBuilder->createBarrier(builder.saveIP(),
7905 llvm::omp::OMPD_barrier);
7907 if (res.succeeded()) {
7910 builder.restoreIP(*afterIP);
7914 .Case([&](omp::TaskyieldOp op) {
7918 ompBuilder->createTaskyield(builder.saveIP());
7921 .Case([&](omp::FlushOp op) {
7933 ompBuilder->createFlush(builder.saveIP());
7936 .Case([&](omp::ParallelOp op) {
7939 .Case([&](omp::MaskedOp) {
7942 .Case([&](omp::MasterOp) {
7945 .Case([&](omp::CriticalOp) {
7948 .Case([&](omp::OrderedRegionOp) {
7951 .Case([&](omp::OrderedOp) {
7954 .Case([&](omp::WsloopOp) {
7957 .Case([&](omp::SimdOp) {
7960 .Case([&](omp::AtomicReadOp) {
7963 .Case([&](omp::AtomicWriteOp) {
7966 .Case([&](omp::AtomicUpdateOp op) {
7969 .Case([&](omp::AtomicCaptureOp op) {
7972 .Case([&](omp::CancelOp op) {
7975 .Case([&](omp::CancellationPointOp op) {
7978 .Case([&](omp::SectionsOp) {
7981 .Case([&](omp::SingleOp op) {
7984 .Case([&](omp::TeamsOp op) {
7987 .Case([&](omp::TaskOp op) {
7990 .Case([&](omp::TaskloopWrapperOp op) {
7993 .Case([&](omp::TaskloopContextOp op) {
7996 .Case([&](omp::TaskgroupOp op) {
7999 .Case([&](omp::TaskwaitOp op) {
8002 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
8003 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
8004 omp::CriticalDeclareOp>([](
auto op) {
8017 .Case([&](omp::ThreadprivateOp) {
8020 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
8021 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
8024 .Case([&](omp::TargetOp) {
8027 .Case([&](omp::DistributeOp) {
8030 .Case([&](omp::LoopNestOp) {
8033 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp,
8034 omp::AffinityEntryOp, omp::IteratorOp>([&](
auto op) {
8040 .Case([&](omp::NewCliOp op) {
8045 .Case([&](omp::CanonicalLoopOp op) {
8048 .Case([&](omp::UnrollHeuristicOp op) {
8057 .Case([&](omp::TileOp op) {
8058 return applyTile(op, builder, moduleTranslation);
8060 .Case([&](omp::FuseOp op) {
8061 return applyFuse(op, builder, moduleTranslation);
8063 .Case([&](omp::TargetAllocMemOp) {
8066 .Case([&](omp::TargetFreeMemOp) {
8069 .Case([&](omp::AllocateDirOp) {
8072 .Case([&](omp::AllocateFreeOp) {
8076 .Default([&](Operation *inst) {
8078 <<
"not yet implemented: " << inst->
getName();
8081 if (isOutermostLoopWrapper)
8088 registry.
insert<omp::OpenMPDialect>();
8090 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static mlir::LogicalResult buildDependData(OperandRange dependVars, std::optional< ArrayAttr > dependKinds, OperandRange dependIterated, std::optional< ArrayAttr > dependIteratedKinds, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static LogicalResult convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static mlir::LogicalResult convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo, mlir::Block &iteratorRegionBlock, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static void popCancelFinalizationCB(const ArrayRef< llvm::UncondBrInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static llvm::Expected< llvm::Value * > lookupOrTranslatePureValue(Value value, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
Look up the given value in the mapping, and if it's not there, translate its defining operation at th...
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static mlir::LogicalResult fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, IteratorInfo &iterInfo, llvm::StringRef loopName, IteratorStoreEntryTy genStoreEntry)
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
llvm::function_ref< void(llvm::Value *linearIV, mlir::omp::YieldOp yield)> IteratorStoreEntryTy
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static omp::DistributeOp getDistributeCapturingTeamsReduction(omp::TeamsOp teamsOp)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static mlir::LogicalResult buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::AffinityData &ad)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void storeAffinityEntry(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, llvm::Value *affinityList, llvm::Value *index, llvm::Value *addr, llvm::Value *len)
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::UncondBrInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static LogicalResult convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static void buildDependDataLocator(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void fillAffinityLocators(Operation::operand_range affinityVars, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *affinityList)
static LogicalResult convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
The correct entry point is convertOmpTaskloopContextOp. This gets called whilst lowering the body of ...
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static llvm::Error computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *&lbVal, llvm::Value *&ubVal, llvm::Value *&stepVal)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static llvm::omp::RTLDependenceKindTy convertDependKind(mlir::omp::ClauseTaskDepend kind)
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Value getBaseValueForTypeLookup(Value value)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static LogicalResult convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder)
Converts the given MLIR operation into LLVM IR using this translator.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars