24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Frontend/OpenMP/OMPConstants.h"
28#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DebugInfoMetadata.h"
31#include "llvm/IR/DerivedTypes.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/MDBuilder.h"
34#include "llvm/IR/ReplaceConstant.h"
35#include "llvm/Support/FileSystem.h"
36#include "llvm/Support/VirtualFileSystem.h"
37#include "llvm/TargetParser/Triple.h"
38#include "llvm/Transforms/Utils/ModuleUtils.h"
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
67 llvm_unreachable(
"unhandled schedule clause argument");
// Stack frame recording the insertion point at which allocas should be
// created while translating the current construct.
// NOTE(review): the extraction dropped original lines 73-76 (presumably the
// base-class / type-ID boilerplate of this stack frame) — the fragment below
// is kept byte-identical; restore the missing lines from upstream.
72class OpenMPAllocaStackFrame
77  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
78      : allocaInsertPoint(allocaIP) {}
  // Where new allocas for this frame's region should be inserted.
79  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
// Stack frame carrying the CanonicalLoopInfo of the loop currently being
// translated, so nested translations can find it.
// NOTE(review): surrounding declaration lines (orig 86-88) are missing from
// this chunk; code kept byte-identical.
85class OpenMPLoopInfoStackFrame
  // Null until the enclosing loop has been emitted.
89  llvm::CanonicalLoopInfo *loopInfo =
nullptr;
// Custom error type signalling that a diagnostic has already been emitted for
// this failure, so callers should fail silently instead of reporting again.
// NOTE(review): method bodies and several lines (orig 110, 112-114, 116,
// 118-123) are missing from this chunk; code kept byte-identical.
108class PreviouslyReportedError
109    :
public llvm::ErrorInfo<PreviouslyReportedError> {
  // Intentionally prints nothing: the diagnostic was already reported.
111  void log(raw_ostream &)
const override {
115  std::error_code convertToErrorCode()
const override {
117        "PreviouslyReportedError doesn't support ECError conversion");
// Type ID anchor required by llvm::ErrorInfo.
124char PreviouslyReportedError::ID = 0;
// Helper that materializes OpenMP `linear` clause semantics for worksharing
// loops: per linear variable it tracks the allocas, the original value, the
// step, and the LLVM type, plus the basic blocks created for finalization.
// NOTE(review): lines 136-137 and 146-148 of the original are missing here;
// the declarations below are kept byte-identical.
135class LinearClauseProcessor {
  // Allocas holding the pre-loop value of each linear variable.
138  SmallVector<llvm::Value *> linearPreconditionVars;
  // Allocas holding the per-iteration value computed in the loop body.
139  SmallVector<llvm::Value *> linearLoopBodyTemps;
  // The original (incoming) variable for each linear clause entry.
140  SmallVector<llvm::Value *> linearOrigVal;
  // Translated step value for each linear variable.
141  SmallVector<llvm::Value *> linearSteps;
  // LLVM type of each linear variable (parallel to the vectors above).
142  SmallVector<llvm::Type *> linearVarTypes;
  // Blocks created by splitLinearFiniBB for the finalization sequence.
143  llvm::BasicBlock *linearFinalizationBB;
144  llvm::BasicBlock *linearExitBB;
145  llvm::BasicBlock *linearLastIterExitBB;
149 void registerType(LLVM::ModuleTranslation &moduleTranslation,
150 mlir::Attribute &ty) {
151 linearVarTypes.push_back(moduleTranslation.
convertType(
152 mlir::cast<mlir::TypeAttr>(ty).getValue()));
156 void createLinearVar(llvm::IRBuilderBase &builder,
157 LLVM::ModuleTranslation &moduleTranslation,
158 llvm::Value *linearVar,
int idx) {
159 linearPreconditionVars.push_back(
160 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_var"));
161 llvm::Value *linearLoopBodyTemp =
162 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_result");
163 linearOrigVal.push_back(linearVar);
164 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
168 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
169 mlir::Value &linearStep) {
170 linearSteps.push_back(moduleTranslation.
lookupValue(linearStep));
174 void initLinearVar(llvm::IRBuilderBase &builder,
175 LLVM::ModuleTranslation &moduleTranslation,
176 llvm::BasicBlock *loopPreHeader) {
177 builder.SetInsertPoint(loopPreHeader->getTerminator());
178 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
179 llvm::LoadInst *linearVarLoad =
180 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
181 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
  // Emits, before the loop body's terminator, the per-iteration update of
  // each linear variable's temporary: temp = start + iv * step (with the
  // floating-point analogue using a sitofp on the integer product).
  // NOTE(review): original lines 193, 196-197, 199, 202, 209, 212 and
  // 218-219 (string continuations, braces and the trailing else with
  // llvm_unreachable) are missing from this chunk; code kept byte-identical.
186  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
187                       llvm::Value *loopInductionVar) {
188    builder.SetInsertPoint(loopBody->getTerminator());
189    for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
190      llvm::Type *linearVarType = linearVarTypes[index];
191      llvm::Value *iv = loopInductionVar;
192      llvm::Value *step = linearSteps[index];
  // The IV must be integral; anything else indicates an invariant violation
  // upstream, hence the unreachable rather than a recoverable error.
194      if (!iv->getType()->isIntegerTy())
195        llvm_unreachable(
"OpenMP loop induction variable must be an integer "
198      if (linearVarType->isIntegerTy()) {
  // Bring iv and step to the variable's width before the multiply/add.
200        iv = builder.CreateSExtOrTrunc(iv, linearVarType);
201        step = builder.CreateSExtOrTrunc(step, linearVarType);
203        llvm::LoadInst *linearVarStart =
204            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
205        llvm::Value *mulInst = builder.CreateMul(iv, step);
206        llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
207        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
208      }
else if (linearVarType->isFloatingPointTy()) {
  // Integer iv*step product is converted to FP before the accumulate.
210        step = builder.CreateSExtOrTrunc(step, iv->getType());
211        llvm::Value *mulInst = builder.CreateMul(iv, step);
213        llvm::LoadInst *linearVarStart =
214            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
215        llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
216        llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
217        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
220            "Linear variable must be of integer or floating-point type");
227 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
228 llvm::BasicBlock *loopExit) {
229 linearFinalizationBB = loopExit->splitBasicBlock(
230 loopExit->getTerminator(),
"omp_loop.linear_finalization");
231 linearExitBB = linearFinalizationBB->splitBasicBlock(
232 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
233 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
234 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
  // Emits the linear-clause finalization: tests the last-iteration flag and,
  // only on the last iteration, copies each loop-body temporary back into the
  // original variable; then rewires the finalization block's terminator into
  // a conditional branch.
  // NOTE(review): original lines 242, 250, 256-260, 264 and 266 are missing
  // from this chunk — line 266 is the head of the call whose arguments appear
  // at line 267 (presumably a barrier emission; confirm against upstream).
  // Code kept byte-identical.
238  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
239  finalizeLinearVar(llvm::IRBuilderBase &builder,
240                    LLVM::ModuleTranslation &moduleTranslation,
241                    llvm::Value *lastIter) {
243    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
  // lastIter is a pointer to an i32 flag: nonzero means last iteration.
244    llvm::Value *loopLastIterLoad = builder.CreateLoad(
245        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
246    llvm::Value *isLast =
247        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
248                          llvm::ConstantInt::get(
249                              llvm::Type::getInt32Ty(builder.getContext()), 0));
  // Write-back of final values happens only in the lastiter-exit block.
251    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
252    for (
size_t index = 0; index < linearOrigVal.size(); index++) {
253      llvm::LoadInst *linearVarTemp =
254          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
255      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
  // Replace the unconditional terminator created by splitBasicBlock with the
  // conditional branch on isLast.
261    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
262    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
263    linearFinalizationBB->getTerminator()->eraseFromParent();
265    builder.SetInsertPoint(linearExitBB->getTerminator());
267        builder.saveIP(), llvm::omp::OMPD_barrier);
272 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
273 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
274 llvm::LoadInst *linearVarTemp =
275 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
276 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
  // Rewrites uses of a linear variable inside blocks whose name contains
  // BBName so that they read/write the loop-body temporary instead of the
  // original variable. Users are snapshotted first because
  // replaceUsesOfWith mutates the use list being iterated.
  // NOTE(review): original lines 283 (remaining parameters, including the
  // variable index), 290 (the comparison's right-hand side) and 293-295 are
  // missing from this chunk; code kept byte-identical.
282  void rewriteInPlace(llvm::IRBuilderBase &builder,
const std::string &BBName,
284    llvm::SmallVector<llvm::User *> users;
285    for (llvm::User *user : linearOrigVal[varIndex]->users())
286      users.push_back(user);
287    for (
auto *user : users) {
288      if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
  // Only instructions inside matching-named blocks are redirected.
289        if (userInst->getParent()->getName().str().find(BBName) !=
291          user->replaceUsesOfWith(linearOrigVal[varIndex],
292                                  linearLoopBodyTemps[varIndex]);
303 SymbolRefAttr symbolName) {
304 omp::PrivateClauseOp privatizer =
307 assert(privatizer &&
"privatizer not found in the symbol table");
318 auto todo = [&op](StringRef clauseName) {
319 return op.
emitError() <<
"not yet implemented: Unhandled clause "
320 << clauseName <<
" in " << op.
getName()
324 auto checkAffinity = [&todo](
auto op, LogicalResult &
result) {
325 if (!op.getAffinityVars().empty())
326 result = todo(
"affinity");
328 auto checkAllocate = [&todo](
auto op, LogicalResult &
result) {
329 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
330 result = todo(
"allocate");
332 auto checkBare = [&todo](
auto op, LogicalResult &
result) {
334 result = todo(
"ompx_bare");
336 auto checkDepend = [&todo](
auto op, LogicalResult &
result) {
337 if (!op.getDependVars().empty() || op.getDependKinds())
340 auto checkHint = [](
auto op, LogicalResult &) {
344 auto checkInReduction = [&todo](
auto op, LogicalResult &
result) {
345 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
346 op.getInReductionSyms())
347 result = todo(
"in_reduction");
349 auto checkNowait = [&todo](
auto op, LogicalResult &
result) {
353 auto checkOrder = [&todo](
auto op, LogicalResult &
result) {
354 if (op.getOrder() || op.getOrderMod())
357 auto checkPrivate = [&todo](
auto op, LogicalResult &
result) {
358 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
359 result = todo(
"privatization");
361 auto checkReduction = [&todo](
auto op, LogicalResult &
result) {
362 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
363 if (!op.getReductionVars().empty() || op.getReductionByref() ||
364 op.getReductionSyms())
365 result = todo(
"reduction");
366 if (op.getReductionMod() &&
367 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
368 result = todo(
"reduction with modifier");
370 auto checkTaskReduction = [&todo](
auto op, LogicalResult &
result) {
371 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
372 op.getTaskReductionSyms())
373 result = todo(
"task_reduction");
375 auto checkNumTeams = [&todo](
auto op, LogicalResult &
result) {
376 if (op.hasNumTeamsMultiDim())
377 result = todo(
"num_teams with multi-dimensional values");
379 auto checkNumThreads = [&todo](
auto op, LogicalResult &
result) {
380 if (op.hasNumThreadsMultiDim())
381 result = todo(
"num_threads with multi-dimensional values");
384 auto checkThreadLimit = [&todo](
auto op, LogicalResult &
result) {
385 if (op.hasThreadLimitMultiDim())
386 result = todo(
"thread_limit with multi-dimensional values");
391 .Case([&](omp::DistributeOp op) {
392 checkAllocate(op,
result);
395 .Case([&](omp::SectionsOp op) {
396 checkAllocate(op,
result);
398 checkReduction(op,
result);
400 .Case([&](omp::SingleOp op) {
401 checkAllocate(op,
result);
404 .Case([&](omp::TeamsOp op) {
405 checkAllocate(op,
result);
407 checkNumTeams(op,
result);
408 checkThreadLimit(op,
result);
410 .Case([&](omp::TaskOp op) {
411 checkAffinity(op,
result);
412 checkAllocate(op,
result);
413 checkInReduction(op,
result);
415 .Case([&](omp::TaskgroupOp op) {
416 checkAllocate(op,
result);
417 checkTaskReduction(op,
result);
419 .Case([&](omp::TaskwaitOp op) {
423 .Case([&](omp::TaskloopOp op) {
424 checkAllocate(op,
result);
425 checkInReduction(op,
result);
426 checkReduction(op,
result);
428 .Case([&](omp::WsloopOp op) {
429 checkAllocate(op,
result);
431 checkReduction(op,
result);
433 .Case([&](omp::ParallelOp op) {
434 checkAllocate(op,
result);
435 checkReduction(op,
result);
436 checkNumThreads(op,
result);
438 .Case([&](omp::SimdOp op) { checkReduction(op,
result); })
439 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
440 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op,
result); })
441 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
442 [&](
auto op) { checkDepend(op,
result); })
443 .Case([&](omp::TargetUpdateOp op) { checkDepend(op,
result); })
444 .Case([&](omp::TargetOp op) {
445 checkAllocate(op,
result);
447 checkInReduction(op,
result);
448 checkThreadLimit(op,
result);
460 llvm::handleAllErrors(
462 [&](
const PreviouslyReportedError &) {
result = failure(); },
463 [&](
const llvm::ErrorInfoBase &err) {
480static llvm::OpenMPIRBuilder::InsertPointTy
486 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
488 [&](OpenMPAllocaStackFrame &frame) {
489 allocaInsertPoint = frame.allocaInsertPoint;
497 allocaInsertPoint.getBlock()->getParent() ==
498 builder.GetInsertBlock()->getParent())
499 return allocaInsertPoint;
508 if (builder.GetInsertBlock() ==
509 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
510 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
511 "Assuming end of basic block");
512 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
513 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
514 builder.GetInsertBlock()->getNextNode());
515 builder.CreateBr(entryBB);
516 builder.SetInsertPoint(entryBB);
519 llvm::BasicBlock &funcEntryBlock =
520 builder.GetInsertBlock()->getParent()->getEntryBlock();
521 return llvm::OpenMPIRBuilder::InsertPointTy(
522 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
528static llvm::CanonicalLoopInfo *
530 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
531 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
532 [&](OpenMPLoopInfoStackFrame &frame) {
533 loopInfo = frame.loopInfo;
545 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
548 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
550 llvm::BasicBlock *continuationBlock =
551 splitBB(builder,
true,
"omp.region.cont");
552 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
554 llvm::LLVMContext &llvmContext = builder.getContext();
555 for (
Block &bb : region) {
556 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
557 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
558 builder.GetInsertBlock()->getNextNode());
559 moduleTranslation.
mapBlock(&bb, llvmBB);
562 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
569 unsigned numYields = 0;
571 if (!isLoopWrapper) {
572 bool operandsProcessed =
false;
574 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
575 if (!operandsProcessed) {
576 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
577 continuationBlockPHITypes.push_back(
578 moduleTranslation.
convertType(yield->getOperand(i).getType()));
580 operandsProcessed =
true;
582 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
583 "mismatching number of values yielded from the region");
584 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
585 llvm::Type *operandType =
586 moduleTranslation.
convertType(yield->getOperand(i).getType());
588 assert(continuationBlockPHITypes[i] == operandType &&
589 "values of mismatching types yielded from the region");
599 if (!continuationBlockPHITypes.empty())
601 continuationBlockPHIs &&
602 "expected continuation block PHIs if converted regions yield values");
603 if (continuationBlockPHIs) {
604 llvm::IRBuilderBase::InsertPointGuard guard(builder);
605 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
606 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
607 for (llvm::Type *ty : continuationBlockPHITypes)
608 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
614 for (
Block *bb : blocks) {
615 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
618 if (bb->isEntryBlock()) {
619 assert(sourceTerminator->getNumSuccessors() == 1 &&
620 "provided entry block has multiple successors");
621 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
622 "ContinuationBlock is not the successor of the entry block");
623 sourceTerminator->setSuccessor(0, llvmBB);
626 llvm::IRBuilderBase::InsertPointGuard guard(builder);
628 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
629 return llvm::make_error<PreviouslyReportedError>();
634 builder.CreateBr(continuationBlock);
645 Operation *terminator = bb->getTerminator();
646 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
647 builder.CreateBr(continuationBlock);
649 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
650 (*continuationBlockPHIs)[i]->addIncoming(
664 return continuationBlock;
// Mapping from MLIR OpenMP proc_bind clause kinds to OpenMPIRBuilder
// ProcBindKind values.
// NOTE(review): the enclosing function's signature (orig lines ~668-669) and
// the Primary/Close context lines are missing from this chunk; fragment kept
// byte-identical.
670  case omp::ClauseProcBindKind::Close:
671    return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
672  case omp::ClauseProcBindKind::Master:
673    return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
674  case omp::ClauseProcBindKind::Primary:
675    return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
676  case omp::ClauseProcBindKind::Spread:
677    return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
  // Exhaustive switch above; new enumerators must be added there.
679  llvm_unreachable(
"Unknown ClauseProcBindKind kind");
686 auto maskedOp = cast<omp::MaskedOp>(opInst);
687 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
692 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
694 auto ®ion = maskedOp.getRegion();
695 builder.restoreIP(codeGenIP);
703 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
705 llvm::Value *filterVal =
nullptr;
706 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
707 filterVal = moduleTranslation.
lookupValue(filterVar);
709 llvm::LLVMContext &llvmContext = builder.getContext();
711 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
713 assert(filterVal !=
nullptr);
714 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
715 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
722 builder.restoreIP(*afterIP);
730 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
731 auto masterOp = cast<omp::MasterOp>(opInst);
736 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
738 auto ®ion = masterOp.getRegion();
739 builder.restoreIP(codeGenIP);
747 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
749 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
750 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
757 builder.restoreIP(*afterIP);
765 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
766 auto criticalOp = cast<omp::CriticalOp>(opInst);
771 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
773 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
774 builder.restoreIP(codeGenIP);
782 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
784 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
785 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
786 llvm::Constant *hint =
nullptr;
789 if (criticalOp.getNameAttr()) {
792 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
793 auto criticalDeclareOp =
797 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
798 static_cast<int>(criticalDeclareOp.getHint()));
800 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
802 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
807 builder.restoreIP(*afterIP);
814 template <
typename OP>
817 cast<
omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
820 collectPrivatizationDecls<OP>(op);
835 void collectPrivatizationDecls(OP op) {
836 std::optional<ArrayAttr> attr = op.getPrivateSyms();
841 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
852 std::optional<ArrayAttr> attr = op.getReductionSyms();
856 reductions.reserve(reductions.size() + op.getNumReductionVars());
857 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
858 reductions.push_back(
870 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
879 llvm::Instruction *potentialTerminator =
880 builder.GetInsertBlock()->empty() ?
nullptr
881 : &builder.GetInsertBlock()->back();
883 if (potentialTerminator && potentialTerminator->isTerminator())
884 potentialTerminator->removeFromParent();
885 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
888 region.
front(),
true, builder)))
892 if (continuationBlockArgs)
894 *continuationBlockArgs,
901 if (potentialTerminator && potentialTerminator->isTerminator()) {
902 llvm::BasicBlock *block = builder.GetInsertBlock();
903 if (block->empty()) {
909 potentialTerminator->insertInto(block, block->begin());
911 potentialTerminator->insertAfter(&block->back());
925 if (continuationBlockArgs)
926 llvm::append_range(*continuationBlockArgs, phis);
927 builder.SetInsertPoint(*continuationBlock,
928 (*continuationBlock)->getFirstInsertionPt());
// Owning std::function wrappers for the reduction-generation callbacks passed
// to the OpenMPIRBuilder. Owning copies are needed because the builder stores
// the callbacks beyond the lifetime of the locals that create them.
// NOTE(review): the closing parameter lines of the first two aliases (orig
// 938 and 942) are missing from this chunk; kept byte-identical.
935using OwningReductionGen =
936    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
937        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
939using OwningAtomicReductionGen =
940    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
941        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
943using OwningDataPtrPtrReductionGen =
944    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
945        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
951static OwningReductionGen
957 OwningReductionGen gen =
958 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
959 llvm::Value *
lhs, llvm::Value *
rhs,
960 llvm::Value *&
result)
mutable
961 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
962 moduleTranslation.
mapValue(decl.getReductionLhsArg(),
lhs);
963 moduleTranslation.
mapValue(decl.getReductionRhsArg(),
rhs);
964 builder.restoreIP(insertPoint);
967 "omp.reduction.nonatomic.body", builder,
968 moduleTranslation, &phis)))
969 return llvm::createStringError(
970 "failed to inline `combiner` region of `omp.declare_reduction`");
971 result = llvm::getSingleElement(phis);
972 return builder.saveIP();
981static OwningAtomicReductionGen
983 llvm::IRBuilderBase &builder,
985 if (decl.getAtomicReductionRegion().empty())
986 return OwningAtomicReductionGen();
992 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
993 llvm::Value *
lhs, llvm::Value *
rhs)
mutable
994 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
995 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(),
lhs);
996 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(),
rhs);
997 builder.restoreIP(insertPoint);
1000 "omp.reduction.atomic.body", builder,
1001 moduleTranslation, &phis)))
1002 return llvm::createStringError(
1003 "failed to inline `atomic` region of `omp.declare_reduction`");
1004 assert(phis.empty());
1005 return builder.saveIP();
1014static OwningDataPtrPtrReductionGen
1018 return OwningDataPtrPtrReductionGen();
1020 OwningDataPtrPtrReductionGen refDataPtrGen =
1021 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1022 llvm::Value *byRefVal, llvm::Value *&
result)
mutable
1023 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1024 moduleTranslation.
mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1025 builder.restoreIP(insertPoint);
1028 "omp.data_ptr_ptr.body", builder,
1029 moduleTranslation, &phis)))
1030 return llvm::createStringError(
1031 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1032 result = llvm::getSingleElement(phis);
1033 return builder.saveIP();
1036 return refDataPtrGen;
1043 auto orderedOp = cast<omp::OrderedOp>(opInst);
1048 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1049 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1050 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1052 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
1054 size_t indexVecValues = 0;
1055 while (indexVecValues < vecValues.size()) {
1057 storeValues.reserve(numLoops);
1058 for (
unsigned i = 0; i < numLoops; i++) {
1059 storeValues.push_back(vecValues[indexVecValues]);
1062 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1064 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1065 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
1066 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
1076 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1077 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1082 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1084 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1085 builder.restoreIP(codeGenIP);
1093 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1095 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1096 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1098 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1103 builder.restoreIP(*afterIP);
1109struct DeferredStore {
1110 DeferredStore(llvm::Value *value, llvm::Value *address)
1111 : value(value), address(address) {}
1114 llvm::Value *address;
1121template <
typename T>
1124 llvm::IRBuilderBase &builder,
1126 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1132 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1133 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1136 deferredStores.reserve(loop.getNumReductionVars());
1138 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1139 Region &allocRegion = reductionDecls[i].getAllocRegion();
1141 if (allocRegion.
empty())
1146 builder, moduleTranslation, &phis)))
1147 return loop.emitError(
1148 "failed to inline `alloc` region of `omp.declare_reduction`");
1150 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1151 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1155 llvm::Value *var = builder.CreateAlloca(
1156 moduleTranslation.
convertType(reductionDecls[i].getType()));
1158 llvm::Type *ptrTy = builder.getPtrTy();
1159 llvm::Value *castVar =
1160 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1161 llvm::Value *castPhi =
1162 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1164 deferredStores.emplace_back(castPhi, castVar);
1166 privateReductionVariables[i] = castVar;
1167 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1168 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1170 assert(allocRegion.
empty() &&
1171 "allocaction is implicit for by-val reduction");
1172 llvm::Value *var = builder.CreateAlloca(
1173 moduleTranslation.
convertType(reductionDecls[i].getType()));
1175 llvm::Type *ptrTy = builder.getPtrTy();
1176 llvm::Value *castVar =
1177 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1179 moduleTranslation.
mapValue(reductionArgs[i], castVar);
1180 privateReductionVariables[i] = castVar;
1181 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1189template <
typename T>
1192 llvm::IRBuilderBase &builder,
1197 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1198 Region &initializerRegion = reduction.getInitializerRegion();
1201 mlir::Value mlirSource = loop.getReductionVars()[i];
1202 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1203 llvm::Value *origVal = llvmSource;
1205 if (!isa<LLVM::LLVMPointerType>(
1206 reduction.getInitializerMoldArg().getType()) &&
1207 isa<LLVM::LLVMPointerType>(mlirSource.
getType())) {
1210 reduction.getInitializerMoldArg().getType()),
1211 llvmSource,
"omp_orig");
1213 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), origVal);
1216 llvm::Value *allocation =
1217 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1218 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1224 llvm::BasicBlock *block =
nullptr) {
1225 if (block ==
nullptr)
1226 block = builder.GetInsertBlock();
1228 if (block->empty() || block->getTerminator() ==
nullptr)
1229 builder.SetInsertPoint(block);
1231 builder.SetInsertPoint(block->getTerminator());
1239template <
typename OP>
1242 llvm::IRBuilderBase &builder,
1244 llvm::BasicBlock *latestAllocaBlock,
1250 if (op.getNumReductionVars() == 0)
1253 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1254 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1255 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1256 builder.restoreIP(allocaIP);
1259 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1261 if (!reductionDecls[i].getAllocRegion().empty())
1267 byRefVars[i] = builder.CreateAlloca(
1268 moduleTranslation.
convertType(reductionDecls[i].getType()));
1276 for (
auto [data, addr] : deferredStores)
1277 builder.CreateStore(data, addr);
1282 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1287 reductionVariableMap, i);
1295 "omp.reduction.neutral", builder,
1296 moduleTranslation, &phis)))
1299 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1300 "reduction neutral element declaration region");
1305 if (!reductionDecls[i].getAllocRegion().empty())
1314 builder.CreateStore(phis[0], byRefVars[i]);
1316 privateReductionVariables[i] = byRefVars[i];
1317 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1318 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1321 builder.CreateStore(phis[0], privateReductionVariables[i]);
1328 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
1335template <
typename T>
1336static void collectReductionInfo(
1337 T loop, llvm::IRBuilderBase &builder,
1346 unsigned numReductions = loop.getNumReductionVars();
1348 for (
unsigned i = 0; i < numReductions; ++i) {
1351 owningAtomicReductionGens.push_back(
1354 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1358 reductionInfos.reserve(numReductions);
1359 for (
unsigned i = 0; i < numReductions; ++i) {
1360 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1361 if (owningAtomicReductionGens[i])
1362 atomicGen = owningAtomicReductionGens[i];
1363 llvm::Value *variable =
1364 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
1367 if (
auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1368 allocatedType = alloca.getElemType();
1375 reductionInfos.push_back(
1377 privateReductionVariables[i],
1378 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1382 allocatedType ? moduleTranslation.
convertType(allocatedType) :
nullptr,
1383 reductionDecls[i].getByrefElementType()
1385 *reductionDecls[i].getByrefElementType())
1395 llvm::IRBuilderBase &builder, StringRef regionName,
1396 bool shouldLoadCleanupRegionArg =
true) {
1397 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1398 if (cleanupRegion->empty())
1404 llvm::Instruction *potentialTerminator =
1405 builder.GetInsertBlock()->empty() ?
nullptr
1406 : &builder.GetInsertBlock()->back();
1407 if (potentialTerminator && potentialTerminator->isTerminator())
1408 builder.SetInsertPoint(potentialTerminator);
1409 llvm::Value *privateVarValue =
1410 shouldLoadCleanupRegionArg
1411 ? builder.CreateLoad(
1413 privateVariables[i])
1414 : privateVariables[i];
1419 moduleTranslation)))
1432 OP op, llvm::IRBuilderBase &builder,
1434 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1437 bool isNowait =
false,
bool isTeamsReduction =
false) {
1439 if (op.getNumReductionVars() == 0)
1451 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1453 owningReductionGenRefDataPtrGens,
1454 privateReductionVariables, reductionInfos, isByRef);
1459 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1460 builder.SetInsertPoint(tempTerminator);
1461 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1462 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1463 isByRef, isNowait, isTeamsReduction);
1468 if (!contInsertPoint->getBlock())
1469 return op->emitOpError() <<
"failed to convert reductions";
1471 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
1472 if (!isTeamsReduction) {
1473 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1474 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1478 afterIP = *barrierIP;
1481 tempTerminator->eraseFromParent();
1482 builder.restoreIP(afterIP);
1486 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1487 [](omp::DeclareReductionOp reductionDecl) {
1488 return &reductionDecl.getCleanupRegion();
1491 moduleTranslation, builder,
1492 "omp.reduction.cleanup");
1503template <
typename OP>
1507 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1512 if (op.getNumReductionVars() == 0)
1518 allocaIP, reductionDecls,
1519 privateReductionVariables, reductionVariableMap,
1520 deferredStores, isByRef)))
1523 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1524 allocaIP.getBlock(), reductionDecls,
1525 privateReductionVariables, reductionVariableMap,
1526 isByRef, deferredStores);
1540 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1543 Value blockArg = (*mappedPrivateVars)[privateVar];
1546 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1547 "A block argument corresponding to a mapped var should have "
1550 if (privVarType == blockArgType)
1557 if (!isa<LLVM::LLVMPointerType>(privVarType))
1558 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1571 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1573 llvm::BasicBlock *privInitBlock,
1575 Region &initRegion = privDecl.getInitRegion();
1576 if (initRegion.
empty())
1577 return llvmPrivateVar;
1579 assert(nonPrivateVar);
1580 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1581 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1586 moduleTranslation, &phis)))
1587 return llvm::createStringError(
1588 "failed to inline `init` region of `omp.private`");
1590 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1607 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1610 builder, moduleTranslation, privDecl,
1613 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1622 return llvm::Error::success();
1624 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1627 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1630 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1632 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1633 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1636 return privVarOrErr.takeError();
1638 llvmPrivateVar = privVarOrErr.get();
1639 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1644 return llvm::Error::success();
1654 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1657 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1658 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1659 allocaTerminator->getIterator()),
1660 true, allocaTerminator->getStableDebugLoc(),
1661 "omp.region.after_alloca");
1663 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1665 allocaTerminator = allocaIP.getBlock()->getTerminator();
1666 builder.SetInsertPoint(allocaTerminator);
1668 assert(allocaTerminator->getNumSuccessors() == 1 &&
1669 "This is an unconditional branch created by splitBB");
1671 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1672 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1674 unsigned int allocaAS =
1675 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1678 .getProgramAddressSpace();
1680 for (
auto [privDecl, mlirPrivVar, blockArg] :
1683 llvm::Type *llvmAllocType =
1684 moduleTranslation.
convertType(privDecl.getType());
1685 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1686 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1687 llvmAllocType,
nullptr,
"omp.private.alloc");
1688 if (allocaAS != defaultAS)
1689 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1690 builder.getPtrTy(defaultAS));
1692 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1695 return afterAllocas;
1703 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1712 if (mlir::isa<omp::ParallelOp>(parent))
1726 bool needsFirstprivate =
1727 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1728 return privOp.getDataSharingType() ==
1729 omp::DataSharingClauseType::FirstPrivate;
1732 if (!needsFirstprivate)
1735 llvm::BasicBlock *copyBlock =
1736 splitBB(builder,
true,
"omp.private.copy");
1739 for (
auto [decl, moldVar, llvmVar] :
1740 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1741 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1745 Region ©Region = decl.getCopyRegion();
1747 moduleTranslation.
mapValue(decl.getCopyMoldArg(), moldVar);
1750 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1754 moduleTranslation)))
1755 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1769 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1770 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1786 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](
mlir::Value mlirVar) {
1788 llvm::Value *moldVar = findAssociatedValue(
1789 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1794 llvmPrivateVars, privateDecls, insertBarrier,
1805 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1806 [](omp::PrivateClauseOp privatizer) {
1807 return &privatizer.getDeallocRegion();
1811 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1812 "omp.private.dealloc",
false)))
1813 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1814 "`omp.private` op in");
1826 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1836 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1837 using StorableBodyGenCallbackTy =
1838 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1840 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1846 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1850 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1854 sectionsOp.getNumReductionVars());
1858 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1861 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1862 reductionDecls, privateReductionVariables, reductionVariableMap,
1869 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1873 Region ®ion = sectionOp.getRegion();
1874 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1875 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1876 builder.restoreIP(codeGenIP);
1883 sectionsOp.getRegion().getNumArguments());
1884 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1885 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1886 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
1888 moduleTranslation.
mapValue(sectionArg, llvmVal);
1895 sectionCBs.push_back(sectionCB);
1901 if (sectionCBs.empty())
1904 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1909 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1910 llvm::Value &vPtr, llvm::Value *&replacementValue)
1911 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1912 replacementValue = &vPtr;
1918 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1922 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1923 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1925 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1926 sectionsOp.getNowait());
1931 builder.restoreIP(*afterIP);
1935 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1936 privateReductionVariables, isByRef, sectionsOp.getNowait());
1943 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1944 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1949 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1950 builder.restoreIP(codegenIP);
1952 builder, moduleTranslation)
1955 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1959 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1962 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1963 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
1965 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1966 llvmCPFuncs.push_back(
1970 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1972 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1978 builder.restoreIP(*afterIP);
1984 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1989 for (
auto ra : iface.getReductionBlockArgs())
1990 for (
auto &use : ra.getUses()) {
1991 auto *useOp = use.getOwner();
1993 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1994 debugUses.push_back(useOp);
1998 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
2003 Operation *currentOp = currentDistOp.getOperation();
2004 if (distOp && (distOp != currentOp))
2013 for (
auto *use : debugUses)
2022 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2027 unsigned numReductionVars = op.getNumReductionVars();
2031 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2037 if (doTeamsReduction) {
2038 isByRef =
getIsByRef(op.getReductionByref());
2040 assert(isByRef.size() == op.getNumReductionVars());
2043 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2048 op, reductionArgs, builder, moduleTranslation, allocaIP,
2049 reductionDecls, privateReductionVariables, reductionVariableMap,
2054 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2056 moduleTranslation, allocaIP);
2057 builder.restoreIP(codegenIP);
2063 llvm::Value *numTeamsLower =
nullptr;
2064 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
2065 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
2067 llvm::Value *numTeamsUpper =
nullptr;
2068 if (!op.getNumTeamsUpperVars().empty())
2069 numTeamsUpper = moduleTranslation.
lookupValue(op.getNumTeams(0));
2071 llvm::Value *threadLimit =
nullptr;
2072 if (!op.getThreadLimitVars().empty())
2073 threadLimit = moduleTranslation.
lookupValue(op.getThreadLimit(0));
2075 llvm::Value *ifExpr =
nullptr;
2076 if (
Value ifVar = op.getIfExpr())
2079 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2080 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2082 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2087 builder.restoreIP(*afterIP);
2088 if (doTeamsReduction) {
2091 op, builder, moduleTranslation, allocaIP, reductionDecls,
2092 privateReductionVariables, isByRef,
2102 if (dependVars.empty())
2104 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2105 llvm::omp::RTLDependenceKindTy type;
2107 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2108 case mlir::omp::ClauseTaskDepend::taskdependin:
2109 type = llvm::omp::RTLDependenceKindTy::DepIn;
2114 case mlir::omp::ClauseTaskDepend::taskdependout:
2115 case mlir::omp::ClauseTaskDepend::taskdependinout:
2116 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2118 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2119 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2121 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2122 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2125 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2126 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2127 dds.emplace_back(dd);
2139 llvm::IRBuilderBase &llvmBuilder,
2141 llvm::omp::Directive cancelDirective) {
2142 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2143 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2147 llvmBuilder.restoreIP(ip);
2153 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2154 return llvm::Error::success();
2159 ompBuilder.pushFinalizationCB(
2169 llvm::OpenMPIRBuilder &ompBuilder,
2170 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2171 ompBuilder.popFinalizationCB();
2172 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2173 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2174 assert(cancelBranch->getNumSuccessors() == 1 &&
2175 "cancel branch should have one target");
2176 cancelBranch->setSuccessor(0, constructFini);
2183class TaskContextStructManager {
2185 TaskContextStructManager(llvm::IRBuilderBase &builder,
2186 LLVM::ModuleTranslation &moduleTranslation,
2187 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2188 : builder{builder}, moduleTranslation{moduleTranslation},
2189 privateDecls{privateDecls} {}
2195 void generateTaskContextStruct();
2201 void createGEPsToPrivateVars();
2207 SmallVector<llvm::Value *>
2208 createGEPsToPrivateVars(llvm::Value *altStructPtr)
const;
2211 void freeStructPtr();
2213 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2214 return llvmPrivateVarGEPs;
2217 llvm::Value *getStructPtr() {
return structPtr; }
2220 llvm::IRBuilderBase &builder;
2221 LLVM::ModuleTranslation &moduleTranslation;
2222 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
2225 SmallVector<llvm::Type *> privateVarTypes;
2229 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
2232 llvm::Value *structPtr =
nullptr;
2234 llvm::Type *structTy =
nullptr;
2238void TaskContextStructManager::generateTaskContextStruct() {
2239 if (privateDecls.empty())
2241 privateVarTypes.reserve(privateDecls.size());
2243 for (omp::PrivateClauseOp &privOp : privateDecls) {
2246 if (!privOp.readsFromMold())
2248 Type mlirType = privOp.getType();
2249 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
2252 if (privateVarTypes.empty())
2255 structTy = llvm::StructType::get(moduleTranslation.
getLLVMContext(),
2258 llvm::DataLayout dataLayout =
2259 builder.GetInsertBlock()->getModule()->getDataLayout();
2260 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2261 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2264 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2266 "omp.task.context_ptr");
2269SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2270 llvm::Value *altStructPtr)
const {
2271 SmallVector<llvm::Value *> ret;
2274 ret.reserve(privateDecls.size());
2275 llvm::Value *zero = builder.getInt32(0);
2277 for (
auto privDecl : privateDecls) {
2278 if (!privDecl.readsFromMold()) {
2280 ret.push_back(
nullptr);
2283 llvm::Value *iVal = builder.getInt32(i);
2284 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2291void TaskContextStructManager::createGEPsToPrivateVars() {
2293 assert(privateVarTypes.empty());
2297 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2300void TaskContextStructManager::freeStructPtr() {
2304 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2306 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2307 builder.CreateFree(structPtr);
2314 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2319 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2331 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2336 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2337 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2338 builder.getContext(),
"omp.task.start",
2339 builder.GetInsertBlock()->getParent());
2340 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2341 builder.SetInsertPoint(branchToTaskStartBlock);
2344 llvm::BasicBlock *copyBlock =
2345 splitBB(builder,
true,
"omp.private.copy");
2346 llvm::BasicBlock *initBlock =
2347 splitBB(builder,
true,
"omp.private.init");
2363 moduleTranslation, allocaIP);
2366 builder.SetInsertPoint(initBlock->getTerminator());
2369 taskStructMgr.generateTaskContextStruct();
2376 taskStructMgr.createGEPsToPrivateVars();
2378 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2381 taskStructMgr.getLLVMPrivateVarGEPs())) {
2383 if (!privDecl.readsFromMold())
2385 assert(llvmPrivateVarAlloc &&
2386 "reads from mold so shouldn't have been skipped");
2389 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2390 blockArg, llvmPrivateVarAlloc, initBlock);
2391 if (!privateVarOrErr)
2392 return handleError(privateVarOrErr, *taskOp.getOperation());
2401 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2402 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2403 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2405 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2406 llvmPrivateVarAlloc);
2408 assert(llvmPrivateVarAlloc->getType() ==
2409 moduleTranslation.
convertType(blockArg.getType()));
2419 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2420 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2421 taskOp.getPrivateNeedsBarrier())))
2422 return llvm::failure();
2425 builder.SetInsertPoint(taskStartBlock);
2427 auto bodyCB = [&](InsertPointTy allocaIP,
2428 InsertPointTy codegenIP) -> llvm::Error {
2432 moduleTranslation, allocaIP);
2435 builder.restoreIP(codegenIP);
2437 llvm::BasicBlock *privInitBlock =
nullptr;
2439 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2442 auto [blockArg, privDecl, mlirPrivVar] = zip;
2444 if (privDecl.readsFromMold())
2447 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2448 llvm::Type *llvmAllocType =
2449 moduleTranslation.
convertType(privDecl.getType());
2450 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2451 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2452 llvmAllocType,
nullptr,
"omp.private.alloc");
2455 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2456 blockArg, llvmPrivateVar, privInitBlock);
2457 if (!privateVarOrError)
2458 return privateVarOrError.takeError();
2459 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2460 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2463 taskStructMgr.createGEPsToPrivateVars();
2464 for (
auto [i, llvmPrivVar] :
2465 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2467 assert(privateVarsInfo.
llvmVars[i] &&
2468 "This is added in the loop above");
2471 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2476 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2480 if (!privateDecl.readsFromMold())
2483 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2484 llvmPrivateVar = builder.CreateLoad(
2485 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2487 assert(llvmPrivateVar->getType() ==
2488 moduleTranslation.
convertType(blockArg.getType()));
2489 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2493 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
2494 if (failed(
handleError(continuationBlockOrError, *taskOp)))
2495 return llvm::make_error<PreviouslyReportedError>();
2497 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2502 return llvm::make_error<PreviouslyReportedError>();
2505 taskStructMgr.freeStructPtr();
2507 return llvm::Error::success();
2516 llvm::omp::Directive::OMPD_taskgroup);
2520 moduleTranslation, dds);
2522 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2523 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2525 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2527 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dds,
2528 taskOp.getMergeable(),
2529 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
2530 moduleTranslation.
lookupValue(taskOp.getPriority()));
2538 builder.restoreIP(*afterIP);
2546 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2547 auto taskloopOp = cast<omp::TaskloopOp>(opInst);
2555 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2558 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2561 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2562 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
2563 builder.getContext(),
"omp.taskloop.start",
2564 builder.GetInsertBlock()->getParent());
2565 llvm::Instruction *branchToTaskloopStartBlock =
2566 builder.CreateBr(taskloopStartBlock);
2567 builder.SetInsertPoint(branchToTaskloopStartBlock);
2569 llvm::BasicBlock *copyBlock =
2570 splitBB(builder,
true,
"omp.private.copy");
2571 llvm::BasicBlock *initBlock =
2572 splitBB(builder,
true,
"omp.private.init");
2575 moduleTranslation, allocaIP);
2578 builder.SetInsertPoint(initBlock->getTerminator());
2581 taskStructMgr.generateTaskContextStruct();
2582 taskStructMgr.createGEPsToPrivateVars();
2584 llvmFirstPrivateVars.resize(privateVarsInfo.
blockArgs.size());
2586 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2588 privateVarsInfo.
blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
2589 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
2591 if (!privDecl.readsFromMold())
2593 assert(llvmPrivateVarAlloc &&
2594 "reads from mold so shouldn't have been skipped");
2597 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2598 blockArg, llvmPrivateVarAlloc, initBlock);
2599 if (!privateVarOrErr)
2600 return handleError(privateVarOrErr, *taskloopOp.getOperation());
2602 llvmFirstPrivateVars[i] = privateVarOrErr.get();
2604 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2605 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2607 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2608 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2609 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2611 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2612 llvmPrivateVarAlloc);
2614 assert(llvmPrivateVarAlloc->getType() ==
2615 moduleTranslation.
convertType(blockArg.getType()));
2621 taskloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2622 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2623 taskloopOp.getPrivateNeedsBarrier())))
2624 return llvm::failure();
2627 builder.SetInsertPoint(taskloopStartBlock);
2629 auto bodyCB = [&](InsertPointTy allocaIP,
2630 InsertPointTy codegenIP) -> llvm::Error {
2634 moduleTranslation, allocaIP);
2637 builder.restoreIP(codegenIP);
2639 llvm::BasicBlock *privInitBlock =
nullptr;
2641 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2644 auto [blockArg, privDecl, mlirPrivVar] = zip;
2646 if (privDecl.readsFromMold())
2649 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2650 llvm::Type *llvmAllocType =
2651 moduleTranslation.
convertType(privDecl.getType());
2652 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2653 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2654 llvmAllocType,
nullptr,
"omp.private.alloc");
2657 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2658 blockArg, llvmPrivateVar, privInitBlock);
2659 if (!privateVarOrError)
2660 return privateVarOrError.takeError();
2661 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2662 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2665 taskStructMgr.createGEPsToPrivateVars();
2666 for (
auto [i, llvmPrivVar] :
2667 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2669 assert(privateVarsInfo.
llvmVars[i] &&
2670 "This is added in the loop above");
2673 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2678 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2682 if (!privateDecl.readsFromMold())
2685 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2686 llvmPrivateVar = builder.CreateLoad(
2687 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2689 assert(llvmPrivateVar->getType() ==
2690 moduleTranslation.
convertType(blockArg.getType()));
2691 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2694 auto continuationBlockOrError =
2696 builder, moduleTranslation);
2698 if (failed(
handleError(continuationBlockOrError, opInst)))
2699 return llvm::make_error<PreviouslyReportedError>();
2701 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2709 taskloopOp.getLoc(), privateVarsInfo.
llvmVars,
2711 return llvm::make_error<PreviouslyReportedError>();
2714 taskStructMgr.freeStructPtr();
2716 return llvm::Error::success();
2722 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2723 llvm::Value *destPtr, llvm::Value *srcPtr)
2725 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2726 builder.restoreIP(codegenIP);
2729 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
2731 builder.CreateLoad(ptrTy, srcPtr,
"omp.taskloop.context.src");
2733 TaskContextStructManager &srcStructMgr = taskStructMgr;
2734 TaskContextStructManager destStructMgr(builder, moduleTranslation,
2736 destStructMgr.generateTaskContextStruct();
2737 llvm::Value *dest = destStructMgr.getStructPtr();
2738 dest->setName(
"omp.taskloop.context.dest");
2739 builder.CreateStore(dest, destPtr);
2742 srcStructMgr.createGEPsToPrivateVars(src);
2744 destStructMgr.createGEPsToPrivateVars(dest);
2747 for (
auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
2748 llvm::zip_equal(privateVarsInfo.
privatizers, srcGEPs,
2751 if (!privDecl.readsFromMold())
2753 assert(llvmPrivateVarAlloc &&
2754 "reads from mold so shouldn't have been skipped");
2757 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
2758 llvmPrivateVarAlloc, builder.GetInsertBlock());
2759 if (!privateVarOrErr)
2760 return privateVarOrErr.takeError();
2769 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2770 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2771 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2773 llvmPrivateVarAlloc = builder.CreateLoad(
2774 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
2776 assert(llvmPrivateVarAlloc->getType() ==
2777 moduleTranslation.
convertType(blockArg.getType()));
2785 &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
2786 privateVarsInfo.
privatizers, taskloopOp.getPrivateNeedsBarrier())))
2787 return llvm::make_error<PreviouslyReportedError>();
2789 return builder.saveIP();
2792 auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());
2802 llvm::Type *boundType =
2803 moduleTranslation.
lookupValue(lowerBounds[0])->getType();
2804 llvm::Value *lbVal =
nullptr;
2805 llvm::Value *ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
2806 llvm::Value *stepVal =
nullptr;
2807 if (loopOp.getCollapseNumLoops() > 1) {
2825 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
2826 llvm::Value *loopLb = moduleTranslation.
lookupValue(lowerBounds[i]);
2827 llvm::Value *loopUb = moduleTranslation.
lookupValue(upperBounds[i]);
2828 llvm::Value *loopStep = moduleTranslation.
lookupValue(steps[i]);
2834 llvm::Value *loopLbMinusOne = builder.CreateSub(
2835 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
2836 llvm::Value *loopUbMinusOne = builder.CreateSub(
2837 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
2838 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
2839 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
2840 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
2841 llvm::Value *loopTripCount =
2842 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
2843 loopTripCount = builder.CreateBinaryIntrinsic(
2844 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
2848 llvm::Value *loopTripCountDivStep =
2849 builder.CreateSDiv(loopTripCount, loopStep);
2850 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
2851 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
2852 llvm::Value *loopTripCountRem =
2853 builder.CreateSRem(loopTripCount, loopStep);
2854 loopTripCountRem = builder.CreateBinaryIntrinsic(
2855 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
2856 llvm::Value *needsRoundUp = builder.CreateICmpNE(
2858 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
2861 builder.CreateAdd(loopTripCountDivStep,
2862 builder.CreateZExtOrTrunc(
2863 needsRoundUp, loopTripCountDivStep->getType()));
2864 ubVal = builder.CreateMul(ubVal, loopTripCount);
2866 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
2867 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
2869 lbVal = moduleTranslation.
lookupValue(lowerBounds[0]);
2870 ubVal = moduleTranslation.
lookupValue(upperBounds[0]);
2871 stepVal = moduleTranslation.
lookupValue(steps[0]);
2873 assert(lbVal !=
nullptr &&
"Expected value for lbVal");
2874 assert(ubVal !=
nullptr &&
"Expected value for ubVal");
2875 assert(stepVal !=
nullptr &&
"Expected value for stepVal");
2877 llvm::Value *ifCond =
nullptr;
2878 llvm::Value *grainsize =
nullptr;
2880 mlir::Value grainsizeVal = taskloopOp.getGrainsize();
2881 mlir::Value numTasksVal = taskloopOp.getNumTasks();
2882 if (
Value ifVar = taskloopOp.getIfExpr())
2885 grainsize = moduleTranslation.
lookupValue(grainsizeVal);
2887 }
else if (numTasksVal) {
2888 grainsize = moduleTranslation.
lookupValue(numTasksVal);
2892 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull =
nullptr;
2893 if (taskStructMgr.getStructPtr())
2894 taskDupOrNull = taskDupCB;
2904 llvm::omp::Directive::OMPD_taskgroup);
2906 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2907 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2909 ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
2910 taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
2911 sched, moduleTranslation.
lookupValue(taskloopOp.getFinal()),
2912 taskloopOp.getMergeable(),
2913 moduleTranslation.
lookupValue(taskloopOp.getPriority()),
2914 loopOp.getCollapseNumLoops(), taskDupOrNull,
2915 taskStructMgr.getStructPtr());
2922 builder.restoreIP(*afterIP);
2930 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2934 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2935 builder.restoreIP(codegenIP);
2937 builder, moduleTranslation)
2942 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2943 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2950 builder.restoreIP(*afterIP);
2969 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2973 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2975 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2979 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2982 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
2983 llvm::Type *ivType = step->getType();
2984 llvm::Value *chunk =
nullptr;
2985 if (wsloopOp.getScheduleChunk()) {
2986 llvm::Value *chunkVar =
2987 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
2988 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2991 omp::DistributeOp distributeOp =
nullptr;
2992 llvm::Value *distScheduleChunk =
nullptr;
2993 bool hasDistSchedule =
false;
2994 if (llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())) {
2995 distributeOp = cast<omp::DistributeOp>(opInst.
getParentOp());
2996 hasDistSchedule = distributeOp.getDistScheduleStatic();
2997 if (distributeOp.getDistScheduleChunkSize()) {
2998 llvm::Value *chunkVar = moduleTranslation.
lookupValue(
2999 distributeOp.getDistScheduleChunkSize());
3000 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3008 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3012 wsloopOp.getNumReductionVars());
3015 builder, moduleTranslation, privateVarsInfo, allocaIP);
3022 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3027 moduleTranslation, allocaIP, reductionDecls,
3028 privateReductionVariables, reductionVariableMap,
3029 deferredStores, isByRef)))
3038 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3040 wsloopOp.getPrivateNeedsBarrier())))
3043 assert(afterAllocas.get()->getSinglePredecessor());
3044 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3046 afterAllocas.get()->getSinglePredecessor(),
3047 reductionDecls, privateReductionVariables,
3048 reductionVariableMap, isByRef, deferredStores)))
3052 bool isOrdered = wsloopOp.getOrdered().has_value();
3053 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3054 bool isSimd = wsloopOp.getScheduleSimd();
3055 bool loopNeedsBarrier = !wsloopOp.getNowait();
3060 llvm::omp::WorksharingLoopType workshareLoopType =
3061 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
3062 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3063 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3067 llvm::omp::Directive::OMPD_for);
3069 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3072 LinearClauseProcessor linearClauseProcessor;
3074 if (!wsloopOp.getLinearVars().empty()) {
3075 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3077 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3079 for (
auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3080 linearClauseProcessor.createLinearVar(
3081 builder, moduleTranslation, moduleTranslation.
lookupValue(linearVar),
3083 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
3084 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3088 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
3096 if (!wsloopOp.getLinearVars().empty()) {
3097 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3098 loopInfo->getPreheader());
3099 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3101 builder.saveIP(), llvm::omp::OMPD_barrier);
3104 builder.restoreIP(*afterBarrierIP);
3105 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3106 loopInfo->getIndVar());
3107 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3110 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3113 bool noLoopMode =
false;
3114 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3116 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3120 if (loopOp == targetCapturedOp) {
3121 omp::TargetRegionFlags kernelFlags =
3122 targetOp.getKernelExecFlags(targetCapturedOp);
3123 if (omp::bitEnumContainsAll(kernelFlags,
3124 omp::TargetRegionFlags::spmd |
3125 omp::TargetRegionFlags::no_loop) &&
3126 !omp::bitEnumContainsAny(kernelFlags,
3127 omp::TargetRegionFlags::generic))
3132 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3133 ompBuilder->applyWorkshareLoop(
3134 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3135 convertToScheduleKind(schedule), chunk, isSimd,
3136 scheduleMod == omp::ScheduleModifier::monotonic,
3137 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3138 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3144 if (!wsloopOp.getLinearVars().empty()) {
3145 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3146 assert(loopInfo->getLastIter() &&
3147 "`lastiter` in CanonicalLoopInfo is nullptr");
3148 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3149 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3150 loopInfo->getLastIter());
3153 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
3154 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
3156 builder.restoreIP(oldIP);
3164 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3165 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3178 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3180 assert(isByRef.size() == opInst.getNumReductionVars());
3193 opInst.getNumReductionVars());
3196 auto bodyGenCB = [&](InsertPointTy allocaIP,
3197 InsertPointTy codeGenIP) -> llvm::Error {
3199 builder, moduleTranslation, privateVarsInfo, allocaIP);
3201 return llvm::make_error<PreviouslyReportedError>();
3207 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3210 InsertPointTy(allocaIP.getBlock(),
3211 allocaIP.getBlock()->getTerminator()->getIterator());
3214 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3215 reductionDecls, privateReductionVariables, reductionVariableMap,
3216 deferredStores, isByRef)))
3217 return llvm::make_error<PreviouslyReportedError>();
3219 assert(afterAllocas.get()->getSinglePredecessor());
3220 builder.restoreIP(codeGenIP);
3226 return llvm::make_error<PreviouslyReportedError>();
3229 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3231 opInst.getPrivateNeedsBarrier())))
3232 return llvm::make_error<PreviouslyReportedError>();
3235 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3236 afterAllocas.get()->getSinglePredecessor(),
3237 reductionDecls, privateReductionVariables,
3238 reductionVariableMap, isByRef, deferredStores)))
3239 return llvm::make_error<PreviouslyReportedError>();
3244 moduleTranslation, allocaIP);
3248 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
3250 return regionBlock.takeError();
3253 if (opInst.getNumReductionVars() > 0) {
3258 owningReductionGenRefDataPtrGens;
3260 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3262 owningReductionGenRefDataPtrGens,
3263 privateReductionVariables, reductionInfos, isByRef);
3266 builder.SetInsertPoint((*regionBlock)->getTerminator());
3269 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3270 builder.SetInsertPoint(tempTerminator);
3272 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3273 ompBuilder->createReductions(
3274 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3276 if (!contInsertPoint)
3277 return contInsertPoint.takeError();
3279 if (!contInsertPoint->getBlock())
3280 return llvm::make_error<PreviouslyReportedError>();
3282 tempTerminator->eraseFromParent();
3283 builder.restoreIP(*contInsertPoint);
3286 return llvm::Error::success();
3289 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3290 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
3299 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
3300 InsertPointTy oldIP = builder.saveIP();
3301 builder.restoreIP(codeGenIP);
3306 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
3307 [](omp::DeclareReductionOp reductionDecl) {
3308 return &reductionDecl.getCleanupRegion();
3311 reductionCleanupRegions, privateReductionVariables,
3312 moduleTranslation, builder,
"omp.reduction.cleanup")))
3313 return llvm::createStringError(
3314 "failed to inline `cleanup` region of `omp.declare_reduction`");
3319 return llvm::make_error<PreviouslyReportedError>();
3323 if (isCancellable) {
3324 auto IPOrErr = ompBuilder->createBarrier(
3325 llvm::OpenMPIRBuilder::LocationDescription(builder),
3326 llvm::omp::Directive::OMPD_unknown,
3330 return IPOrErr.takeError();
3333 builder.restoreIP(oldIP);
3334 return llvm::Error::success();
3337 llvm::Value *ifCond =
nullptr;
3338 if (
auto ifVar = opInst.getIfExpr())
3340 llvm::Value *numThreads =
nullptr;
3341 if (!opInst.getNumThreadsVars().empty())
3342 numThreads = moduleTranslation.
lookupValue(opInst.getNumThreads(0));
3343 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
3344 if (
auto bind = opInst.getProcBindKind())
3347 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3349 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3351 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3352 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
3353 ifCond, numThreads, pbKind, isCancellable);
3358 builder.restoreIP(*afterIP);
3363static llvm::omp::OrderKind
3366 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3368 case omp::ClauseOrderKind::Concurrent:
3369 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3371 llvm_unreachable(
"Unknown ClauseOrderKind kind");
3379 auto simdOp = cast<omp::SimdOp>(opInst);
3387 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3390 simdOp.getNumReductionVars());
3395 assert(isByRef.size() == simdOp.getNumReductionVars());
3397 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3401 builder, moduleTranslation, privateVarsInfo, allocaIP);
3406 LinearClauseProcessor linearClauseProcessor;
3408 if (!simdOp.getLinearVars().empty()) {
3409 auto linearVarTypes = simdOp.getLinearVarTypes().value();
3411 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3412 for (
auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
3413 bool isImplicit =
false;
3414 for (
auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
3418 if (linearVar == mlirPrivVar) {
3420 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
3421 llvmPrivateVar, idx);
3427 linearClauseProcessor.createLinearVar(
3428 builder, moduleTranslation,
3431 for (
mlir::Value linearStep : simdOp.getLinearStepVars())
3432 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3436 moduleTranslation, allocaIP, reductionDecls,
3437 privateReductionVariables, reductionVariableMap,
3438 deferredStores, isByRef)))
3449 assert(afterAllocas.get()->getSinglePredecessor());
3450 if (failed(initReductionVars(simdOp, reductionArgs, builder,
3452 afterAllocas.get()->getSinglePredecessor(),
3453 reductionDecls, privateReductionVariables,
3454 reductionVariableMap, isByRef, deferredStores)))
3457 llvm::ConstantInt *simdlen =
nullptr;
3458 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
3459 simdlen = builder.getInt64(simdlenVar.value());
3461 llvm::ConstantInt *safelen =
nullptr;
3462 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
3463 safelen = builder.getInt64(safelenVar.value());
3465 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
3468 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3469 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
3471 for (
size_t i = 0; i < operands.size(); ++i) {
3472 llvm::Value *alignment =
nullptr;
3473 llvm::Value *llvmVal = moduleTranslation.
lookupValue(operands[i]);
3474 llvm::Type *ty = llvmVal->getType();
3476 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
3477 alignment = builder.getInt64(intAttr.getInt());
3478 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
3479 assert(alignment &&
"Invalid alignment value");
3483 if (!intAttr.getValue().isPowerOf2())
3486 auto curInsert = builder.saveIP();
3487 builder.SetInsertPoint(sourceBlock);
3488 llvmVal = builder.CreateLoad(ty, llvmVal);
3489 builder.restoreIP(curInsert);
3490 alignedVars[llvmVal] = alignment;
3494 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
3501 if (simdOp.getLinearVars().size()) {
3502 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3503 loopInfo->getPreheader());
3505 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3506 loopInfo->getIndVar());
3508 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3510 ompBuilder->applySimd(loopInfo, alignedVars,
3512 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
3514 order, simdlen, safelen);
3516 linearClauseProcessor.emitStoresForLinearVar(builder);
3519 bool hasOrderedRegions =
false;
3520 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
3521 hasOrderedRegions =
true;
3525 for (
size_t index = 0;
index < simdOp.getLinearVars().size();
index++) {
3526 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
3528 if (hasOrderedRegions) {
3530 linearClauseProcessor.rewriteInPlace(builder,
"omp.ordered.region",
3533 linearClauseProcessor.rewriteInPlace(builder,
"omp_region.finalize",
3542 for (
auto [i, tuple] : llvm::enumerate(
3543 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
3544 privateReductionVariables))) {
3545 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
3547 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
3548 llvm::Value *originalVariable = moduleTranslation.
lookupValue(reductionVar);
3549 llvm::Type *reductionType = moduleTranslation.
convertType(decl.getType());
3553 llvm::Value *redValue = originalVariable;
3556 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
3557 llvm::Value *privateRedValue = builder.CreateLoad(
3558 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
3559 llvm::Value *reduced;
3561 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
3564 builder.restoreIP(res.get());
3568 builder.CreateStore(reduced, originalVariable);
3573 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
3574 [](omp::DeclareReductionOp reductionDecl) {
3575 return &reductionDecl.getCleanupRegion();
3578 moduleTranslation, builder,
3579 "omp.reduction.cleanup")))
3592 auto loopOp = cast<omp::LoopNestOp>(opInst);
3598 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3603 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
3604 llvm::Value *iv) -> llvm::Error {
3607 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
3612 bodyInsertPoints.push_back(ip);
3614 if (loopInfos.size() != loopOp.getNumLoops() - 1)
3615 return llvm::Error::success();
3618 builder.restoreIP(ip);
3620 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
3622 return regionBlock.takeError();
3624 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3625 return llvm::Error::success();
3633 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3634 llvm::Value *lowerBound =
3635 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
3636 llvm::Value *upperBound =
3637 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
3638 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
3643 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3644 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3646 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3648 computeIP = loopInfos.front()->getPreheaderIP();
3652 ompBuilder->createCanonicalLoop(
3653 loc, bodyGen, lowerBound, upperBound, step,
3654 true, loopOp.getLoopInclusive(), computeIP);
3659 loopInfos.push_back(*loopResult);
3662 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
3663 loopInfos.front()->getAfterIP();
3666 if (
const auto &tiles = loopOp.getTileSizes()) {
3667 llvm::Type *ivType = loopInfos.front()->getIndVarType();
3670 for (
auto tile : tiles.value()) {
3671 llvm::Value *tileVal = llvm::ConstantInt::get(ivType,
tile);
3672 tileSizes.push_back(tileVal);
3675 std::vector<llvm::CanonicalLoopInfo *> newLoops =
3676 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
3680 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
3681 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
3682 afterIP = {afterAfterBB, afterAfterBB->begin()};
3686 for (
const auto &newLoop : newLoops)
3687 loopInfos.push_back(newLoop);
3691 const auto &numCollapse = loopOp.getCollapseNumLoops();
3693 loopInfos.begin(), loopInfos.begin() + (numCollapse));
3695 auto newTopLoopInfo =
3696 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
3698 assert(newTopLoopInfo &&
"New top loop information is missing");
3699 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
3700 [&](OpenMPLoopInfoStackFrame &frame) {
3701 frame.loopInfo = newTopLoopInfo;
3709 builder.restoreIP(afterIP);
3719 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
3720 Value loopIV = op.getInductionVar();
3721 Value loopTC = op.getTripCount();
3723 llvm::Value *llvmTC = moduleTranslation.
lookupValue(loopTC);
3726 ompBuilder->createCanonicalLoop(
3728 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
3731 moduleTranslation.
mapValue(loopIV, llvmIV);
3733 builder.restoreIP(ip);
3738 return bodyGenStatus.takeError();
3740 llvmTC,
"omp.loop");
3742 return op.emitError(llvm::toString(llvmOrError.takeError()));
3744 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
3745 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
3746 builder.restoreIP(afterIP);
3749 if (
Value cli = op.getCli())
3762 Value applyee = op.getApplyee();
3763 assert(applyee &&
"Loop to apply unrolling on required");
3765 llvm::CanonicalLoopInfo *consBuilderCLI =
3767 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3768 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
3776static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3779 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3784 for (
Value size : op.getSizes()) {
3785 llvm::Value *translatedSize = moduleTranslation.
lookupValue(size);
3786 assert(translatedSize &&
3787 "sizes clause arguments must already be translated");
3788 translatedSizes.push_back(translatedSize);
3791 for (
Value applyee : op.getApplyees()) {
3792 llvm::CanonicalLoopInfo *consBuilderCLI =
3794 assert(applyee &&
"Canonical loop must already been translated");
3795 translatedLoops.push_back(consBuilderCLI);
3798 auto generatedLoops =
3799 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3800 if (!op.getGeneratees().empty()) {
3801 for (
auto [mlirLoop,
genLoop] :
3802 zip_equal(op.getGeneratees(), generatedLoops))
3807 for (
Value applyee : op.getApplyees())
3815static LogicalResult
applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
3818 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3822 for (
size_t i = 0; i < op.getApplyees().size(); i++) {
3823 Value applyee = op.getApplyees()[i];
3824 llvm::CanonicalLoopInfo *consBuilderCLI =
3826 assert(applyee &&
"Canonical loop must already been translated");
3827 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
3828 beforeFuse.push_back(consBuilderCLI);
3829 else if (op.getCount().has_value() &&
3830 i >= op.getFirst().value() + op.getCount().value() - 1)
3831 afterFuse.push_back(consBuilderCLI);
3833 toFuse.push_back(consBuilderCLI);
3836 (op.getGeneratees().empty() ||
3837 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
3838 "Wrong number of generatees");
3841 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
3842 if (!op.getGeneratees().empty()) {
3844 for (; i < beforeFuse.size(); i++)
3845 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
3846 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
3847 for (; i < afterFuse.size(); i++)
3848 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
3852 for (
Value applyee : op.getApplyees())
3859static llvm::AtomicOrdering
3862 return llvm::AtomicOrdering::Monotonic;
3865 case omp::ClauseMemoryOrderKind::Seq_cst:
3866 return llvm::AtomicOrdering::SequentiallyConsistent;
3867 case omp::ClauseMemoryOrderKind::Acq_rel:
3868 return llvm::AtomicOrdering::AcquireRelease;
3869 case omp::ClauseMemoryOrderKind::Acquire:
3870 return llvm::AtomicOrdering::Acquire;
3871 case omp::ClauseMemoryOrderKind::Release:
3872 return llvm::AtomicOrdering::Release;
3873 case omp::ClauseMemoryOrderKind::Relaxed:
3874 return llvm::AtomicOrdering::Monotonic;
3876 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
3883 auto readOp = cast<omp::AtomicReadOp>(opInst);
3888 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3891 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3894 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
3895 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
3897 llvm::Type *elementType =
3898 moduleTranslation.
convertType(readOp.getElementType());
3900 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
3901 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
3902 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3910 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3915 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3918 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3920 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
3921 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
3922 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
3923 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
3926 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3934 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
3935 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
3936 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
3937 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
3938 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
3939 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
3940 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
3941 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
3942 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
3943 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3947 bool &isIgnoreDenormalMode,
3948 bool &isFineGrainedMemory,
3949 bool &isRemoteMemory) {
3950 isIgnoreDenormalMode =
false;
3951 isFineGrainedMemory =
false;
3952 isRemoteMemory =
false;
3953 if (atomicUpdateOp &&
3954 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3955 mlir::omp::AtomicControlAttr atomicControlAttr =
3956 atomicUpdateOp.getAtomicControlAttr();
3957 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3958 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3959 isRemoteMemory = atomicControlAttr.getRemoteMemory();
3966 llvm::IRBuilderBase &builder,
3973 auto &innerOpList = opInst.getRegion().front().getOperations();
3974 bool isXBinopExpr{
false};
3975 llvm::AtomicRMWInst::BinOp binop;
3977 llvm::Value *llvmExpr =
nullptr;
3978 llvm::Value *llvmX =
nullptr;
3979 llvm::Type *llvmXElementType =
nullptr;
3980 if (innerOpList.size() == 2) {
3986 opInst.getRegion().getArgument(0))) {
3987 return opInst.emitError(
"no atomic update operation with region argument"
3988 " as operand found inside atomic.update region");
3991 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
3993 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
3997 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3999 llvmX = moduleTranslation.
lookupValue(opInst.getX());
4001 opInst.getRegion().getArgument(0).getType());
4002 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4006 llvm::AtomicOrdering atomicOrdering =
4011 [&opInst, &moduleTranslation](
4012 llvm::Value *atomicx,
4015 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
4016 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4017 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4018 return llvm::make_error<PreviouslyReportedError>();
4020 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4021 assert(yieldop && yieldop.getResults().size() == 1 &&
4022 "terminator must be omp.yield op and it must have exactly one "
4024 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
4027 bool isIgnoreDenormalMode;
4028 bool isFineGrainedMemory;
4029 bool isRemoteMemory;
4034 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4035 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4036 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
4037 atomicOrdering, binop, updateFn,
4038 isXBinopExpr, isIgnoreDenormalMode,
4039 isFineGrainedMemory, isRemoteMemory);
4044 builder.restoreIP(*afterIP);
4050 llvm::IRBuilderBase &builder,
4057 bool isXBinopExpr =
false, isPostfixUpdate =
false;
4058 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4060 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
4061 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
4063 assert((atomicUpdateOp || atomicWriteOp) &&
4064 "internal op must be an atomic.update or atomic.write op");
4066 if (atomicWriteOp) {
4067 isPostfixUpdate =
true;
4068 mlirExpr = atomicWriteOp.getExpr();
4070 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
4071 atomicCaptureOp.getAtomicUpdateOp().getOperation();
4072 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
4075 if (innerOpList.size() == 2) {
4078 atomicUpdateOp.getRegion().getArgument(0))) {
4079 return atomicUpdateOp.emitError(
4080 "no atomic update operation with region argument"
4081 " as operand found inside atomic.update region");
4085 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
4088 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4092 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4093 llvm::Value *llvmX =
4094 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
4095 llvm::Value *llvmV =
4096 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
4097 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
4098 atomicCaptureOp.getAtomicReadOp().getElementType());
4099 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4102 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
4106 llvm::AtomicOrdering atomicOrdering =
4110 [&](llvm::Value *atomicx,
4113 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
4114 Block &bb = *atomicUpdateOp.getRegion().
begin();
4115 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
4117 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4118 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4119 return llvm::make_error<PreviouslyReportedError>();
4121 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4122 assert(yieldop && yieldop.getResults().size() == 1 &&
4123 "terminator must be omp.yield op and it must have exactly one "
4125 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
4128 bool isIgnoreDenormalMode;
4129 bool isFineGrainedMemory;
4130 bool isRemoteMemory;
4132 isFineGrainedMemory, isRemoteMemory);
4135 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4136 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4137 ompBuilder->createAtomicCapture(
4138 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4139 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4140 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4142 if (failed(
handleError(afterIP, *atomicCaptureOp)))
4145 builder.restoreIP(*afterIP);
4150 omp::ClauseCancellationConstructType directive) {
4151 switch (directive) {
4152 case omp::ClauseCancellationConstructType::Loop:
4153 return llvm::omp::Directive::OMPD_for;
4154 case omp::ClauseCancellationConstructType::Parallel:
4155 return llvm::omp::Directive::OMPD_parallel;
4156 case omp::ClauseCancellationConstructType::Sections:
4157 return llvm::omp::Directive::OMPD_sections;
4158 case omp::ClauseCancellationConstructType::Taskgroup:
4159 return llvm::omp::Directive::OMPD_taskgroup;
4161 llvm_unreachable(
"Unhandled cancellation construct type");
4170 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4173 llvm::Value *ifCond =
nullptr;
4174 if (
Value ifVar = op.getIfExpr())
4177 llvm::omp::Directive cancelledDirective =
4180 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4181 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4183 if (failed(
handleError(afterIP, *op.getOperation())))
4186 builder.restoreIP(afterIP.get());
4193 llvm::IRBuilderBase &builder,
4198 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4201 llvm::omp::Directive cancelledDirective =
4204 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4205 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4207 if (failed(
handleError(afterIP, *op.getOperation())))
4210 builder.restoreIP(afterIP.get());
4220 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4222 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4227 Value symAddr = threadprivateOp.getSymAddr();
4230 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4233 if (!isa<LLVM::AddressOfOp>(symOp))
4234 return opInst.
emitError(
"Addressing symbol not found");
4235 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4237 LLVM::GlobalOp global =
4238 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
4239 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
4240 llvm::Type *type = globalValue->getValueType();
4241 llvm::TypeSize typeSize =
4242 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4244 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4245 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4246 ompLoc, globalValue, size, global.getSymName() +
".cache");
4252static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4254 switch (deviceClause) {
4255 case mlir::omp::DeclareTargetDeviceType::host:
4256 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4258 case mlir::omp::DeclareTargetDeviceType::nohost:
4259 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4261 case mlir::omp::DeclareTargetDeviceType::any:
4262 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4265 llvm_unreachable(
"unhandled device clause");
4268static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4270 mlir::omp::DeclareTargetCaptureClause captureClause) {
4271 switch (captureClause) {
4272 case mlir::omp::DeclareTargetCaptureClause::to:
4273 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4274 case mlir::omp::DeclareTargetCaptureClause::link:
4275 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4276 case mlir::omp::DeclareTargetCaptureClause::enter:
4277 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4278 case mlir::omp::DeclareTargetCaptureClause::none:
4279 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4281 llvm_unreachable(
"unhandled capture clause");
4286 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4288 if (
auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4289 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4290 return modOp.lookupSymbol(addressOfOp.getGlobalName());
4295static llvm::SmallString<64>
4297 llvm::OpenMPIRBuilder &ompBuilder) {
4299 llvm::raw_svector_ostream os(suffix);
4302 auto fileInfoCallBack = [&loc]() {
4303 return std::pair<std::string, uint64_t>(
4304 llvm::StringRef(loc.getFilename()), loc.getLine());
4307 auto vfs = llvm::vfs::getRealFileSystem();
4310 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
4312 os <<
"_decl_tgt_ref_ptr";
4318 if (
auto declareTargetGlobal =
4319 dyn_cast_if_present<omp::DeclareTargetInterface>(
4321 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4322 omp::DeclareTargetCaptureClause::link)
4328 if (
auto declareTargetGlobal =
4329 dyn_cast_if_present<omp::DeclareTargetInterface>(
4331 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4332 omp::DeclareTargetCaptureClause::to ||
4333 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4334 omp::DeclareTargetCaptureClause::enter)
4348 if (
auto declareTargetGlobal =
4349 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
4352 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
4353 omp::DeclareTargetCaptureClause::link) ||
4354 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4355 omp::DeclareTargetCaptureClause::to &&
4356 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
4360 if (gOp.getSymName().contains(suffix))
4365 (gOp.getSymName().str() + suffix.str()).str());
4374struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
4375 SmallVector<Operation *, 4> Mappers;
4378 void append(MapInfosTy &curInfo) {
4379 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
4380 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
4389struct MapInfoData : MapInfosTy {
4390 llvm::SmallVector<bool, 4> IsDeclareTarget;
4391 llvm::SmallVector<bool, 4> IsAMember;
4393 llvm::SmallVector<bool, 4> IsAMapping;
4394 llvm::SmallVector<mlir::Operation *, 4> MapClause;
4395 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
4398 llvm::SmallVector<llvm::Type *, 4> BaseType;
4401 void append(MapInfoData &CurInfo) {
4402 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
4403 CurInfo.IsDeclareTarget.end());
4404 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
4405 OriginalValue.append(CurInfo.OriginalValue.begin(),
4406 CurInfo.OriginalValue.end());
4407 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
4408 MapInfosTy::append(CurInfo);
4412enum class TargetDirectiveEnumTy : uint32_t {
4416 TargetEnterData = 3,
4421static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4422 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4423 .Case([](omp::TargetDataOp) {
return TargetDirectiveEnumTy::TargetData; })
4424 .Case([](omp::TargetEnterDataOp) {
4425 return TargetDirectiveEnumTy::TargetEnterData;
4427 .Case([&](omp::TargetExitDataOp) {
4428 return TargetDirectiveEnumTy::TargetExitData;
4430 .Case([&](omp::TargetUpdateOp) {
4431 return TargetDirectiveEnumTy::TargetUpdate;
4433 .Case([&](omp::TargetOp) {
return TargetDirectiveEnumTy::Target; })
4434 .Default([&](Operation *op) {
return TargetDirectiveEnumTy::None; });
// --- Extraction-garbled fragment (original lines ~4441-4500) ---
// Appears to compute a map entry's size in bytes: an element count is
// accumulated from omp.map.bounds operands (using upper/lower bound values
// plus 1 — the subtraction line was dropped in extraction) and multiplied by
// the underlying element size (underlyingTypeSzInBits / 8). The enclosing
// function signature(s) and several guard lines were lost in extraction —
// recover them from the original file before modifying.
 4441 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
 4442 arrTy.getElementType()))
 4459 llvm::Value *basePointer,
 4460 llvm::Type *baseType,
 4461 llvm::IRBuilderBase &builder,
 4463 if (
auto memberClause =
 4464 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
// Bounds present: multiply per-dimension extents into elementCount.
 4469 if (!memberClause.getBounds().empty()) {
 4470 llvm::Value *elementCount = builder.getInt64(1);
 4471 for (
auto bounds : memberClause.getBounds()) {
 4472 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
 4473 bounds.getDefiningOp())) {
 4478 elementCount = builder.CreateMul(
 4482 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
 4483 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
 4484 builder.getInt64(1)));
 4491 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
// Final size = elementCount * (element size in bytes).
 4499 return builder.CreateMul(elementCount,
 4500 builder.getInt64(underlyingTypeSzInBits / 8));
// Converts MLIR omp::ClauseMapFlags into LLVM's OpenMPOffloadMappingFlags.
// NOTE(review): the `if (mapTypeToBool(...))` guard preceding each `|=` line
// was dropped in extraction; each OMP_MAP_* bit presumably mirrors the
// same-named MLIR flag — confirm against the original source before editing.
 4511static llvm::omp::OpenMPOffloadMappingFlags
 4513 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
 4514 return (mlirFlags & flag) == flag;
// "Explicit map" = any flag other than is_device_ptr is set.
 4516 const bool hasExplicitMap =
 4517 (mlirFlags &
~omp::ClauseMapFlags::is_device_ptr) !=
 4518 omp::ClauseMapFlags::none;
 4520 llvm::omp::OpenMPOffloadMappingFlags mapType =
 4521 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
 4524 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
 4527 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
 4530 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
 4533 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
 4536 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
 4539 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
 4542 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
 4545 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
 4548 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
 4551 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
 4554 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
 4557 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
 4560 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
// Entries with no explicit map flags are passed as literals.
 4561 if (!hasExplicitMap)
 4562 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
// Collects per-clause MapInfoData from the map / use_device_ptr /
// use_device_addr / has_device_addr operand lists. The function name and the
// leading parameters were dropped in extraction (this span starts mid
// signature) — recover the full signature from the original file before
// editing. Each loop below pushes one parallel entry per clause into
// mapData's vectors.
 4572 ArrayRef<Value> useDevAddrOperands = {},
 4573 ArrayRef<Value> hasDevAddrOperands = {}) {
// Returns true when mapOp appears as a member of any map in mapVars.
 4574 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
 4582 for (Value mapValue : mapVars) {
 4583 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
 4584 for (
auto member : map.getMembers())
 4585 if (member == mapOp)
// Main pass over explicit map clauses.
 4592 for (Value mapValue : mapVars) {
 4593 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
 4595 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
 4596 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
 4597 mapData.Pointers.push_back(mapData.OriginalValue.back());
// declare-target handling: several guard lines were dropped here.
 4599 if (llvm::Value *refPtr =
 4601 mapData.IsDeclareTarget.push_back(
true);
 4602 mapData.BasePointers.push_back(refPtr);
 4604 mapData.IsDeclareTarget.push_back(
true);
 4605 mapData.BasePointers.push_back(mapData.OriginalValue.back());
 4607 mapData.IsDeclareTarget.push_back(
false);
 4608 mapData.BasePointers.push_back(mapData.OriginalValue.back());
 4611 mapData.BaseType.push_back(
 4612 moduleTranslation.
convertType(mapOp.getVarType()));
 4613 mapData.Sizes.push_back(
 4614 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
 4615 mapData.BaseType.back(), builder, moduleTranslation));
 4616 mapData.MapClause.push_back(mapOp.getOperation());
 4620 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
 4621 if (mapOp.getMapperId())
 4622 mapData.Mappers.push_back(
 4624 mapOp, mapOp.getMapperIdAttr()));
 4626 mapData.Mappers.push_back(
nullptr);
 4627 mapData.IsAMapping.push_back(
true);
 4628 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
// Marks an existing entry as a device pointer/address return param if the
// same original value was already mapped.
 4631 auto findMapInfo = [&mapData](llvm::Value *val,
 4632 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
 4635 for (llvm::Value *basePtr : mapData.OriginalValue) {
 4636 if (basePtr == val && mapData.IsAMapping[index]) {
 4638 mapData.Types[index] |=
 4639 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
 4640 mapData.DevicePointers[index] = devInfoTy;
// use_device_ptr / use_device_addr operands get zero-sized RETURN_PARAM
// entries unless they alias an existing mapping.
 4648 auto addDevInfos = [&](
const llvm::ArrayRef<Value> &useDevOperands,
 4649 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
 4650 for (Value mapValue : useDevOperands) {
 4651 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp())
 4653 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
 4654 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
 4657 if (!findMapInfo(origValue, devInfoTy)) {
 4658 mapData.OriginalValue.push_back(origValue);
 4659 mapData.Pointers.push_back(mapData.OriginalValue.back());
 4660 mapData.IsDeclareTarget.push_back(
false);
 4661 mapData.BasePointers.push_back(mapData.OriginalValue.back());
 4662 mapData.BaseType.push_back(
 4663 moduleTranslation.
convertType(mapOp.getVarType()));
 4664 mapData.Sizes.push_back(builder.getInt64(0));
 4665 mapData.MapClause.push_back(mapOp.getOperation());
 4666 mapData.Types.push_back(
 4667 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
 4670 mapData.DevicePointers.push_back(devInfoTy);
 4671 mapData.Mappers.push_back(
nullptr);
 4672 mapData.IsAMapping.push_back(
false);
 4673 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
 4678 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
 4679 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
// has_device_addr operands: mapped directly as device pointers/addresses.
 4681 for (Value mapValue : hasDevAddrOperands) {
 4682 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
 4684 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
 4685 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
 4687 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
 4689 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
 4690 omp::ClauseMapFlags::none;
 4692 mapData.OriginalValue.push_back(origValue);
 4693 mapData.BasePointers.push_back(origValue);
 4694 mapData.Pointers.push_back(origValue);
 4695 mapData.IsDeclareTarget.push_back(
false);
 4696 mapData.BaseType.push_back(
 4697 moduleTranslation.
convertType(mapOp.getVarType()));
 4698 mapData.Sizes.push_back(
 4699 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
 4700 mapData.MapClause.push_back(mapOp.getOperation());
// `always` maps keep the full computed mapType; others fall through below.
 4701 if (llvm::to_underlying(mapType & mapTypeAlways)) {
 4705 mapData.Types.push_back(mapType);
 4709 if (mapOp.getMapperId()) {
 4710 mapData.Mappers.push_back(
 4712 mapOp, mapOp.getMapperIdAttr()));
 4714 mapData.Mappers.push_back(
nullptr);
 4719 mapData.Types.push_back(
 4720 isDevicePtr ? mapType
 4721 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
 4722 mapData.Mappers.push_back(
nullptr);
 4726 mapData.DevicePointers.push_back(
 4727 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
 4728 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
 4729 mapData.IsAMapping.push_back(
false);
 4730 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
// Member-index helpers (extraction-garbled; enclosing signatures were
// dropped). First: look up the MapInfoData index of a member's MapInfoOp
// (asserts it exists). Then: a comparator over member index paths used for
// sorting members, and retrieval of the first/last mapped member.
 4735 auto *res = llvm::find(mapData.MapClause, memberOp);
 4736 assert(res != mapData.MapClause.end() &&
 4737 "MapInfoOp for member not found in MapData, cannot return index");
 4738 return std::distance(mapData.MapClause.begin(), res);
 4742 omp::MapInfoOp mapInfo) {
 4743 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
// Lexicographic comparison of two members' index paths.
 4753 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
 4754 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
 4756 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
 4757 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
 4758 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
 4760 if (aIndex == bIndex)
 4763 if (aIndex < bIndex)
 4766 if (aIndex > bIndex)
// Equal prefix: the shorter path is the parent; record the occluded child.
 4773 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
 4775 occludedChildren.push_back(
b);
 4777 occludedChildren.push_back(a);
 4778 return memberAParent;
 4784 for (
auto v : occludedChildren)
 4791 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
// Single member: trivially first and last.
 4793 if (indexAttr.size() == 1)
 4794 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
 4798 return llvm::cast<omp::MapInfoOp>(
// Bounds-index and overlapped-member helpers (extraction-garbled; several
// signatures and guard lines were dropped — recover from the original file
// before editing).
// First helper: builds a GEP index vector from omp.map.bounds lower bounds,
// iterating dimensions from innermost to outermost; for non-array cases it
// linearizes via idx = idx * extent + lowerBound.
 4823static std::vector<llvm::Value *>
 4825 llvm::IRBuilderBase &builder,
bool isArrayTy,
 4827 std::vector<llvm::Value *> idx;
 4838 idx.push_back(builder.getInt64(0));
 4839 for (
int i = bounds.size() - 1; i >= 0; --i) {
 4840 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
 4841 bounds[i].getDefiningOp())) {
 4842 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
 4860 std::vector<llvm::Value *> dimensionIndexSizeOffset;
 4861 for (
int i = bounds.size() - 1; i >= 0; --i) {
 4862 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
 4863 bounds[i].getDefiningOp())) {
 4864 if (i == ((
int)bounds.size() - 1))
 4866 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
 4868 idx.back() = builder.CreateAdd(
 4869 builder.CreateMul(idx.back(), moduleTranslation.
lookupValue(
 4870 boundOp.getExtent())),
 4871 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
// getAsIntegers: converts an ArrayAttr of IntegerAttr to int64_t values.
 4880 llvm::transform(values, std::back_inserter(ints), [](
Attribute value) {
 4881 return cast<IntegerAttr>(value).getInt();
// getOverlappedMembers: computes which member entries of parentOp overlap
// (are prefixes of) other members, collecting their MapInfoData indices.
 4889 omp::MapInfoOp parentOp) {
 4891 if (parentOp.getMembers().empty())
 4895 if (parentOp.getMembers().size() == 1) {
 4896 overlapMapDataIdxs.push_back(0);
 4902 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
 4903 for (
auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
 4904 memberByIndex.push_back(
 4905 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
// Shorter index paths (closer to the parent) sort first.
 4910 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
 4911 [&](
auto a,
auto b) { return a.second.size() < b.second.size(); });
 4917 for (
auto v : memberByIndex) {
 4921 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](
auto x) {
 4924 llvm::SmallVector<int64_t> xArr(x.second.size());
 4925 getAsIntegers(x.second, xArr);
 4926 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
 4927 xArr.size() >= vArr.size();
 4933 for (
auto v : memberByIndex)
 4934 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
 4935 overlapMapDataIdxs.push_back(v.first);
// Fragment of a pointer-map check (checks for a var_ptr_ptr operand).
 4947 if (mapOp.getVarPtrPtr())
// Emits the "parent" combined-entry for a structure map that has member
// maps: a base entry covering the parent's address range, plus (for partial
// maps with overlapping members) a series of gap entries that skip the
// overlapped member ranges. Returns the MEMBER_OF flag callers use to tag
// member entries. Host-only (asserted below). The function signature's first
// lines were dropped in extraction.
 4976 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
 4977 MapInfoData &mapData, uint64_t mapDataIndex,
 4978 TargetDirectiveEnumTy targetDirective) {
 4979 assert(!ompBuilder.Config.isTargetDevice() &&
 4980 "function only supported for host device codegen");
 4983 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
 4985 auto *parentMapper = mapData.Mappers[mapDataIndex];
// TARGET_PARAM only for kernel-argument entries of an omp.target op.
 4991 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
 4992 (targetDirective == TargetDirectiveEnumTy::Target &&
 4993 !mapData.IsDeclareTarget[mapDataIndex])
 4994 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
 4995 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
 4998 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
// Preserve only data-movement/behavior bits from the parent clause's flags.
 5002 mapFlags parentFlags = mapData.Types[mapDataIndex];
 5003 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
 5004 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
 5005 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
 5006 baseFlag |= (parentFlags & preserve);
 5009 combinedInfo.Types.emplace_back(baseFlag);
 5010 combinedInfo.DevicePointers.emplace_back(
 5011 mapData.DevicePointers[mapDataIndex]);
 5015 combinedInfo.Mappers.emplace_back(
 5016 parentMapper && !parentClause.getPartialMap() ? parentMapper :
nullptr);
 5018 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
 5019 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
// Compute the [lowAddr, highAddr) byte range the parent entry spans.
 5028 llvm::Value *lowAddr, *highAddr;
 5029 if (!parentClause.getPartialMap()) {
 5030 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
 5031 builder.getPtrTy());
 5032 highAddr = builder.CreatePointerCast(
 5033 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
 5034 mapData.Pointers[mapDataIndex], 1),
 5035 builder.getPtrTy());
 5036 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
 5038 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
// Partial map: range runs from the first to one-past the last member.
 5041 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
 5042 builder.getPtrTy());
 5045 highAddr = builder.CreatePointerCast(
 5046 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
 5047 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
 5048 builder.getPtrTy());
 5049 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
 5052 llvm::Value *size = builder.CreateIntCast(
 5053 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
 5054 builder.getInt64Ty(),
 5056 combinedInfo.Sizes.push_back(size);
 5058 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
 5059 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
 5067 if (!parentClause.getPartialMap()) {
 5072 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
 5073 bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
 5074 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
 5075 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
 5076 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
// target update / close maps re-emit the full parent entry directly.
 5078 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
 5079 combinedInfo.Types.emplace_back(mapFlag);
 5080 combinedInfo.DevicePointers.emplace_back(
 5081 mapData.DevicePointers[mapDataIndex]);
 5083 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
 5084 combinedInfo.BasePointers.emplace_back(
 5085 mapData.BasePointers[mapDataIndex]);
 5086 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
 5087 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
 5088 combinedInfo.Mappers.emplace_back(
nullptr);
 5099 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
 5100 builder.getPtrTy());
 5101 highAddr = builder.CreatePointerCast(
 5102 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
 5103 mapData.Pointers[mapDataIndex], 1),
 5104 builder.getPtrTy());
// For each overlapped member, emit the chunk before it, then advance
// lowAddr past the member's storage.
 5111 for (
auto v : overlapIdxs) {
 5114 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
 5115 combinedInfo.Types.emplace_back(mapFlag);
 5116 combinedInfo.DevicePointers.emplace_back(
 5117 mapData.DevicePointers[mapDataOverlapIdx]);
 5119 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
 5120 combinedInfo.BasePointers.emplace_back(
 5121 mapData.BasePointers[mapDataIndex]);
 5122 combinedInfo.Mappers.emplace_back(
nullptr);
 5123 combinedInfo.Pointers.emplace_back(lowAddr);
 5124 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
 5125 builder.CreatePtrDiff(builder.getInt8Ty(),
 5126 mapData.OriginalValue[mapDataOverlapIdx],
 5128 builder.getInt64Ty(),
true));
 5129 lowAddr = builder.CreateConstGEP1_32(
 5131 mapData.MapClause[mapDataOverlapIdx]))
 5132 ? builder.getPtrTy()
 5133 : mapData.BaseType[mapDataOverlapIdx],
 5134 mapData.BasePointers[mapDataOverlapIdx], 1);
// Trailing chunk from the last overlapped member to the end of the parent.
 5137 combinedInfo.Types.emplace_back(mapFlag);
 5138 combinedInfo.DevicePointers.emplace_back(
 5139 mapData.DevicePointers[mapDataIndex]);
 5141 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
 5142 combinedInfo.BasePointers.emplace_back(
 5143 mapData.BasePointers[mapDataIndex]);
 5144 combinedInfo.Mappers.emplace_back(
nullptr);
 5145 combinedInfo.Pointers.emplace_back(lowAddr);
 5146 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
 5147 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
 5148 builder.getInt64Ty(),
true));
 5151 return memberOfFlag;
// Emits one combined-info entry per member of a parent map clause, tagging
// each with MEMBER_OF (and PTR_AND_OBJ where the member is reached through a
// pointer). Host-only (asserted). The signature's first line was dropped in
// extraction.
 5157 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
 5158 MapInfoData &mapData, uint64_t mapDataIndex,
 5159 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
 5160 TargetDirectiveEnumTy targetDirective) {
 5161 assert(!ompBuilder.Config.isTargetDevice() &&
 5162 "function only supported for host device codegen");
 5165 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
 5167 for (
auto mappedMembers : parentClause.getMembers()) {
 5169 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
 5172 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
// First branch (guard dropped in extraction): emit a pointer-sized entry
// anchored at the parent's base pointer.
 5183 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
 5184 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
 5185 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
 5186 combinedInfo.Types.emplace_back(mapFlag);
 5187 combinedInfo.DevicePointers.emplace_back(
 5188 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
 5189 combinedInfo.Mappers.emplace_back(
nullptr);
 5190 combinedInfo.Names.emplace_back(
 5192 combinedInfo.BasePointers.emplace_back(
 5193 mapData.BasePointers[mapDataIndex]);
 5194 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
 5195 combinedInfo.Sizes.emplace_back(builder.getInt64(
 5196 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
// Second branch: regular member entry; PTR_AND_OBJ for pointer members
// outside target update/data directives.
 5202 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
 5203 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
 5204 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
 5206 ? parentClause.getVarPtr()
 5207 : parentClause.getVarPtrPtr());
 5210 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
 5211 targetDirective != TargetDirectiveEnumTy::TargetData))) {
 5212 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
 5215 combinedInfo.Types.emplace_back(mapFlag);
 5216 combinedInfo.DevicePointers.emplace_back(
 5217 mapData.DevicePointers[memberDataIdx]);
 5218 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
 5219 combinedInfo.Names.emplace_back(
 5221 uint64_t basePointerIndex =
 5223 combinedInfo.BasePointers.emplace_back(
 5224 mapData.BasePointers[basePointerIndex]);
 5225 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
// Null member pointer maps as size 0 (guard line dropped in extraction).
 5227 llvm::Value *size = mapData.Sizes[memberDataIdx];
 5229 size = builder.CreateSelect(
 5230 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
 5231 builder.getInt64(0), size);
 5234 combinedInfo.Sizes.emplace_back(size);
// Emits one combined-info entry for a standalone (member-less) map clause.
// When mapDataParentIdx >= 0, the entry is anchored at that parent's base
// pointer instead of its own. The signature's first line was dropped in
// extraction.
 5239 MapInfosTy &combinedInfo,
 5240 TargetDirectiveEnumTy targetDirective,
 5241 int mapDataParentIdx = -1) {
 5245 auto mapFlag = mapData.Types[mapDataIdx];
 5246 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
// Pointer maps get PTR_AND_OBJ (guard line dropped in extraction).
 5250 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
// Kernel arguments of omp.target (excluding declare-target) are params.
 5252 if (targetDirective == TargetDirectiveEnumTy::Target &&
 5253 !mapData.IsDeclareTarget[mapDataIdx])
 5254 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
 5256 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
 5258 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
 5263 if (mapDataParentIdx >= 0)
 5264 combinedInfo.BasePointers.emplace_back(
 5265 mapData.BasePointers[mapDataParentIdx]);
 5267 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
 5269 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
 5270 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
 5271 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
 5272 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
 5273 combinedInfo.Types.emplace_back(mapFlag);
 5274 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
// Dispatches a map clause that has members: single-member partial maps are
// handled specially; otherwise the parent entry is emitted first (yielding
// the MEMBER_OF flag) and then every member. Host-only (asserted). Several
// lines, including the signature start, were dropped in extraction.
 5278 llvm::IRBuilderBase &builder,
 5279 llvm::OpenMPIRBuilder &ompBuilder,
 5281 MapInfoData &mapData, uint64_t mapDataIndex,
 5282 TargetDirectiveEnumTy targetDirective) {
 5283 assert(!ompBuilder.Config.isTargetDevice() &&
 5284 "function only supported for host device codegen");
 5287 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
 5292 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
 5293 auto memberClause = llvm::cast<omp::MapInfoOp>(
 5294 parentClause.getMembers()[0].getDefiningOp());
 5311 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
 5313 combinedInfo, mapData, mapDataIndex,
 5316 combinedInfo, mapData, mapDataIndex,
 5317 memberOfParentFlag, targetDirective);
// Adjusts each map entry's pointers according to its MLIR capture kind:
// ByRef entries may be loaded through and offset by bounds-derived GEP
// indices; ByCopy entries are loaded by value and round-tripped through a
// temporary alloca so a pointer can be passed. Host-only. The signature
// start and several lines were dropped in extraction.
 5327 llvm::IRBuilderBase &builder) {
 5329 "function only supported for host device codegen");
 5330 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
 5332 if (!mapData.IsDeclareTarget[i]) {
 5333 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
 5334 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
 5344 switch (captureKind) {
 5345 case omp::VariableCaptureKind::ByRef: {
 5346 llvm::Value *newV = mapData.Pointers[i];
 5348 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
 5351 newV = builder.CreateLoad(builder.getPtrTy(), newV);
 5353 if (!offsetIdx.empty())
 5354 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
 5356 mapData.Pointers[i] = newV;
 5358 case omp::VariableCaptureKind::ByCopy: {
 5359 llvm::Type *type = mapData.BaseType[i];
 5361 if (mapData.Pointers[i]->getType()->isPointerTy())
 5362 newV = builder.CreateLoad(type, mapData.Pointers[i]);
 5364 newV = mapData.Pointers[i];
// Alloca is created at a saved insertion point, then the current debug
// location and IP are restored before the store/load pair.
 5367 auto curInsert = builder.saveIP();
 5368 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
 5370 auto *memTempAlloc =
 5371 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
 5372 builder.SetCurrentDebugLocation(DbgLoc);
 5373 builder.restoreIP(curInsert);
 5375 builder.CreateStore(newV, memTempAlloc);
 5376 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
 5379 mapData.Pointers[i] = newV;
 5380 mapData.BasePointers[i] = newV;
 5382 case omp::VariableCaptureKind::This:
 5383 case omp::VariableCaptureKind::VLAType:
 5384 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
// Top-level driver filling combinedInfo from mapData: member entries are
// skipped (their parents emit them); clauses with members go through the
// parent/member path, the rest through the individual-map path. Host-only.
// The signature start and the else-branch call were dropped in extraction.
 5395 MapInfoData &mapData,
 5396 TargetDirectiveEnumTy targetDirective) {
 5398 "function only supported for host device codegen");
 5419 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
 5422 if (mapData.IsAMember[i])
 5425 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
 5426 if (!mapInfoOp.getMembers().empty()) {
 5428 combinedInfo, mapData, i, targetDirective);
// Forward declaration: emits the LLVM function implementing a user-defined
// (declare mapper) mapping routine; defined further below. The declared
// name and leading parameters were dropped in extraction.
 5436static llvm::Expected<llvm::Function *>
 5438 LLVM::ModuleTranslation &moduleTranslation,
 5439 llvm::StringRef mapperFuncName,
 5440 TargetDirectiveEnumTy targetDirective);
// Returns the LLVM function for a declare-mapper op, creating it on first
// use: checks the translation's function map, then the module, and finally
// emits a fresh mapper function. Host-only. Parts of the signature were
// dropped in extraction.
 5442static llvm::Expected<llvm::Function *>
 5445 TargetDirectiveEnumTy targetDirective) {
 5447 "function only supported for host device codegen");
 5448 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
// Mapper functions are named "omp_mapper" + the op's symbol name.
 5449 std::string mapperFuncName =
 5451 {
"omp_mapper", declMapperOp.getSymName()});
 5453 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
// Already present in the module (e.g. emitted earlier): register and reuse.
 5461 if (llvm::Function *existingFunc =
 5462 moduleTranslation.
getLLVMModule()->getFunction(mapperFuncName)) {
 5463 moduleTranslation.
mapFunction(mapperFuncName, existingFunc);
 5464 return existingFunc;
 5468 mapperFuncName, targetDirective);
// Emits the body of a user-defined mapper function via
// OpenMPIRBuilder: translates the declare-mapper region at the provided
// insertion point, collects its map infos, and recursively materializes any
// nested custom mappers. Registers the resulting function with the module
// translation. Several lines were dropped in extraction.
 5471static llvm::Expected<llvm::Function *>
 5474 llvm::StringRef mapperFuncName,
 5475 TargetDirectiveEnumTy targetDirective) {
 5477 "function only supported for host device codegen");
 5478 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
 5479 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
 5482 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
 5485 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
 5488 MapInfosTy combinedInfo;
// Callback invoked by the IR builder with the mapper's pointer PHI; it
// translates the region's block and produces the combined map info.
 5490 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
 5491 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
 5492 builder.restoreIP(codeGenIP);
 5493 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
 5494 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
 5495 builder.GetInsertBlock());
 5496 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
 5499 return llvm::make_error<PreviouslyReportedError>();
 5500 MapInfoData mapData;
 5503 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
 5509 return combinedInfo;
// Nested custom mappers are resolved lazily per combined-info entry.
 5513 if (!combinedInfo.Mappers[i])
 5516 moduleTranslation, targetDirective);
 5520 genMapInfoCB, varType, mapperFuncName, customMapperCB);
 5522 return newFn.takeError();
 5523 if ([[maybe_unused]] llvm::Function *mappedFunc =
 5525 assert(mappedFunc == *newFn &&
 5526 "mapper function mapping disagrees with emitted function");
 5528 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
// Translates omp.target_data / target_enter_data / target_exit_data /
// target_update to LLVM IR via OpenMPIRBuilder::createTargetData. The
// TypeSwitch below extracts per-op if/device/map/nowait operands; TargetData
// additionally handles use_device_ptr / use_device_addr block arguments in
// its body-generation callback. Host-only (asserted below). The function
// signature and several lines were dropped in extraction.
 5536 llvm::Value *ifCond =
nullptr;
 5537 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
 5541 llvm::omp::RuntimeFunction RTLFn;
 5543 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
 5546 llvm::OpenMPIRBuilder::TargetDataInfo info(
 5549 assert(!ompBuilder->Config.isTargetDevice() &&
 5550 "target data/enter/exit/update are host ops");
 5551 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
// Device id operand is sign-extended/truncated to i64.
 5553 auto getDeviceID = [&](
mlir::Value dev) -> llvm::Value * {
 5554 llvm::Value *v = moduleTranslation.
lookupValue(dev);
 5555 return builder.CreateIntCast(v, builder.getInt64Ty(),
true);
 5560 .Case([&](omp::TargetDataOp dataOp) {
 5564 if (
auto ifVar = dataOp.getIfExpr())
 5568 deviceID = getDeviceID(devId);
 5570 mapVars = dataOp.getMapVars();
 5571 useDevicePtrVars = dataOp.getUseDevicePtrVars();
 5572 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
 5575 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
 5579 if (
auto ifVar = enterDataOp.getIfExpr())
 5583 deviceID = getDeviceID(devId);
 5586 enterDataOp.getNowait()
 5587 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
 5588 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
 5589 mapVars = enterDataOp.getMapVars();
 5590 info.HasNoWait = enterDataOp.getNowait();
 5593 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
 5597 if (
auto ifVar = exitDataOp.getIfExpr())
 5601 deviceID = getDeviceID(devId);
 5603 RTLFn = exitDataOp.getNowait()
 5604 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
 5605 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
 5606 mapVars = exitDataOp.getMapVars();
 5607 info.HasNoWait = exitDataOp.getNowait();
 5610 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
 5614 if (
auto ifVar = updateDataOp.getIfExpr())
 5618 deviceID = getDeviceID(devId);
 5621 updateDataOp.getNowait()
 5622 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
 5623 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
 5624 mapVars = updateDataOp.getMapVars();
 5625 info.HasNoWait = updateDataOp.getNowait();
 5628 .DefaultUnreachable(
"unexpected operation");
// Without offload targets, force the runtime call's condition to false.
 5633 if (!isOffloadEntry)
 5634 ifCond = builder.getFalse();
 5636 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
 5637 MapInfoData mapData;
 5639 builder, useDevicePtrVars, useDeviceAddrVars);
 5642 MapInfosTy combinedInfo;
 5643 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
 5644 builder.restoreIP(codeGenIP);
 5645 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
 5647 return combinedInfo;
// Binds each use_device block argument to the (possibly remapped) device
// pointer of the matching map entry.
 5653 [&moduleTranslation](
 5654 llvm::OpenMPIRBuilder::DeviceInfoTy type,
 5658 for (
auto [arg, useDevVar] :
 5659 llvm::zip_equal(blockArgs, useDeviceVars)) {
 5661 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
 5662 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
 5663 : mapInfoOp.getVarPtr();
 5666 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
 5667 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
 5668 mapInfoData.MapClause, mapInfoData.DevicePointers,
 5669 mapInfoData.BasePointers)) {
 5670 auto mapOp = cast<omp::MapInfoOp>(mapClause);
 5671 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
 5672 devicePointer != type)
 5675 if (llvm::Value *devPtrInfoMap =
 5676 mapper ? mapper(basePointer) : basePointer) {
 5677 moduleTranslation.
mapValue(arg, devPtrInfoMap);
 5684 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
 5685 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
 5686 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
 5689 builder.restoreIP(codeGenIP);
 5690 assert(isa<omp::TargetDataOp>(op) &&
 5691 "BodyGen requested for non TargetDataOp");
 5692 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
// NOTE(review): "®ion" below is an HTML-entity mangling of "&region"
// introduced by extraction — restore `Region &region = ...` from the
// original file.
 5693 Region &region = cast<omp::TargetDataOp>(op).getRegion();
 5694 switch (bodyGenType) {
 5695 case BodyGenTy::Priv:
 5697 if (!info.DevicePtrInfoMap.empty()) {
 5698 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
 5699 blockArgIface.getUseDeviceAddrBlockArgs(),
 5700 useDeviceAddrVars, mapData,
 5701 [&](llvm::Value *basePointer) -> llvm::Value * {
 5702 if (!info.DevicePtrInfoMap[basePointer].second)
 5704 return builder.CreateLoad(
 5706 info.DevicePtrInfoMap[basePointer].second);
 5708 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
 5709 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
 5710 mapData, [&](llvm::Value *basePointer) {
 5711 return info.DevicePtrInfoMap[basePointer].second;
 5715 moduleTranslation)))
 5716 return llvm::make_error<PreviouslyReportedError>();
 5719 case BodyGenTy::DupNoPriv:
 5720 if (info.DevicePtrInfoMap.empty()) {
 5723 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
 5724 blockArgIface.getUseDeviceAddrBlockArgs(),
 5725 useDeviceAddrVars, mapData);
 5726 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
 5727 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
 5731 case BodyGenTy::NoPriv:
 5733 if (info.DevicePtrInfoMap.empty()) {
 5735 moduleTranslation)))
 5736 return llvm::make_error<PreviouslyReportedError>();
 5740 return builder.saveIP();
 5743 auto customMapperCB =
 5745 if (!combinedInfo.Mappers[i])
 5747 info.HasMapper =
true;
 5749 moduleTranslation, targetDirective);
 5752 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 5753 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// TargetDataOp supplies a body generator; the standalone directives pass
// the runtime function to call instead.
 5755 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
 5756 if (isa<omp::TargetDataOp>(op))
 5757 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
 5758 deviceID, ifCond, info, genMapInfoCB,
 5762 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
 5763 deviceID, ifCond, info, genMapInfoCB,
 5764 customMapperCB, &RTLFn);
 5770 builder.restoreIP(*afterIP);
// Translates omp.distribute: sets up optional teams-level reductions,
// privatization, region translation, and — when no omp.wsloop is nested —
// applies a distribute workshare loop (DistributeStaticLoop schedule) via
// the OpenMPIRBuilder. The function signature and several lines were
// dropped in extraction.
 5778 auto distributeOp = cast<omp::DistributeOp>(opInst);
 5785 bool doDistributeReduction =
 5789 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
 5794 if (doDistributeReduction) {
 5795 isByRef =
getIsByRef(teamsOp.getReductionByref());
 5796 assert(isByRef.size() == teamsOp.getNumReductionVars());
 5799 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
 5803 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
 5804 .getReductionBlockArgs();
 5807 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
 5808 reductionDecls, privateReductionVariables, reductionVariableMap,
 5813 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
 5814 auto bodyGenCB = [&](InsertPointTy allocaIP,
 5815 InsertPointTy codeGenIP) -> llvm::Error {
 5819 moduleTranslation, allocaIP);
 5822 builder.restoreIP(codeGenIP);
 5828 return llvm::make_error<PreviouslyReportedError>();
 5833 return llvm::make_error<PreviouslyReportedError>();
 5836 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
 5838 distributeOp.getPrivateNeedsBarrier())))
 5839 return llvm::make_error<PreviouslyReportedError>();
 5842 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 5845 builder, moduleTranslation);
 5847 return regionBlock.takeError();
 5848 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// No nested wsloop wrapper: apply the distribute loop schedule here.
 5853 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
 5856 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
 5857 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
 5858 : omp::ClauseScheduleKind::Static;
 5860 bool isOrdered = hasDistSchedule;
 5861 std::optional<omp::ScheduleModifier> scheduleMod;
 5862 bool isSimd =
false;
 5863 llvm::omp::WorksharingLoopType workshareLoopType =
 5864 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
 5865 bool loopNeedsBarrier =
false;
 5866 llvm::Value *chunk = moduleTranslation.
lookupValue(
 5867 distributeOp.getDistScheduleChunkSize());
 5868 llvm::CanonicalLoopInfo *loopInfo =
 5870 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
 5871 ompBuilder->applyWorkshareLoop(
 5872 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
 5873 convertToScheduleKind(schedule), chunk, isSimd,
 5874 scheduleMod == omp::ScheduleModifier::monotonic,
 5875 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
 5876 workshareLoopType,
false, hasDistSchedule, chunk);
 5879 return wsloopIP.takeError();
 5882 distributeOp.getLoc(), privVarsInfo.
llvmVars,
 5884 return llvm::make_error<PreviouslyReportedError>();
 5886 return llvm::Error::success();
 5889 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
 5891 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 5892 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
 5893 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
 5898 builder.restoreIP(*afterIP);
 5900 if (doDistributeReduction) {
 5903 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
 5904 privateReductionVariables, isByRef,
// Emits OpenMP device runtime configuration: an "openmp-device" module flag
// plus the __omp_rtl_* global flags derived from the version/flags
// attribute. The surrounding function signature and several condition lines
// were dropped in extraction.
// NOTE(review): `!cast<mlir::ModuleOp>(op)` below asserts on a non-module op
// rather than testing it — an `isa`/`dyn_cast` check may have been intended;
// confirm against the original source.
 5916 if (!cast<mlir::ModuleOp>(op))
 5921 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
 5922 attribute.getOpenmpDeviceVersion());
 5924 if (attribute.getNoGpuLib())
 5927 ompBuilder->createGlobalFlag(
 5928 attribute.getDebugKind() ,
 5929 "__omp_rtl_debug_kind");
 5930 ompBuilder->createGlobalFlag(
 5932 .getAssumeTeamsOversubscription()
 5934 "__omp_rtl_assume_teams_oversubscription");
 5935 ompBuilder->createGlobalFlag(
 5937 .getAssumeThreadsOversubscription()
 5939 "__omp_rtl_assume_threads_oversubscription");
 5940 ompBuilder->createGlobalFlag(
 5941 attribute.getAssumeNoThreadState() ,
 5942 "__omp_rtl_assume_no_thread_state");
 5943 ompBuilder->createGlobalFlag(
 5945 .getAssumeNoNestedParallelism()
 5947 "__omp_rtl_assume_no_nested_parallelism");
// Builds the unique offload-entry identity for an omp.target op from its
// source location: (parent name, file device id, file inode, line). When the
// file's unique ID cannot be obtained, falls back to a hash of the filename
// and a sentinel device id (0xdeadf17e). The signature start was dropped in
// extraction.
 5952 omp::TargetOp targetOp,
 5953 llvm::StringRef parentName =
"") {
 5954 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
 5956 assert(fileLoc &&
"No file found from location");
 5957 StringRef fileName = fileLoc.getFilename().getValue();
 5959 llvm::sys::fs::UniqueID id;
 5960 uint64_t line = fileLoc.getLine();
// Fallback path: filesystem lookup failed (ec non-zero).
 5961 if (
auto ec = llvm::sys::fs::getUniqueID(fileName,
id)) {
 5962 size_t fileHash = llvm::hash_value(fileName.str());
 5963 size_t deviceId = 0xdeadf17e;
 5965 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
 5967 targetInfo = llvm::TargetRegionEntryInfo(parentName,
id.
getDevice(),
 5968 id.getFile(), line);
5975 llvm::IRBuilderBase &builder, llvm::Function *
func) {
5977 "function only supported for target device codegen");
5978 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5979 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5992 if (mapData.IsDeclareTarget[i]) {
5999 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
6000 convertUsersOfConstantsToInstructions(constant,
func,
false);
6007 for (llvm::User *user : mapData.OriginalValue[i]->users())
6008 userVec.push_back(user);
6010 for (llvm::User *user : userVec) {
6011 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
6012 if (insn->getFunction() ==
func) {
6013 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6014 llvm::Value *substitute = mapData.BasePointers[i];
6016 : mapOp.getVarPtr())) {
6017 builder.SetCurrentDebugLocation(insn->getDebugLoc());
6018 substitute = builder.CreateLoad(
6019 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
6020 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
6022 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
6069static llvm::IRBuilderBase::InsertPoint
6071 llvm::Value *input, llvm::Value *&retVal,
6072 llvm::IRBuilderBase &builder,
6073 llvm::OpenMPIRBuilder &ompBuilder,
6075 llvm::IRBuilderBase::InsertPoint allocaIP,
6076 llvm::IRBuilderBase::InsertPoint codeGenIP) {
6077 assert(ompBuilder.Config.isTargetDevice() &&
6078 "function only supported for target device codegen");
6079 builder.restoreIP(allocaIP);
6081 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
6083 ompBuilder.M.getContext());
6084 unsigned alignmentValue = 0;
6086 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
6087 if (mapData.OriginalValue[i] == input) {
6088 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6089 capture = mapOp.getMapCaptureType();
6092 mapOp.getVarType(), ompBuilder.M.getDataLayout());
6096 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
6097 unsigned int defaultAS =
6098 ompBuilder.M.getDataLayout().getProgramAddressSpace();
6101 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
6103 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
6104 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
6106 builder.CreateStore(&arg, v);
6108 builder.restoreIP(codeGenIP);
6111 case omp::VariableCaptureKind::ByCopy: {
6115 case omp::VariableCaptureKind::ByRef: {
6116 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
6118 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
6133 if (v->getType()->isPointerTy() && alignmentValue) {
6134 llvm::MDBuilder MDB(builder.getContext());
6135 loadInst->setMetadata(
6136 llvm::LLVMContext::MD_align,
6137 llvm::MDNode::get(builder.getContext(),
6138 MDB.createConstant(llvm::ConstantInt::get(
6139 llvm::Type::getInt64Ty(builder.getContext()),
6146 case omp::VariableCaptureKind::This:
6147 case omp::VariableCaptureKind::VLAType:
6150 assert(
false &&
"Currently unsupported capture kind");
6154 return builder.saveIP();
6171 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6172 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6173 blockArgIface.getHostEvalBlockArgs())) {
6174 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6178 .Case([&](omp::TeamsOp teamsOp) {
6179 if (teamsOp.getNumTeamsLower() == blockArg)
6180 numTeamsLower = hostEvalVar;
6181 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6183 numTeamsUpper = hostEvalVar;
6184 else if (!teamsOp.getThreadLimitVars().empty() &&
6185 teamsOp.getThreadLimit(0) == blockArg)
6186 threadLimit = hostEvalVar;
6188 llvm_unreachable(
"unsupported host_eval use");
6190 .Case([&](omp::ParallelOp parallelOp) {
6191 if (!parallelOp.getNumThreadsVars().empty() &&
6192 parallelOp.getNumThreads(0) == blockArg)
6193 numThreads = hostEvalVar;
6195 llvm_unreachable(
"unsupported host_eval use");
6197 .Case([&](omp::LoopNestOp loopOp) {
6198 auto processBounds =
6202 for (
auto [i, lb] : llvm::enumerate(opBounds)) {
6203 if (lb == blockArg) {
6206 (*outBounds)[i] = hostEvalVar;
6212 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6213 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6215 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6217 assert(found &&
"unsupported host_eval use");
6219 .DefaultUnreachable(
"unsupported host_eval use");
6231template <
typename OpTy>
6236 if (OpTy casted = dyn_cast<OpTy>(op))
6239 if (immediateParent)
6240 return dyn_cast_if_present<OpTy>(op->
getParentOp());
6249 return std::nullopt;
6252 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6253 return constAttr.getInt();
6255 return std::nullopt;
6260 uint64_t sizeInBytes = sizeInBits / 8;
6264template <
typename OpTy>
6266 if (op.getNumReductionVars() > 0) {
6271 members.reserve(reductions.size());
6272 for (omp::DeclareReductionOp &red : reductions)
6273 members.push_back(red.getType());
6275 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
6291 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
6292 bool isTargetDevice,
bool isGPU) {
6295 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
6296 if (!isTargetDevice) {
6304 numTeamsLower = teamsOp.getNumTeamsLower();
6306 if (!teamsOp.getNumTeamsUpperVars().empty())
6307 numTeamsUpper = teamsOp.getNumTeams(0);
6308 if (!teamsOp.getThreadLimitVars().empty())
6309 threadLimit = teamsOp.getThreadLimit(0);
6313 if (!parallelOp.getNumThreadsVars().empty())
6314 numThreads = parallelOp.getNumThreads(0);
6320 int32_t minTeamsVal = 1, maxTeamsVal = -1;
6324 if (numTeamsUpper) {
6326 minTeamsVal = maxTeamsVal = *val;
6328 minTeamsVal = maxTeamsVal = 0;
6334 minTeamsVal = maxTeamsVal = 1;
6336 minTeamsVal = maxTeamsVal = -1;
6341 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &
result) {
6355 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
6356 if (!targetOp.getThreadLimitVars().empty())
6357 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
6358 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
6361 int32_t maxThreadsVal = -1;
6363 setMaxValueFromClause(numThreads, maxThreadsVal);
6371 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
6372 if (combinedMaxThreadsVal < 0 ||
6373 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
6374 combinedMaxThreadsVal = teamsThreadLimitVal;
6376 if (combinedMaxThreadsVal < 0 ||
6377 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
6378 combinedMaxThreadsVal = maxThreadsVal;
6380 int32_t reductionDataSize = 0;
6381 if (isGPU && capturedOp) {
6387 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
6389 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
6390 omp::TargetRegionFlags::spmd) &&
6391 "invalid kernel flags");
6393 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
6394 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
6395 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
6396 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
6397 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
6398 if (omp::bitEnumContainsAll(kernelFlags,
6399 omp::TargetRegionFlags::spmd |
6400 omp::TargetRegionFlags::no_loop) &&
6401 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
6402 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
6404 attrs.MinTeams = minTeamsVal;
6405 attrs.MaxTeams.front() = maxTeamsVal;
6406 attrs.MinThreads = 1;
6407 attrs.MaxThreads.front() = combinedMaxThreadsVal;
6408 attrs.ReductionDataSize = reductionDataSize;
6411 if (attrs.ReductionDataSize != 0)
6412 attrs.ReductionBufferLength = 1024;
6424 omp::TargetOp targetOp,
Operation *capturedOp,
6425 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
6427 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
6429 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
6433 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
6436 if (!targetOp.getThreadLimitVars().empty()) {
6437 Value targetThreadLimit = targetOp.getThreadLimit(0);
6438 attrs.TargetThreadLimit.front() =
6446 attrs.MinTeams = builder.CreateSExtOrTrunc(
6447 moduleTranslation.
lookupValue(numTeamsLower), builder.getInt32Ty());
6450 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
6451 moduleTranslation.
lookupValue(numTeamsUpper), builder.getInt32Ty());
6453 if (teamsThreadLimit)
6454 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
6455 moduleTranslation.
lookupValue(teamsThreadLimit), builder.getInt32Ty());
6458 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
6460 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
6461 omp::TargetRegionFlags::trip_count)) {
6463 attrs.LoopTripCount =
nullptr;
6468 for (
auto [loopLower, loopUpper, loopStep] :
6469 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
6470 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
6471 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
6472 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
6474 if (!lowerBound || !upperBound || !step) {
6475 attrs.LoopTripCount =
nullptr;
6479 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
6480 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
6481 loc, lowerBound, upperBound, step,
true,
6482 loopOp.getLoopInclusive());
6484 if (!attrs.LoopTripCount) {
6485 attrs.LoopTripCount = tripCount;
6490 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
6495 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6497 attrs.DeviceID = moduleTranslation.
lookupValue(devId);
6499 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
6506 auto targetOp = cast<omp::TargetOp>(opInst);
6511 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
6520 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
6521 assert(parentBB &&
"No insert block is set for the builder");
6522 llvm::Function *parentLLVMFn = parentBB->getParent();
6523 assert(parentLLVMFn &&
"Parent Function must be valid");
6524 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
6525 builder.SetCurrentDebugLocation(llvm::DILocation::get(
6526 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
6527 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
6530 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
6531 bool isGPU = ompBuilder->Config.isGPU();
6534 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
6535 auto &targetRegion = targetOp.getRegion();
6552 llvm::Function *llvmOutlinedFn =
nullptr;
6553 TargetDirectiveEnumTy targetDirective =
6554 getTargetDirectiveEnumTyFromOp(&opInst);
6558 bool isOffloadEntry =
6559 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
6566 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
6568 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
6569 std::optional<DenseI64ArrayAttr> privateMapIndices =
6570 targetOp.getPrivateMapsAttr();
6572 for (
auto [privVarIdx, privVarSymPair] :
6573 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
6574 auto privVar = std::get<0>(privVarSymPair);
6575 auto privSym = std::get<1>(privVarSymPair);
6577 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
6578 omp::PrivateClauseOp privatizer =
6581 if (!privatizer.needsMap())
6585 targetOp.getMappedValueForPrivateVar(privVarIdx);
6586 assert(mappedValue &&
"Expected to find mapped value for a privatized "
6587 "variable that needs mapping");
6592 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
6593 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
6597 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
6599 varType == privVar.getType() &&
6600 "Type of private var doesn't match the type of the mapped value");
6604 mappedPrivateVars.insert(
6606 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
6607 (*privateMapIndices)[privVarIdx])});
6611 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6612 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
6613 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6614 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6615 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6618 llvm::Function *llvmParentFn =
6620 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
6621 assert(llvmParentFn && llvmOutlinedFn &&
6622 "Both parent and outlined functions must exist at this point");
6624 if (outlinedFnLoc && llvmParentFn->getSubprogram())
6625 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
6627 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
6628 attr.isStringAttribute())
6629 llvmOutlinedFn->addFnAttr(attr);
6631 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
6632 attr.isStringAttribute())
6633 llvmOutlinedFn->addFnAttr(attr);
6635 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
6636 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6637 llvm::Value *mapOpValue =
6638 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6639 moduleTranslation.
mapValue(arg, mapOpValue);
6641 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
6642 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6643 llvm::Value *mapOpValue =
6644 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6645 moduleTranslation.
mapValue(arg, mapOpValue);
6654 allocaIP, &mappedPrivateVars);
6657 return llvm::make_error<PreviouslyReportedError>();
6659 builder.restoreIP(codeGenIP);
6661 &mappedPrivateVars),
6664 return llvm::make_error<PreviouslyReportedError>();
6667 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
6669 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
6670 return llvm::make_error<PreviouslyReportedError>();
6674 std::back_inserter(privateCleanupRegions),
6675 [](omp::PrivateClauseOp privatizer) {
6676 return &privatizer.getDeallocRegion();
6680 targetRegion,
"omp.target", builder, moduleTranslation);
6683 return exitBlock.takeError();
6685 builder.SetInsertPoint(*exitBlock);
6686 if (!privateCleanupRegions.empty()) {
6688 privateCleanupRegions, privateVarsInfo.
llvmVars,
6689 moduleTranslation, builder,
"omp.targetop.private.cleanup",
6691 return llvm::createStringError(
6692 "failed to inline `dealloc` region of `omp.private` "
6693 "op in the target region");
6695 return builder.saveIP();
6698 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
6701 StringRef parentName = parentFn.getName();
6703 llvm::TargetRegionEntryInfo entryInfo;
6707 MapInfoData mapData;
6712 MapInfosTy combinedInfos;
6714 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
6715 builder.restoreIP(codeGenIP);
6716 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
6718 return combinedInfos;
6721 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
6722 llvm::Value *&retVal, InsertPointTy allocaIP,
6723 InsertPointTy codeGenIP)
6724 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6725 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6726 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6732 if (!isTargetDevice) {
6733 retVal = cast<llvm::Value>(&arg);
6738 *ompBuilder, moduleTranslation,
6739 allocaIP, codeGenIP);
6742 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
6743 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
6744 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
6746 isTargetDevice, isGPU);
6750 if (!isTargetDevice)
6752 targetCapturedOp, runtimeAttrs);
6760 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
6761 llvm::Value *value = moduleTranslation.
lookupValue(var);
6762 moduleTranslation.
mapValue(arg, value);
6764 if (!llvm::isa<llvm::Constant>(value))
6765 kernelInput.push_back(value);
6768 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
6775 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
6776 kernelInput.push_back(mapData.OriginalValue[i]);
6781 moduleTranslation, dds);
6783 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6785 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6787 llvm::OpenMPIRBuilder::TargetDataInfo info(
6791 auto customMapperCB =
6793 if (!combinedInfos.Mappers[i])
6795 info.HasMapper =
true;
6797 moduleTranslation, targetDirective);
6800 llvm::Value *ifCond =
nullptr;
6801 if (
Value targetIfCond = targetOp.getIfExpr())
6802 ifCond = moduleTranslation.
lookupValue(targetIfCond);
6804 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6806 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
6807 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
6808 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
6813 builder.restoreIP(*afterIP);
6834 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
6835 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
6837 if (!offloadMod.getIsTargetDevice())
6840 omp::DeclareTargetDeviceType declareType =
6841 attribute.getDeviceType().getValue();
6843 if (declareType == omp::DeclareTargetDeviceType::host) {
6844 llvm::Function *llvmFunc =
6846 llvmFunc->dropAllReferences();
6847 llvmFunc->eraseFromParent();
6853 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
6854 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
6855 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
6857 bool isDeclaration = gOp.isDeclaration();
6858 bool isExternallyVisible =
6861 llvm::StringRef mangledName = gOp.getSymName();
6862 auto captureClause =
6868 std::vector<llvm::GlobalVariable *> generatedRefs;
6870 std::vector<llvm::Triple> targetTriple;
6871 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
6873 LLVM::LLVMDialect::getTargetTripleAttrName()));
6874 if (targetTripleAttr)
6875 targetTriple.emplace_back(targetTripleAttr.data());
6877 auto fileInfoCallBack = [&loc]() {
6878 std::string filename =
"";
6879 std::uint64_t lineNo = 0;
6882 filename = loc.getFilename().str();
6883 lineNo = loc.getLine();
6886 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
6890 auto vfs = llvm::vfs::getRealFileSystem();
6892 ompBuilder->registerTargetGlobalVariable(
6893 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6894 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6895 mangledName, generatedRefs,
false, targetTriple,
6897 gVal->getType(), gVal);
6899 if (ompBuilder->Config.isTargetDevice() &&
6900 (attribute.getCaptureClause().getValue() !=
6901 mlir::omp::DeclareTargetCaptureClause::to ||
6902 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
6903 ompBuilder->getAddrOfDeclareTargetVar(
6904 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6905 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6906 mangledName, generatedRefs,
false, targetTriple,
6907 gVal->getType(),
nullptr,
6920class OpenMPDialectLLVMIRTranslationInterface
6921 :
public LLVMTranslationDialectInterface {
6923 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
6928 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
6929 LLVM::ModuleTranslation &moduleTranslation)
const final;
6934 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
6935 NamedAttribute attribute,
6936 LLVM::ModuleTranslation &moduleTranslation)
const final;
6941LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6942 Operation *op, ArrayRef<llvm::Instruction *> instructions,
6943 NamedAttribute attribute,
6944 LLVM::ModuleTranslation &moduleTranslation)
const {
6945 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
6947 .Case(
"omp.is_target_device",
6948 [&](Attribute attr) {
6949 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6950 llvm::OpenMPIRBuilderConfig &config =
6952 config.setIsTargetDevice(deviceAttr.getValue());
6958 [&](Attribute attr) {
6959 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6960 llvm::OpenMPIRBuilderConfig &config =
6962 config.setIsGPU(gpuAttr.getValue());
6967 .Case(
"omp.host_ir_filepath",
6968 [&](Attribute attr) {
6969 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6970 llvm::OpenMPIRBuilder *ompBuilder =
6972 auto VFS = llvm::vfs::getRealFileSystem();
6973 ompBuilder->loadOffloadInfoMetadata(*VFS,
6974 filepathAttr.getValue());
6980 [&](Attribute attr) {
6981 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6985 .Case(
"omp.version",
6986 [&](Attribute attr) {
6987 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6988 llvm::OpenMPIRBuilder *ompBuilder =
6990 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
6991 versionAttr.getVersion());
6996 .Case(
"omp.declare_target",
6997 [&](Attribute attr) {
6998 if (
auto declareTargetAttr =
6999 dyn_cast<omp::DeclareTargetAttr>(attr))
7004 .Case(
"omp.requires",
7005 [&](Attribute attr) {
7006 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
7007 using Requires = omp::ClauseRequires;
7008 Requires flags = requiresAttr.getValue();
7009 llvm::OpenMPIRBuilderConfig &config =
7011 config.setHasRequiresReverseOffload(
7012 bitEnumContainsAll(flags, Requires::reverse_offload));
7013 config.setHasRequiresUnifiedAddress(
7014 bitEnumContainsAll(flags, Requires::unified_address));
7015 config.setHasRequiresUnifiedSharedMemory(
7016 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7017 config.setHasRequiresDynamicAllocators(
7018 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7023 .Case(
"omp.target_triples",
7024 [&](Attribute attr) {
7025 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7026 llvm::OpenMPIRBuilderConfig &config =
7028 config.TargetTriples.clear();
7029 config.TargetTriples.reserve(triplesAttr.size());
7030 for (Attribute tripleAttr : triplesAttr) {
7031 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7032 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7040 .Default([](Attribute) {
7056 if (
auto declareTargetIface =
7057 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
7058 parentFn.getOperation()))
7059 if (declareTargetIface.isDeclareTarget() &&
7060 declareTargetIface.getDeclareTargetDeviceType() !=
7061 mlir::omp::DeclareTargetDeviceType::host)
7071 llvm::Module *llvmModule) {
7072 llvm::Type *i64Ty = builder.getInt64Ty();
7073 llvm::Type *i32Ty = builder.getInt32Ty();
7074 llvm::Type *returnType = builder.getPtrTy(0);
7075 llvm::FunctionType *fnType =
7076 llvm::FunctionType::get(returnType, {i64Ty, i32Ty},
false);
7077 llvm::Function *
func = cast<llvm::Function>(
7078 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
7085 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
7090 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7094 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
7096 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
7097 mlir::Type heapTy = allocMemOp.getAllocatedType();
7098 llvm::Type *llvmHeapTy = moduleTranslation.
convertType(heapTy);
7099 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
7100 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7101 for (
auto typeParam : allocMemOp.getTypeparams())
7103 builder.CreateMul(allocSize, moduleTranslation.
lookupValue(typeParam));
7105 llvm::CallInst *call =
7106 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
7107 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
7110 moduleTranslation.
mapValue(allocMemOp.getResult(), resultI64);
7115 llvm::Module *llvmModule) {
7116 llvm::Type *ptrTy = builder.getPtrTy(0);
7117 llvm::Type *i32Ty = builder.getInt32Ty();
7118 llvm::Type *voidTy = builder.getVoidTy();
7119 llvm::FunctionType *fnType =
7120 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty},
false);
7121 llvm::Function *
func = dyn_cast<llvm::Function>(
7122 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
7129 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
7134 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7138 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
7141 llvm::Value *llvmHeapref = moduleTranslation.
lookupValue(heapref);
7143 llvm::Value *intToPtr =
7144 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
7145 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
7151LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7152 Operation *op, llvm::IRBuilderBase &builder,
7153 LLVM::ModuleTranslation &moduleTranslation)
const {
7156 if (ompBuilder->Config.isTargetDevice() &&
7157 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
7160 return op->
emitOpError() <<
"unsupported host op found in device";
7168 bool isOutermostLoopWrapper =
7169 isa_and_present<omp::LoopWrapperInterface>(op) &&
7170 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
7172 if (isOutermostLoopWrapper)
7173 moduleTranslation.
stackPush<OpenMPLoopInfoStackFrame>();
7176 llvm::TypeSwitch<Operation *, LogicalResult>(op)
7177 .Case([&](omp::BarrierOp op) -> LogicalResult {
7181 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7182 ompBuilder->createBarrier(builder.saveIP(),
7183 llvm::omp::OMPD_barrier);
7185 if (res.succeeded()) {
7188 builder.restoreIP(*afterIP);
7192 .Case([&](omp::TaskyieldOp op) {
7196 ompBuilder->createTaskyield(builder.saveIP());
7199 .Case([&](omp::FlushOp op) {
7211 ompBuilder->createFlush(builder.saveIP());
7214 .Case([&](omp::ParallelOp op) {
7217 .Case([&](omp::MaskedOp) {
7220 .Case([&](omp::MasterOp) {
7223 .Case([&](omp::CriticalOp) {
7226 .Case([&](omp::OrderedRegionOp) {
7229 .Case([&](omp::OrderedOp) {
7232 .Case([&](omp::WsloopOp) {
7235 .Case([&](omp::SimdOp) {
7238 .Case([&](omp::AtomicReadOp) {
7241 .Case([&](omp::AtomicWriteOp) {
7244 .Case([&](omp::AtomicUpdateOp op) {
7247 .Case([&](omp::AtomicCaptureOp op) {
7250 .Case([&](omp::CancelOp op) {
7253 .Case([&](omp::CancellationPointOp op) {
7256 .Case([&](omp::SectionsOp) {
7259 .Case([&](omp::SingleOp op) {
7262 .Case([&](omp::TeamsOp op) {
7265 .Case([&](omp::TaskOp op) {
7268 .Case([&](omp::TaskloopOp op) {
7271 .Case([&](omp::TaskgroupOp op) {
7274 .Case([&](omp::TaskwaitOp op) {
7277 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
7278 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
7279 omp::CriticalDeclareOp>([](
auto op) {
7292 .Case([&](omp::ThreadprivateOp) {
7295 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
7296 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
7299 .Case([&](omp::TargetOp) {
7302 .Case([&](omp::DistributeOp) {
7305 .Case([&](omp::LoopNestOp) {
7308 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
7315 .Case([&](omp::NewCliOp op) {
7320 .Case([&](omp::CanonicalLoopOp op) {
7323 .Case([&](omp::UnrollHeuristicOp op) {
7332 .Case([&](omp::TileOp op) {
7333 return applyTile(op, builder, moduleTranslation);
7335 .Case([&](omp::FuseOp op) {
7336 return applyFuse(op, builder, moduleTranslation);
7338 .Case([&](omp::TargetAllocMemOp) {
7341 .Case([&](omp::TargetFreeMemOp) {
7344 .Default([&](Operation *inst) {
7346 <<
"not yet implemented: " << inst->
getName();
7349 if (isOutermostLoopWrapper)
7356 registry.
insert<omp::OpenMPDialect>();
7358 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region ®ion)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry ®istry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region ®ion, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry ®istry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry;.
llvm::SetVector< T, Vector, Set, N > SetVector
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars