#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
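// Maps an optional MLIR OpenMP `schedule` clause kind to the corresponding
// OpenMPIRBuilder schedule kind, defaulting to OMP_SCHEDULE_Default when no
// clause is present.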
static llvm::omp::ScheduleKind
convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
  if (!schedKind.has_value())
    return llvm::omp::OMP_SCHEDULE_Default;
  switch (schedKind.value()) {
  case omp::ClauseScheduleKind::Static:
    return llvm::omp::OMP_SCHEDULE_Static;
  case omp::ClauseScheduleKind::Dynamic:
    return llvm::omp::OMP_SCHEDULE_Dynamic;
  case omp::ClauseScheduleKind::Guided:
    return llvm::omp::OMP_SCHEDULE_Guided;
  case omp::ClauseScheduleKind::Auto:
    return llvm::omp::OMP_SCHEDULE_Auto;
  case omp::ClauseScheduleKind::Runtime:
    return llvm::omp::OMP_SCHEDULE_Runtime;
  case omp::ClauseScheduleKind::Distribute:
    return llvm::omp::OMP_SCHEDULE_Distribute;
  }
  llvm_unreachable("unhandled schedule clause argument");
}
class OpenMPAllocaStackFrame
  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;

class OpenMPLoopInfoStackFrame
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
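// Error type used to signal that a diagnostic has already been emitted, so
// callers can fail without reporting the same problem twice.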
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
public:
  void log(raw_ostream &) const override {
    // Do not log anything.
  }

  std::error_code convertToErrorCode() const override {
    llvm_unreachable(
        "PreviouslyReportedError doesn't support ECError conversion");
  }

  // Used by ErrorInfo::classID.
  static char ID;
};

char PreviouslyReportedError::ID = 0;
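// Helper for translating the `linear` clause: it allocates ".linear_var" /
// ".linear_result" temporaries, seeds them before the loop, updates them as
// `start + iv * step` in the loop body, and on the last iteration writes the
// results back to the original variables before a closing barrier.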
class LinearClauseProcessor {
  SmallVector<llvm::Value *> linearPreconditionVars;
  SmallVector<llvm::Value *> linearLoopBodyTemps;
  SmallVector<llvm::Value *> linearOrigVal;
  SmallVector<llvm::Value *> linearSteps;
  SmallVector<llvm::Type *> linearVarTypes;
  llvm::BasicBlock *linearFinalizationBB;
  llvm::BasicBlock *linearExitBB;
  llvm::BasicBlock *linearLastIterExitBB;

  void registerType(LLVM::ModuleTranslation &moduleTranslation,
                    mlir::Attribute &ty) {
    linearVarTypes.push_back(moduleTranslation.convertType(
        mlir::cast<mlir::TypeAttr>(ty).getValue()));
  }

  void createLinearVar(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::Value *linearVar, int idx) {
    linearPreconditionVars.push_back(
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
    llvm::Value *linearLoopBodyTemp =
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
    linearOrigVal.push_back(linearVar);
    linearLoopBodyTemps.push_back(linearLoopBodyTemp);
  }

  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
                             mlir::Value &linearStep) {
    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
  }

  void initLinearVar(llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     llvm::BasicBlock *loopPreHeader) {
    builder.SetInsertPoint(loopPreHeader->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarLoad =
          builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
    }
  }

  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
                       llvm::Value *loopInductionVar) {
    builder.SetInsertPoint(loopBody->getTerminator());
    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
      llvm::Type *linearVarType = linearVarTypes[index];
      llvm::Value *iv = loopInductionVar;
      llvm::Value *step = linearSteps[index];

      if (!iv->getType()->isIntegerTy())
        llvm_unreachable("OpenMP loop induction variable must be an integer "

      if (linearVarType->isIntegerTy()) {
        iv = builder.CreateSExtOrTrunc(iv, linearVarType);
        step = builder.CreateSExtOrTrunc(step, linearVarType);
        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulInst = builder.CreateMul(iv, step);
        llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else if (linearVarType->isFloatingPointTy()) {
        step = builder.CreateSExtOrTrunc(step, iv->getType());
        llvm::Value *mulInst = builder.CreateMul(iv, step);
        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
        llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      }

            "Linear variable must be of integer or floating-point type");
    }
  }

  void splitLinearFiniBB(llvm::IRBuilderBase &builder,
                         llvm::BasicBlock *loopExit) {
    linearFinalizationBB = loopExit->splitBasicBlock(
        loopExit->getTerminator(), "omp_loop.linear_finalization");
    linearExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  finalizeLinearVar(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::Value *lastIter) {
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    llvm::Value *loopLastIterLoad = builder.CreateLoad(
        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
    llvm::Value *isLast =
        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
                          llvm::ConstantInt::get(
                              llvm::Type::getInt32Ty(builder.getContext()), 0));

    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }

    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
    linearFinalizationBB->getTerminator()->eraseFromParent();

    builder.SetInsertPoint(linearExitBB->getTerminator());

        builder.saveIP(), llvm::omp::OMPD_barrier);
  }

  void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }
  }

  void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,

    llvm::SmallVector<llvm::User *> users;
    for (llvm::User *user : linearOrigVal[varIndex]->users())
      users.push_back(user);
    for (auto *user : users) {
      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
        if (userInst->getParent()->getName().str().find(BBName) !=
          user->replaceUsesOfWith(linearOrigVal[varIndex],
                                  linearLoopBodyTemps[varIndex]);
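// Next: a symbol-table lookup helper that resolves a privatizer symbol to
// its declaration (asserting it exists), followed by a large clause-support
// check that reports "not yet implemented" diagnostics through the `todo`
// lambda and dispatches per-operation checks via a TypeSwitch.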
                                           SymbolRefAttr symbolName) {
  omp::PrivateClauseOp privatizer =
  assert(privatizer && "privatizer not found in the symbol table");
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()

  auto checkAffinity = [&todo](auto op, LogicalResult &result) {
    if (!op.getAffinityVars().empty())
      result = todo("affinity");
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  auto checkBare = [&todo](auto op, LogicalResult &result) {
      result = todo("ompx_bare");
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
  auto checkHint = [](auto op, LogicalResult &) {
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
  auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
    if (op.getParLevelSimd())
      result = todo("parallelization-level");
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if (!op.getPrivateVars().empty() || op.getPrivateSyms())
      result = todo("privatization");
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
    if (op.hasNumTeamsMultiDim())
      result = todo("num_teams with multi-dimensional values");
  auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
    if (op.hasNumThreadsMultiDim())
      result = todo("num_threads with multi-dimensional values");
  auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
    if (op.hasThreadLimitMultiDim())
      result = todo("thread_limit with multi-dimensional values");

      .Case([&](omp::DistributeOp op) {
        checkAllocate(op, result);
      .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
        checkNumTeams(op, result);
        checkThreadLimit(op, result);
      .Case([&](omp::TaskOp op) {
        checkAffinity(op, result);
        checkAllocate(op, result);
        checkInReduction(op, result);
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      .Case([&](omp::TaskwaitOp op) {
      .Case([&](omp::TaskloopOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
        checkReduction(op, result);
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
        checkNumThreads(op, result);
      .Case([&](omp::SimdOp op) { checkReduction(op, result); })
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
        checkThreadLimit(op, result);

  llvm::handleAllErrors(
      [&](const PreviouslyReportedError &) { result = failure(); },
      [&](const llvm::ErrorInfoBase &err) {
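// Finds a suitable alloca insertion point: the innermost
// OpenMPAllocaStackFrame if one is on the translation stack, otherwise the
// entry block of the current function (creating a fresh "entry" successor
// block if the builder currently points at the end of the function entry).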
static llvm::OpenMPIRBuilder::InsertPointTy
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
      [&](OpenMPAllocaStackFrame &frame) {
        allocaInsertPoint = frame.allocaInsertPoint;

      allocaInsertPoint.getBlock()->getParent() ==
          builder.GetInsertBlock()->getParent())
    return allocaInsertPoint;

  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);

  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
static llvm::CanonicalLoopInfo *
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        loopInfo = frame.loopInfo;
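// Converts the blocks of an MLIR region into LLVM IR basic blocks and
// branches them into a newly created "omp.region.cont" continuation block,
// wiring up PHI nodes for any values yielded by omp.yield terminators.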
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,

  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  unsigned numYields = 0;
  if (!isLoopWrapper) {
    bool operandsProcessed = false;
      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));
          operandsProcessed = true;
          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());
            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");

  if (!continuationBlockPHITypes.empty())
           continuationBlockPHIs &&
           "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));

  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
        moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      return llvm::make_error<PreviouslyReportedError>();

      builder.CreateBr(continuationBlock);

    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);
      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(

  return continuationBlock;
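// The proc_bind mapping below mirrors convertToScheduleKind above, and the
// masked/master/critical translations that follow all use the same
// OpenMPIRBuilder callback pattern: a body-generation callback that inlines
// the op's region at codeGenIP and a trivial finalization callback.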
  case omp::ClauseProcBindKind::Close:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
  case omp::ClauseProcBindKind::Master:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
  case omp::ClauseProcBindKind::Primary:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
  case omp::ClauseProcBindKind::Spread:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
  }
  llvm_unreachable("Unknown ClauseProcBindKind kind");
  auto maskedOp = cast<omp::MaskedOp>(opInst);
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = maskedOp.getRegion();
    builder.restoreIP(codeGenIP);

  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::Value *filterVal = nullptr;
  if (auto filterVar = maskedOp.getFilteredThreadId()) {
    filterVal = moduleTranslation.lookupValue(filterVar);

    llvm::LLVMContext &llvmContext = builder.getContext();
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);

  assert(filterVal != nullptr);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =

  builder.restoreIP(*afterIP);
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto masterOp = cast<omp::MasterOp>(opInst);

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = masterOp.getRegion();
    builder.restoreIP(codeGenIP);

  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =

  builder.restoreIP(*afterIP);
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto criticalOp = cast<omp::CriticalOp>(opInst);

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = cast<omp::CriticalOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);

  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
  llvm::Constant *hint = nullptr;
  if (criticalOp.getNameAttr()) {
    auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
    auto criticalDeclareOp =

        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
                               static_cast<int>(criticalDeclareOp.getHint()));

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
          ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);

  builder.restoreIP(*afterIP);
  template <typename OP>
       cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
    collectPrivatizationDecls<OP>(op);

  void collectPrivatizationDecls(OP op) {
    std::optional<ArrayAttr> attr = op.getPrivateSyms();
    for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {

  std::optional<ArrayAttr> attr = op.getReductionSyms();
  reductions.reserve(reductions.size() + op.getNumReductionVars());
  for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
    reductions.push_back(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,

    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();
    if (potentialTerminator && potentialTerminator->isTerminator())
      potentialTerminator->removeFromParent();
    moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());

            region.front(), true, builder)))

    if (continuationBlockArgs)
          *continuationBlockArgs,

    if (potentialTerminator && potentialTerminator->isTerminator()) {
      llvm::BasicBlock *block = builder.GetInsertBlock();
      if (block->empty()) {
        potentialTerminator->insertInto(block, block->begin());
        potentialTerminator->insertAfter(&block->back());

  if (continuationBlockArgs)
    llvm::append_range(*continuationBlockArgs, phis);
  builder.SetInsertPoint(*continuationBlock,
                         (*continuationBlock)->getFirstInsertionPt());
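// Owning callback types for reduction code generation. The helpers below
// capture the `omp.declare_reduction` declaration by value and inline its
// `combiner`, `atomic`, and `data_ptr_ptr` regions at the insertion point
// provided by the OpenMPIRBuilder.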
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
using OwningDataPtrPtrReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
static OwningReductionGen

  OwningReductionGen gen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *lhs, llvm::Value *rhs,
                llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
                                       "omp.reduction.nonatomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `combiner` region of `omp.declare_reduction`");
    result = llvm::getSingleElement(phis);
    return builder.saveIP();

static OwningAtomicReductionGen
                         llvm::IRBuilderBase &builder,
  if (decl.getAtomicReductionRegion().empty())
    return OwningAtomicReductionGen();

      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
                llvm::Value *lhs, llvm::Value *rhs) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
                                       "omp.reduction.atomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `atomic` region of `omp.declare_reduction`");
    assert(phis.empty());
    return builder.saveIP();

static OwningDataPtrPtrReductionGen
    return OwningDataPtrPtrReductionGen();

  OwningDataPtrPtrReductionGen refDataPtrGen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *byRefVal, llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
    builder.restoreIP(insertPoint);
                                       "omp.data_ptr_ptr.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
    result = llvm::getSingleElement(phis);
    return builder.saveIP();

  return refDataPtrGen;
  auto orderedOp = cast<omp::OrderedOp>(opInst);

  omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
  bool isDependSource = dependType == omp::ClauseDepend::dependsource;
  unsigned numLoops = *orderedOp.getDoacrossNumLoops();
      moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());

  size_t indexVecValues = 0;
  while (indexVecValues < vecValues.size()) {
    storeValues.reserve(numLoops);
    for (unsigned i = 0; i < numLoops; i++) {
      storeValues.push_back(vecValues[indexVecValues]);

    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
    builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
        ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);

  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
          ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());

  builder.restoreIP(*afterIP);
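// DeferredStore records a store that must only be emitted after all allocas
// have been created; the reduction-variable allocation helper below uses it
// for by-ref reductions so the private pointer is not written into the
// allocation block itself.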
struct DeferredStore {
  DeferredStore(llvm::Value *value, llvm::Value *address)
      : value(value), address(address) {}

  llvm::Value *address;

template <typename T>
                    llvm::IRBuilderBase &builder,
                    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,

  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

  deferredStores.reserve(loop.getNumReductionVars());

  for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
    Region &allocRegion = reductionDecls[i].getAllocRegion();
      if (allocRegion.empty())
                                         builder, moduleTranslation, &phis)))
        return loop.emitError(
            "failed to inline `alloc` region of `omp.declare_reduction`");

      assert(phis.size() == 1 && "expected one allocation to be yielded");
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));

      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Value *castVar =
          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
      llvm::Value *castPhi =
          builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);

      deferredStores.emplace_back(castPhi, castVar);

      privateReductionVariables[i] = castVar;
      moduleTranslation.mapValue(reductionArgs[i], castPhi);
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);

      assert(allocRegion.empty() &&
             "allocation is implicit for by-val reduction");
      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));

      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Value *castVar =
          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);

      moduleTranslation.mapValue(reductionArgs[i], castVar);
      privateReductionVariables[i] = castVar;
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
template <typename T>
                              llvm::IRBuilderBase &builder,

  mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
  Region &initializerRegion = reduction.getInitializerRegion();

  mlir::Value mlirSource = loop.getReductionVars()[i];
  llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
  llvm::Value *origVal = llvmSource;

  if (!isa<LLVM::LLVMPointerType>(
          reduction.getInitializerMoldArg().getType()) &&
      isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
                                   reduction.getInitializerMoldArg().getType()),
                               llvmSource, "omp_orig");

  moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);

  llvm::Value *allocation =
      reductionVariableMap.lookup(loop.getReductionVars()[i]);
  moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);

                                              llvm::BasicBlock *block = nullptr) {
  if (block == nullptr)
    block = builder.GetInsertBlock();

  if (block->empty() || block->getTerminator() == nullptr)
    builder.SetInsertPoint(block);
  else
    builder.SetInsertPoint(block->getTerminator());
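// Initializes the reduction variables by inlining each declaration's
// initializer ("omp.reduction.neutral") region, storing the neutral element
// either through the by-ref pointer or directly into the private variable.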
template <typename OP>
                  llvm::IRBuilderBase &builder,
                  llvm::BasicBlock *latestAllocaBlock,

  if (op.getNumReductionVars() == 0)

  llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
  auto allocaIP = llvm::IRBuilderBase::InsertPoint(
      latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
  builder.restoreIP(allocaIP);

  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
    if (!reductionDecls[i].getAllocRegion().empty())

      byRefVars[i] = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));

  for (auto [data, addr] : deferredStores)
    builder.CreateStore(data, addr);

  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
                                     reductionVariableMap, i);
                                       "omp.reduction.neutral", builder,
                                       moduleTranslation, &phis)))

    assert(phis.size() == 1 && "expected one value to be yielded from the "
                               "reduction neutral element declaration region");

    if (!reductionDecls[i].getAllocRegion().empty())

      builder.CreateStore(phis[0], byRefVars[i]);

      privateReductionVariables[i] = byRefVars[i];
      moduleTranslation.mapValue(reductionArgs[i], phis[0]);
      reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);

      builder.CreateStore(phis[0], privateReductionVariables[i]);

    moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
template <typename T>
static void collectReductionInfo(
    T loop, llvm::IRBuilderBase &builder,

  unsigned numReductions = loop.getNumReductionVars();

  for (unsigned i = 0; i < numReductions; ++i) {
    owningAtomicReductionGens.push_back(
        reductionDecls[i], builder, moduleTranslation, isByRef[i]));

  reductionInfos.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
    if (owningAtomicReductionGens[i])
      atomicGen = owningAtomicReductionGens[i];
    llvm::Value *variable =
        moduleTranslation.lookupValue(loop.getReductionVars()[i]);

    if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
      allocatedType = alloca.getElemType();

    reductionInfos.push_back(
        privateReductionVariables[i],
        llvm::OpenMPIRBuilder::EvalKind::Scalar,
        allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
        reductionDecls[i].getByrefElementType()
            *reductionDecls[i].getByrefElementType())
    llvm::IRBuilderBase &builder, StringRef regionName,
    bool shouldLoadCleanupRegionArg = true) {
  for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
    if (cleanupRegion->empty())

    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();
    if (potentialTerminator && potentialTerminator->isTerminator())
      builder.SetInsertPoint(potentialTerminator);
    llvm::Value *privateVarValue =
        shouldLoadCleanupRegionArg
            ? builder.CreateLoad(
                  privateVariables[i])
            : privateVariables[i];

                                      moduleTranslation)))
    OP op, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
    bool isNowait = false, bool isTeamsReduction = false) {

  if (op.getNumReductionVars() == 0)

  collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
                       owningReductionGenRefDataPtrGens,
                       privateReductionVariables, reductionInfos, isByRef);

  llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
  builder.SetInsertPoint(tempTerminator);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
      ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
                                   isByRef, isNowait, isTeamsReduction);

  if (!contInsertPoint->getBlock())
    return op->emitOpError() << "failed to convert reductions";

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);

  tempTerminator->eraseFromParent();
  builder.restoreIP(*afterIP);

  llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
                  [](omp::DeclareReductionOp reductionDecl) {
                    return &reductionDecl.getCleanupRegion();
                                 moduleTranslation, builder,
                                 "omp.reduction.cleanup");
template <typename OP>
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,

  if (op.getNumReductionVars() == 0)

                              allocaIP, reductionDecls,
                              privateReductionVariables, reductionVariableMap,
                              deferredStores, isByRef)))

  return initReductionVars(op, reductionArgs, builder, moduleTranslation,
                           allocaIP.getBlock(), reductionDecls,
                           privateReductionVariables, reductionVariableMap,
                           isByRef, deferredStores);
  if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))

  Value blockArg = (*mappedPrivateVars)[privateVar];

  assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
         "A block argument corresponding to a mapped var should have "

  if (privVarType == blockArgType)

  if (!isa<LLVM::LLVMPointerType>(privVarType))
    return builder.CreateLoad(moduleTranslation.convertType(privVarType),

    omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
    llvm::BasicBlock *privInitBlock,

  Region &initRegion = privDecl.getInitRegion();
  if (initRegion.empty())
    return llvmPrivateVar;

  assert(nonPrivateVar);
  moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
  moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);

                                     moduleTranslation, &phis)))
    return llvm::createStringError(
        "failed to inline `init` region of `omp.private`");

  assert(phis.size() == 1 && "expected one allocation to be yielded");

    llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,

      builder, moduleTranslation, privDecl,
      blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
    return llvm::Error::success();

  llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");

  for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
    auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
        builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
        llvmPrivateVar, privInitBlock, mappedPrivateVars);

      return privVarOrErr.takeError();

    llvmPrivateVar = privVarOrErr.get();
    moduleTranslation.mapValue(blockArg, llvmPrivateVar);

  return llvm::Error::success();
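// Allocates one private copy per delayed-privatization block argument in
// the alloca block, address-space-casting the alloca when the alloca and
// program address spaces differ.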
    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,

  llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
  splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
                                               allocaTerminator->getIterator()),
          true, allocaTerminator->getStableDebugLoc(),
          "omp.region.after_alloca");

  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  allocaTerminator = allocaIP.getBlock()->getTerminator();
  builder.SetInsertPoint(allocaTerminator);

  assert(allocaTerminator->getNumSuccessors() == 1 &&
         "This is an unconditional branch created by splitBB");

  llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
  llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);

  unsigned int allocaAS =
      moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
          .getProgramAddressSpace();

  for (auto [privDecl, mlirPrivVar, blockArg] :

    llvm::Type *llvmAllocType =
        moduleTranslation.convertType(privDecl.getType());
    builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
    llvm::Value *llvmPrivateVar = builder.CreateAlloca(
        llvmAllocType, nullptr, "omp.private.alloc");
    if (allocaAS != defaultAS)
      llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
                                                   builder.getPtrTy(defaultAS));

    privateVarsInfo.llvmVars.push_back(llvmPrivateVar);

  return afterAllocas;
  if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))

  if (mlir::isa<omp::ParallelOp>(parent))

  bool needsFirstprivate =
      llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
        return privOp.getDataSharingType() ==
               omp::DataSharingClauseType::FirstPrivate;

  if (!needsFirstprivate)

  llvm::BasicBlock *copyBlock =
      splitBB(builder, true, "omp.private.copy");

  for (auto [decl, moldVar, llvmVar] :
       llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
    if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)

    Region &copyRegion = decl.getCopyRegion();

    moduleTranslation.mapValue(decl.getCopyMoldArg(), moldVar);

    moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);

                                      moduleTranslation)))
      return decl.emitError("failed to inline `copy` region of `omp.private`");

    llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
        ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);

  llvm::transform(mlirPrivateVars, moldVars.begin(), [&](mlir::Value mlirVar) {
    llvm::Value *moldVar = findAssociatedValue(
        mlirVar, builder, moduleTranslation, mappedPrivateVars);

                              llvmPrivateVars, privateDecls, insertBarrier,

  llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
                  [](omp::PrivateClauseOp privatizer) {
                    return &privatizer.getDeallocRegion();

          privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
          "omp.private.dealloc", false)))
    return mlir::emitError(loc, "failed to inline `dealloc` region of an "
                                "`omp.private` op in");
  if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  using StorableBodyGenCallbackTy =
      llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;

  auto sectionsOp = cast<omp::SectionsOp>(opInst);

  assert(isByRef.size() == sectionsOp.getNumReductionVars());

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =

                                            sectionsOp.getNumReductionVars());

      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

          sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
          reductionDecls, privateReductionVariables, reductionVariableMap,

    auto sectionOp = dyn_cast<omp::SectionOp>(op);

    Region &region = sectionOp.getRegion();
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
                         InsertPointTy allocaIP, InsertPointTy codeGenIP) {
      builder.restoreIP(codeGenIP);

                     sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               sectionsOp.getRegion().getArguments(), region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);

        moduleTranslation.mapValue(sectionArg, llvmVal);

    sectionCBs.push_back(sectionCB);

  if (sectionCBs.empty())

  assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));

  auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
                    llvm::Value &vPtr, llvm::Value *&replacementValue)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    replacementValue = &vPtr;

  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
          ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
          sectionsOp.getNowait());

  builder.restoreIP(*afterIP);

      sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
      privateReductionVariables, isByRef, sectionsOp.getNowait());
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(codegenIP);
                                  builder, moduleTranslation)

  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
  for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
    llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
        singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
    llvmCPFuncs.push_back(

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
          ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,

  builder.restoreIP(*afterIP);
      llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());

  for (auto ra : iface.getReductionBlockArgs())
    for (auto &use : ra.getUses()) {
      auto *useOp = use.getOwner();

      if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
        debugUses.push_back(useOp);

      auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();

      Operation *currentOp = currentDistOp.getOperation();
      if (distOp && (distOp != currentOp))

  for (auto *use : debugUses)
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  unsigned numReductionVars = op.getNumReductionVars();

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =

  if (doTeamsReduction) {
    isByRef = getIsByRef(op.getReductionByref());

    assert(isByRef.size() == op.getNumReductionVars());

        llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();

            op, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
                                   moduleTranslation, allocaIP);
    builder.restoreIP(codegenIP);

  llvm::Value *numTeamsLower = nullptr;
  if (Value numTeamsLowerVar = op.getNumTeamsLower())
    numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);

  llvm::Value *numTeamsUpper = nullptr;
  if (!op.getNumTeamsUpperVars().empty())
    numTeamsUpper = moduleTranslation.lookupValue(op.getNumTeams(0));

  llvm::Value *threadLimit = nullptr;
  if (!op.getThreadLimitVars().empty())
    threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));

  llvm::Value *ifExpr = nullptr;
  if (Value ifVar = op.getIfExpr())

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
          ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);

  builder.restoreIP(*afterIP);
  if (doTeamsReduction) {
        op, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
  if (dependVars.empty())

  for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
    llvm::omp::RTLDependenceKindTy type;
        cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
    case mlir::omp::ClauseTaskDepend::taskdependin:
      type = llvm::omp::RTLDependenceKindTy::DepIn;

    case mlir::omp::ClauseTaskDepend::taskdependout:
    case mlir::omp::ClauseTaskDepend::taskdependinout:
      type = llvm::omp::RTLDependenceKindTy::DepInOut;

    case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;

    case mlir::omp::ClauseTaskDepend::taskdependinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepInOutSet;

    llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
    dds.emplace_back(dd);
                                     llvm::IRBuilderBase &llvmBuilder,
                                     llvm::omp::Directive cancelDirective) {
  auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
    llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);

    llvmBuilder.restoreIP(ip);

    cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
    return llvm::Error::success();

  ompBuilder.pushFinalizationCB(

                                 llvm::OpenMPIRBuilder &ompBuilder,
                                 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
  ompBuilder.popFinalizationCB();
  llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
  for (llvm::BranchInst *cancelBranch : cancelTerminators) {
    assert(cancelBranch->getNumSuccessors() == 1 &&
           "cancel branch should have one target");
    cancelBranch->setSuccessor(0, constructFini);
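// Manages the heap-allocated context struct used to pass firstprivate
// values into task bodies: it collects the types of privatizers that read
// from their mold, mallocs a matching struct, and hands out GEPs to the
// individual fields (released again via freeStructPtr).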
class TaskContextStructManager {
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  void generateTaskContextStruct();

  void createGEPsToPrivateVars();

  SmallVector<llvm::Value *>
  createGEPsToPrivateVars(llvm::Value *altStructPtr) const;

  void freeStructPtr();

  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;

  llvm::Value *getStructPtr() { return structPtr; }

  llvm::IRBuilderBase &builder;
  LLVM::ModuleTranslation &moduleTranslation;
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  SmallVector<llvm::Type *> privateVarTypes;

  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  llvm::Value *structPtr = nullptr;

  llvm::Type *structTy = nullptr;

void TaskContextStructManager::generateTaskContextStruct() {
  if (privateDecls.empty())

  privateVarTypes.reserve(privateDecls.size());

  for (omp::PrivateClauseOp &privOp : privateDecls) {
    if (!privOp.readsFromMold())

    Type mlirType = privOp.getType();
    privateVarTypes.push_back(moduleTranslation.convertType(mlirType));

  if (privateVarTypes.empty())

  structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),

  llvm::DataLayout dataLayout =
      builder.GetInsertBlock()->getModule()->getDataLayout();
  llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
  llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);

  structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
                                   "omp.task.context_ptr");

SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
    llvm::Value *altStructPtr) const {
  SmallVector<llvm::Value *> ret;

  ret.reserve(privateDecls.size());
  llvm::Value *zero = builder.getInt32(0);

  for (auto privDecl : privateDecls) {
    if (!privDecl.readsFromMold()) {
      ret.push_back(nullptr);

    llvm::Value *iVal = builder.getInt32(i);
    llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});

void TaskContextStructManager::createGEPsToPrivateVars() {
    assert(privateVarTypes.empty());

  llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);

void TaskContextStructManager::freeStructPtr() {

  llvm::IRBuilderBase::InsertPointGuard guard{builder};

  builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
  builder.CreateFree(structPtr);
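// Translation of `omp.task`: firstprivate values are initialized up front
// and stored into the task context struct, while the body callback
// re-creates private allocas inside the task and loads the captured values
// back out of the struct before inlining the "omp.task.region".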
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  TaskContextStructManager taskStructMgr{builder, moduleTranslation,

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =

  assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
  llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
      builder.getContext(), "omp.task.start",
      builder.GetInsertBlock()->getParent());
  llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
  builder.SetInsertPoint(branchToTaskStartBlock);

  llvm::BasicBlock *copyBlock =
      splitBB(builder, true, "omp.private.copy");
  llvm::BasicBlock *initBlock =
      splitBB(builder, true, "omp.private.init");

                                 moduleTranslation, allocaIP);

  builder.SetInsertPoint(initBlock->getTerminator());

  taskStructMgr.generateTaskContextStruct();

  taskStructMgr.createGEPsToPrivateVars();

  for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
                       taskStructMgr.getLLVMPrivateVarGEPs())) {

    if (!privDecl.readsFromMold())

    assert(llvmPrivateVarAlloc &&
           "reads from mold so shouldn't have been skipped");

        initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                       blockArg, llvmPrivateVarAlloc, initBlock);
    if (!privateVarOrErr)
      return handleError(privateVarOrErr, *taskOp.getOperation());

    if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
        !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
      builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);

      llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
                                               llvmPrivateVarAlloc);

      assert(llvmPrivateVarAlloc->getType() ==
             moduleTranslation.convertType(blockArg.getType()));

          taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
          taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
          taskOp.getPrivateNeedsBarrier())))
    return llvm::failure();

  builder.SetInsertPoint(taskStartBlock);

  auto bodyCB = [&](InsertPointTy allocaIP,
                    InsertPointTy codegenIP) -> llvm::Error {
                                   moduleTranslation, allocaIP);

    builder.restoreIP(codegenIP);

    llvm::BasicBlock *privInitBlock = nullptr;

    for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(

      auto [blockArg, privDecl, mlirPrivVar] = zip;

      if (privDecl.readsFromMold())

      llvm::IRBuilderBase::InsertPointGuard guard(builder);
      llvm::Type *llvmAllocType =
          moduleTranslation.convertType(privDecl.getType());
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
      llvm::Value *llvmPrivateVar = builder.CreateAlloca(
          llvmAllocType, nullptr, "omp.private.alloc");

          initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                         blockArg, llvmPrivateVar, privInitBlock);
      if (!privateVarOrError)
        return privateVarOrError.takeError();
      moduleTranslation.mapValue(blockArg, privateVarOrError.get());
      privateVarsInfo.llvmVars[i] = privateVarOrError.get();

    taskStructMgr.createGEPsToPrivateVars();
    for (auto [i, llvmPrivVar] :
         llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {

      assert(privateVarsInfo.llvmVars[i] &&
             "This is added in the loop above");

      privateVarsInfo.llvmVars[i] = llvmPrivVar;

    for (auto [blockArg, llvmPrivateVar, privateDecl] :

      if (!privateDecl.readsFromMold())

      if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
        llvmPrivateVar = builder.CreateLoad(
            moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);

      assert(llvmPrivateVar->getType() ==
             moduleTranslation.convertType(blockArg.getType()));
      moduleTranslation.mapValue(blockArg, llvmPrivateVar);

        taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
    if (failed(handleError(continuationBlockOrError, *taskOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());

      return llvm::make_error<PreviouslyReportedError>();

    taskStructMgr.freeStructPtr();

    return llvm::Error::success();

                               llvm::omp::Directive::OMPD_taskgroup);

                   moduleTranslation, dds);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
          ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
          moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
          taskOp.getMergeable(),
          moduleTranslation.lookupValue(taskOp.getEventHandle()),
          moduleTranslation.lookupValue(taskOp.getPriority()));

  builder.restoreIP(*afterIP);
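// Translation of `omp.taskloop`: follows the same context-struct scheme as
// omp.task, adds a task-duplication callback that copies firstprivate
// values into a per-task copy of the struct, and computes a collapsed
// lower bound, upper bound, and step for the wrapped omp.loop_nest when
// more than one loop is collapsed.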
2547 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2548 auto taskloopOp = cast<omp::TaskloopOp>(opInst);
2556 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2559 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2562 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2563 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
2564 builder.getContext(),
"omp.taskloop.start",
2565 builder.GetInsertBlock()->getParent());
2566 llvm::Instruction *branchToTaskloopStartBlock =
2567 builder.CreateBr(taskloopStartBlock);
2568 builder.SetInsertPoint(branchToTaskloopStartBlock);
2570 llvm::BasicBlock *copyBlock =
2571 splitBB(builder,
true,
"omp.private.copy");
2572 llvm::BasicBlock *initBlock =
2573 splitBB(builder,
true,
"omp.private.init");
2576 moduleTranslation, allocaIP);
2579 builder.SetInsertPoint(initBlock->getTerminator());
2582 taskStructMgr.generateTaskContextStruct();
2583 taskStructMgr.createGEPsToPrivateVars();
2585 llvmFirstPrivateVars.resize(privateVarsInfo.
blockArgs.size());
2587 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2589 privateVarsInfo.
blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
2590 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
2592 if (!privDecl.readsFromMold())
2594 assert(llvmPrivateVarAlloc &&
2595 "reads from mold so shouldn't have been skipped");
2598 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2599 blockArg, llvmPrivateVarAlloc, initBlock);
2600 if (!privateVarOrErr)
2601 return handleError(privateVarOrErr, *taskloopOp.getOperation());
2603 llvmFirstPrivateVars[i] = privateVarOrErr.get();
2605 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2606 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2608 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2609 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2610 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2612 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2613 llvmPrivateVarAlloc);
2615 assert(llvmPrivateVarAlloc->getType() ==
2616 moduleTranslation.
convertType(blockArg.getType()));
2622 taskloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2623 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2624 taskloopOp.getPrivateNeedsBarrier())))
2625 return llvm::failure();
2628 builder.SetInsertPoint(taskloopStartBlock);
2630 auto bodyCB = [&](InsertPointTy allocaIP,
2631 InsertPointTy codegenIP) -> llvm::Error {
2635 moduleTranslation, allocaIP);
2638 builder.restoreIP(codegenIP);
2640 llvm::BasicBlock *privInitBlock =
nullptr;
2642 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2645 auto [blockArg, privDecl, mlirPrivVar] = zip;
2647 if (privDecl.readsFromMold())
2650 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2651 llvm::Type *llvmAllocType =
2652 moduleTranslation.
convertType(privDecl.getType());
2653 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2654 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2655 llvmAllocType,
nullptr,
"omp.private.alloc");
2658 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2659 blockArg, llvmPrivateVar, privInitBlock);
2660 if (!privateVarOrError)
2661 return privateVarOrError.takeError();
2662 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2663 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2666 taskStructMgr.createGEPsToPrivateVars();
2667 for (
auto [i, llvmPrivVar] :
2668 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2670 assert(privateVarsInfo.
llvmVars[i] &&
2671 "This is added in the loop above");
2674 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2679 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2683 if (!privateDecl.readsFromMold())
2686 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2687 llvmPrivateVar = builder.CreateLoad(
2688 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2690 assert(llvmPrivateVar->getType() ==
2691 moduleTranslation.
convertType(blockArg.getType()));
2692 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2695 auto continuationBlockOrError =
2697 builder, moduleTranslation);
2699 if (failed(
handleError(continuationBlockOrError, opInst)))
2700 return llvm::make_error<PreviouslyReportedError>();
2702 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2710 taskloopOp.getLoc(), privateVarsInfo.
llvmVars,
2712 return llvm::make_error<PreviouslyReportedError>();
2715 taskStructMgr.freeStructPtr();
2717 return llvm::Error::success();
  auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
                       llvm::Value *destPtr, llvm::Value *srcPtr)
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.restoreIP(codegenIP);

    llvm::Type *ptrTy =
        builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
    llvm::Value *src =
        builder.CreateLoad(ptrTy, srcPtr, "omp.taskloop.context.src");

    TaskContextStructManager &srcStructMgr = taskStructMgr;
    TaskContextStructManager destStructMgr(builder, moduleTranslation,
    destStructMgr.generateTaskContextStruct();
    llvm::Value *dest = destStructMgr.getStructPtr();
    dest->setName("omp.taskloop.context.dest");
    builder.CreateStore(dest, destPtr);

    srcStructMgr.createGEPsToPrivateVars(src);
    destStructMgr.createGEPsToPrivateVars(dest);

    for (auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
         llvm::zip_equal(privateVarsInfo.privatizers, srcGEPs,
      if (!privDecl.readsFromMold())

      assert(llvmPrivateVarAlloc &&
             "reads from mold so shouldn't have been skipped");

      llvm::Expected<llvm::Value *> privateVarOrErr =
          initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
                         llvmPrivateVarAlloc, builder.GetInsertBlock());
      if (!privateVarOrErr)
        return privateVarOrErr.takeError();

      if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
          !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
        builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
        llvmPrivateVarAlloc = builder.CreateLoad(
            privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
      }
      assert(llvmPrivateVarAlloc->getType() ==
             moduleTranslation.convertType(blockArg.getType()));
    }

            &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
            privateVarsInfo.privatizers, taskloopOp.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    return builder.saveIP();
  };
  auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());

  llvm::Type *boundType =
      moduleTranslation.lookupValue(lowerBounds[0])->getType();
  llvm::Value *lbVal = nullptr;
  llvm::Value *ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
  llvm::Value *stepVal = nullptr;
  if (loopOp.getCollapseNumLoops() > 1) {
    for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
      llvm::Value *loopLb = moduleTranslation.lookupValue(lowerBounds[i]);
      llvm::Value *loopUb = moduleTranslation.lookupValue(upperBounds[i]);
      llvm::Value *loopStep = moduleTranslation.lookupValue(steps[i]);

      llvm::Value *loopLbMinusOne = builder.CreateSub(
          loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
      llvm::Value *loopUbMinusOne = builder.CreateSub(
          loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
      llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
      llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
      llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
      llvm::Value *loopTripCount =
          builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
      loopTripCount = builder.CreateBinaryIntrinsic(
          llvm::Intrinsic::abs, loopTripCount, builder.getFalse());

      llvm::Value *loopTripCountDivStep =
          builder.CreateSDiv(loopTripCount, loopStep);
      loopTripCountDivStep = builder.CreateBinaryIntrinsic(
          llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
      llvm::Value *loopTripCountRem =
          builder.CreateSRem(loopTripCount, loopStep);
      loopTripCountRem = builder.CreateBinaryIntrinsic(
          llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
      llvm::Value *needsRoundUp = builder.CreateICmpNE(
          loopTripCountRem,
          builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
                          0));
      loopTripCount =
          builder.CreateAdd(loopTripCountDivStep,
                            builder.CreateZExtOrTrunc(
                                needsRoundUp, loopTripCountDivStep->getType()));
      ubVal = builder.CreateMul(ubVal, loopTripCount);
    }
    lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
    stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
  } else {
    lbVal = moduleTranslation.lookupValue(lowerBounds[0]);
    ubVal = moduleTranslation.lookupValue(upperBounds[0]);
    stepVal = moduleTranslation.lookupValue(steps[0]);
  }
  assert(lbVal != nullptr && "Expected value for lbVal");
  assert(ubVal != nullptr && "Expected value for ubVal");
  assert(stepVal != nullptr && "Expected value for stepVal");

  llvm::Value *ifCond = nullptr;
  llvm::Value *grainsize = nullptr;
  mlir::Value grainsizeVal = taskloopOp.getGrainsize();
  mlir::Value numTasksVal = taskloopOp.getNumTasks();
  if (Value ifVar = taskloopOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);
  if (grainsizeVal) {
    grainsize = moduleTranslation.lookupValue(grainsizeVal);
  } else if (numTasksVal) {
    grainsize = moduleTranslation.lookupValue(numTasksVal);
  }

  llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull = nullptr;
  if (taskStructMgr.getStructPtr())
    taskDupOrNull = taskDupCB;

                              llvm::omp::Directive::OMPD_taskgroup);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
          ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
          taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
          sched, moduleTranslation.lookupValue(taskloopOp.getFinal()),
          taskloopOp.getMergeable(),
          moduleTranslation.lookupValue(taskloopOp.getPriority()),
          loopOp.getCollapseNumLoops(), taskDupOrNull,
          taskStructMgr.getStructPtr());

  builder.restoreIP(*afterIP);
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(codegenIP);
                              builder, moduleTranslation)

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =

  builder.restoreIP(*afterIP);
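
/// Converts an OpenMP worksharing loop (omp.wsloop wrapping an omp.loop_nest)
/// to LLVM IR via OpenMPIRBuilder::applyWorkshareLoop. This covers private,
/// firstprivate, reduction and linear clauses, schedule/dist_schedule chunks,
/// and the "no-loop" SPMD mode used by some target kernels.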
  auto wsloopOp = cast<omp::WsloopOp>(opInst);
  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
  assert(isByRef.size() == wsloopOp.getNumReductionVars());

  omp::ClauseScheduleKind schedule =
      wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
  llvm::Type *ivType = step->getType();
  llvm::Value *chunk = nullptr;
  if (wsloopOp.getScheduleChunk()) {
    llvm::Value *chunkVar =
        moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
    chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
  }

  omp::DistributeOp distributeOp = nullptr;
  llvm::Value *distScheduleChunk = nullptr;
  bool hasDistSchedule = false;
  if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
    distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
    hasDistSchedule = distributeOp.getDistScheduleStatic();
    if (distributeOp.getDistScheduleChunkSize()) {
      llvm::Value *chunkVar = moduleTranslation.lookupValue(
          distributeOp.getDistScheduleChunkSize());
      distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
    }
  }

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
                                      wsloopOp.getNumReductionVars());

                         builder, moduleTranslation, privateVarsInfo, allocaIP);

      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

                               moduleTranslation, allocaIP, reductionDecls,
                               privateReductionVariables, reductionVariableMap,
                               deferredStores, isByRef)))

                                  wsloopOp, builder, moduleTranslation,
                                  privateVarsInfo.mlirVars,
                                  wsloopOp.getPrivateNeedsBarrier())))

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))

  bool isOrdered = wsloopOp.getOrdered().has_value();
  std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
  bool isSimd = wsloopOp.getScheduleSimd();
  bool loopNeedsBarrier = !wsloopOp.getNowait();

  llvm::omp::WorksharingLoopType workshareLoopType =
      llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
          ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
          : llvm::omp::WorksharingLoopType::ForStaticLoop;

                              llvm::omp::Directive::OMPD_for);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  LinearClauseProcessor linearClauseProcessor;
  if (!wsloopOp.getLinearVars().empty()) {
    auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
    for (mlir::Attribute linearVarType : linearVarTypes)
      linearClauseProcessor.registerType(moduleTranslation, linearVarType);
    for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
      linearClauseProcessor.createLinearVar(
          builder, moduleTranslation, moduleTranslation.lookupValue(linearVar),
          idx);
    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
  }

  llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
      wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);

  if (!wsloopOp.getLinearVars().empty()) {
    linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                        loopInfo->getPreheader());
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
            builder.saveIP(), llvm::omp::OMPD_barrier);
    builder.restoreIP(*afterBarrierIP);
    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
                                          loopInfo->getIndVar());
    linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
  }

  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

  bool noLoopMode = false;
  omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
  if (targetOp) {
    Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
    if (loopOp == targetCapturedOp) {
      omp::TargetRegionFlags kernelFlags =
          targetOp.getKernelExecFlags(targetCapturedOp);
      if (omp::bitEnumContainsAll(kernelFlags,
                                  omp::TargetRegionFlags::spmd |
                                      omp::TargetRegionFlags::no_loop) &&
          !omp::bitEnumContainsAny(kernelFlags,
                                   omp::TargetRegionFlags::generic))
        noLoopMode = true;
    }
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
      ompBuilder->applyWorkshareLoop(
          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
          convertToScheduleKind(schedule), chunk, isSimd,
          scheduleMod == omp::ScheduleModifier::monotonic,
          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
          workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);

  if (!wsloopOp.getLinearVars().empty()) {
    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
    assert(loopInfo->getLastIter() &&
           "`lastiter` in CanonicalLoopInfo is nullptr");
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
                                                loopInfo->getLastIter());
    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
      linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
                                           index);
    builder.restoreIP(oldIP);
  }

      wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
      privateReductionVariables, isByRef, wsloopOp.getNowait(),
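
/// Converts omp.parallel to an OpenMPIRBuilder::createParallel call. The body
/// callback allocates private/reduction storage, inlines the parallel region
/// and emits reductions; the finalization callback inlines the
/// omp.declare_reduction cleanup regions and, for cancellable regions, a
/// cancellation barrier.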
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  assert(isByRef.size() == opInst.getNumReductionVars());
                                                    opInst.getNumReductionVars());

  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
                           builder, moduleTranslation, privateVarsInfo, allocaIP);
      return llvm::make_error<PreviouslyReportedError>();

        cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
    allocaIP =
        InsertPointTy(allocaIP.getBlock(),
                      allocaIP.getBlock()->getTerminator()->getIterator());
    if (failed(allocReductionVars(
            opInst, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            deferredStores, isByRef)))
      return llvm::make_error<PreviouslyReportedError>();

    assert(afterAllocas.get()->getSinglePredecessor());
    builder.restoreIP(codeGenIP);

      return llvm::make_error<PreviouslyReportedError>();

            opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
            opInst.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(
            initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
                              afterAllocas.get()->getSinglePredecessor(),
                              reductionDecls, privateReductionVariables,
                              reductionVariableMap, isByRef, deferredStores)))
      return llvm::make_error<PreviouslyReportedError>();

        moduleTranslation, allocaIP);

    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
        opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    if (opInst.getNumReductionVars() > 0) {
          owningReductionGenRefDataPtrGens;
      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
                           owningReductionGenRefDataPtrGens,
                           privateReductionVariables, reductionInfos, isByRef);

      builder.SetInsertPoint((*regionBlock)->getTerminator());
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);

      llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
          ompBuilder->createReductions(
              builder.saveIP(), allocaIP, reductionInfos, isByRef,
      if (!contInsertPoint)
        return contInsertPoint.takeError();
      if (!contInsertPoint->getBlock())
        return llvm::make_error<PreviouslyReportedError>();

      tempTerminator->eraseFromParent();
      builder.restoreIP(*contInsertPoint);
    }
    return llvm::Error::success();
  };

  auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                   llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {

  auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(codeGenIP);

    llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
                    [](omp::DeclareReductionOp reductionDecl) {
                      return &reductionDecl.getCleanupRegion();
                    });
    if (failed(
            reductionCleanupRegions, privateReductionVariables,
            moduleTranslation, builder, "omp.reduction.cleanup")))
      return llvm::createStringError(
          "failed to inline `cleanup` region of `omp.declare_reduction`");

      return llvm::make_error<PreviouslyReportedError>();

    if (isCancellable) {
      auto IPOrErr = ompBuilder->createBarrier(
          llvm::OpenMPIRBuilder::LocationDescription(builder),
          llvm::omp::Directive::OMPD_unknown,
      if (!IPOrErr)
        return IPOrErr.takeError();
    }
    builder.restoreIP(oldIP);
    return llvm::Error::success();
  };

  llvm::Value *ifCond = nullptr;
  if (auto ifVar = opInst.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);
  llvm::Value *numThreads = nullptr;
  if (!opInst.getNumThreadsVars().empty())
    numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindKind())

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                 ifCond, numThreads, pbKind, isCancellable);

  builder.restoreIP(*afterIP);
static llvm::omp::OrderKind
convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
  if (!o)
    return llvm::omp::OrderKind::OMP_ORDER_unknown;
  switch (*o) {
  case omp::ClauseOrderKind::Concurrent:
    return llvm::omp::OrderKind::OMP_ORDER_concurrent;
  }
  llvm_unreachable("Unknown ClauseOrderKind kind");
}
  auto simdOp = cast<omp::SimdOp>(opInst);

      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
                                                    simdOp.getNumReductionVars());
  assert(isByRef.size() == simdOp.getNumReductionVars());

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =

                         builder, moduleTranslation, privateVarsInfo, allocaIP);

  LinearClauseProcessor linearClauseProcessor;
  if (!simdOp.getLinearVars().empty()) {
    auto linearVarTypes = simdOp.getLinearVarTypes().value();
    for (mlir::Attribute linearVarType : linearVarTypes)
      linearClauseProcessor.registerType(moduleTranslation, linearVarType);
    for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
      bool isImplicit = false;
      for (auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
        if (linearVar == mlirPrivVar) {
          linearClauseProcessor.createLinearVar(builder, moduleTranslation,
                                                llvmPrivateVar, idx);
        }
        linearClauseProcessor.createLinearVar(
            builder, moduleTranslation,
    }
    for (mlir::Value linearStep : simdOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
  }

                        moduleTranslation, allocaIP, reductionDecls,
                        privateReductionVariables, reductionVariableMap,
                        deferredStores, isByRef)))

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(simdOp, reductionArgs, builder,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))

  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(simdlenVar.value());

  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(safelenVar.value());

  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;

  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
  for (size_t i = 0; i < operands.size(); ++i) {
    llvm::Value *alignment = nullptr;
    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
    llvm::Type *ty = llvmVal->getType();

    auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
    alignment = builder.getInt64(intAttr.getInt());
    assert(ty->isPointerTy() && "Invalid type for aligned variable");
    assert(alignment && "Invalid alignment value");

    if (!intAttr.getValue().isPowerOf2())

    auto curInsert = builder.saveIP();
    builder.SetInsertPoint(sourceBlock);
    llvmVal = builder.CreateLoad(ty, llvmVal);
    builder.restoreIP(curInsert);
    alignedVars[llvmVal] = alignment;
  }

  llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
      simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);

  if (simdOp.getLinearVars().size()) {
    linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                        loopInfo->getPreheader());
    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
                                          loopInfo->getIndVar());
  }

  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

  ompBuilder->applySimd(loopInfo, alignedVars,
                        simdOp.getIfExpr()
                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                            : nullptr,
                        order, simdlen, safelen);

  linearClauseProcessor.emitStoresForLinearVar(builder);
  for (size_t index = 0; index < simdOp.getLinearVars().size(); index++)
    linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
                                         index);

  for (auto [i, tuple] : llvm::enumerate(
           llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
                     privateReductionVariables))) {
    auto [decl, byRef, reductionVar, privateReductionVar] = tuple;

    OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
    llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
    llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());

    llvm::Value *redValue = originalVariable;
        builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
    llvm::Value *privateRedValue = builder.CreateLoad(
        reductionType, privateReductionVar, "red.private.value." + Twine(i));
    llvm::Value *reduced;

    auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
    builder.restoreIP(res.get());
    builder.CreateStore(reduced, originalVariable);
  }

  llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
                  [](omp::DeclareReductionOp reductionDecl) {
                    return &reductionDecl.getCleanupRegion();
                  });
                                    moduleTranslation, builder,
                                    "omp.reduction.cleanup")))
  auto loopOp = cast<omp::LoopNestOp>(opInst);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    bodyInsertPoints.push_back(ip);
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    builder.restoreIP(ip);
    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
        loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
    return llvm::Error::success();
  };

  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);

    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
      computeIP = loopInfos.front()->getPreheaderIP();

    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
        ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            true, loopOp.getLoopInclusive(), computeIP);

    loopInfos.push_back(*loopResult);
  }

  llvm::OpenMPIRBuilder::InsertPointTy afterIP =
      loopInfos.front()->getAfterIP();

  if (const auto &tiles = loopOp.getTileSizes()) {
    llvm::Type *ivType = loopInfos.front()->getIndVarType();
    for (auto tile : tiles.value()) {
      llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
      tileSizes.push_back(tileVal);
    }

    std::vector<llvm::CanonicalLoopInfo *> newLoops =
        ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);

    llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
    llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
    afterIP = {afterAfterBB, afterAfterBB->begin()};

    for (const auto &newLoop : newLoops)
      loopInfos.push_back(newLoop);
  }

  const auto &numCollapse = loopOp.getCollapseNumLoops();
      loopInfos.begin(), loopInfos.begin() + (numCollapse));

  auto newTopLoopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
  assert(newTopLoopInfo && "New top loop information is missing");
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        frame.loopInfo = newTopLoopInfo;

  builder.restoreIP(afterIP);
  llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
  Value loopIV = op.getInductionVar();
  Value loopTC = op.getTripCount();

  llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);

      ompBuilder->createCanonicalLoop(
          [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
            moduleTranslation.mapValue(loopIV, llvmIV);
            builder.restoreIP(ip);
              return bodyGenStatus.takeError();
          llvmTC, "omp.loop");
  if (!llvmOrError)
    return op.emitError(llvm::toString(llvmOrError.takeError()));

  llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
  llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
  builder.restoreIP(afterIP);

  if (Value cli = op.getCli())
  Value applyee = op.getApplyee();
  assert(applyee && "Loop to apply unrolling on required");

  llvm::CanonicalLoopInfo *consBuilderCLI =
  llvm::OpenMPIRBuilder::LocationDescription loc(builder);
  ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
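
/// Applies omp.tile: translates the `sizes` operands and the already
/// translated canonical loops, calls OpenMPIRBuilder::tileLoops, and maps the
/// `generatees` results back to the newly created loops.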
static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
          LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder::LocationDescription loc(builder);

  for (Value size : op.getSizes()) {
    llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
    assert(translatedSize &&
           "sizes clause arguments must already be translated");
    translatedSizes.push_back(translatedSize);
  }

  for (Value applyee : op.getApplyees()) {
    llvm::CanonicalLoopInfo *consBuilderCLI =
    assert(applyee && "Canonical loop must have already been translated");
    translatedLoops.push_back(consBuilderCLI);
  }

  auto generatedLoops =
      ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
  if (!op.getGeneratees().empty()) {
    for (auto [mlirLoop, genLoop] :
         zip_equal(op.getGeneratees(), generatedLoops))
      moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
  }

  for (Value applyee : op.getApplyees())
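
/// Applies omp.fuse: splits the applyee loops into the ranges before, inside
/// and after the [first, first + count) fusion window, fuses the middle range
/// with OpenMPIRBuilder::fuseLoops, and remaps the generatees accordingly.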
static LogicalResult
applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
          LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder::LocationDescription loc(builder);

  for (size_t i = 0; i < op.getApplyees().size(); i++) {
    Value applyee = op.getApplyees()[i];
    llvm::CanonicalLoopInfo *consBuilderCLI =
    assert(applyee && "Canonical loop must have already been translated");
    if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
      beforeFuse.push_back(consBuilderCLI);
    else if (op.getCount().has_value() &&
             i >= op.getFirst().value() + op.getCount().value() - 1)
      afterFuse.push_back(consBuilderCLI);
    else
      toFuse.push_back(consBuilderCLI);
  }

  assert(
      (op.getGeneratees().empty() ||
       beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
      "Wrong number of generatees");

  auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
  if (!op.getGeneratees().empty()) {
    for (; i < beforeFuse.size(); i++)
      moduleTranslation.mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
    moduleTranslation.mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
    for (; i < afterFuse.size(); i++)
      moduleTranslation.mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
  }

  for (Value applyee : op.getApplyees())
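
/// Converts the memory-order clause of an omp.atomic operation into the
/// corresponding llvm::AtomicOrdering, defaulting to monotonic (relaxed) when
/// no clause is present.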
static llvm::AtomicOrdering
convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
  if (!ao)
    return llvm::AtomicOrdering::Monotonic;

  switch (*ao) {
  case omp::ClauseMemoryOrderKind::Seq_cst:
    return llvm::AtomicOrdering::SequentiallyConsistent;
  case omp::ClauseMemoryOrderKind::Acq_rel:
    return llvm::AtomicOrdering::AcquireRelease;
  case omp::ClauseMemoryOrderKind::Acquire:
    return llvm::AtomicOrdering::Acquire;
  case omp::ClauseMemoryOrderKind::Release:
    return llvm::AtomicOrdering::Release;
  case omp::ClauseMemoryOrderKind::Relaxed:
    return llvm::AtomicOrdering::Monotonic;
  }
  llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
}
  auto readOp = cast<omp::AtomicReadOp>(opInst);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
  llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());

  llvm::Type *elementType =
      moduleTranslation.convertType(readOp.getElementType());

  llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
  llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
  builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
  auto writeOp = cast<omp::AtomicWriteOp>(opInst);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
  llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
  llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
  llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, false, false};
  builder.restoreIP(
      ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3918 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
3919 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
3920 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
3921 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
3922 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
3923 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
3924 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
3925 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
3926 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
3927 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3931 bool &isIgnoreDenormalMode,
3932 bool &isFineGrainedMemory,
3933 bool &isRemoteMemory) {
3934 isIgnoreDenormalMode =
false;
3935 isFineGrainedMemory =
false;
3936 isRemoteMemory =
false;
3937 if (atomicUpdateOp &&
3938 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3939 mlir::omp::AtomicControlAttr atomicControlAttr =
3940 atomicUpdateOp.getAtomicControlAttr();
3941 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3942 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3943 isRemoteMemory = atomicControlAttr.getRemoteMemory();
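
/// Converts omp.atomic.update. The single operation inside the update region
/// is matched against the atomicrmw binops above; anything else is lowered
/// through the generic update callback. A typical input looks roughly like
/// (sketch, not exact syntax):
///
///   omp.atomic.update %x : !llvm.ptr {
///   ^bb0(%xval: i32):
///     %new = llvm.add %xval, %expr : i32
///     omp.yield(%new : i32)
///   }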
                                     llvm::IRBuilderBase &builder,

  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }
  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      false, false};

  llvm::AtomicOrdering atomicOrdering =

  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr, isIgnoreDenormalMode,
                                     isFineGrainedMemory, isRemoteMemory);

  builder.restoreIP(*afterIP);
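
/// Converts omp.atomic.capture, which pairs an atomic read with either an
/// atomic update or an atomic write. Whether the captured value is the old or
/// the new one (postfix vs. prefix update) depends on which nested operation
/// appears second inside the capture region.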
                                      llvm::IRBuilderBase &builder,

  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    if (innerOpList.size() == 2) {
              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
    } else {
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      false, false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      false, false};

  llvm::AtomicOrdering atomicOrdering =

  auto updateFn =
      [&](llvm::Value *atomicx,
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
                           isFineGrainedMemory, isRemoteMemory);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
          isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *atomicCaptureOp)))

  builder.restoreIP(*afterIP);
                                 omp::ClauseCancellationConstructType directive) {
  switch (directive) {
  case omp::ClauseCancellationConstructType::Loop:
    return llvm::omp::Directive::OMPD_for;
  case omp::ClauseCancellationConstructType::Parallel:
    return llvm::omp::Directive::OMPD_parallel;
  case omp::ClauseCancellationConstructType::Sections:
    return llvm::omp::Directive::OMPD_sections;
  case omp::ClauseCancellationConstructType::Taskgroup:
    return llvm::omp::Directive::OMPD_taskgroup;
  }
  llvm_unreachable("Unhandled cancellation construct type");
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::Value *ifCond = nullptr;
  if (Value ifVar = op.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);

  llvm::omp::Directive cancelledDirective =

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);

  if (failed(handleError(afterIP, *op.getOperation())))

  builder.restoreIP(afterIP.get());

                                      llvm::IRBuilderBase &builder,

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::omp::Directive cancelledDirective =

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);

  if (failed(handleError(afterIP, *op.getOperation())))

  builder.restoreIP(afterIP.get());
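
/// Converts omp.threadprivate: resolves the global behind the symbol address
/// and emits OpenMPIRBuilder::createCachedThreadPrivate for it, sized by the
/// module data layout.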
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);

  Value symAddr = threadprivateOp.getSymAddr();
  if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
  if (!isa<LLVM::AddressOfOp>(symOp))
    return opInst.emitError("Addressing symbol not found");
  LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);

  LLVM::GlobalOp global =
      addressOfOp.getGlobal(moduleTranslation.symbolTable());
  llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
  llvm::Type *type = globalValue->getValueType();
  llvm::TypeSize typeSize =
      builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
          type);
  llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
  llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
      ompLoc, globalValue, size, global.getSymName() + ".cache");
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
  switch (deviceClause) {
  case mlir::omp::DeclareTargetDeviceType::host:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
  case mlir::omp::DeclareTargetDeviceType::nohost:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
  case mlir::omp::DeclareTargetDeviceType::any:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
  }
  llvm_unreachable("unhandled device clause");
}

static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
convertToCaptureClauseKind(
    mlir::omp::DeclareTargetCaptureClause captureClause) {
  switch (captureClause) {
  case mlir::omp::DeclareTargetCaptureClause::to:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
  case mlir::omp::DeclareTargetCaptureClause::link:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
  case mlir::omp::DeclareTargetCaptureClause::enter:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
  case mlir::omp::DeclareTargetCaptureClause::none:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
  }
  llvm_unreachable("unhandled capture clause");
}
  if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
  if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
    auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
    return modOp.lookupSymbol(addressOfOp.getGlobalName());
  }

static llvm::SmallString<64>
                              llvm::OpenMPIRBuilder &ompBuilder) {
  llvm::SmallString<64> suffix;
  llvm::raw_svector_ostream os(suffix);

  auto fileInfoCallBack = [&loc]() {
    return std::pair<std::string, uint64_t>(
        llvm::StringRef(loc.getFilename()), loc.getLine());
  };
  auto vfs = llvm::vfs::getRealFileSystem();
      ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);

  os << "_decl_tgt_ref_ptr";
  if (auto declareTargetGlobal =
          dyn_cast_if_present<omp::DeclareTargetInterface>(
    if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
        omp::DeclareTargetCaptureClause::link)

  if (auto declareTargetGlobal =
          dyn_cast_if_present<omp::DeclareTargetInterface>(
    if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
            omp::DeclareTargetCaptureClause::to ||
        declareTargetGlobal.getDeclareTargetCaptureClause() ==
            omp::DeclareTargetCaptureClause::enter)

  if (auto declareTargetGlobal =
          dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
    if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
             omp::DeclareTargetCaptureClause::link) ||
        (declareTargetGlobal.getDeclareTargetCaptureClause() ==
             omp::DeclareTargetCaptureClause::to &&
         ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
      if (gOp.getSymName().contains(suffix))
          (gOp.getSymName().str() + suffix.str()).str());
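
// Map-lowering helpers. MapInfosTy extends the OpenMPIRBuilder map info with
// the declare-mapper operation attached to each entry; MapInfoData
// additionally records, per map clause, the originating MapInfoOp, the
// original LLVM value and type, and whether the entry is a declare-target
// reference, a member of another mapping, or an explicit mapping at all.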
struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
  SmallVector<Operation *, 4> Mappers;

  void append(MapInfosTy &curInfo) {
    Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
  }
};

struct MapInfoData : MapInfosTy {
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  llvm::SmallVector<bool, 4> IsAMember;
  llvm::SmallVector<bool, 4> IsAMapping;
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    MapInfosTy::append(CurInfo);
  }
};
enum class TargetDirectiveEnumTy : uint32_t {
  TargetEnterData = 3,

static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
  return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
      .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
      .Case([](omp::TargetEnterDataOp) {
        return TargetDirectiveEnumTy::TargetEnterData;
      })
      .Case([&](omp::TargetExitDataOp) {
        return TargetDirectiveEnumTy::TargetExitData;
      })
      .Case([&](omp::TargetUpdateOp) {
        return TargetDirectiveEnumTy::TargetUpdate;
      })
      .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
      .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
}
  if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
          arrTy.getElementType()))

                              llvm::Value *basePointer, llvm::Type *baseType,
                              llvm::IRBuilderBase &builder,
  if (auto memberClause =
          mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
    if (!memberClause.getBounds().empty()) {
      llvm::Value *elementCount = builder.getInt64(1);
      for (auto bounds : memberClause.getBounds()) {
        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
                bounds.getDefiningOp())) {
          elementCount = builder.CreateMul(
                  moduleTranslation.lookupValue(boundOp.getUpperBound()),
                  moduleTranslation.lookupValue(boundOp.getLowerBound())),
              builder.getInt64(1)));

      if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))

      return builder.CreateMul(elementCount,
                               builder.getInt64(underlyingTypeSzInBits / 8));
static llvm::omp::OpenMPOffloadMappingFlags
  auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
    return (mlirFlags & flag) == flag;
  };

  const bool hasExplicitMap =
      (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
      omp::ClauseMapFlags::none;

  llvm::omp::OpenMPOffloadMappingFlags mapType =
      llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
  if (!hasExplicitMap)
    mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
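
// Collects one MapInfoData entry per map / use_device_ptr / use_device_addr /
// has_device_addr operand: the base pointer (or the declare-target reference
// pointer), the translated type and size, the offload mapping flags, the
// optional user-defined mapper, and whether the value is a member of a parent
// mapping.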
4556 ArrayRef<Value> useDevAddrOperands = {},
4557 ArrayRef<Value> hasDevAddrOperands = {}) {
4558 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
4566 for (Value mapValue : mapVars) {
4567 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4568 for (
auto member : map.getMembers())
4569 if (member == mapOp)
4576 for (Value mapValue : mapVars) {
4577 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4579 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4580 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
4581 mapData.Pointers.push_back(mapData.OriginalValue.back());
4583 if (llvm::Value *refPtr =
4585 mapData.IsDeclareTarget.push_back(
true);
4586 mapData.BasePointers.push_back(refPtr);
4588 mapData.IsDeclareTarget.push_back(
true);
4589 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4591 mapData.IsDeclareTarget.push_back(
false);
4592 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4595 mapData.BaseType.push_back(
4596 moduleTranslation.
convertType(mapOp.getVarType()));
4597 mapData.Sizes.push_back(
4598 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
4599 mapData.BaseType.back(), builder, moduleTranslation));
4600 mapData.MapClause.push_back(mapOp.getOperation());
4602 mapData.Names.push_back(LLVM::createMappingInformation(
4604 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4605 if (mapOp.getMapperId())
4606 mapData.Mappers.push_back(
4608 mapOp, mapOp.getMapperIdAttr()));
4610 mapData.Mappers.push_back(
nullptr);
4611 mapData.IsAMapping.push_back(
true);
4612 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
4615 auto findMapInfo = [&mapData](llvm::Value *val,
4616 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4619 for (llvm::Value *basePtr : mapData.OriginalValue) {
4620 if (basePtr == val && mapData.IsAMapping[index]) {
4622 mapData.Types[index] |=
4623 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4624 mapData.DevicePointers[index] = devInfoTy;
4632 auto addDevInfos = [&](
const llvm::ArrayRef<Value> &useDevOperands,
4633 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4634 for (Value mapValue : useDevOperands) {
4635 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4637 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4638 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
4641 if (!findMapInfo(origValue, devInfoTy)) {
4642 mapData.OriginalValue.push_back(origValue);
4643 mapData.Pointers.push_back(mapData.OriginalValue.back());
4644 mapData.IsDeclareTarget.push_back(
false);
4645 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4646 mapData.BaseType.push_back(
4647 moduleTranslation.
convertType(mapOp.getVarType()));
4648 mapData.Sizes.push_back(builder.getInt64(0));
4649 mapData.MapClause.push_back(mapOp.getOperation());
4650 mapData.Types.push_back(
4651 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
4652 mapData.Names.push_back(LLVM::createMappingInformation(
4654 mapData.DevicePointers.push_back(devInfoTy);
4655 mapData.Mappers.push_back(
nullptr);
4656 mapData.IsAMapping.push_back(
false);
4657 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
4662 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4663 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
4665 for (Value mapValue : hasDevAddrOperands) {
4666 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4668 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4669 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
4671 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4673 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4674 omp::ClauseMapFlags::none;
4676 mapData.OriginalValue.push_back(origValue);
4677 mapData.BasePointers.push_back(origValue);
4678 mapData.Pointers.push_back(origValue);
4679 mapData.IsDeclareTarget.push_back(
false);
4680 mapData.BaseType.push_back(
4681 moduleTranslation.
convertType(mapOp.getVarType()));
4682 mapData.Sizes.push_back(
4683 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
4684 mapData.MapClause.push_back(mapOp.getOperation());
4685 if (llvm::to_underlying(mapType & mapTypeAlways)) {
4689 mapData.Types.push_back(mapType);
4693 if (mapOp.getMapperId()) {
4694 mapData.Mappers.push_back(
4696 mapOp, mapOp.getMapperIdAttr()));
4698 mapData.Mappers.push_back(
nullptr);
4703 mapData.Types.push_back(
4704 isDevicePtr ? mapType
4705 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4706 mapData.Mappers.push_back(
nullptr);
4708 mapData.Names.push_back(LLVM::createMappingInformation(
4710 mapData.DevicePointers.push_back(
4711 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4712 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4713 mapData.IsAMapping.push_back(
false);
4714 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
  auto *res = llvm::find(mapData.MapClause, memberOp);
  assert(res != mapData.MapClause.end() &&
         "MapInfoOp for member not found in MapData, cannot return index");
  return std::distance(mapData.MapClause.begin(), res);

                                 omp::MapInfoOp mapInfo) {
  ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();

    auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
    auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);

    for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
      int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
      int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
      if (aIndex == bIndex)
      if (aIndex < bIndex)
      if (aIndex > bIndex)

    bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
      occludedChildren.push_back(b);
      occludedChildren.push_back(a);
    return memberAParent;

  for (auto v : occludedChildren)

  ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
  if (indexAttr.size() == 1)
    return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());

  return llvm::cast<omp::MapInfoOp>(
static std::vector<llvm::Value *>
                      llvm::IRBuilderBase &builder, bool isArrayTy,
  std::vector<llvm::Value *> idx;

    idx.push_back(builder.getInt64(0));
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));

    std::vector<llvm::Value *> dimensionIndexSizeOffset;
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        if (i == ((int)bounds.size() - 1))
              moduleTranslation.lookupValue(boundOp.getLowerBound()));
          idx.back() = builder.CreateAdd(
              builder.CreateMul(
                  idx.back(),
                  moduleTranslation.lookupValue(boundOp.getExtent())),
              moduleTranslation.lookupValue(boundOp.getLowerBound()));
  llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
    return cast<IntegerAttr>(value).getInt();
  });

                                 omp::MapInfoOp parentOp) {
  if (parentOp.getMembers().empty())

  if (parentOp.getMembers().size() == 1) {
    overlapMapDataIdxs.push_back(0);

  ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
  for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
    memberByIndex.push_back(
        std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));

  llvm::sort(memberByIndex.begin(), memberByIndex.end(),
             [&](auto a, auto b) { return a.second.size() < b.second.size(); });

  for (auto v : memberByIndex) {
        *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
          llvm::SmallVector<int64_t> xArr(x.second.size());
          getAsIntegers(x.second, xArr);
          return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
                 xArr.size() >= vArr.size();

  for (auto v : memberByIndex)
    if (find(skipList.begin(), skipList.end(), v) == skipList.end())
      overlapMapDataIdxs.push_back(v.first);

  if (mapOp.getVarPtrPtr())
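
// Emits the combined entry for a parent object that has mapped members: one
// entry spanning from the lowest to the highest mapped address of the object
// (computed as a pointer difference), followed by MEMBER_OF-tagged entries
// for the members themselves. Partial maps and overlapping members are split
// into several address ranges so that unmapped gaps are not transferred.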
                       llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
                       MapInfosTy &combinedInfo, MapInfoData &mapData,
                       uint64_t mapDataIndex,
                       TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
  auto *parentMapper = mapData.Mappers[mapDataIndex];

  llvm::omp::OpenMPOffloadMappingFlags baseFlag =
      (targetDirective == TargetDirectiveEnumTy::Target &&
       !mapData.IsDeclareTarget[mapDataIndex])
          ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
          : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;

  using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
  mapFlags parentFlags = mapData.Types[mapDataIndex];
  mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
                      mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
                      mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
  baseFlag |= (parentFlags & preserve);

  combinedInfo.Types.emplace_back(baseFlag);
  combinedInfo.DevicePointers.emplace_back(
      mapData.DevicePointers[mapDataIndex]);
  combinedInfo.Mappers.emplace_back(
      parentMapper && !parentClause.getPartialMap() ? parentMapper : nullptr);
      mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
  combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);

  llvm::Value *lowAddr, *highAddr;
  if (!parentClause.getPartialMap()) {
    lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                        builder.getPtrTy());
    highAddr = builder.CreatePointerCast(
        builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                   mapData.Pointers[mapDataIndex], 1),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
  } else {
    auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
    lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
                                        builder.getPtrTy());
    highAddr = builder.CreatePointerCast(
        builder.CreateGEP(mapData.BaseType[lastMemberIdx],
                          mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
  }

  llvm::Value *size = builder.CreateIntCast(
      builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
      builder.getInt64Ty(),
  combinedInfo.Sizes.push_back(size);

  llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
      ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);

  if (!parentClause.getPartialMap()) {
    llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
    bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);

    if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
      combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(nullptr);
    } else {
      lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                          builder.getPtrTy());
      highAddr = builder.CreatePointerCast(
          builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                     mapData.Pointers[mapDataIndex], 1),
          builder.getPtrTy());

      for (auto v : overlapIdxs) {
            cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
        combinedInfo.Types.emplace_back(mapFlag);
        combinedInfo.DevicePointers.emplace_back(
            mapData.DevicePointers[mapDataOverlapIdx]);
            mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
        combinedInfo.BasePointers.emplace_back(
            mapData.BasePointers[mapDataIndex]);
        combinedInfo.Mappers.emplace_back(nullptr);
        combinedInfo.Pointers.emplace_back(lowAddr);
        combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
            builder.CreatePtrDiff(builder.getInt8Ty(),
                                  mapData.OriginalValue[mapDataOverlapIdx],
            builder.getInt64Ty(), true));
        lowAddr = builder.CreateConstGEP1_32(
                mapData.MapClause[mapDataOverlapIdx]))
                ? builder.getPtrTy()
                : mapData.BaseType[mapDataOverlapIdx],
            mapData.BasePointers[mapDataOverlapIdx], 1);
      }

      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(nullptr);
      combinedInfo.Pointers.emplace_back(lowAddr);
      combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
          builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
          builder.getInt64Ty(), true));
    }
  }
  return memberOfFlag;
static void processMapMembersWithParent(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
    MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex,
    llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
    TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  for (auto mappedMembers : parentClause.getMembers()) {
    auto memberClause =
        llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
    assert(memberDataIdx >= 0 && "could not find mapped member of structure");

    mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
    combinedInfo.Types.emplace_back(mapFlag);
    combinedInfo.DevicePointers.emplace_back(
        llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    combinedInfo.Mappers.emplace_back(nullptr);
    combinedInfo.Names.emplace_back(
        LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
    combinedInfo.BasePointers.emplace_back(
        mapData.BasePointers[mapDataIndex]);
    combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
    combinedInfo.Sizes.emplace_back(builder.getInt64(
        moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));

    mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
                                 ? parentClause.getVarPtr()
                                 : parentClause.getVarPtrPtr());
        (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
         targetDirective != TargetDirectiveEnumTy::TargetData))) {
      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
    }
    combinedInfo.Types.emplace_back(mapFlag);
    combinedInfo.DevicePointers.emplace_back(
        mapData.DevicePointers[memberDataIdx]);
    combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
    combinedInfo.Names.emplace_back(
        LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
    uint64_t basePointerIndex =
    combinedInfo.BasePointers.emplace_back(
        mapData.BasePointers[basePointerIndex]);
    combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);

    llvm::Value *size = mapData.Sizes[memberDataIdx];
    size = builder.CreateSelect(
        builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
        builder.getInt64(0), size);

    combinedInfo.Sizes.emplace_back(size);
  }
}
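// Emits a single map entry that is not itself a member of a structure map.
// When mapDataParentIdx is provided, the parent's base pointer is reused so
// the runtime attaches the entry to that parent.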
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
                                 MapInfosTy &combinedInfo,
                                 TargetDirectiveEnumTy targetDirective,
                                 int mapDataParentIdx = -1) {
  auto mapFlag = mapData.Types[mapDataIdx];
  auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);

  bool isPtrTy = checkIfPointerMap(mapInfoOp);
  if (isPtrTy)
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;

  if (targetDirective == TargetDirectiveEnumTy::Target &&
      !mapData.IsDeclareTarget[mapDataIdx])
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;

  if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
      !isPtrTy)
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;

  if (mapDataParentIdx >= 0)
    combinedInfo.BasePointers.emplace_back(
        mapData.BasePointers[mapDataParentIdx]);
  else
    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);

  combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
  combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
  combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
  combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
  combinedInfo.Types.emplace_back(mapFlag);
  combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
}
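// Maps a parent value together with its member maps: first the parent (and
// any overlap regions) via mapParentWithMembers, then each member via
// processMapMembersWithParent using the parent's MEMBER_OF flag.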
static void processMapWithMembersOf(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
    MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex,
    TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
    auto memberClause = llvm::cast<omp::MapInfoOp>(
        parentClause.getMembers()[0].getDefiningOp());
  }

  llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
      mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
                           combinedInfo, mapData, mapDataIndex,
                           targetDirective);
  processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
                              combinedInfo, mapData, mapDataIndex,
                              memberOfParentFlag, targetDirective);
}
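// Rewrites the collected map pointers based on how each variable is captured:
// by-reference maps apply any bounds-derived offsets, while by-copy maps load
// the value and stage it through a temporary alloca.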
static void
createAlteredByCaptureMap(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    if (!mapData.IsDeclareTarget[i]) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();

      switch (captureKind) {
      case omp::VariableCaptureKind::ByRef: {
        llvm::Value *newV = mapData.Pointers[i];
        std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
            moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
            mapOp.getBounds());
        newV = builder.CreateLoad(builder.getPtrTy(), newV);

        if (!offsetIdx.empty())
          newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
                                           "array_offset");
        mapData.Pointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::ByCopy: {
        llvm::Type *type = mapData.BaseType[i];
        llvm::Value *newV;
        if (mapData.Pointers[i]->getType()->isPointerTy())
          newV = builder.CreateLoad(type, mapData.Pointers[i]);
        else
          newV = mapData.Pointers[i];

        auto curInsert = builder.saveIP();
        llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
        auto *memTempAlloc =
            builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
        builder.SetCurrentDebugLocation(DbgLoc);
        builder.restoreIP(curInsert);

        builder.CreateStore(newV, memTempAlloc);
        newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);

        mapData.Pointers[i] = newV;
        mapData.BasePointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::This:
      case omp::VariableCaptureKind::VLAType:
        mapData.MapClause[i]->emitOpError("Unhandled capture kind");
        break;
      }
    }
  }
}
static void genMapInfos(llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation,
                        DataLayout &dl, MapInfosTy &combinedInfo,
                        MapInfoData &mapData,
                        TargetDirectiveEnumTy targetDirective) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");

  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    if (mapData.IsAMember[i])
      continue;

    auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
    if (!mapInfoOp.getMembers().empty()) {
      processMapWithMembersOf(moduleTranslation, builder,
                              *moduleTranslation.getOpenMPBuilder(), dl,
                              combinedInfo, mapData, i, targetDirective);
    }
  }
}

static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective);
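// Returns the LLVM function implementing an `omp.declare_mapper`, reusing a
// previously translated or already-present function before emitting a new one
// through the OpenMPIRBuilder.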
static llvm::Expected<llvm::Function *>
getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
                                 LLVM::ModuleTranslation &moduleTranslation,
                                 TargetDirectiveEnumTy targetDirective) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  auto declMapperOp = cast<omp::DeclareMapperOp>(op);
  std::string mapperFuncName =
      moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
          {"omp_mapper", declMapperOp.getSymName()});

  if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
    return lookupFunc;

  if (llvm::Function *existingFunc =
          moduleTranslation.getLLVMModule()->getFunction(mapperFuncName)) {
    moduleTranslation.mapFunction(mapperFuncName, existingFunc);
    return existingFunc;
  }

  return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
                               mapperFuncName, targetDirective);
}

static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  auto declMapperOp = cast<omp::DeclareMapperOp>(op);
  auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
  llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
          llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
    builder.restoreIP(codeGenIP);
    moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
    moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
                               builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
                                              /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();
    MapInfoData mapData;
    genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
                targetDirective);
    return combinedInfo;
  };

  auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Expected<llvm::Function *> newFn =
      moduleTranslation.getOpenMPBuilder()->emitUserDefinedMapper(
          genMapInfoCB, varType, mapperFuncName, customMapperCB);
  if (!newFn)
    return newFn.takeError();
  if ([[maybe_unused]] llvm::Function *mappedFunc =
          moduleTranslation.lookupFunction(mapperFuncName))
    assert(mappedFunc == *newFn &&
           "mapper function mapping disagrees with emitted function");
  else
    moduleTranslation.mapFunction(mapperFuncName, *newFn);
  return *newFn;
}
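// Translates omp.target_data, omp.target_enter_data, omp.target_exit_data and
// omp.target_update: selects the matching offloading runtime entry point,
// builds the combined map info and, for omp.target_data, emits the region
// body through OpenMPIRBuilder::createTargetData.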
5520 llvm::Value *ifCond =
nullptr;
5521 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
5525 llvm::omp::RuntimeFunction RTLFn;
5527 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
5530 llvm::OpenMPIRBuilder::TargetDataInfo info(
5533 assert(!ompBuilder->Config.isTargetDevice() &&
5534 "target data/enter/exit/update are host ops");
5535 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
5537 auto getDeviceID = [&](
mlir::Value dev) -> llvm::Value * {
5538 llvm::Value *v = moduleTranslation.
lookupValue(dev);
5539 return builder.CreateIntCast(v, builder.getInt64Ty(),
true);
5544 .Case([&](omp::TargetDataOp dataOp) {
5548 if (
auto ifVar = dataOp.getIfExpr())
5552 deviceID = getDeviceID(devId);
5554 mapVars = dataOp.getMapVars();
5555 useDevicePtrVars = dataOp.getUseDevicePtrVars();
5556 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
5559 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
5563 if (
auto ifVar = enterDataOp.getIfExpr())
5567 deviceID = getDeviceID(devId);
5570 enterDataOp.getNowait()
5571 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
5572 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
5573 mapVars = enterDataOp.getMapVars();
5574 info.HasNoWait = enterDataOp.getNowait();
5577 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
5581 if (
auto ifVar = exitDataOp.getIfExpr())
5585 deviceID = getDeviceID(devId);
5587 RTLFn = exitDataOp.getNowait()
5588 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
5589 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
5590 mapVars = exitDataOp.getMapVars();
5591 info.HasNoWait = exitDataOp.getNowait();
5594 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
5598 if (
auto ifVar = updateDataOp.getIfExpr())
5602 deviceID = getDeviceID(devId);
5605 updateDataOp.getNowait()
5606 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
5607 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
5608 mapVars = updateDataOp.getMapVars();
5609 info.HasNoWait = updateDataOp.getNowait();
5612 .DefaultUnreachable(
"unexpected operation");
5617 if (!isOffloadEntry)
5618 ifCond = builder.getFalse();
5620 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5621 MapInfoData mapData;
5623 builder, useDevicePtrVars, useDeviceAddrVars);
5626 MapInfosTy combinedInfo;
5627 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
5628 builder.restoreIP(codeGenIP);
5629 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
5631 return combinedInfo;
5637 [&moduleTranslation](
5638 llvm::OpenMPIRBuilder::DeviceInfoTy type,
5642 for (
auto [arg, useDevVar] :
5643 llvm::zip_equal(blockArgs, useDeviceVars)) {
5645 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
5646 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
5647 : mapInfoOp.getVarPtr();
5650 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
5651 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
5652 mapInfoData.MapClause, mapInfoData.DevicePointers,
5653 mapInfoData.BasePointers)) {
5654 auto mapOp = cast<omp::MapInfoOp>(mapClause);
5655 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
5656 devicePointer != type)
5659 if (llvm::Value *devPtrInfoMap =
5660 mapper ? mapper(basePointer) : basePointer) {
5661 moduleTranslation.
mapValue(arg, devPtrInfoMap);
5668 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
5669 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
5670 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5673 builder.restoreIP(codeGenIP);
5674 assert(isa<omp::TargetDataOp>(op) &&
5675 "BodyGen requested for non TargetDataOp");
5676 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
5677 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
5678 switch (bodyGenType) {
5679 case BodyGenTy::Priv:
5681 if (!info.DevicePtrInfoMap.empty()) {
5682 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5683 blockArgIface.getUseDeviceAddrBlockArgs(),
5684 useDeviceAddrVars, mapData,
5685 [&](llvm::Value *basePointer) -> llvm::Value * {
5686 if (!info.DevicePtrInfoMap[basePointer].second)
5688 return builder.CreateLoad(
5690 info.DevicePtrInfoMap[basePointer].second);
5692 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5693 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5694 mapData, [&](llvm::Value *basePointer) {
5695 return info.DevicePtrInfoMap[basePointer].second;
5699 moduleTranslation)))
5700 return llvm::make_error<PreviouslyReportedError>();
5703 case BodyGenTy::DupNoPriv:
5704 if (info.DevicePtrInfoMap.empty()) {
5707 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5708 blockArgIface.getUseDeviceAddrBlockArgs(),
5709 useDeviceAddrVars, mapData);
5710 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5711 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5715 case BodyGenTy::NoPriv:
5717 if (info.DevicePtrInfoMap.empty()) {
5719 moduleTranslation)))
5720 return llvm::make_error<PreviouslyReportedError>();
5724 return builder.saveIP();
5727 auto customMapperCB =
5729 if (!combinedInfo.Mappers[i])
5731 info.HasMapper =
true;
5733 moduleTranslation, targetDirective);
5736 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5737 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5739 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
5740 if (isa<omp::TargetDataOp>(op))
5741 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5742 deviceID, ifCond, info, genMapInfoCB,
5746 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5747 deviceID, ifCond, info, genMapInfoCB,
5748 customMapperCB, &RTLFn);
5754 builder.restoreIP(*afterIP);
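// Translates omp.distribute. Reductions declared on a parent omp.teams op are
// materialized here when the teams reduction is contained in the distribute
// region, and a standalone distribute (one not wrapping an omp.wsloop) is
// lowered as a statically scheduled workshare loop.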
5762 auto distributeOp = cast<omp::DistributeOp>(opInst);
5769 bool doDistributeReduction =
5773 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
5778 if (doDistributeReduction) {
5779 isByRef =
getIsByRef(teamsOp.getReductionByref());
5780 assert(isByRef.size() == teamsOp.getNumReductionVars());
5783 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5787 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
5788 .getReductionBlockArgs();
5791 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
5792 reductionDecls, privateReductionVariables, reductionVariableMap,
5797 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5798 auto bodyGenCB = [&](InsertPointTy allocaIP,
5799 InsertPointTy codeGenIP) -> llvm::Error {
5803 moduleTranslation, allocaIP);
5806 builder.restoreIP(codeGenIP);
5812 return llvm::make_error<PreviouslyReportedError>();
5817 return llvm::make_error<PreviouslyReportedError>();
5820 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
5822 distributeOp.getPrivateNeedsBarrier())))
5823 return llvm::make_error<PreviouslyReportedError>();
5826 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5829 builder, moduleTranslation);
5831 return regionBlock.takeError();
5832 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
5837 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
5840 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
5841 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
5842 : omp::ClauseScheduleKind::Static;
5844 bool isOrdered = hasDistSchedule;
5845 std::optional<omp::ScheduleModifier> scheduleMod;
5846 bool isSimd =
false;
5847 llvm::omp::WorksharingLoopType workshareLoopType =
5848 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
5849 bool loopNeedsBarrier =
false;
5850 llvm::Value *chunk = moduleTranslation.
lookupValue(
5851 distributeOp.getDistScheduleChunkSize());
5852 llvm::CanonicalLoopInfo *loopInfo =
5854 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
5855 ompBuilder->applyWorkshareLoop(
5856 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
5857 convertToScheduleKind(schedule), chunk, isSimd,
5858 scheduleMod == omp::ScheduleModifier::monotonic,
5859 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
5860 workshareLoopType,
false, hasDistSchedule, chunk);
5863 return wsloopIP.takeError();
5866 distributeOp.getLoc(), privVarsInfo.
llvmVars,
5868 return llvm::make_error<PreviouslyReportedError>();
5870 return llvm::Error::success();
5873 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5875 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5876 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5877 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
5882 builder.restoreIP(*afterIP);
5884 if (doDistributeReduction) {
5887 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
5888 privateReductionVariables, isByRef,
5900 if (!cast<mlir::ModuleOp>(op))
5905 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
5906 attribute.getOpenmpDeviceVersion());
5908 if (attribute.getNoGpuLib())
5911 ompBuilder->createGlobalFlag(
5912 attribute.getDebugKind() ,
5913 "__omp_rtl_debug_kind");
5914 ompBuilder->createGlobalFlag(
5916 .getAssumeTeamsOversubscription()
5918 "__omp_rtl_assume_teams_oversubscription");
5919 ompBuilder->createGlobalFlag(
5921 .getAssumeThreadsOversubscription()
5923 "__omp_rtl_assume_threads_oversubscription");
5924 ompBuilder->createGlobalFlag(
5925 attribute.getAssumeNoThreadState() ,
5926 "__omp_rtl_assume_no_thread_state");
5927 ompBuilder->createGlobalFlag(
5929 .getAssumeNoNestedParallelism()
5931 "__omp_rtl_assume_no_nested_parallelism");
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
                                     omp::TargetOp targetOp,
                                     llvm::StringRef parentName = "") {
  auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();

  assert(fileLoc && "No file found from location");
  StringRef fileName = fileLoc.getFilename().getValue();

  llvm::sys::fs::UniqueID id;
  uint64_t line = fileLoc.getLine();
  if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
    size_t fileHash = llvm::hash_value(fileName.str());
    size_t deviceId = 0xdeadf17e;
    targetInfo =
        llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
  } else {
    targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
                                             id.getFile(), line);
  }
}
static void handleDeclareTargetMapVar(MapInfoData &mapData,
                                      LLVM::ModuleTranslation &moduleTranslation,
                                      llvm::IRBuilderBase &builder,
                                      llvm::Function *func) {
  assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for target device codegen");
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    if (mapData.IsDeclareTarget[i]) {
      if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
        convertUsersOfConstantsToInstructions(constant, func, false);

      llvm::SmallVector<llvm::User *> userVec;
      for (llvm::User *user : mapData.OriginalValue[i]->users())
        userVec.push_back(user);

      for (llvm::User *user : userVec) {
        if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
          if (insn->getFunction() == func) {
            auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
            llvm::Value *substitute = mapData.BasePointers[i];
                                  : mapOp.getVarPtr())) {
              builder.SetCurrentDebugLocation(insn->getDebugLoc());
              substitute = builder.CreateLoad(
                  mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
              cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
            }
            user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
          }
        }
      }
    }
  }
}
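// On the device side, rebuilds access to a kernel argument: the argument is
// stored into an alloca (address-space cast to the program address space if
// needed) and, for by-ref captures, loaded back with alignment metadata
// derived from the mapped variable's type.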
6053static llvm::IRBuilderBase::InsertPoint
6055 llvm::Value *input, llvm::Value *&retVal,
6056 llvm::IRBuilderBase &builder,
6057 llvm::OpenMPIRBuilder &ompBuilder,
6059 llvm::IRBuilderBase::InsertPoint allocaIP,
6060 llvm::IRBuilderBase::InsertPoint codeGenIP) {
6061 assert(ompBuilder.Config.isTargetDevice() &&
6062 "function only supported for target device codegen");
6063 builder.restoreIP(allocaIP);
6065 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
6067 ompBuilder.M.getContext());
6068 unsigned alignmentValue = 0;
6070 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
6071 if (mapData.OriginalValue[i] == input) {
6072 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6073 capture = mapOp.getMapCaptureType();
6076 mapOp.getVarType(), ompBuilder.M.getDataLayout());
6080 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
6081 unsigned int defaultAS =
6082 ompBuilder.M.getDataLayout().getProgramAddressSpace();
6085 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
6087 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
6088 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
6090 builder.CreateStore(&arg, v);
6092 builder.restoreIP(codeGenIP);
6095 case omp::VariableCaptureKind::ByCopy: {
6099 case omp::VariableCaptureKind::ByRef: {
6100 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
6102 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
6117 if (v->getType()->isPointerTy() && alignmentValue) {
6118 llvm::MDBuilder MDB(builder.getContext());
6119 loadInst->setMetadata(
6120 llvm::LLVMContext::MD_align,
6121 llvm::MDNode::get(builder.getContext(),
6122 MDB.createConstant(llvm::ConstantInt::get(
6123 llvm::Type::getInt64Ty(builder.getContext()),
6130 case omp::VariableCaptureKind::This:
6131 case omp::VariableCaptureKind::VLAType:
6134 assert(
false &&
"Currently unsupported capture kind");
6138 return builder.saveIP();
6155 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6156 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6157 blockArgIface.getHostEvalBlockArgs())) {
6158 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6162 .Case([&](omp::TeamsOp teamsOp) {
6163 if (teamsOp.getNumTeamsLower() == blockArg)
6164 numTeamsLower = hostEvalVar;
6165 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6167 numTeamsUpper = hostEvalVar;
6168 else if (!teamsOp.getThreadLimitVars().empty() &&
6169 teamsOp.getThreadLimit(0) == blockArg)
6170 threadLimit = hostEvalVar;
6172 llvm_unreachable(
"unsupported host_eval use");
6174 .Case([&](omp::ParallelOp parallelOp) {
6175 if (!parallelOp.getNumThreadsVars().empty() &&
6176 parallelOp.getNumThreads(0) == blockArg)
6177 numThreads = hostEvalVar;
6179 llvm_unreachable(
"unsupported host_eval use");
6181 .Case([&](omp::LoopNestOp loopOp) {
6182 auto processBounds =
6186 for (
auto [i, lb] : llvm::enumerate(opBounds)) {
6187 if (lb == blockArg) {
6190 (*outBounds)[i] = hostEvalVar;
6196 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6197 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6199 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6201 assert(found &&
"unsupported host_eval use");
6203 .DefaultUnreachable(
"unsupported host_eval use");
6215template <
typename OpTy>
6220 if (OpTy casted = dyn_cast<OpTy>(op))
6223 if (immediateParent)
6224 return dyn_cast_if_present<OpTy>(op->
getParentOp());
6233 return std::nullopt;
6236 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6237 return constAttr.getInt();
6239 return std::nullopt;
6244 uint64_t sizeInBytes = sizeInBits / 8;
6248template <
typename OpTy>
6250 if (op.getNumReductionVars() > 0) {
6255 members.reserve(reductions.size());
6256 for (omp::DeclareReductionOp &red : reductions)
6257 members.push_back(red.getType());
6259 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
6275 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
6276 bool isTargetDevice,
bool isGPU) {
6279 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
6280 if (!isTargetDevice) {
6288 numTeamsLower = teamsOp.getNumTeamsLower();
6290 if (!teamsOp.getNumTeamsUpperVars().empty())
6291 numTeamsUpper = teamsOp.getNumTeams(0);
6292 if (!teamsOp.getThreadLimitVars().empty())
6293 threadLimit = teamsOp.getThreadLimit(0);
6297 if (!parallelOp.getNumThreadsVars().empty())
6298 numThreads = parallelOp.getNumThreads(0);
6304 int32_t minTeamsVal = 1, maxTeamsVal = -1;
6308 if (numTeamsUpper) {
6310 minTeamsVal = maxTeamsVal = *val;
6312 minTeamsVal = maxTeamsVal = 0;
6318 minTeamsVal = maxTeamsVal = 1;
6320 minTeamsVal = maxTeamsVal = -1;
6325 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &
result) {
6339 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
6340 if (!targetOp.getThreadLimitVars().empty())
6341 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
6342 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
6345 int32_t maxThreadsVal = -1;
6347 setMaxValueFromClause(numThreads, maxThreadsVal);
6355 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
6356 if (combinedMaxThreadsVal < 0 ||
6357 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
6358 combinedMaxThreadsVal = teamsThreadLimitVal;
6360 if (combinedMaxThreadsVal < 0 ||
6361 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
6362 combinedMaxThreadsVal = maxThreadsVal;
6364 int32_t reductionDataSize = 0;
6365 if (isGPU && capturedOp) {
6371 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
6373 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
6374 omp::TargetRegionFlags::spmd) &&
6375 "invalid kernel flags");
6377 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
6378 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
6379 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
6380 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
6381 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
6382 if (omp::bitEnumContainsAll(kernelFlags,
6383 omp::TargetRegionFlags::spmd |
6384 omp::TargetRegionFlags::no_loop) &&
6385 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
6386 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
6388 attrs.MinTeams = minTeamsVal;
6389 attrs.MaxTeams.front() = maxTeamsVal;
6390 attrs.MinThreads = 1;
6391 attrs.MaxThreads.front() = combinedMaxThreadsVal;
6392 attrs.ReductionDataSize = reductionDataSize;
6395 if (attrs.ReductionDataSize != 0)
6396 attrs.ReductionBufferLength = 1024;
6408 omp::TargetOp targetOp,
Operation *capturedOp,
6409 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
6411 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
6413 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
6417 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
6420 if (!targetOp.getThreadLimitVars().empty()) {
6421 Value targetThreadLimit = targetOp.getThreadLimit(0);
6422 attrs.TargetThreadLimit.front() =
6430 attrs.MinTeams = builder.CreateSExtOrTrunc(
6431 moduleTranslation.
lookupValue(numTeamsLower), builder.getInt32Ty());
6434 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
6435 moduleTranslation.
lookupValue(numTeamsUpper), builder.getInt32Ty());
6437 if (teamsThreadLimit)
6438 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
6439 moduleTranslation.
lookupValue(teamsThreadLimit), builder.getInt32Ty());
6442 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
6444 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
6445 omp::TargetRegionFlags::trip_count)) {
6447 attrs.LoopTripCount =
nullptr;
6452 for (
auto [loopLower, loopUpper, loopStep] :
6453 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
6454 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
6455 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
6456 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
6458 if (!lowerBound || !upperBound || !step) {
6459 attrs.LoopTripCount =
nullptr;
6463 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
6464 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
6465 loc, lowerBound, upperBound, step,
true,
6466 loopOp.getLoopInclusive());
6468 if (!attrs.LoopTripCount) {
6469 attrs.LoopTripCount = tripCount;
6474 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
6479 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6481 attrs.DeviceID = moduleTranslation.
lookupValue(devId);
6483 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
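// Translates omp.target: outlines the region into a kernel, wires up map,
// host_eval and private variables, computes default and runtime kernel
// attributes, and emits the kernel launch through the OpenMPIRBuilder.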
6490 auto targetOp = cast<omp::TargetOp>(opInst);
6495 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
6504 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
6505 assert(parentBB &&
"No insert block is set for the builder");
6506 llvm::Function *parentLLVMFn = parentBB->getParent();
6507 assert(parentLLVMFn &&
"Parent Function must be valid");
6508 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
6509 builder.SetCurrentDebugLocation(llvm::DILocation::get(
6510 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
6511 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
6514 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
6515 bool isGPU = ompBuilder->Config.isGPU();
6518 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
6519 auto &targetRegion = targetOp.getRegion();
6536 llvm::Function *llvmOutlinedFn =
nullptr;
6537 TargetDirectiveEnumTy targetDirective =
6538 getTargetDirectiveEnumTyFromOp(&opInst);
6542 bool isOffloadEntry =
6543 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
6550 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
6552 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
6553 std::optional<DenseI64ArrayAttr> privateMapIndices =
6554 targetOp.getPrivateMapsAttr();
6556 for (
auto [privVarIdx, privVarSymPair] :
6557 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
6558 auto privVar = std::get<0>(privVarSymPair);
6559 auto privSym = std::get<1>(privVarSymPair);
6561 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
6562 omp::PrivateClauseOp privatizer =
6565 if (!privatizer.needsMap())
6569 targetOp.getMappedValueForPrivateVar(privVarIdx);
6570 assert(mappedValue &&
"Expected to find mapped value for a privatized "
6571 "variable that needs mapping");
6576 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
6577 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
6581 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
6583 varType == privVar.getType() &&
6584 "Type of private var doesn't match the type of the mapped value");
6588 mappedPrivateVars.insert(
6590 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
6591 (*privateMapIndices)[privVarIdx])});
6595 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6596 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
6597 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6598 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6599 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6602 llvm::Function *llvmParentFn =
6604 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
6605 assert(llvmParentFn && llvmOutlinedFn &&
6606 "Both parent and outlined functions must exist at this point");
6608 if (outlinedFnLoc && llvmParentFn->getSubprogram())
6609 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
6611 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
6612 attr.isStringAttribute())
6613 llvmOutlinedFn->addFnAttr(attr);
6615 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
6616 attr.isStringAttribute())
6617 llvmOutlinedFn->addFnAttr(attr);
6619 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
6620 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6621 llvm::Value *mapOpValue =
6622 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6623 moduleTranslation.
mapValue(arg, mapOpValue);
6625 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
6626 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6627 llvm::Value *mapOpValue =
6628 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6629 moduleTranslation.
mapValue(arg, mapOpValue);
6638 allocaIP, &mappedPrivateVars);
6641 return llvm::make_error<PreviouslyReportedError>();
6643 builder.restoreIP(codeGenIP);
6645 &mappedPrivateVars),
6648 return llvm::make_error<PreviouslyReportedError>();
6651 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
6653 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
6654 return llvm::make_error<PreviouslyReportedError>();
6658 std::back_inserter(privateCleanupRegions),
6659 [](omp::PrivateClauseOp privatizer) {
6660 return &privatizer.getDeallocRegion();
6664 targetRegion,
"omp.target", builder, moduleTranslation);
6667 return exitBlock.takeError();
6669 builder.SetInsertPoint(*exitBlock);
6670 if (!privateCleanupRegions.empty()) {
6672 privateCleanupRegions, privateVarsInfo.
llvmVars,
6673 moduleTranslation, builder,
"omp.targetop.private.cleanup",
6675 return llvm::createStringError(
6676 "failed to inline `dealloc` region of `omp.private` "
6677 "op in the target region");
6679 return builder.saveIP();
6682 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
6685 StringRef parentName = parentFn.getName();
6687 llvm::TargetRegionEntryInfo entryInfo;
6691 MapInfoData mapData;
6696 MapInfosTy combinedInfos;
6698 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
6699 builder.restoreIP(codeGenIP);
6700 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
6702 return combinedInfos;
6705 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
6706 llvm::Value *&retVal, InsertPointTy allocaIP,
6707 InsertPointTy codeGenIP)
6708 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6709 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6710 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6716 if (!isTargetDevice) {
6717 retVal = cast<llvm::Value>(&arg);
6722 *ompBuilder, moduleTranslation,
6723 allocaIP, codeGenIP);
6726 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
6727 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
6728 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
6730 isTargetDevice, isGPU);
6734 if (!isTargetDevice)
6736 targetCapturedOp, runtimeAttrs);
6744 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
6745 llvm::Value *value = moduleTranslation.
lookupValue(var);
6746 moduleTranslation.
mapValue(arg, value);
6748 if (!llvm::isa<llvm::Constant>(value))
6749 kernelInput.push_back(value);
6752 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
6759 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
6760 kernelInput.push_back(mapData.OriginalValue[i]);
6765 moduleTranslation, dds);
6767 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6769 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6771 llvm::OpenMPIRBuilder::TargetDataInfo info(
6775 auto customMapperCB =
6777 if (!combinedInfos.Mappers[i])
6779 info.HasMapper =
true;
6781 moduleTranslation, targetDirective);
6784 llvm::Value *ifCond =
nullptr;
6785 if (
Value targetIfCond = targetOp.getIfExpr())
6786 ifCond = moduleTranslation.
lookupValue(targetIfCond);
6788 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6790 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
6791 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
6792 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
6797 builder.restoreIP(*afterIP);
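// Handles the `omp.declare_target` attribute on functions and globals:
// host-only functions are dropped from device modules, while globals are
// registered with the offload entry machinery (declare target to/link).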
6818 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
6819 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
6821 if (!offloadMod.getIsTargetDevice())
6824 omp::DeclareTargetDeviceType declareType =
6825 attribute.getDeviceType().getValue();
6827 if (declareType == omp::DeclareTargetDeviceType::host) {
6828 llvm::Function *llvmFunc =
6830 llvmFunc->dropAllReferences();
6831 llvmFunc->eraseFromParent();
6837 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
6838 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
6839 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
6841 bool isDeclaration = gOp.isDeclaration();
6842 bool isExternallyVisible =
6845 llvm::StringRef mangledName = gOp.getSymName();
6846 auto captureClause =
6852 std::vector<llvm::GlobalVariable *> generatedRefs;
6854 std::vector<llvm::Triple> targetTriple;
6855 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
6857 LLVM::LLVMDialect::getTargetTripleAttrName()));
6858 if (targetTripleAttr)
6859 targetTriple.emplace_back(targetTripleAttr.data());
6861 auto fileInfoCallBack = [&loc]() {
6862 std::string filename =
"";
6863 std::uint64_t lineNo = 0;
6866 filename = loc.getFilename().str();
6867 lineNo = loc.getLine();
6870 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
6874 auto vfs = llvm::vfs::getRealFileSystem();
6876 ompBuilder->registerTargetGlobalVariable(
6877 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6878 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6879 mangledName, generatedRefs,
false, targetTriple,
6881 gVal->getType(), gVal);
6883 if (ompBuilder->Config.isTargetDevice() &&
6884 (attribute.getCaptureClause().getValue() !=
6885 mlir::omp::DeclareTargetCaptureClause::to ||
6886 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
6887 ompBuilder->getAddrOfDeclareTargetVar(
6888 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6889 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6890 mangledName, generatedRefs,
false, targetTriple,
6891 gVal->getType(),
nullptr,
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;

  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;
};

LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
    Operation *op, ArrayRef<llvm::Instruction *> instructions,
    NamedAttribute attribute,
    LLVM::ModuleTranslation &moduleTranslation) const {
6929 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
6931 .Case(
"omp.is_target_device",
6932 [&](Attribute attr) {
6933 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6934 llvm::OpenMPIRBuilderConfig &
config =
6936 config.setIsTargetDevice(deviceAttr.getValue());
6942 [&](Attribute attr) {
6943 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6944 llvm::OpenMPIRBuilderConfig &
config =
6946 config.setIsGPU(gpuAttr.getValue());
6951 .Case(
"omp.host_ir_filepath",
6952 [&](Attribute attr) {
6953 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6954 llvm::OpenMPIRBuilder *ompBuilder =
6956 auto VFS = llvm::vfs::getRealFileSystem();
6957 ompBuilder->loadOffloadInfoMetadata(*VFS,
6958 filepathAttr.getValue());
6964 [&](Attribute attr) {
6965 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6969 .Case(
"omp.version",
6970 [&](Attribute attr) {
6971 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6972 llvm::OpenMPIRBuilder *ompBuilder =
6974 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
6975 versionAttr.getVersion());
6980 .Case(
"omp.declare_target",
6981 [&](Attribute attr) {
6982 if (
auto declareTargetAttr =
6983 dyn_cast<omp::DeclareTargetAttr>(attr))
6988 .Case(
"omp.requires",
6989 [&](Attribute attr) {
6990 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6991 using Requires = omp::ClauseRequires;
6992 Requires flags = requiresAttr.getValue();
6993 llvm::OpenMPIRBuilderConfig &
config =
6995 config.setHasRequiresReverseOffload(
6996 bitEnumContainsAll(flags, Requires::reverse_offload));
6997 config.setHasRequiresUnifiedAddress(
6998 bitEnumContainsAll(flags, Requires::unified_address));
6999 config.setHasRequiresUnifiedSharedMemory(
7000 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7001 config.setHasRequiresDynamicAllocators(
7002 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7007 .Case(
"omp.target_triples",
7008 [&](Attribute attr) {
7009 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7010 llvm::OpenMPIRBuilderConfig &
config =
7012 config.TargetTriples.clear();
7013 config.TargetTriples.reserve(triplesAttr.size());
7014 for (Attribute tripleAttr : triplesAttr) {
7015 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7016 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7024 .Default([](Attribute) {
7040 if (
auto declareTargetIface =
7041 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
7042 parentFn.getOperation()))
7043 if (declareTargetIface.isDeclareTarget() &&
7044 declareTargetIface.getDeclareTargetDeviceType() !=
7045 mlir::omp::DeclareTargetDeviceType::host)
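// Declares (or reuses) the `omp_target_alloc` runtime function:
// void *omp_target_alloc(size_t size, int device_num).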
static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
                                         llvm::Module *llvmModule) {
  llvm::Type *i64Ty = builder.getInt64Ty();
  llvm::Type *i32Ty = builder.getInt32Ty();
  llvm::Type *returnType = builder.getPtrTy(0);
  llvm::FunctionType *fnType =
      llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
  llvm::Function *func = cast<llvm::Function>(
      llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
  return func;
}
static LogicalResult
convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);

  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
  llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
  mlir::Value deviceNum = allocMemOp.getDevice();
  llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);

  llvm::DataLayout dataLayout = llvmModule->getDataLayout();
  mlir::Type heapTy = allocMemOp.getAllocatedType();
  llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
  for (auto typeParam : allocMemOp.getTypeparams())
    allocSize =
        builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));

  llvm::CallInst *call =
      builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
  llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());

  moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
  return success();
}
static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
                                        llvm::Module *llvmModule) {
  llvm::Type *ptrTy = builder.getPtrTy(0);
  llvm::Type *i32Ty = builder.getInt32Ty();
  llvm::Type *voidTy = builder.getVoidTy();
  llvm::FunctionType *fnType =
      llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
  llvm::Function *func = dyn_cast<llvm::Function>(
      llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
  return func;
}

static LogicalResult
convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);

  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
  llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
  mlir::Value deviceNum = freeMemOp.getDevice();
  llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
  mlir::Value heapref = freeMemOp.getHeapref();
  llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);

  llvm::Value *intToPtr =
      builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
  builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
  return success();
}
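// Main dispatch for translating OpenMP dialect operations to LLVM IR. Ops
// that only carry information consumed by other translations (declarations,
// map and bounds ops) succeed without emitting any IR.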
LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
    Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  if (ompBuilder->Config.isTargetDevice() &&
      !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(op))
    return op->emitOpError() << "unsupported host op found in device";

  bool isOutermostLoopWrapper =
      isa_and_present<omp::LoopWrapperInterface>(op) &&
      !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
7160 llvm::TypeSwitch<Operation *, LogicalResult>(op)
7161 .Case([&](omp::BarrierOp op) -> LogicalResult {
7165 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7166 ompBuilder->createBarrier(builder.saveIP(),
7167 llvm::omp::OMPD_barrier);
7169 if (res.succeeded()) {
7172 builder.restoreIP(*afterIP);
7176 .Case([&](omp::TaskyieldOp op) {
7180 ompBuilder->createTaskyield(builder.saveIP());
7183 .Case([&](omp::FlushOp op) {
7195 ompBuilder->createFlush(builder.saveIP());
7198 .Case([&](omp::ParallelOp op) {
7201 .Case([&](omp::MaskedOp) {
7204 .Case([&](omp::MasterOp) {
7207 .Case([&](omp::CriticalOp) {
7210 .Case([&](omp::OrderedRegionOp) {
7213 .Case([&](omp::OrderedOp) {
7216 .Case([&](omp::WsloopOp) {
7219 .Case([&](omp::SimdOp) {
7222 .Case([&](omp::AtomicReadOp) {
7225 .Case([&](omp::AtomicWriteOp) {
7228 .Case([&](omp::AtomicUpdateOp op) {
7231 .Case([&](omp::AtomicCaptureOp op) {
7234 .Case([&](omp::CancelOp op) {
7237 .Case([&](omp::CancellationPointOp op) {
7240 .Case([&](omp::SectionsOp) {
7243 .Case([&](omp::SingleOp op) {
7246 .Case([&](omp::TeamsOp op) {
7249 .Case([&](omp::TaskOp op) {
7252 .Case([&](omp::TaskloopOp op) {
7255 .Case([&](omp::TaskgroupOp op) {
7258 .Case([&](omp::TaskwaitOp op) {
7261 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
7262 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
7263 omp::CriticalDeclareOp>([](
auto op) {
7276 .Case([&](omp::ThreadprivateOp) {
7279 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
7280 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
7283 .Case([&](omp::TargetOp) {
7286 .Case([&](omp::DistributeOp) {
7289 .Case([&](omp::LoopNestOp) {
7292 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
7299 .Case([&](omp::NewCliOp op) {
7304 .Case([&](omp::CanonicalLoopOp op) {
7307 .Case([&](omp::UnrollHeuristicOp op) {
7316 .Case([&](omp::TileOp op) {
7317 return applyTile(op, builder, moduleTranslation);
7319 .Case([&](omp::FuseOp op) {
7320 return applyFuse(op, builder, moduleTranslation);
7322 .Case([&](omp::TargetAllocMemOp) {
7325 .Case([&](omp::TargetFreeMemOp) {
7328 .Default([&](Operation *inst) {
            << "not yet implemented: " << inst->getName();
7333 if (isOutermostLoopWrapper)
  registry.insert<omp::OpenMPDialect>();
7342 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a termiator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
allocatedType moduleTranslation static convertType(allocatedType) LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations, e.g. when only a section of an array is mapped.
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the builder.
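A simplified sketch of how such a helper can be written on top of the ModuleTranslation stack, using the alloca stack frame defined earlier in this file; the real helper performs additional bookkeeping, so treat this as illustrative only.

static llvm::OpenMPIRBuilder::InsertPointTy
findAllocaInsertPointSketch(llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) {
  // Prefer an insertion point recorded by an enclosing OpenMP construct.
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP;
  bool found = false;
  moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
      [&](OpenMPAllocaStackFrame &frame) {
        allocaIP = frame.allocaInsertPoint;
        found = true;
        return WalkResult::interrupt(); // the innermost frame wins
      });
  if (found)
    return allocaIP;

  // Fall back to the start of the entry block of the enclosing LLVM function.
  llvm::BasicBlock &entry =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(&entry,
                                              entry.getFirstInsertionPt());
}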
static LogicalResult inlineConvertOmpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the builder's current insertion point.
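A hedged usage sketch, assuming singleOp, builder and moduleTranslation are in scope while translating an omp.single; the block name string is illustrative.

  // Inline the operation's region at the current insertion point and collect
  // the values yielded to its terminator.
  SmallVector<llvm::Value *> regionResults;
  if (failed(inlineConvertOmpRegions(singleOp.getRegion(), "omp.single.region",
                                     builder, moduleTranslation,
                                     &regionResults)))
    return failure();
  // On success the builder points at the continuation block and regionResults
  // holds the translated terminator operands, if any.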
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Starting from the operation from, looks up and returns the PrivateClauseOp with the name symbolName.
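A minimal sketch of how such a lookup can be expressed with the symbol table utilities listed further below; lookupPrivatizerSketch is a hypothetical name.

static omp::PrivateClauseOp lookupPrivatizerSketch(Operation *from,
                                                   SymbolRefAttr symbolName) {
  // Resolve the symbol relative to `from` and cast to the expected op type;
  // returns a null op if the symbol is missing or of a different kind.
  return SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
                                                                    symbolName);
}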
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate the output variables with their corresponding host-evaluated values.
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR, creating new basic blocks at the builder's current insertion point and returning the continuation block.
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
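An illustrative walk-based detection in the same spirit, assuming the op types from the MLIR OpenMP dialect; the real helper may additionally restrict which construct kinds it considers.

static bool regionHasCancellationSketch(Operation *op) {
  // Interrupt the walk as soon as a cancellation-related op is found.
  WalkResult result = op->walk([](Operation *nested) {
    if (isa<omp::CancelOp, omp::CancellationPointOp>(nested))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return result.wasInterrupted();
}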
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
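A sketch of the enum mapping implied by this signature; the enumerator spellings are assumptions based on the LLVM OpenMP frontend and the MLIR OpenMP dialect, and mapProcBindSketch is a hypothetical name.

static llvm::omp::ProcBindKind mapProcBindSketch(omp::ClauseProcBindKind kind) {
  switch (kind) {
  case omp::ClauseProcBindKind::Close:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
  case omp::ClauseProcBindKind::Master:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
  case omp::ClauseProcBindKind::Primary:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
  case omp::ClauseProcBindKind::Spread:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
  }
  llvm_unreachable("unhandled proc_bind clause argument");
}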
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
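A hedged sketch of what a helper with this shape typically does, assuming the surrounding file's usings; handleErrorSketch is a hypothetical name and the real helper may special-case errors that were already reported.

static LogicalResult handleErrorSketch(llvm::Error error, Operation &op) {
  // Report every contained error on the offending operation and fold the
  // outcome into a LogicalResult.
  LogicalResult result = success();
  llvm::handleAllErrors(std::move(error),
                        [&](const llvm::ErrorInfoBase &info) {
                          op.emitError() << info.message();
                          result = failure();
                        });
  return result;
}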
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating unnecessary compiler-generated synchronization.
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the corresponding clauses, if present.
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invocation.
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading.
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up the remapped LLVM IR values for a list of MLIR values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region ®ion)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
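The following hypothetical helper (not part of this file) ties together lookupValue and convertType from the ModuleTranslation interface above; it assumes the MLIR value maps to a pointer to an object of the given element type, as is typical for privatized OpenMP clause operands.

static llvm::Value *
emitStackCopySketch(mlir::Value mlirVar, mlir::Type elementType,
                    llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) {
  // Translate the SSA value and the MLIR LLVM dialect type to their LLVM IR
  // counterparts, then materialize a stack copy of the pointed-to object.
  llvm::Value *origPtr = moduleTranslation.lookupValue(mlirVar);
  llvm::Type *llvmTy = moduleTranslation.convertType(elementType);
  llvm::Value *copy =
      builder.CreateAlloca(llvmTy, /*ArraySize=*/nullptr, ".copy");
  builder.CreateStore(builder.CreateLoad(llvmTy, origPtr), copy);
  return copy;
}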
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry ®istry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator range over the underlying Values.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one).
user_range getUsers()
Returns a range of all users.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the current symbol table.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, from, that defines a symbol table.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk: advance, skip, or interrupt.
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region ®ion, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation, connect the PHI nodes of the corresponding LLVM IR blocks to the values produced by their predecessor blocks.
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location information.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
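A hedged sketch combining getBlocksSortedByDominance with ModuleTranslation::convertBlock and lookupBlock documented above: translate every block of a region so that a block is emitted only after the blocks that dominate it. Real code also takes care of block arguments and PHI wiring; convertRegionBlocksSketch is a hypothetical name.

static LogicalResult
convertRegionBlocksSketch(Region &region, llvm::IRBuilderBase &builder,
                          LLVM::ModuleTranslation &moduleTranslation) {
  for (Block *bb : getBlocksSortedByDominance(region)) {
    // Assumes an LLVM basic block has already been mapped for `bb`.
    builder.SetInsertPoint(moduleTranslation.lookupBlock(bb));
    if (failed(moduleTranslation.convertBlock(*bb, /*ignoreArguments=*/false,
                                              builder)))
      return failure();
  }
  return success();
}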
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry ®istry)
Register the OpenMP dialect and the translation from it to LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars