24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Frontend/OpenMP/OMPConstants.h"
28#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DebugInfoMetadata.h"
31#include "llvm/IR/DerivedTypes.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/MDBuilder.h"
34#include "llvm/IR/ReplaceConstant.h"
35#include "llvm/Support/FileSystem.h"
36#include "llvm/Support/VirtualFileSystem.h"
37#include "llvm/TargetParser/Triple.h"
38#include "llvm/Transforms/Utils/ModuleUtils.h"
/// Maps an optional MLIR `omp::ClauseScheduleKind` onto the corresponding
/// `llvm::omp::ScheduleKind` enumerator.
///
/// \param schedKind schedule kind taken from the clause; may be empty.
/// \returns `OMP_SCHEDULE_Default` when `schedKind` holds no value, otherwise
///          the enumerator with the same name as the MLIR kind.
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
// Every case above returns, so this is only hit for an unlisted enumerator.
67 llvm_unreachable(
"unhandled schedule clause argument");
/// Frame type that carries an OpenMPIRBuilder insertion point to be used for
/// allocas.
/// NOTE(review): the base class and type-ID boilerplate are not visible in
/// this view.
72class OpenMPAllocaStackFrame
// Stores `allocaIP` unchanged in `allocaInsertPoint`.
77 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
78 : allocaInsertPoint(allocaIP) {}
// Insertion point captured at construction time.
79 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
/// Frame type that carries a pointer to an `llvm::CanonicalLoopInfo`.
/// NOTE(review): the base class and type-ID boilerplate are not visible in
/// this view.
85class OpenMPLoopInfoStackFrame
// Null until a loop info is assigned to the frame.
89 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
/// Custom `llvm::ErrorInfo` error type.
/// NOTE(review): judging by the name, it marks failures whose diagnostic has
/// already been emitted elsewhere — confirm against the handlers that consume
/// it.
108class PreviouslyReportedError
109 :
public llvm::ErrorInfo<PreviouslyReportedError> {
// `log` override; its body is not visible in this view.
111 void log(raw_ostream &)
const override {
// Conversion to `std::error_code` is not supported; the visible string below
// is the message used to say so.
115 std::error_code convertToErrorCode()
const override {
117 "PreviouslyReportedError doesn't support ECError conversion");
// Definition of the static `ID` member required by `llvm::ErrorInfo`.
124char PreviouslyReportedError::ID = 0;
/// Helper that emits LLVM IR for `linear` clause variables of a loop.
///
/// For each linear variable it keeps two allocas (`.linear_var` and
/// `.linear_result`), the original variable pointer, its step and its LLVM
/// type, all stored at matching indices in the vectors below.
/// NOTE(review): access specifiers, the class tail and several statements are
/// not visible in this view.
135class LinearClauseProcessor {
// Allocas named ".linear_var"; receive the variable's value before the loop.
138 SmallVector<llvm::Value *> linearPreconditionVars;
// Allocas named ".linear_result"; receive the per-iteration value.
139 SmallVector<llvm::Value *> linearLoopBodyTemps;
// Original linear variables as passed to createLinearVar().
140 SmallVector<llvm::Value *> linearOrigVal;
// LLVM values of the linear steps.
141 SmallVector<llvm::Value *> linearSteps;
// LLVM types of the linear variables.
142 SmallVector<llvm::Type *> linearVarTypes;
// Blocks created by splitLinearFiniBB(); not initialized before that call.
143 llvm::BasicBlock *linearFinalizationBB;
144 llvm::BasicBlock *linearExitBB;
145 llvm::BasicBlock *linearLastIterExitBB;
// Converts the type held by the `TypeAttr` `ty` to an LLVM type and appends it
// to `linearVarTypes`.
149 void registerType(LLVM::ModuleTranslation &moduleTranslation,
150 mlir::Attribute &ty) {
151 linearVarTypes.push_back(moduleTranslation.
convertType(
152 mlir::cast<mlir::TypeAttr>(ty).getValue()));
// Creates the two allocas for the linear variable at index `idx` (using
// `linearVarTypes[idx]`) and records `linearVar` as its original value.
// registerType() must already have been called for `idx`.
156 void createLinearVar(llvm::IRBuilderBase &builder,
157 LLVM::ModuleTranslation &moduleTranslation,
158 llvm::Value *linearVar,
int idx) {
159 linearPreconditionVars.push_back(
160 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_var"));
161 llvm::Value *linearLoopBodyTemp =
162 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_result");
163 linearOrigVal.push_back(linearVar);
164 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
// Looks up the LLVM value mapped to `linearStep` and appends it to
// `linearSteps`.
168 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
169 mlir::Value &linearStep) {
170 linearSteps.push_back(moduleTranslation.
lookupValue(linearStep));
// Before the terminator of `loopPreHeader`, copies the current value of every
// original linear variable into its ".linear_var" alloca.
174 void initLinearVar(llvm::IRBuilderBase &builder,
175 LLVM::ModuleTranslation &moduleTranslation,
176 llvm::BasicBlock *loopPreHeader) {
177 builder.SetInsertPoint(loopPreHeader->getTerminator());
178 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
179 llvm::LoadInst *linearVarLoad =
180 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
181 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
// Before the terminator of `loopBody`, computes `start + iv * step` for every
// linear variable and stores it into the ".linear_result" alloca. `start` is
// the value saved by initLinearVar().
186 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
187 llvm::Value *loopInductionVar) {
188 builder.SetInsertPoint(loopBody->getTerminator());
189 for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
190 llvm::Type *linearVarType = linearVarTypes[index];
191 llvm::Value *iv = loopInductionVar;
192 llvm::Value *step = linearSteps[index];
// A non-integer induction variable is treated as unreachable.
194 if (!iv->getType()->isIntegerTy())
195 llvm_unreachable(
"OpenMP loop induction variable must be an integer "
198 if (linearVarType->isIntegerTy()) {
// Integer variable: sign-extend or truncate both operands to the variable's
// type, then use integer multiply and add.
200 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
201 step = builder.CreateSExtOrTrunc(step, linearVarType);
203 llvm::LoadInst *linearVarStart =
204 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
205 llvm::Value *mulInst = builder.CreateMul(iv, step);
206 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
207 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
208 }
else if (linearVarType->isFloatingPointTy()) {
// Floating-point variable: multiply in the induction variable's integer type,
// convert the product with a signed int-to-FP cast, then add with FAdd.
210 step = builder.CreateSExtOrTrunc(step, iv->getType());
211 llvm::Value *mulInst = builder.CreateMul(iv, step);
213 llvm::LoadInst *linearVarStart =
214 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
215 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
216 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
217 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
220 "Linear variable must be of integer or floating-point type");
// Splits `loopExit` at its terminator to create the finalization block, then
// splits the finalization block at its terminator twice more to create the
// exit block and the last-iteration exit block.
227 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
228 llvm::BasicBlock *loopExit) {
229 linearFinalizationBB = loopExit->splitBasicBlock(
230 loopExit->getTerminator(),
"omp_loop.linear_finalization");
231 linearExitBB = linearFinalizationBB->splitBasicBlock(
232 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
233 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
234 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
// Emits the finalization logic using the blocks made by splitLinearFiniBB().
// `lastIter` is loaded as an i32 flag; when it is non-zero, control goes to
// the last-iteration block, which copies every ".linear_result" value back to
// the original variable.
238 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
239 finalizeLinearVar(llvm::IRBuilderBase &builder,
240 LLVM::ModuleTranslation &moduleTranslation,
241 llvm::Value *lastIter) {
243 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
244 llvm::Value *loopLastIterLoad = builder.CreateLoad(
245 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
246 llvm::Value *isLast =
247 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
248 llvm::ConstantInt::get(
249 llvm::Type::getInt32Ty(builder.getContext()), 0));
251 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
252 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
253 llvm::LoadInst *linearVarTemp =
254 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
255 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// The conditional branch is emitted in front of the finalization block's
// existing terminator; the erase that follows removes a terminator from that
// block.
261 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
262 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
263 linearFinalizationBB->getTerminator()->eraseFromParent();
// NOTE(review): the callee of the argument list below is not visible here; it
// receives the insertion point in the exit block and `OMPD_barrier`.
265 builder.SetInsertPoint(linearExitBB->getTerminator());
267 builder.saveIP(), llvm::omp::OMPD_barrier);
// At the builder's current insertion point, copies every ".linear_result"
// value back to the original linear variable.
272 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
273 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
274 llvm::LoadInst *linearVarTemp =
275 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
276 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// Redirects uses of the original linear variable at `varIndex` to its
// ".linear_result" alloca, for instruction users selected by a substring
// search of their parent block's name for `BBName`.
// NOTE(review): the remaining parameters and the right-hand side of the
// `find` comparison are not visible here.
282 void rewriteInPlace(llvm::IRBuilderBase &builder,
const std::string &BBName,
// Users are copied first so that replacing uses does not disturb the
// iteration over users().
284 llvm::SmallVector<llvm::User *> users;
285 for (llvm::User *user : linearOrigVal[varIndex]->users())
286 users.push_back(user);
287 for (
auto *user : users) {
288 if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
289 if (userInst->getParent()->getName().str().find(BBName) !=
291 user->replaceUsesOfWith(linearOrigVal[varIndex],
292 linearLoopBodyTemps[varIndex]);
// Fragment of a helper that resolves `symbolName` to an
// `omp::PrivateClauseOp`. The lookup expression itself is not visible here;
// the result is asserted to be non-null.
303 SymbolRefAttr symbolName) {
304 omp::PrivateClauseOp privatizer =
307 assert(privatizer &&
"privatizer not found in the symbol table");
// Emits a "not yet implemented: Unhandled clause <name> in <op name>" error
// on `op` and returns the resulting diagnostic.
318 auto todo = [&op](StringRef clauseName) {
319 return op.
emitError() <<
"not yet implemented: Unhandled clause "
320 << clauseName <<
" in " << op.
getName()
// Each `check*` lambda below inspects one clause on the given operation and,
// when the clause is in use, overwrites `result` with the diagnostic produced
// by `todo`. Several lambdas are only partly visible in this view.

// Reports "affinity" when the op has affinity variables.
324 auto checkAffinity = [&todo](
auto op, LogicalResult &
result) {
325 if (!op.getAffinityVars().empty())
326 result = todo(
"affinity");
// Reports "allocate" when the op has allocate or allocator variables.
328 auto checkAllocate = [&todo](
auto op, LogicalResult &
result) {
329 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
330 result = todo(
"allocate");
// Reports "ompx_bare"; the guarding condition is not visible here.
332 auto checkBare = [&todo](
auto op, LogicalResult &
result) {
334 result = todo(
"ompx_bare");
// Checks for depend variables or depend kinds; the reporting line is not
// visible here.
336 auto checkDepend = [&todo](
auto op, LogicalResult &
result) {
337 if (!op.getDependVars().empty() || op.getDependKinds())
// Captures nothing and leaves the result parameter unnamed; the body is not
// visible here.
340 auto checkHint = [](
auto op, LogicalResult &) {
// Reports "in_reduction" when any in_reduction variables, byref flags or
// symbols are present.
344 auto checkInReduction = [&todo](
auto op, LogicalResult &
result) {
345 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
346 op.getInReductionSyms())
347 result = todo(
"in_reduction");
// Body not visible here.
349 auto checkNowait = [&todo](
auto op, LogicalResult &
result) {
// Checks for an order kind or order modifier; the reporting line is not
// visible here.
353 auto checkOrder = [&todo](
auto op, LogicalResult &
result) {
354 if (op.getOrder() || op.getOrderMod())
// Reports "parallelization-level" when `getParLevelSimd()` is set.
357 auto checkParLevelSimd = [&todo](
auto op, LogicalResult &
result) {
358 if (op.getParLevelSimd())
359 result = todo(
"parallelization-level");
// Reports "privatization" when private variables or private symbols are
// present.
361 auto checkPrivate = [&todo](
auto op, LogicalResult &
result) {
362 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
363 result = todo(
"privatization");
// Two independent checks: (1) for TeamsOp and TaskloopOp only, any reduction
// variables, byref flags or symbols are reported as "reduction"; (2) for any
// op, a reduction modifier other than `defaultmod` is reported as
// "reduction with modifier".
365 auto checkReduction = [&todo](
auto op, LogicalResult &
result) {
366 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
367 if (!op.getReductionVars().empty() || op.getReductionByref() ||
368 op.getReductionSyms())
369 result = todo(
"reduction");
370 if (op.getReductionMod() &&
371 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
372 result = todo(
"reduction with modifier");
// Reports "task_reduction" when any task_reduction variables, byref flags or
// symbols are present.
374 auto checkTaskReduction = [&todo](
auto op, LogicalResult &
result) {
375 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
376 op.getTaskReductionSyms())
377 result = todo(
"task_reduction");
// The next three lambdas report multi-dimensional num_teams, num_threads and
// thread_limit values respectively.
379 auto checkNumTeams = [&todo](
auto op, LogicalResult &
result) {
380 if (op.hasNumTeamsMultiDim())
381 result = todo(
"num_teams with multi-dimensional values");
383 auto checkNumThreads = [&todo](
auto op, LogicalResult &
result) {
384 if (op.hasNumThreadsMultiDim())
385 result = todo(
"num_threads with multi-dimensional values");
388 auto checkThreadLimit = [&todo](
auto op, LogicalResult &
result) {
389 if (op.hasThreadLimitMultiDim())
390 result = todo(
"thread_limit with multi-dimensional values");
// Per-operation dispatch: each `.Case` runs the checks that apply to the
// operation type named in its lambda parameter. The head of the switch
// expression and some check calls are not visible in this view.
395 .Case([&](omp::DistributeOp op) {
396 checkAllocate(op,
result);
399 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op,
result); })
400 .Case([&](omp::SectionsOp op) {
401 checkAllocate(op,
result);
403 checkReduction(op,
result);
405 .Case([&](omp::SingleOp op) {
406 checkAllocate(op,
result);
409 .Case([&](omp::TeamsOp op) {
410 checkAllocate(op,
result);
412 checkNumTeams(op,
result);
413 checkThreadLimit(op,
result);
415 .Case([&](omp::TaskOp op) {
416 checkAffinity(op,
result);
417 checkAllocate(op,
result);
418 checkInReduction(op,
result);
420 .Case([&](omp::TaskgroupOp op) {
421 checkAllocate(op,
result);
422 checkTaskReduction(op,
result);
424 .Case([&](omp::TaskwaitOp op) {
428 .Case([&](omp::TaskloopOp op) {
429 checkAllocate(op,
result);
430 checkInReduction(op,
result);
431 checkReduction(op,
result);
433 .Case([&](omp::WsloopOp op) {
434 checkAllocate(op,
result);
436 checkReduction(op,
result);
438 .Case([&](omp::ParallelOp op) {
439 checkAllocate(op,
result);
440 checkReduction(op,
result);
441 checkNumThreads(op,
result);
443 .Case([&](omp::SimdOp op) { checkReduction(op,
result); })
// The atomic operations share one generic case that only checks the hint.
444 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
445 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op,
result); })
// The target data movement operations only have their depend clause checked.
446 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
447 [&](
auto op) { checkDepend(op,
result); })
448 .Case([&](omp::TargetUpdateOp op) { checkDepend(op,
result); })
449 .Case([&](omp::TargetOp op) {
450 checkAllocate(op,
result);
452 checkInReduction(op,
result);
453 checkThreadLimit(op,
result);
// Fragment of an error-to-LogicalResult conversion. A
// `PreviouslyReportedError` only marks `result` as failed; the second handler
// takes any other `llvm::ErrorInfoBase`, and its body is not visible here.
465 llvm::handleAllErrors(
467 [&](
const PreviouslyReportedError &) {
result = failure(); },
468 [&](
const llvm::ErrorInfoBase &err) {
/// Computes the insertion point at which allocas should be emitted.
/// NOTE(review): the function name, its parameters and the stack-walk call
/// are not visible in this view.
485static llvm::OpenMPIRBuilder::InsertPointTy
491 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
// Callback that receives an `OpenMPAllocaStackFrame` and copies its insertion
// point.
493 [&](OpenMPAllocaStackFrame &frame) {
494 allocaInsertPoint = frame.allocaInsertPoint;
// The recorded point is returned when its block belongs to the same function
// the builder is currently inserting into.
502 allocaInsertPoint.getBlock()->getParent() ==
503 builder.GetInsertBlock()->getParent())
504 return allocaInsertPoint;
// If the builder currently sits in the function's entry block, a new block
// named "entry" is created after it, the current block branches to it, and
// the builder continues in the new block.
513 if (builder.GetInsertBlock() ==
514 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
515 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
516 "Assuming end of basic block");
517 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
518 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
519 builder.GetInsertBlock()->getNextNode());
520 builder.CreateBr(entryBB);
521 builder.SetInsertPoint(entryBB);
// Fallback: the first insertion point of the function's entry block.
524 llvm::BasicBlock &funcEntryBlock =
525 builder.GetInsertBlock()->getParent()->getEntryBlock();
526 return llvm::OpenMPIRBuilder::InsertPointTy(
527 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
/// Retrieves a `CanonicalLoopInfo` pointer by walking the translation stack
/// for `OpenMPLoopInfoStackFrame` frames. `loopInfo` stays null when the
/// callback is never invoked.
/// NOTE(review): the function name, the callback's return value and the final
/// return are not visible in this view.
533static llvm::CanonicalLoopInfo *
535 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
536 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
537 [&](OpenMPLoopInfoStackFrame &frame) {
538 loopInfo = frame.loopInfo;
// Converts the blocks of an OpenMP `region` into LLVM IR basic blocks named
// `blockName`, wiring them between the builder's current block and a new
// continuation block, which is what the function returns.
// NOTE(review): the function name, the remaining parameters and a number of
// statements are not visible in this view.
550 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
553 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
// Split at the current insertion point; the new block is where control
// continues after the region.
555 llvm::BasicBlock *continuationBlock =
556 splitBB(builder,
true,
"omp.region.cont");
557 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
// Create one LLVM block per MLIR block up front and record the mapping, so
// that branch targets can be looked up while blocks are converted.
559 llvm::LLVMContext &llvmContext = builder.getContext();
560 for (
Block &bb : region) {
561 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
562 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
563 builder.GetInsertBlock()->getNextNode());
564 moduleTranslation.
mapBlock(&bb, llvmBB);
567 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
// For regions that are not loop wrappers, the operand types of the
// `omp.yield` terminators determine the PHI types of the continuation block.
// The types are recorded from the first yield encountered, and every yield is
// asserted to agree in operand count and types.
574 unsigned numYields = 0;
576 if (!isLoopWrapper) {
577 bool operandsProcessed =
false;
579 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
580 if (!operandsProcessed) {
581 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
582 continuationBlockPHITypes.push_back(
583 moduleTranslation.
convertType(yield->getOperand(i).getType()));
585 operandsProcessed =
true;
587 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
588 "mismatching number of values yielded from the region");
589 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
590 llvm::Type *operandType =
591 moduleTranslation.
convertType(yield->getOperand(i).getType());
593 assert(continuationBlockPHITypes[i] == operandType &&
594 "values of mismatching types yielded from the region");
// Create the PHI nodes at the start of the continuation block. The guard
// restores the builder's insertion point when this scope ends.
604 if (!continuationBlockPHITypes.empty())
606 continuationBlockPHIs &&
607 "expected continuation block PHIs if converted regions yield values");
608 if (continuationBlockPHIs) {
609 llvm::IRBuilderBase::InsertPointGuard guard(builder);
610 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
611 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
612 for (llvm::Type *ty : continuationBlockPHITypes)
613 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
// Convert each block. `blocks` is defined on a line not visible here.
619 for (
Block *bb : blocks) {
620 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
// The source block's terminator targets the continuation block after the
// split above; retarget it to the region's entry block.
623 if (bb->isEntryBlock()) {
624 assert(sourceTerminator->getNumSuccessors() == 1 &&
625 "provided entry block has multiple successors");
626 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
627 "ContinuationBlock is not the successor of the entry block");
628 sourceTerminator->setSuccessor(0, llvmBB);
// A failed block conversion is surfaced as a PreviouslyReportedError.
631 llvm::IRBuilderBase::InsertPointGuard guard(builder);
633 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
634 return llvm::make_error<PreviouslyReportedError>();
// NOTE(review): the condition guarding this branch is not visible here.
639 builder.CreateBr(continuationBlock);
// Blocks ending in `omp.terminator` or `omp.yield` branch to the continuation
// block, and the terminator's operands feed the continuation PHIs.
650 Operation *terminator = bb->getTerminator();
651 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
652 builder.CreateBr(continuationBlock);
654 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
655 (*continuationBlockPHIs)[i]->addIncoming(
669 return continuationBlock;
// Fragment of a switch that maps `omp::ClauseProcBindKind` enumerators onto
// the equally named `llvm::omp::ProcBindKind` values. The switch head and
// enclosing function are not visible in this view.
675 case omp::ClauseProcBindKind::Close:
676 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
677 case omp::ClauseProcBindKind::Master:
678 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
679 case omp::ClauseProcBindKind::Primary:
680 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
681 case omp::ClauseProcBindKind::Spread:
682 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
// Only reached when no visible case returned.
684 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
// Fragment of the `omp.masked` translation. The function header, the body of
// `bodyGenCB` and the OpenMPIRBuilder call that produces `afterIP` are not
// visible in this view.
691 auto maskedOp = cast<omp::MaskedOp>(opInst);
692 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// Body callback: moves the builder to `codeGenIP` before emitting the region.
697 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
699 auto ®ion = maskedOp.getRegion();
700 builder.restoreIP(codeGenIP);
// Finalization callback: nothing to emit, always succeeds.
708 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
// The filter value is the translated filtered-thread-id operand when the op
// has one.
// NOTE(review): the i32 zero constant below is presumably the fallback filter
// value — its assignment target is not visible here.
710 llvm::Value *filterVal =
nullptr;
711 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
712 filterVal = moduleTranslation.
lookupValue(filterVar);
714 llvm::LLVMContext &llvmContext = builder.getContext();
716 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
718 assert(filterVal !=
nullptr);
719 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
720 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
// Continue emitting code at the insertion point held by `afterIP`.
727 builder.restoreIP(*afterIP);
// Fragment of the `omp.master` translation. The function header, the body of
// `bodyGenCB` and the OpenMPIRBuilder call that produces `afterIP` are not
// visible in this view.
735 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
736 auto masterOp = cast<omp::MasterOp>(opInst);
// Body callback: moves the builder to `codeGenIP` before emitting the region.
741 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
743 auto ®ion = masterOp.getRegion();
744 builder.restoreIP(codeGenIP);
// Finalization callback: nothing to emit, always succeeds.
752 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
754 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
755 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
// Continue emitting code at the insertion point held by `afterIP`.
762 builder.restoreIP(*afterIP);
// Fragment of the `omp.critical` translation. The function header, the body
// of `bodyGenCB` and the callee of the final builder call are not visible in
// this view.
770 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
771 auto criticalOp = cast<omp::CriticalOp>(opInst);
// Body callback: moves the builder to `codeGenIP` before emitting the region.
776 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
778 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
779 builder.restoreIP(codeGenIP);
// Finalization callback: nothing to emit, always succeeds.
787 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
789 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
790 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
// `hint` stays null unless the op carries a name attribute.
791 llvm::Constant *hint =
nullptr;
// For a named critical op, the hint of the declaration referenced by the name
// symbol is turned into an i32 constant. The symbol lookup and the assignment
// target are not visible here.
794 if (criticalOp.getNameAttr()) {
797 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
798 auto criticalDeclareOp =
802 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
803 static_cast<int>(criticalDeclareOp.getHint()));
// An unnamed critical op is passed on with the empty string as its name.
805 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
807 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
// Continue emitting code at the insertion point held by `afterIP`.
812 builder.restoreIP(*afterIP);
// Fragment of a templated routine that obtains the private block arguments of
// `op` through `omp::BlockArgOpenMPOpInterface` and then calls
// collectPrivatizationDecls<OP>(op). Its name and remaining lines are not
// visible in this view.
819 template <
typename OP>
822 cast<
omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
825 collectPrivatizationDecls<OP>(op);
// Iterates over the symbol references in `op`'s optional private-symbols
// array attribute. The handling of an absent attribute and the loop body are
// not visible here.
840 void collectPrivatizationDecls(OP op) {
841 std::optional<ArrayAttr> attr = op.getPrivateSyms();
846 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
// Fragment of a routine that appends one entry to `reductions` per symbol in
// `op`'s optional reduction-symbols array attribute. Capacity is reserved up
// front using the op's reduction variable count. The function header, the
// handling of an absent attribute and the pushed expression are not visible
// here.
857 std::optional<ArrayAttr> attr = op.getReductionSyms();
861 reductions.reserve(reductions.size() + op.getNumReductionVars());
862 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
863 reductions.push_back(
// Fragment of a routine that translates `region` at the builder's current
// insertion point. Two paths are partly visible: converting the region's
// front block directly into the current LLVM block, and a path that goes
// through a separate continuation block and PHIs.
// NOTE(review): the function name, the condition selecting between the two
// paths and several statements are not visible in this view.
875 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
// Temporarily detach a trailing terminator so the region's operations can be
// emitted into the current block; it is re-attached below.
884 llvm::Instruction *potentialTerminator =
885 builder.GetInsertBlock()->empty() ?
nullptr
886 : &builder.GetInsertBlock()->back();
888 if (potentialTerminator && potentialTerminator->isTerminator())
889 potentialTerminator->removeFromParent();
// Map the region's front block onto the current LLVM block.
890 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
893 region.
front(),
true, builder)))
897 if (continuationBlockArgs)
899 *continuationBlockArgs,
// Re-attach the detached terminator at the end of the block the builder is
// now in, which may be empty.
906 if (potentialTerminator && potentialTerminator->isTerminator()) {
907 llvm::BasicBlock *block = builder.GetInsertBlock();
908 if (block->empty()) {
914 potentialTerminator->insertInto(block, block->begin());
916 potentialTerminator->insertAfter(&block->back());
// Continuation-block path: hand the PHIs to the caller and continue at the
// first insertion point of the continuation block.
930 if (continuationBlockArgs)
931 llvm::append_range(*continuationBlockArgs, phis);
932 builder.SetInsertPoint(*continuationBlock,
933 (*continuationBlock)->getFirstInsertionPt());
/// `std::function` wrapper for a non-atomic reduction generator callback.
/// NOTE(review): the tail of the parameter list is not visible in this view.
940using OwningReductionGen =
941 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
942 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
/// `std::function` wrapper for an atomic reduction generator callback.
/// NOTE(review): the tail of the parameter list is not visible in this view.
944using OwningAtomicReductionGen =
945 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
946 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
/// `std::function` wrapper for a callback taking an insertion point, an input
/// value and a reference through which a result value is returned.
948using OwningDataPtrPtrReductionGen =
949 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
950 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
/// Builds the non-atomic reduction generator for a reduction declaration.
/// NOTE(review): the function name, its parameters and the call that inlines
/// the region are not visible in this view.
956static OwningReductionGen
// The callback copies `decl` and captures everything else by reference, so
// `builder` and `moduleTranslation` must outlive any use of the returned
// functor.
962 OwningReductionGen gen =
963 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
964 llvm::Value *
lhs, llvm::Value *
rhs,
965 llvm::Value *&
result)
mutable
966 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
// Bind the combiner's block arguments to the incoming values, then emit the
// combiner at `insertPoint`.
967 moduleTranslation.
mapValue(decl.getReductionLhsArg(),
lhs);
968 moduleTranslation.
mapValue(decl.getReductionRhsArg(),
rhs);
969 builder.restoreIP(insertPoint);
972 "omp.reduction.nonatomic.body", builder,
973 moduleTranslation, &phis)))
974 return llvm::createStringError(
975 "failed to inline `combiner` region of `omp.declare_reduction`");
// The single value collected in `phis` is the reduction result.
976 result = llvm::getSingleElement(phis);
977 return builder.saveIP();
/// Builds the atomic reduction generator for a reduction declaration.
/// NOTE(review): the function name, some parameters and the call that inlines
/// the region are not visible in this view.
986static OwningAtomicReductionGen
988 llvm::IRBuilderBase &builder,
// A declaration without an atomic region yields an empty functor.
990 if (decl.getAtomicReductionRegion().empty())
991 return OwningAtomicReductionGen();
// The callback copies `decl` and captures everything else by reference. Its
// `llvm::Type *` parameter is unnamed and therefore unused.
997 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
998 llvm::Value *
lhs, llvm::Value *
rhs)
mutable
999 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
// Bind the atomic region's block arguments to the incoming values, then emit
// the region at `insertPoint`.
1000 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(),
lhs);
1001 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(),
rhs);
1002 builder.restoreIP(insertPoint);
1005 "omp.reduction.atomic.body", builder,
1006 moduleTranslation, &phis)))
1007 return llvm::createStringError(
1008 "failed to inline `atomic` region of `omp.declare_reduction`");
// No values are expected to be collected from the atomic region.
1009 assert(phis.empty());
1010 return builder.saveIP();
/// Builds the generator that inlines the `data_ptr_ptr` region of a reduction
/// declaration.
/// NOTE(review): the function name, its parameters, the condition guarding
/// the early return and the inlining call are not visible in this view.
1019static OwningDataPtrPtrReductionGen
// Early exit that returns an empty functor.
1023 return OwningDataPtrPtrReductionGen();
// The callback copies `decl` and captures everything else by reference.
1025 OwningDataPtrPtrReductionGen refDataPtrGen =
1026 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1027 llvm::Value *byRefVal, llvm::Value *&
result)
mutable
1028 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
// Bind the region's block argument to `byRefVal`, then emit the region at
// `insertPoint`.
1029 moduleTranslation.
mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1030 builder.restoreIP(insertPoint);
1033 "omp.data_ptr_ptr.body", builder,
1034 moduleTranslation, &phis)))
1035 return llvm::createStringError(
1036 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
// The single value collected in `phis` is handed back through `result`.
1037 result = llvm::getSingleElement(phis);
1038 return builder.saveIP();
1041 return refDataPtrGen;
// Fragment of the `omp.ordered` (doacross) translation. The function header,
// the declarations of `vecValues` and `storeValues`, and the index increment
// are not visible in this view.
1048 auto orderedOp = cast<omp::OrderedOp>(opInst);
// The optional doacross attributes are dereferenced unconditionally here.
1053 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1054 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1055 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1057 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
// The depend values are consumed `numLoops` at a time; each group results in
// one createOrderedDepend call.
1059 size_t indexVecValues = 0;
1060 while (indexVecValues < vecValues.size()) {
1062 storeValues.reserve(numLoops);
1063 for (
unsigned i = 0; i < numLoops; i++) {
1064 storeValues.push_back(vecValues[indexVecValues]);
1067 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1069 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1070 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
1071 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
// Fragment of the `omp.ordered.region` translation. The function header, the
// body of `bodyGenCB` and the callee of the final builder call are not
// visible in this view.
1081 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1082 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
// Body callback: moves the builder to `codeGenIP` before emitting the region.
1087 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1089 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1090 builder.restoreIP(codeGenIP);
// Finalization callback: nothing to emit, always succeeds.
1098 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
// The last argument is the negation of the op's par-level-simd flag.
1100 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1101 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1103 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
// Continue emitting code at the insertion point held by `afterIP`.
1108 builder.restoreIP(*afterIP);
/// A (value, address) pair describing a store that is emitted later instead
/// of where the pair is created.
1114struct DeferredStore {
// Records the value to store and the destination address.
1115 DeferredStore(llvm::Value *value, llvm::Value *address)
1116 : value(value), address(address) {}
// Destination of the store. The `value` member declaration is not visible in
// this view.
1119 llvm::Value *address;
/// Allocates the private copies of the reduction variables of `loop` in the
/// block of `allocaIP`, filling `privateReductionVariables`,
/// `reductionVariableMap` and, for variables with an `alloc` region,
/// `deferredStores`.
/// NOTE(review): the function name, part of the parameter list and the
/// conditions selecting between the two visible code paths are not visible in
/// this view.
1126template <
typename T>
1129 llvm::IRBuilderBase &builder,
1131 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// Emit before the alloca block's terminator; the guard restores the builder's
// insertion point on exit.
1137 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1138 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1141 deferredStores.reserve(loop.getNumReductionVars());
1143 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1144 Region &allocRegion = reductionDecls[i].getAllocRegion();
1146 if (allocRegion.
empty())
// Path with an `alloc` region: the region is inlined and must yield exactly
// one value.
1151 builder, moduleTranslation, &phis)))
1152 return loop.emitError(
1153 "failed to inline `alloc` region of `omp.declare_reduction`");
1155 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1156 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
// A separate alloca holds the pointer; storing the yielded value into it is
// deferred, while the block argument and the variable map refer to the
// yielded value itself.
1160 llvm::Value *var = builder.CreateAlloca(
1161 moduleTranslation.
convertType(reductionDecls[i].getType()));
1163 llvm::Type *ptrTy = builder.getPtrTy();
1164 llvm::Value *castVar =
1165 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1166 llvm::Value *castPhi =
1167 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1169 deferredStores.emplace_back(castPhi, castVar);
1171 privateReductionVariables[i] = castVar;
1172 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1173 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
// By-value path: no `alloc` region is allowed; a plain alloca serves as the
// private variable, block argument mapping and variable-map entry.
1175 assert(allocRegion.
empty() &&
1176 "allocaction is implicit for by-val reduction");
1177 llvm::Value *var = builder.CreateAlloca(
1178 moduleTranslation.
convertType(reductionDecls[i].getType()));
1180 llvm::Type *ptrTy = builder.getPtrTy();
1181 llvm::Value *castVar =
1182 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1184 moduleTranslation.
mapValue(reductionArgs[i], castVar);
1185 privateReductionVariables[i] = castVar;
1186 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
/// Maps the block arguments of a reduction declaration's initializer region
/// (the mold argument and the allocation argument) for reduction index `i`.
/// NOTE(review): the function name, part of the parameter list and the
/// conditions around the two `mapValue` calls are not visible in this view.
1194template <
typename T>
1197 llvm::IRBuilderBase &builder,
1202 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1203 Region &initializerRegion = reduction.getInitializerRegion();
1206 mlir::Value mlirSource = loop.getReductionVars()[i];
1207 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1208 llvm::Value *origVal = llvmSource;
// When the mold argument is not a pointer but the reduction variable is, a
// value named "omp_orig" is derived from the source pointer — the visible
// arguments indicate a load; the call itself is not visible here.
1210 if (!isa<LLVM::LLVMPointerType>(
1211 reduction.getInitializerMoldArg().getType()) &&
1212 isa<LLVM::LLVMPointerType>(mlirSource.
getType())) {
1215 reduction.getInitializerMoldArg().getType()),
1216 llvmSource,
"omp_orig");
1218 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), origVal);
// The allocation argument is bound to the entry recorded for this reduction
// variable in `reductionVariableMap`.
1221 llvm::Value *allocation =
1222 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1223 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
// Fragment of a helper that positions `builder` in `block`, defaulting to the
// builder's current block when `block` is null. The builder is placed at the
// end of the block when it is empty or has no terminator, otherwise before
// the terminator. The function name and the `else` line are not visible in
// this view.
1229 llvm::BasicBlock *block =
nullptr) {
1230 if (block ==
nullptr)
1231 block = builder.GetInsertBlock();
1233 if (block->empty() || block->getTerminator() ==
nullptr)
1234 builder.SetInsertPoint(block);
1236 builder.SetInsertPoint(block->getTerminator());
/// Emits the initialization of the private reduction variables of `op`.
/// NOTE(review): the function name, part of the parameter list, the early
/// return value and several conditions are not visible in this view.
1244template <
typename OP>
1247 llvm::IRBuilderBase &builder,
1249 llvm::BasicBlock *latestAllocaBlock,
// Nothing to do without reduction variables.
1255 if (op.getNumReductionVars() == 0)
// Split off the "omp.reduction.init" block, then move the builder in front of
// the terminator of the latest alloca block to emit allocas.
1258 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1259 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1260 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1261 builder.restoreIP(allocaIP);
// One alloca of the reduction type is created per entry of `byRefVars`; the
// conditions around it are not fully visible here.
1264 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1266 if (!reductionDecls[i].getAllocRegion().empty())
1272 byRefVars[i] = builder.CreateAlloca(
1273 moduleTranslation.
convertType(reductionDecls[i].getType()));
// Emit the stores that were recorded as deferred.
1281 for (
auto [data, addr] : deferredStores)
1282 builder.CreateStore(data, addr);
// For each reduction, inline the neutral-element region, which must yield
// exactly one value, and store it into the private variable.
1287 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1292 reductionVariableMap, i);
1300 "omp.reduction.neutral", builder,
1301 moduleTranslation, &phis)))
1304 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1305 "reduction neutral element declaration region");
1310 if (!reductionDecls[i].getAllocRegion().empty())
1319 builder.CreateStore(phis[0], byRefVars[i]);
1321 privateReductionVariables[i] = byRefVars[i];
1322 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1323 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1326 builder.CreateStore(phis[0], privateReductionVariables[i]);
// Drop the value mappings made for the initializer region.
1333 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
/// Fills `reductionInfos` with one entry per reduction variable of `loop`.
/// The owning generator functors are created first, then referenced from the
/// entries.
/// NOTE(review): part of the parameter list and several arguments of the
/// pushed entries are not visible in this view.
1340template <
typename T>
1341static void collectReductionInfo(
1342 T loop, llvm::IRBuilderBase &builder,
1351 unsigned numReductions = loop.getNumReductionVars();
// First pass: build the owning generators for every reduction.
1353 for (
unsigned i = 0; i < numReductions; ++i) {
1356 owningAtomicReductionGens.push_back(
1359 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
// Second pass: assemble the reduction info entries.
1363 reductionInfos.reserve(numReductions);
1364 for (
unsigned i = 0; i < numReductions; ++i) {
// The atomic generator stays null when the owning functor is empty.
1365 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1366 if (owningAtomicReductionGens[i])
1367 atomicGen = owningAtomicReductionGens[i];
1368 llvm::Value *variable =
1369 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
// When `op` is an `LLVM::AllocaOp`, its element type is recorded. The
// definition of `op` is not visible here.
1372 if (
auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1373 allocatedType = alloca.getElemType();
1380 reductionInfos.push_back(
1382 privateReductionVariables[i],
1383 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1387 allocatedType ? moduleTranslation.
convertType(allocatedType) :
nullptr,
1388 reductionDecls[i].getByrefElementType()
1390 *reductionDecls[i].getByrefElementType())
// Fragment of a routine that processes each entry of `cleanupRegions`
// together with the private variable at the same index.
// NOTE(review): the function name, part of the parameter list, the action
// taken for empty regions and the inlining call are not visible in this view.
1400 llvm::IRBuilderBase &builder, StringRef regionName,
1401 bool shouldLoadCleanupRegionArg =
true) {
1402 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1403 if (cleanupRegion->empty())
// If the current block already ends in a terminator, emit in front of it.
1409 llvm::Instruction *potentialTerminator =
1410 builder.GetInsertBlock()->empty() ?
nullptr
1411 : &builder.GetInsertBlock()->back();
1412 if (potentialTerminator && potentialTerminator->isTerminator())
1413 builder.SetInsertPoint(potentialTerminator);
// The value used for the private variable is either loaded from it or the
// variable itself, depending on `shouldLoadCleanupRegionArg`.
1414 llvm::Value *privateVarValue =
1415 shouldLoadCleanupRegionArg
1416 ? builder.CreateLoad(
1418 privateVariables[i])
1419 : privateVariables[i];
1424 moduleTranslation)))
// Fragment of a routine that emits the reductions of `op`, a barrier after
// them, and the cleanup regions of the reduction declarations.
// NOTE(review): the function name, part of the parameter list, the early
// return value and the error checks on the builder results are not visible in
// this view.
1437 OP op, llvm::IRBuilderBase &builder,
1439 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1442 bool isNowait =
false,
bool isTeamsReduction =
false) {
// Nothing to do without reduction variables.
1444 if (op.getNumReductionVars() == 0)
1456 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1458 owningReductionGenRefDataPtrGens,
1459 privateReductionVariables, reductionInfos, isByRef);
// A temporary `unreachable` serves as the block terminator while the
// reductions are created in front of it; it is erased again below.
1464 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1465 builder.SetInsertPoint(tempTerminator);
1466 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1467 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1468 isByRef, isNowait, isTeamsReduction);
// An insertion point without a block is reported as a conversion failure.
1473 if (!contInsertPoint->getBlock())
1474 return op->emitOpError() <<
"failed to convert reductions";
// A barrier with directive kind `OMPD_for` is created at the insertion point
// returned by createReductions.
1476 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1477 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1482 tempTerminator->eraseFromParent();
1483 builder.restoreIP(*afterIP);
// Gather the cleanup region of each reduction declaration; they are passed on
// under the name "omp.reduction.cleanup".
1487 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1488 [](omp::DeclareReductionOp reductionDecl) {
1489 return &reductionDecl.getCleanupRegion();
1492 moduleTranslation, builder,
1493 "omp.reduction.cleanup");
/// Two-step reduction setup for `op`: a first call receives the alloca
/// insertion point together with `deferredStores`, and the result of
/// `initReductionVars` is returned afterwards.
/// NOTE(review): the `)))` after the first call suggests its result is
/// checked for failure — confirm against the full definition.
1504template <
typename OP>
1508 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// Operations without reduction variables are special-cased up front.
1513 if (op.getNumReductionVars() == 0)
1519 allocaIP, reductionDecls,
1520 privateReductionVariables, reductionVariableMap,
1521 deferredStores, isByRef)))
// Initialization is performed relative to the block of the alloca insertion
// point.
1524 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1525 allocaIP.getBlock(), reductionDecls,
1526 privateReductionVariables, reductionVariableMap,
1527 isByRef, deferredStores);
// Variables that have no entry in `mappedPrivateVars` (or when no map was
// supplied at all) are special-cased first.
1541 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
// Otherwise the mapped block argument stands in for the private variable.
1544 Value blockArg = (*mappedPrivateVars)[privateVar];
// The block argument of a mapped variable is required to be an LLVM pointer.
1547 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1548 "A block argument corresponding to a mapped var should have "
// Matching types are special-cased.
1551 if (privVarType == blockArgType)
// A non-pointer private variable type results in a load of that type.
1558 if (!isa<LLVM::LLVMPointerType>(privVarType))
1559 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
// Runs the `init` region of the privatizer `privDecl` for one private
// variable.
1572 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1574 llvm::BasicBlock *privInitBlock,
1576 Region &initRegion = privDecl.getInitRegion();
// Without an init region there is nothing to run; the private variable is
// returned as is.
1577 if (initRegion.
empty())
1578 return llvmPrivateVar;
// Bind the init region's mold and private arguments to their LLVM values
// before the region is inlined.
1580 assert(nonPrivateVar);
1581 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1582 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
// Values yielded by the region are collected in `phis`; a failure is turned
// into a string error.
1587 moduleTranslation, &phis)))
1588 return llvm::createStringError(
1589 "failed to inline `init` region of `omp.private`");
1591 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1608 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1611 builder, moduleTranslation, privDecl,
1614 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
// Initializes every private variable in turn and maps its block argument to
// the resulting LLVM value.
1623 return llvm::Error::success();
// Initialization code is associated with a block split off under the name
// "omp.private.init".
1625 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1628 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1631 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1633 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1634 llvmPrivateVar, privInitBlock, mappedPrivateVars);
// Errors from the per-variable initialization are propagated to the caller.
1637 return privVarOrErr.takeError();
// `llvmPrivateVar` is overwritten with the initialized value, which also
// becomes the mapping of the block argument.
1639 llvmPrivateVar = privVarOrErr.get();
1640 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1645 return llvm::Error::success();
// Emits one alloca per privatized variable in the alloca block and returns
// the block that follows the allocas.
1655 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// Split the alloca block at its terminator; the split-off part is named
// "omp.region.after_alloca".
1658 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1659 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1660 allocaTerminator->getIterator()),
1661 true, allocaTerminator->getStableDebugLoc(),
1662 "omp.region.after_alloca");
// The guard restores the builder's insertion point when it goes out of scope.
1664 llvm::IRBuilderBase::InsertPointGuard guard(builder);
// Re-read the terminator after the split and insert before it.
1666 allocaTerminator = allocaIP.getBlock()->getTerminator();
1667 builder.SetInsertPoint(allocaTerminator);
1669 assert(allocaTerminator->getNumSuccessors() == 1 &&
1670 "This is an unconditional branch created by splitBB");
1672 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1673 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
// Address space used for allocas according to the module's data layout.
1675 unsigned int allocaAS =
1676 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1679 .getProgramAddressSpace();
1681 for (
auto [privDecl, mlirPrivVar, blockArg] :
// Allocate storage of the privatizer's converted type right before the alloca
// block's terminator.
1684 llvm::Type *llvmAllocType =
1685 moduleTranslation.
convertType(privDecl.getType());
1686 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1687 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1688 llvmAllocType,
nullptr,
"omp.private.alloc");
// Cast the alloca to `defaultAS` when that differs from the alloca address
// space.
1689 if (allocaAS != defaultAS)
1690 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1691 builder.getPtrTy(defaultAS));
1693 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1696 return afterAllocas;
1704 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1713 if (mlir::isa<omp::ParallelOp>(parent))
// Inlines the `copy` region of every firstprivate privatizer.
// Firstprivate handling is only needed if at least one privatizer has the
// FirstPrivate data-sharing type.
1727 bool needsFirstprivate =
1728 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1729 return privOp.getDataSharingType() ==
1730 omp::DataSharingClauseType::FirstPrivate;
1733 if (!needsFirstprivate)
// Copy code is associated with a block split off under the name
// "omp.private.copy".
1736 llvm::BasicBlock *copyBlock =
1737 splitBB(builder,
true,
"omp.private.copy");
1740 for (
auto [decl, moldVar, llvmVar] :
1741 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
// Privatizers that are not firstprivate are special-cased.
1742 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
// NOTE(review): "©Region" looks like a mis-decoded "&copyRegion" (the HTML
// entity "&copy" rendered as a symbol) — confirm against the pristine file.
1746 Region ©Region = decl.getCopyRegion();
// Bind the copy region's mold and private arguments to their LLVM values.
1748 moduleTranslation.
mapValue(decl.getCopyMoldArg(), moldVar);
1751 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1755 moduleTranslation)))
1756 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
// A barrier is created at the builder's current insertion point.
1770 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1771 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
// Resolve each MLIR private variable to its associated LLVM value via
// `findAssociatedValue`; the results are written into `moldVars`.
1787 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](
mlir::Value mlirVar) {
1789 llvm::Value *moldVar = findAssociatedValue(
1790 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1795 llvmPrivateVars, privateDecls, insertBarrier,
// Collect the dealloc region of every privatizer into
// `privateCleanupRegions`.
1806 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1807 [](omp::PrivateClauseOp privatizer) {
1808 return &privatizer.getDeallocRegion();
// The regions are processed under the name "omp.private.dealloc" with the
// trailing flag set to false; a failure is reported at `loc`.
1812 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1813 "omp.private.dealloc",
false)))
1814 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1815 "`omp.private` op in");
1827 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
// Translates an `omp.sections` operation: one callback is built per
// `omp.section`, and all of them are handed over in a single call together
// with the privatization and finalization callbacks.
1837 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1838 using StorableBodyGenCallbackTy =
1839 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1841 auto sectionsOp = cast<omp::SectionsOp>(opInst);
// One by-ref flag is expected per reduction variable.
1847 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1851 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1855 sectionsOp.getNumReductionVars());
1859 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1862 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1863 reductionDecls, privateReductionVariables, reductionVariableMap,
1870 auto sectionOp = dyn_cast<omp::SectionOp>(op);
// NOTE(review): "®ion" and "§ionsOp" below look like mis-decoded "&region"
// and "&sectionsOp" (HTML entities "&reg" / "&sect" rendered as symbols) —
// confirm against the pristine file.
1874 Region ®ion = sectionOp.getRegion();
1875 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1876 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1877 builder.restoreIP(codeGenIP);
1884 sectionsOp.getRegion().getNumArguments());
// Each block argument of the section region is mapped to the LLVM value
// already associated with the corresponding argument of the enclosing
// sections region.
1885 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1886 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1887 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
1889 moduleTranslation.
mapValue(sectionArg, llvmVal);
1896 sectionCBs.push_back(sectionCB);
// A construct without any section callback is special-cased.
1902 if (sectionCBs.empty())
1905 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
// The privatization callback performs no privatization of its own: the
// replacement value is the incoming pointer itself.
1910 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1911 llvm::Value &vPtr, llvm::Value *&replacementValue)
1912 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1913 replacementValue = &vPtr;
// The finalization callback does nothing and reports success.
1919 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1923 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1924 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1926 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1927 sectionsOp.getNowait());
// Continue emitting code after the construct.
1932 builder.restoreIP(*afterIP);
1936 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1937 privateReductionVariables, isByRef, sectionsOp.getNowait());
// Translates an `omp.single` operation, including its copyprivate variables
// and their associated functions.
1944 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1945 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// The body callback emits code at the code-generation insertion point it is
// given.
1950 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1951 builder.restoreIP(codegenIP);
1953 builder, moduleTranslation)
// The finalization callback does nothing and reports success.
1956 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1960 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
// Entry `i` of `cpVars` is paired with entry `i` of `cpFuncs`, so the two
// output lists stay index-aligned.
1963 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1964 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
1966 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1967 llvmCPFuncs.push_back(
1971 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1973 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
// Continue emitting code after the construct.
1979 builder.restoreIP(*afterIP);
// Inspects every use of the teams operation's reduction block arguments.
1985 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1990 for (
auto ra : iface.getReductionBlockArgs())
1991 for (
auto &use : ra.getUses()) {
1992 auto *useOp = use.getOwner();
// Debug declare/value operations are set aside in `debugUses`.
1994 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1995 debugUses.push_back(useOp);
// Find the `omp.distribute` enclosing this use and compare it with the one
// recorded in `distOp`.
1999 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
2004 Operation *currentOp = currentDistOp.getOperation();
2005 if (distOp && (distOp != currentOp))
// Post-process the debug uses collected above.
2014 for (
auto *use : debugUses)
// Translates an `omp.teams` operation, optionally with reduction handling
// (guarded by `doTeamsReduction`).
2023 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2028 unsigned numReductionVars = op.getNumReductionVars();
2032 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// Reduction bookkeeping is only set up when a teams reduction is performed.
2038 if (doTeamsReduction) {
2039 isByRef =
getIsByRef(op.getReductionByref());
// One by-ref flag is expected per reduction variable.
2041 assert(isByRef.size() == op.getNumReductionVars());
2044 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2049 op, reductionArgs, builder, moduleTranslation, allocaIP,
2050 reductionDecls, privateReductionVariables, reductionVariableMap,
// The body callback emits code at the code-generation insertion point it is
// given.
2055 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2057 moduleTranslation, allocaIP);
2058 builder.restoreIP(codegenIP);
// Every optional clause value below stays null when the clause is absent.
2064 llvm::Value *numTeamsLower =
nullptr;
2065 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
2066 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
// Only the first num_teams upper value (index 0) is used.
2068 llvm::Value *numTeamsUpper =
nullptr;
2069 if (!op.getNumTeamsUpperVars().empty())
2070 numTeamsUpper = moduleTranslation.
lookupValue(op.getNumTeams(0));
// Only the first thread_limit value (index 0) is used.
2072 llvm::Value *threadLimit =
nullptr;
2073 if (!op.getThreadLimitVars().empty())
2074 threadLimit = moduleTranslation.
lookupValue(op.getThreadLimit(0));
2076 llvm::Value *ifExpr =
nullptr;
2077 if (
Value ifVar = op.getIfExpr())
2080 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2081 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2083 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
// Continue emitting code after the construct; reductions are finalized there
// when requested.
2088 builder.restoreIP(*afterIP);
2089 if (doTeamsReduction) {
2092 op, builder, moduleTranslation, allocaIP, reductionDecls,
2093 privateReductionVariables, isByRef,
// Converts each (variable, dependence kind) pair of a depend clause into a
// `DependData` entry appended to `dds`.
// An empty list of dependence variables is special-cased.
2103 if (dependVars.empty())
2105 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2106 llvm::omp::RTLDependenceKindTy type;
// Map the MLIR dependence kind onto the runtime dependence kind.
2108 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2109 case mlir::omp::ClauseTaskDepend::taskdependin:
2110 type = llvm::omp::RTLDependenceKindTy::DepIn;
// `out` and `inout` share the same runtime kind.
2115 case mlir::omp::ClauseTaskDepend::taskdependout:
2116 case mlir::omp::ClauseTaskDepend::taskdependinout:
2117 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2119 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2120 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2122 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2123 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
// The entry records the kind together with the LLVM value and its type.
2126 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2127 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2128 dds.emplace_back(dd);
2140 llvm::IRBuilderBase &llvmBuilder,
2142 llvm::omp::Directive cancelDirective) {
// The callback emits a branch at the given insertion point and records it in
// `cancelTerminators`. The guard restores the builder's own insertion point
// when the callback returns.
2143 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2144 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2148 llvmBuilder.restoreIP(ip);
// The branch initially targets the block of `ip` itself.
2154 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2155 return llvm::Error::success();
// Register the callback with the OpenMP IR builder.
2160 ompBuilder.pushFinalizationCB(
2170 llvm::OpenMPIRBuilder &ompBuilder,
2171 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
// Unregister the finalization callback, then point every recorded cancel
// branch at the single predecessor of the block in `afterIP`.
2172 ompBuilder.popFinalizationCB();
// NOTE(review): `getSinglePredecessor()` yields null when the block does not
// have exactly one predecessor; no check is visible here — confirm that
// callers guarantee it.
2173 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2174 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2175 assert(cancelBranch->getNumSuccessors() == 1 &&
2176 "cancel branch should have one target");
2177 cancelBranch->setSuccessor(0, constructFini);
/// Manages a heap-allocated struct holding the private variables of a task:
/// it derives the struct type from the privatizers, allocates the struct,
/// computes GEPs to its fields and frees it again.
2184class TaskContextStructManager {
/// Stores references to the builder and module translation and a view of the
/// privatizers; nothing is emitted at construction time.
2186 TaskContextStructManager(llvm::IRBuilderBase &builder,
2187 LLVM::ModuleTranslation &moduleTranslation,
2188 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2189 : builder{builder}, moduleTranslation{moduleTranslation},
2190 privateDecls{privateDecls} {}
/// Builds the struct type and emits the allocation of the struct.
2196 void generateTaskContextStruct();
/// Computes GEPs into the struct held by this manager and caches them in
/// `llvmPrivateVarGEPs`.
2202 void createGEPsToPrivateVars();
/// Computes GEPs into the struct pointed to by `altStructPtr` and returns
/// them without modifying this object.
2208 SmallVector<llvm::Value *>
2209 createGEPsToPrivateVars(llvm::Value *altStructPtr)
const;
/// Emits the deallocation of the struct.
2212 void freeStructPtr();
/// Returns the GEPs cached by `createGEPsToPrivateVars()`.
2214 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2215 return llvmPrivateVarGEPs;
/// Returns the pointer to the allocated struct; null until one has been
/// allocated.
2218 llvm::Value *getStructPtr() {
return structPtr; }
2221 llvm::IRBuilderBase &builder;
2222 LLVM::ModuleTranslation &moduleTranslation;
2223 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
// LLVM types of the struct fields, in field order.
2226 SmallVector<llvm::Type *> privateVarTypes;
// Cached GEPs, one per privatizer.
2230 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
// Pointer to the allocated struct.
2233 llvm::Value *structPtr =
nullptr;
// LLVM type of the struct.
2235 llvm::Type *structTy =
nullptr;
/// Derives the context struct type from the privatizers that read from their
/// mold and emits a malloc of that struct, storing the result in `structPtr`.
2239void TaskContextStructManager::generateTaskContextStruct() {
// No privatizers at all is special-cased.
2240 if (privateDecls.empty())
2242 privateVarTypes.reserve(privateDecls.size());
2244 for (omp::PrivateClauseOp &privOp : privateDecls) {
// Privatizers that do not read from their mold are special-cased; the others
// contribute one struct field of their converted type.
2247 if (!privOp.readsFromMold())
2249 Type mlirType = privOp.getType();
2250 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
// An empty field list is special-cased.
2253 if (privateVarTypes.empty())
2256 structTy = llvm::StructType::get(moduleTranslation.
getLLVMContext(),
// The allocation size is the size of the struct type, expressed as a constant
// expression.
2259 llvm::DataLayout dataLayout =
2260 builder.GetInsertBlock()->getModule()->getDataLayout();
2261 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2262 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2265 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2267 "omp.task.context_ptr");
/// Returns one entry per privatizer: a GEP into the struct at `altStructPtr`
/// for privatizers that read from their mold, and a null placeholder for the
/// others so that indices stay aligned with `privateDecls`.
2270SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2271 llvm::Value *altStructPtr)
const {
2272 SmallVector<llvm::Value *> ret;
2275 ret.reserve(privateDecls.size());
2276 llvm::Value *zero = builder.getInt32(0);
2278 for (
auto privDecl : privateDecls) {
2279 if (!privDecl.readsFromMold()) {
// Placeholder for a privatizer that has no field in the struct.
2281 ret.push_back(
nullptr);
// Address of field `i`: index 0 steps through the pointer, `i` selects the
// field.
2284 llvm::Value *iVal = builder.getInt32(i);
2285 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
/// Computes the GEPs for the struct owned by this manager and caches them in
/// `llvmPrivateVarGEPs`.
2292void TaskContextStructManager::createGEPsToPrivateVars() {
// NOTE(review): as shown, this assertion is unconditional, which would
// require the field list to be empty on every call — confirm whether a guard
// (for example a null-`structPtr` check) precedes it in the full file.
2294 assert(privateVarTypes.empty());
2298 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
/// Emits a free of the context struct immediately before the terminator of
/// the builder's current block.
2301void TaskContextStructManager::freeStructPtr() {
// The guard restores the builder's insertion point when this function
// returns.
2305 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2307 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2308 builder.CreateFree(structPtr);
// Translates an `omp.task` operation. Private variables whose privatizer
// reads from its mold are kept in a context struct managed by
// `taskStructMgr`; the remaining ones are allocated inside the task body.
2315 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2320 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2332 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// The builder is expected to sit at the end of its block. A new block
// "omp.task.start" is created and branched to.
2337 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2338 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2339 builder.getContext(),
"omp.task.start",
2340 builder.GetInsertBlock()->getParent());
2341 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2342 builder.SetInsertPoint(branchToTaskStartBlock);
// With the builder positioned at that branch, split off the blocks named
// "omp.private.copy" and "omp.private.init".
2345 llvm::BasicBlock *copyBlock =
2346 splitBB(builder,
true,
"omp.private.copy");
2347 llvm::BasicBlock *initBlock =
2348 splitBB(builder,
true,
"omp.private.init");
2364 moduleTranslation, allocaIP);
// Emit the context struct allocation and the GEPs to its fields before the
// terminator of the init block.
2367 builder.SetInsertPoint(initBlock->getTerminator());
2370 taskStructMgr.generateTaskContextStruct();
2377 taskStructMgr.createGEPsToPrivateVars();
2379 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2382 taskStructMgr.getLLVMPrivateVarGEPs())) {
// Privatizers that do not read from their mold are special-cased here.
2384 if (!privDecl.readsFromMold())
2386 assert(llvmPrivateVarAlloc &&
2387 "reads from mold so shouldn't have been skipped");
2390 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2391 blockArg, llvmPrivateVarAlloc, initBlock);
2392 if (!privateVarOrErr)
2393 return handleError(privateVarOrErr, *taskOp.getOperation());
// If initialization produced a value other than the struct slot and the block
// argument is not a pointer, store that value into the slot and load it back.
2402 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2403 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2404 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2406 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2407 llvmPrivateVarAlloc);
2409 assert(llvmPrivateVarAlloc->getType() ==
2410 moduleTranslation.
convertType(blockArg.getType()));
// Firstprivate copies target the struct fields (the GEPs).
2420 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2421 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2422 taskOp.getPrivateNeedsBarrier())))
2423 return llvm::failure();
2426 builder.SetInsertPoint(taskStartBlock);
// Callback that emits the body of the task.
2428 auto bodyCB = [&](InsertPointTy allocaIP,
2429 InsertPointTy codegenIP) -> llvm::Error {
2433 moduleTranslation, allocaIP);
2436 builder.restoreIP(codegenIP);
2438 llvm::BasicBlock *privInitBlock =
nullptr;
2440 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2443 auto [blockArg, privDecl, mlirPrivVar] = zip;
// Privatizers that read from their mold are special-cased here; the others
// get an alloca before the terminator of the alloca block and are initialized
// in place.
2445 if (privDecl.readsFromMold())
2448 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2449 llvm::Type *llvmAllocType =
2450 moduleTranslation.
convertType(privDecl.getType());
2451 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2452 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2453 llvmAllocType,
nullptr,
"omp.private.alloc");
2456 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2457 blockArg, llvmPrivateVar, privInitBlock);
2458 if (!privateVarOrError)
2459 return privateVarOrError.takeError();
2460 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2461 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
// Recompute the GEPs at the current insertion point and record them in
// `privateVarsInfo.llvmVars`.
2464 taskStructMgr.createGEPsToPrivateVars();
2465 for (
auto [i, llvmPrivVar] :
2466 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2468 assert(privateVarsInfo.
llvmVars[i] &&
2469 "This is added in the loop above");
2472 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
// Map the block arguments of the struct-held variables; non-pointer block
// arguments are mapped to a load from the struct field.
2477 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2481 if (!privateDecl.readsFromMold())
2484 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2485 llvmPrivateVar = builder.CreateLoad(
2486 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2488 assert(llvmPrivateVar->getType() ==
2489 moduleTranslation.
convertType(blockArg.getType()));
2490 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
// Translate the task region; a failure that has already been reported is
// signalled through `PreviouslyReportedError`.
2494 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
2495 if (failed(
handleError(continuationBlockOrError, *taskOp)))
2496 return llvm::make_error<PreviouslyReportedError>();
2498 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2503 return llvm::make_error<PreviouslyReportedError>();
// Release the context struct at the end of the task body.
2506 taskStructMgr.freeStructPtr();
2508 return llvm::Error::success();
2517 llvm::omp::Directive::OMPD_taskgroup);
2521 moduleTranslation, dds);
// Note that the `untied` attribute is passed negated.
2523 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2524 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2526 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2528 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dds,
2529 taskOp.getMergeable(),
2530 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
2531 moduleTranslation.
lookupValue(taskOp.getPriority()));
// Continue emitting code after the task.
2539 builder.restoreIP(*afterIP);
// Translates an `omp.taskloop` operation. Private variables whose privatizer
// reads from its mold are kept in a context struct managed by
// `taskStructMgr`; a duplication callback rebuilds that struct for a
// destination pointer from a source pointer.
2547 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2548 auto taskloopOp = cast<omp::TaskloopOp>(opInst);
2556 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2559 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// The builder is expected to sit at the end of its block. A new block
// "omp.taskloop.start" is created and branched to.
2562 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2563 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
2564 builder.getContext(),
"omp.taskloop.start",
2565 builder.GetInsertBlock()->getParent());
2566 llvm::Instruction *branchToTaskloopStartBlock =
2567 builder.CreateBr(taskloopStartBlock);
2568 builder.SetInsertPoint(branchToTaskloopStartBlock);
// With the builder positioned at that branch, split off the blocks named
// "omp.private.copy" and "omp.private.init".
2570 llvm::BasicBlock *copyBlock =
2571 splitBB(builder,
true,
"omp.private.copy");
2572 llvm::BasicBlock *initBlock =
2573 splitBB(builder,
true,
"omp.private.init");
2576 moduleTranslation, allocaIP);
// Emit the context struct allocation and the GEPs to its fields before the
// terminator of the init block.
2579 builder.SetInsertPoint(initBlock->getTerminator());
2582 taskStructMgr.generateTaskContextStruct();
2583 taskStructMgr.createGEPsToPrivateVars();
2585 llvmFirstPrivateVars.resize(privateVarsInfo.
blockArgs.size());
2587 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2589 privateVarsInfo.
blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
2590 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
// Privatizers that do not read from their mold are special-cased here.
2592 if (!privDecl.readsFromMold())
2594 assert(llvmPrivateVarAlloc &&
2595 "reads from mold so shouldn't have been skipped");
2598 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2599 blockArg, llvmPrivateVarAlloc, initBlock);
2600 if (!privateVarOrErr)
2601 return handleError(privateVarOrErr, *taskloopOp.getOperation());
// Remember the initialized value at the privatizer's index.
2603 llvmFirstPrivateVars[i] = privateVarOrErr.get();
2605 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2606 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
// If initialization produced a value other than the struct slot and the block
// argument is not a pointer, store that value into the slot and load it back.
2608 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2609 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2610 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2612 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2613 llvmPrivateVarAlloc);
2615 assert(llvmPrivateVarAlloc->getType() ==
2616 moduleTranslation.
convertType(blockArg.getType()));
// Firstprivate copies target the struct fields (the GEPs).
2622 taskloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2623 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2624 taskloopOp.getPrivateNeedsBarrier())))
2625 return llvm::failure();
2628 builder.SetInsertPoint(taskloopStartBlock);
// Callback that emits the body of the taskloop.
2630 auto bodyCB = [&](InsertPointTy allocaIP,
2631 InsertPointTy codegenIP) -> llvm::Error {
2635 moduleTranslation, allocaIP);
2638 builder.restoreIP(codegenIP);
2640 llvm::BasicBlock *privInitBlock =
nullptr;
2642 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2645 auto [blockArg, privDecl, mlirPrivVar] = zip;
// Privatizers that read from their mold are special-cased here; the others
// get an alloca before the terminator of the alloca block and are initialized
// in place.
2647 if (privDecl.readsFromMold())
2650 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2651 llvm::Type *llvmAllocType =
2652 moduleTranslation.
convertType(privDecl.getType());
2653 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2654 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2655 llvmAllocType,
nullptr,
"omp.private.alloc");
2658 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2659 blockArg, llvmPrivateVar, privInitBlock);
2660 if (!privateVarOrError)
2661 return privateVarOrError.takeError();
2662 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2663 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
// Recompute the GEPs at the current insertion point and record them in
// `privateVarsInfo.llvmVars`.
2666 taskStructMgr.createGEPsToPrivateVars();
2667 for (
auto [i, llvmPrivVar] :
2668 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2670 assert(privateVarsInfo.
llvmVars[i] &&
2671 "This is added in the loop above");
2674 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
// Map the block arguments of the struct-held variables; non-pointer block
// arguments are mapped to a load from the struct field.
2679 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2683 if (!privateDecl.readsFromMold())
2686 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2687 llvmPrivateVar = builder.CreateLoad(
2688 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2690 assert(llvmPrivateVar->getType() ==
2691 moduleTranslation.
convertType(blockArg.getType()));
2692 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
// Translate the region; a failure that has already been reported is signalled
// through `PreviouslyReportedError`.
2695 auto continuationBlockOrError =
2697 builder, moduleTranslation);
2699 if (failed(
handleError(continuationBlockOrError, opInst)))
2700 return llvm::make_error<PreviouslyReportedError>();
2702 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2710 taskloopOp.getLoc(), privateVarsInfo.
llvmVars,
2712 return llvm::make_error<PreviouslyReportedError>();
// Release the context struct at the end of the body.
2715 taskStructMgr.freeStructPtr();
2717 return llvm::Error::success();
// Duplication callback: loads the source context pointer from `srcPtr`,
// allocates a fresh context struct, stores its address to `destPtr`, and
// initializes and copies the private variables from the source struct into
// the new one.
2723 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2724 llvm::Value *destPtr, llvm::Value *srcPtr)
2726 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2727 builder.restoreIP(codegenIP);
// The loaded pointer uses the same address space as `srcPtr`.
2730 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
2732 builder.CreateLoad(ptrTy, srcPtr,
"omp.taskloop.context.src");
// A second manager allocates the destination struct; the existing manager is
// reused only to compute GEPs into the source struct.
2734 TaskContextStructManager &srcStructMgr = taskStructMgr;
2735 TaskContextStructManager destStructMgr(builder, moduleTranslation,
2737 destStructMgr.generateTaskContextStruct();
2738 llvm::Value *dest = destStructMgr.getStructPtr();
2739 dest->setName(
"omp.taskloop.context.dest");
2740 builder.CreateStore(dest, destPtr);
2743 srcStructMgr.createGEPsToPrivateVars(src);
2745 destStructMgr.createGEPsToPrivateVars(dest);
// The source struct fields act as molds for initializing the destination
// fields.
2748 for (
auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
2749 llvm::zip_equal(privateVarsInfo.
privatizers, srcGEPs,
2752 if (!privDecl.readsFromMold())
2754 assert(llvmPrivateVarAlloc &&
2755 "reads from mold so shouldn't have been skipped");
2758 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
2759 llvmPrivateVarAlloc, builder.GetInsertBlock());
2760 if (!privateVarOrErr)
2761 return privateVarOrErr.takeError();
// Same store-and-reload handling as in the initial setup above.
2770 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2771 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2772 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2774 llvmPrivateVarAlloc = builder.CreateLoad(
2775 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
2777 assert(llvmPrivateVarAlloc->getType() ==
2778 moduleTranslation.
convertType(blockArg.getType()));
// Firstprivate copies go from the source fields to the destination fields.
2786 &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
2787 privateVarsInfo.
privatizers, taskloopOp.getPrivateNeedsBarrier())))
2788 return llvm::make_error<PreviouslyReportedError>();
2790 return builder.saveIP();
2793 auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());
// Bound computations use the integer type of the first lower bound. `ubVal`
// starts at 1 so that per-loop trip counts can be multiplied into it.
2803 llvm::Type *boundType =
2804 moduleTranslation.
lookupValue(lowerBounds[0])->getType();
2805 llvm::Value *lbVal =
nullptr;
2806 llvm::Value *ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
2807 llvm::Value *stepVal =
nullptr;
// For collapsed loop nests, every loop contributes its trip count as a factor
// of `ubVal`.
2808 if (loopOp.getCollapseNumLoops() > 1) {
2826 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
2827 llvm::Value *loopLb = moduleTranslation.
lookupValue(lowerBounds[i]);
2828 llvm::Value *loopUb = moduleTranslation.
lookupValue(upperBounds[i]);
2829 llvm::Value *loopStep = moduleTranslation.
lookupValue(steps[i]);
// Inclusive distance between the bounds: ub - (lb - 1) when lb < ub (signed
// compare), lb - (ub - 1) otherwise, followed by an absolute value.
2835 llvm::Value *loopLbMinusOne = builder.CreateSub(
2836 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
2837 llvm::Value *loopUbMinusOne = builder.CreateSub(
2838 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
2839 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
2840 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
2841 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
2842 llvm::Value *loopTripCount =
2843 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
2844 loopTripCount = builder.CreateBinaryIntrinsic(
2845 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
// Divide by the step and round up: |count / step| plus one when the remainder
// is non-zero.
2849 llvm::Value *loopTripCountDivStep =
2850 builder.CreateSDiv(loopTripCount, loopStep);
2851 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
2852 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
2853 llvm::Value *loopTripCountRem =
2854 builder.CreateSRem(loopTripCount, loopStep);
2855 loopTripCountRem = builder.CreateBinaryIntrinsic(
2856 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
2857 llvm::Value *needsRoundUp = builder.CreateICmpNE(
2859 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
2862 builder.CreateAdd(loopTripCountDivStep,
2863 builder.CreateZExtOrTrunc(
2864 needsRoundUp, loopTripCountDivStep->getType()));
2865 ubVal = builder.CreateMul(ubVal, loopTripCount);
// The collapsed iteration space starts at 1 and advances in steps of 1.
2867 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
2868 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
// NOTE(review): presumably the non-collapsed alternative of the branch above
// (the separating line is not visible) — the bounds and step of the first
// loop are used directly.
2870 lbVal = moduleTranslation.
lookupValue(lowerBounds[0]);
2871 ubVal = moduleTranslation.
lookupValue(upperBounds[0]);
2872 stepVal = moduleTranslation.
lookupValue(steps[0]);
2874 assert(lbVal !=
nullptr &&
"Expected value for lbVal");
2875 assert(ubVal !=
nullptr &&
"Expected value for ubVal");
2876 assert(stepVal !=
nullptr &&
"Expected value for stepVal");
2878 llvm::Value *ifCond =
nullptr;
2879 llvm::Value *grainsize =
nullptr;
2881 mlir::Value grainsizeVal = taskloopOp.getGrainsize();
2882 mlir::Value numTasksVal = taskloopOp.getNumTasks();
2883 if (
Value ifVar = taskloopOp.getIfExpr())
// `grainsize` carries either the grainsize value or, failing that, the
// num_tasks value.
2886 grainsize = moduleTranslation.
lookupValue(grainsizeVal);
2888 }
else if (numTasksVal) {
2889 grainsize = moduleTranslation.
lookupValue(numTasksVal);
// The duplication callback is only supplied when a context struct was
// actually allocated.
2893 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull =
nullptr;
2894 if (taskStructMgr.getStructPtr())
2895 taskDupOrNull = taskDupCB;
2905 llvm::omp::Directive::OMPD_taskgroup);
2907 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2908 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2910 ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
2911 taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
2912 sched, moduleTranslation.
lookupValue(taskloopOp.getFinal()),
2913 taskloopOp.getMergeable(),
2914 moduleTranslation.
lookupValue(taskloopOp.getPriority()),
2915 loopOp.getCollapseNumLoops(), taskDupOrNull,
2916 taskStructMgr.getStructPtr());
// Continue emitting code after the taskloop.
2923 builder.restoreIP(*afterIP);
// Translates a construct whose body is emitted through a single body
// callback.
2931 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// The body callback emits code at the code-generation insertion point it is
// given.
2935 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2936 builder.restoreIP(codegenIP);
2938 builder, moduleTranslation)
2943 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2944 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
// Continue emitting code after the construct.
2951 builder.restoreIP(*afterIP);
2970 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2974 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2976 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2980 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2983 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
2984 llvm::Type *ivType = step->getType();
2985 llvm::Value *chunk =
nullptr;
2986 if (wsloopOp.getScheduleChunk()) {
2987 llvm::Value *chunkVar =
2988 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
2989 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2992 omp::DistributeOp distributeOp =
nullptr;
2993 llvm::Value *distScheduleChunk =
nullptr;
2994 bool hasDistSchedule =
false;
2995 if (llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())) {
2996 distributeOp = cast<omp::DistributeOp>(opInst.
getParentOp());
2997 hasDistSchedule = distributeOp.getDistScheduleStatic();
2998 if (distributeOp.getDistScheduleChunkSize()) {
2999 llvm::Value *chunkVar = moduleTranslation.
lookupValue(
3000 distributeOp.getDistScheduleChunkSize());
3001 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3009 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3013 wsloopOp.getNumReductionVars());
3016 builder, moduleTranslation, privateVarsInfo, allocaIP);
3023 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3028 moduleTranslation, allocaIP, reductionDecls,
3029 privateReductionVariables, reductionVariableMap,
3030 deferredStores, isByRef)))
3039 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3041 wsloopOp.getPrivateNeedsBarrier())))
3044 assert(afterAllocas.get()->getSinglePredecessor());
3045 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3047 afterAllocas.get()->getSinglePredecessor(),
3048 reductionDecls, privateReductionVariables,
3049 reductionVariableMap, isByRef, deferredStores)))
3053 bool isOrdered = wsloopOp.getOrdered().has_value();
3054 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3055 bool isSimd = wsloopOp.getScheduleSimd();
3056 bool loopNeedsBarrier = !wsloopOp.getNowait();
3061 llvm::omp::WorksharingLoopType workshareLoopType =
3062 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
3063 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3064 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3068 llvm::omp::Directive::OMPD_for);
3070 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3073 LinearClauseProcessor linearClauseProcessor;
3075 if (!wsloopOp.getLinearVars().empty()) {
3076 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3078 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3080 for (
auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3081 linearClauseProcessor.createLinearVar(
3082 builder, moduleTranslation, moduleTranslation.
lookupValue(linearVar),
3084 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
3085 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3089 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
3097 if (!wsloopOp.getLinearVars().empty()) {
3098 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3099 loopInfo->getPreheader());
3100 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3102 builder.saveIP(), llvm::omp::OMPD_barrier);
3105 builder.restoreIP(*afterBarrierIP);
3106 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3107 loopInfo->getIndVar());
3108 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3111 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3114 bool noLoopMode =
false;
3115 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3117 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3121 if (loopOp == targetCapturedOp) {
3122 omp::TargetRegionFlags kernelFlags =
3123 targetOp.getKernelExecFlags(targetCapturedOp);
3124 if (omp::bitEnumContainsAll(kernelFlags,
3125 omp::TargetRegionFlags::spmd |
3126 omp::TargetRegionFlags::no_loop) &&
3127 !omp::bitEnumContainsAny(kernelFlags,
3128 omp::TargetRegionFlags::generic))
3133 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3134 ompBuilder->applyWorkshareLoop(
3135 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3136 convertToScheduleKind(schedule), chunk, isSimd,
3137 scheduleMod == omp::ScheduleModifier::monotonic,
3138 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3139 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3145 if (!wsloopOp.getLinearVars().empty()) {
3146 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3147 assert(loopInfo->getLastIter() &&
3148 "`lastiter` in CanonicalLoopInfo is nullptr");
3149 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3150 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3151 loopInfo->getLastIter());
3154 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
3155 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
3157 builder.restoreIP(oldIP);
3165 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3166 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3179 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3181 assert(isByRef.size() == opInst.getNumReductionVars());
3194 opInst.getNumReductionVars());
3197 auto bodyGenCB = [&](InsertPointTy allocaIP,
3198 InsertPointTy codeGenIP) -> llvm::Error {
3200 builder, moduleTranslation, privateVarsInfo, allocaIP);
3202 return llvm::make_error<PreviouslyReportedError>();
3208 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3211 InsertPointTy(allocaIP.getBlock(),
3212 allocaIP.getBlock()->getTerminator()->getIterator());
3215 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3216 reductionDecls, privateReductionVariables, reductionVariableMap,
3217 deferredStores, isByRef)))
3218 return llvm::make_error<PreviouslyReportedError>();
3220 assert(afterAllocas.get()->getSinglePredecessor());
3221 builder.restoreIP(codeGenIP);
3227 return llvm::make_error<PreviouslyReportedError>();
3230 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3232 opInst.getPrivateNeedsBarrier())))
3233 return llvm::make_error<PreviouslyReportedError>();
3236 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3237 afterAllocas.get()->getSinglePredecessor(),
3238 reductionDecls, privateReductionVariables,
3239 reductionVariableMap, isByRef, deferredStores)))
3240 return llvm::make_error<PreviouslyReportedError>();
3245 moduleTranslation, allocaIP);
3249 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
3251 return regionBlock.takeError();
3254 if (opInst.getNumReductionVars() > 0) {
3259 owningReductionGenRefDataPtrGens;
3261 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3263 owningReductionGenRefDataPtrGens,
3264 privateReductionVariables, reductionInfos, isByRef);
3267 builder.SetInsertPoint((*regionBlock)->getTerminator());
3270 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3271 builder.SetInsertPoint(tempTerminator);
3273 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3274 ompBuilder->createReductions(
3275 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3277 if (!contInsertPoint)
3278 return contInsertPoint.takeError();
3280 if (!contInsertPoint->getBlock())
3281 return llvm::make_error<PreviouslyReportedError>();
3283 tempTerminator->eraseFromParent();
3284 builder.restoreIP(*contInsertPoint);
3287 return llvm::Error::success();
3290 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3291 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
3300 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
3301 InsertPointTy oldIP = builder.saveIP();
3302 builder.restoreIP(codeGenIP);
3307 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
3308 [](omp::DeclareReductionOp reductionDecl) {
3309 return &reductionDecl.getCleanupRegion();
3312 reductionCleanupRegions, privateReductionVariables,
3313 moduleTranslation, builder,
"omp.reduction.cleanup")))
3314 return llvm::createStringError(
3315 "failed to inline `cleanup` region of `omp.declare_reduction`");
3320 return llvm::make_error<PreviouslyReportedError>();
3324 if (isCancellable) {
3325 auto IPOrErr = ompBuilder->createBarrier(
3326 llvm::OpenMPIRBuilder::LocationDescription(builder),
3327 llvm::omp::Directive::OMPD_unknown,
3331 return IPOrErr.takeError();
3334 builder.restoreIP(oldIP);
3335 return llvm::Error::success();
3338 llvm::Value *ifCond =
nullptr;
3339 if (
auto ifVar = opInst.getIfExpr())
3341 llvm::Value *numThreads =
nullptr;
3342 if (!opInst.getNumThreadsVars().empty())
3343 numThreads = moduleTranslation.
lookupValue(opInst.getNumThreads(0));
3344 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
3345 if (
auto bind = opInst.getProcBindKind())
3348 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3350 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3352 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3353 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
3354 ifCond, numThreads, pbKind, isCancellable);
3359 builder.restoreIP(*afterIP);
/// File-local helper returning an `llvm::omp::OrderKind` (its name and
/// parameter list are not visible in this excerpt).
///
/// Visible behavior:
///  - one path returns `OMP_ORDER_unknown`; the guarding condition is not
///    visible here (presumably "no order clause present" — TODO confirm);
///  - `omp::ClauseOrderKind::Concurrent` maps to `OMP_ORDER_concurrent`;
///  - falling out of the switch reaches `llvm_unreachable`.
3364static llvm::omp::OrderKind
3367 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3369 case omp::ClauseOrderKind::Concurrent:
3370 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3372 llvm_unreachable(
"Unknown ClauseOrderKind kind");
3380 auto simdOp = cast<omp::SimdOp>(opInst);
3388 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3391 simdOp.getNumReductionVars());
3396 assert(isByRef.size() == simdOp.getNumReductionVars());
3398 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3402 builder, moduleTranslation, privateVarsInfo, allocaIP);
3407 LinearClauseProcessor linearClauseProcessor;
3409 if (!simdOp.getLinearVars().empty()) {
3410 auto linearVarTypes = simdOp.getLinearVarTypes().value();
3412 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3413 for (
auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
3414 bool isImplicit =
false;
3415 for (
auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
3419 if (linearVar == mlirPrivVar) {
3421 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
3422 llvmPrivateVar, idx);
3428 linearClauseProcessor.createLinearVar(
3429 builder, moduleTranslation,
3432 for (
mlir::Value linearStep : simdOp.getLinearStepVars())
3433 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3437 moduleTranslation, allocaIP, reductionDecls,
3438 privateReductionVariables, reductionVariableMap,
3439 deferredStores, isByRef)))
3450 assert(afterAllocas.get()->getSinglePredecessor());
3451 if (failed(initReductionVars(simdOp, reductionArgs, builder,
3453 afterAllocas.get()->getSinglePredecessor(),
3454 reductionDecls, privateReductionVariables,
3455 reductionVariableMap, isByRef, deferredStores)))
3458 llvm::ConstantInt *simdlen =
nullptr;
3459 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
3460 simdlen = builder.getInt64(simdlenVar.value());
3462 llvm::ConstantInt *safelen =
nullptr;
3463 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
3464 safelen = builder.getInt64(safelenVar.value());
3466 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
3469 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3470 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
3472 for (
size_t i = 0; i < operands.size(); ++i) {
3473 llvm::Value *alignment =
nullptr;
3474 llvm::Value *llvmVal = moduleTranslation.
lookupValue(operands[i]);
3475 llvm::Type *ty = llvmVal->getType();
3477 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
3478 alignment = builder.getInt64(intAttr.getInt());
3479 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
3480 assert(alignment &&
"Invalid alignment value");
3484 if (!intAttr.getValue().isPowerOf2())
3487 auto curInsert = builder.saveIP();
3488 builder.SetInsertPoint(sourceBlock);
3489 llvmVal = builder.CreateLoad(ty, llvmVal);
3490 builder.restoreIP(curInsert);
3491 alignedVars[llvmVal] = alignment;
3495 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
3502 if (simdOp.getLinearVars().size()) {
3503 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3504 loopInfo->getPreheader());
3506 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3507 loopInfo->getIndVar());
3509 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3511 ompBuilder->applySimd(loopInfo, alignedVars,
3513 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
3515 order, simdlen, safelen);
3517 linearClauseProcessor.emitStoresForLinearVar(builder);
3518 for (
size_t index = 0;
index < simdOp.getLinearVars().size();
index++)
3519 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
3526 for (
auto [i, tuple] : llvm::enumerate(
3527 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
3528 privateReductionVariables))) {
3529 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
3531 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
3532 llvm::Value *originalVariable = moduleTranslation.
lookupValue(reductionVar);
3533 llvm::Type *reductionType = moduleTranslation.
convertType(decl.getType());
3537 llvm::Value *redValue = originalVariable;
3540 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
3541 llvm::Value *privateRedValue = builder.CreateLoad(
3542 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
3543 llvm::Value *reduced;
3545 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
3548 builder.restoreIP(res.get());
3552 builder.CreateStore(reduced, originalVariable);
3557 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
3558 [](omp::DeclareReductionOp reductionDecl) {
3559 return &reductionDecl.getCleanupRegion();
3562 moduleTranslation, builder,
3563 "omp.reduction.cleanup")))
3576 auto loopOp = cast<omp::LoopNestOp>(opInst);
3582 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3587 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
3588 llvm::Value *iv) -> llvm::Error {
3591 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
3596 bodyInsertPoints.push_back(ip);
3598 if (loopInfos.size() != loopOp.getNumLoops() - 1)
3599 return llvm::Error::success();
3602 builder.restoreIP(ip);
3604 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
3606 return regionBlock.takeError();
3608 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3609 return llvm::Error::success();
3617 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3618 llvm::Value *lowerBound =
3619 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
3620 llvm::Value *upperBound =
3621 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
3622 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
3627 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3628 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3630 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3632 computeIP = loopInfos.front()->getPreheaderIP();
3636 ompBuilder->createCanonicalLoop(
3637 loc, bodyGen, lowerBound, upperBound, step,
3638 true, loopOp.getLoopInclusive(), computeIP);
3643 loopInfos.push_back(*loopResult);
3646 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
3647 loopInfos.front()->getAfterIP();
3650 if (
const auto &tiles = loopOp.getTileSizes()) {
3651 llvm::Type *ivType = loopInfos.front()->getIndVarType();
3654 for (
auto tile : tiles.value()) {
3655 llvm::Value *tileVal = llvm::ConstantInt::get(ivType,
tile);
3656 tileSizes.push_back(tileVal);
3659 std::vector<llvm::CanonicalLoopInfo *> newLoops =
3660 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
3664 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
3665 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
3666 afterIP = {afterAfterBB, afterAfterBB->begin()};
3670 for (
const auto &newLoop : newLoops)
3671 loopInfos.push_back(newLoop);
3675 const auto &numCollapse = loopOp.getCollapseNumLoops();
3677 loopInfos.begin(), loopInfos.begin() + (numCollapse));
3679 auto newTopLoopInfo =
3680 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
3682 assert(newTopLoopInfo &&
"New top loop information is missing");
3683 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
3684 [&](OpenMPLoopInfoStackFrame &frame) {
3685 frame.loopInfo = newTopLoopInfo;
3693 builder.restoreIP(afterIP);
3703 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
3704 Value loopIV = op.getInductionVar();
3705 Value loopTC = op.getTripCount();
3707 llvm::Value *llvmTC = moduleTranslation.
lookupValue(loopTC);
3710 ompBuilder->createCanonicalLoop(
3712 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
3715 moduleTranslation.
mapValue(loopIV, llvmIV);
3717 builder.restoreIP(ip);
3722 return bodyGenStatus.takeError();
3724 llvmTC,
"omp.loop");
3726 return op.emitError(llvm::toString(llvmOrError.takeError()));
3728 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
3729 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
3730 builder.restoreIP(afterIP);
3733 if (
Value cli = op.getCli())
3746 Value applyee = op.getApplyee();
3747 assert(applyee &&
"Loop to apply unrolling on required");
3749 llvm::CanonicalLoopInfo *consBuilderCLI =
3751 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3752 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
/// Applies loop tiling for an `omp::TileOp` through the OpenMPIRBuilder.
///
/// Visible steps:
///  1. Build a `LocationDescription` from the current builder state.
///  2. Translate every `sizes` operand with `lookupValue`; each one is
///     asserted to be already translated, then collected.
///  3. Collect one `llvm::CanonicalLoopInfo *` per applyee (the lookup
///     expression itself is not visible in this excerpt).
///  4. Call `ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes)`.
///  5. When the op has generatees, pair them one-to-one (`zip_equal`) with the
///     generated loops; a final loop revisits the applyees. The bodies of
///     both loops are not visible here.
///
/// NOTE(review): the assertion in the applyee loop tests `applyee` rather than
/// the looked-up `consBuilderCLI`, although its message is about the loop
/// having been translated — confirm whether `consBuilderCLI` was intended.
3760static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3763 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3768 for (
Value size : op.getSizes()) {
3769 llvm::Value *translatedSize = moduleTranslation.
lookupValue(size);
3770 assert(translatedSize &&
3771 "sizes clause arguments must already be translated");
3772 translatedSizes.push_back(translatedSize);
3775 for (
Value applyee : op.getApplyees()) {
3776 llvm::CanonicalLoopInfo *consBuilderCLI =
3778 assert(applyee &&
"Canonical loop must already been translated");
3779 translatedLoops.push_back(consBuilderCLI);
3782 auto generatedLoops =
3783 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3784 if (!op.getGeneratees().empty()) {
3785 for (
auto [mlirLoop,
genLoop] :
3786 zip_equal(op.getGeneratees(), generatedLoops))
3791 for (
Value applyee : op.getApplyees())
/// File-local helper returning an `llvm::AtomicOrdering` for an OpenMP
/// memory-order clause (its name and parameter list are not visible in this
/// excerpt).
///
/// Visible mapping:
///  - an early path returns `Monotonic`; its guarding condition is not
///    visible (presumably "no memory-order clause given" — TODO confirm);
///  - `Seq_cst` -> `SequentiallyConsistent`
///  - `Acq_rel` -> `AcquireRelease`
///  - `Acquire` -> `Acquire`
///  - `Release` -> `Release`
///  - `Relaxed` -> `Monotonic`
///  - any other enumerator reaches `llvm_unreachable`.
3798static llvm::AtomicOrdering
3801 return llvm::AtomicOrdering::Monotonic;
3804 case omp::ClauseMemoryOrderKind::Seq_cst:
3805 return llvm::AtomicOrdering::SequentiallyConsistent;
3806 case omp::ClauseMemoryOrderKind::Acq_rel:
3807 return llvm::AtomicOrdering::AcquireRelease;
3808 case omp::ClauseMemoryOrderKind::Acquire:
3809 return llvm::AtomicOrdering::Acquire;
3810 case omp::ClauseMemoryOrderKind::Release:
3811 return llvm::AtomicOrdering::Release;
3812 case omp::ClauseMemoryOrderKind::Relaxed:
3813 return llvm::AtomicOrdering::Monotonic;
3815 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
3822 auto readOp = cast<omp::AtomicReadOp>(opInst);
3827 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3830 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3833 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
3834 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
3836 llvm::Type *elementType =
3837 moduleTranslation.
convertType(readOp.getElementType());
3839 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
3840 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
3841 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3849 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3854 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3857 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3859 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
3860 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
3861 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
3862 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
3865 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
// Case chain of a type switch (its head is not visible in this excerpt) that
// maps an LLVM-dialect arithmetic/bitwise op type to the matching
// `llvm::AtomicRMWInst::BinOp`: Add, Sub, And, Or, Xor, UMax, UMin, FAdd and
// FSub. Every op type not listed below yields `BAD_BINOP`.
3873 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
3874 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
3875 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
3876 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
3877 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
3878 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
3879 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
3880 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
3881 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
3882 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
// Tail of a helper's signature plus its body (the first signature line, which
// declares `atomicUpdateOp`, is not visible in this excerpt).
//
// The three out-parameters are first reset to false. They are then
// overwritten from the op's `AtomicControlAttr` only when `atomicUpdateOp` is
// non-null AND carries the attribute named by `getAtomicControlAttrName()`.
// Callers therefore always receive defined values, even for a null op.
3886 bool &isIgnoreDenormalMode,
3887 bool &isFineGrainedMemory,
3888 bool &isRemoteMemory) {
3889 isIgnoreDenormalMode =
false;
3890 isFineGrainedMemory =
false;
3891 isRemoteMemory =
false;
3892 if (atomicUpdateOp &&
3893 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3894 mlir::omp::AtomicControlAttr atomicControlAttr =
3895 atomicUpdateOp.getAtomicControlAttr();
3896 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3897 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3898 isRemoteMemory = atomicControlAttr.getRemoteMemory();
3905 llvm::IRBuilderBase &builder,
3912 auto &innerOpList = opInst.getRegion().front().getOperations();
3913 bool isXBinopExpr{
false};
3914 llvm::AtomicRMWInst::BinOp binop;
3916 llvm::Value *llvmExpr =
nullptr;
3917 llvm::Value *llvmX =
nullptr;
3918 llvm::Type *llvmXElementType =
nullptr;
3919 if (innerOpList.size() == 2) {
3925 opInst.getRegion().getArgument(0))) {
3926 return opInst.emitError(
"no atomic update operation with region argument"
3927 " as operand found inside atomic.update region");
3930 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
3932 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
3936 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3938 llvmX = moduleTranslation.
lookupValue(opInst.getX());
3940 opInst.getRegion().getArgument(0).getType());
3941 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3945 llvm::AtomicOrdering atomicOrdering =
3950 [&opInst, &moduleTranslation](
3951 llvm::Value *atomicx,
3954 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
3955 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
3956 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
3957 return llvm::make_error<PreviouslyReportedError>();
3959 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
3960 assert(yieldop && yieldop.getResults().size() == 1 &&
3961 "terminator must be omp.yield op and it must have exactly one "
3963 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
3966 bool isIgnoreDenormalMode;
3967 bool isFineGrainedMemory;
3968 bool isRemoteMemory;
3973 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3974 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3975 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
3976 atomicOrdering, binop, updateFn,
3977 isXBinopExpr, isIgnoreDenormalMode,
3978 isFineGrainedMemory, isRemoteMemory);
3983 builder.restoreIP(*afterIP);
3989 llvm::IRBuilderBase &builder,
3996 bool isXBinopExpr =
false, isPostfixUpdate =
false;
3997 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3999 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
4000 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
4002 assert((atomicUpdateOp || atomicWriteOp) &&
4003 "internal op must be an atomic.update or atomic.write op");
4005 if (atomicWriteOp) {
4006 isPostfixUpdate =
true;
4007 mlirExpr = atomicWriteOp.getExpr();
4009 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
4010 atomicCaptureOp.getAtomicUpdateOp().getOperation();
4011 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
4014 if (innerOpList.size() == 2) {
4017 atomicUpdateOp.getRegion().getArgument(0))) {
4018 return atomicUpdateOp.emitError(
4019 "no atomic update operation with region argument"
4020 " as operand found inside atomic.update region");
4024 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
4027 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4031 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4032 llvm::Value *llvmX =
4033 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
4034 llvm::Value *llvmV =
4035 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
4036 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
4037 atomicCaptureOp.getAtomicReadOp().getElementType());
4038 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4041 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
4045 llvm::AtomicOrdering atomicOrdering =
4049 [&](llvm::Value *atomicx,
4052 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
4053 Block &bb = *atomicUpdateOp.getRegion().
begin();
4054 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
4056 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4057 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4058 return llvm::make_error<PreviouslyReportedError>();
4060 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4061 assert(yieldop && yieldop.getResults().size() == 1 &&
4062 "terminator must be omp.yield op and it must have exactly one "
4064 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
4067 bool isIgnoreDenormalMode;
4068 bool isFineGrainedMemory;
4069 bool isRemoteMemory;
4071 isFineGrainedMemory, isRemoteMemory);
4074 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4075 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4076 ompBuilder->createAtomicCapture(
4077 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4078 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4079 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4081 if (failed(
handleError(afterIP, *atomicCaptureOp)))
4084 builder.restoreIP(*afterIP);
// Maps the MLIR cancellation construct type to the `llvm::omp::Directive`
// being cancelled (the function's name/return-type line is not visible in
// this excerpt):
//   Loop      -> OMPD_for
//   Parallel  -> OMPD_parallel
//   Sections  -> OMPD_sections
//   Taskgroup -> OMPD_taskgroup
// Any other enumerator reaches `llvm_unreachable`.
4089 omp::ClauseCancellationConstructType directive) {
4090 switch (directive) {
4091 case omp::ClauseCancellationConstructType::Loop:
4092 return llvm::omp::Directive::OMPD_for;
4093 case omp::ClauseCancellationConstructType::Parallel:
4094 return llvm::omp::Directive::OMPD_parallel;
4095 case omp::ClauseCancellationConstructType::Sections:
4096 return llvm::omp::Directive::OMPD_sections;
4097 case omp::ClauseCancellationConstructType::Taskgroup:
4098 return llvm::omp::Directive::OMPD_taskgroup;
4100 llvm_unreachable(
"Unhandled cancellation construct type");
4109 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4112 llvm::Value *ifCond =
nullptr;
4113 if (
Value ifVar = op.getIfExpr())
4116 llvm::omp::Directive cancelledDirective =
4119 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4120 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4122 if (failed(
handleError(afterIP, *op.getOperation())))
4125 builder.restoreIP(afterIP.get());
4132 llvm::IRBuilderBase &builder,
4137 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4140 llvm::omp::Directive cancelledDirective =
4143 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4144 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4146 if (failed(
handleError(afterIP, *op.getOperation())))
4149 builder.restoreIP(afterIP.get());
// Body fragment of the `omp::ThreadprivateOp` lowering (the signature is not
// visible in this excerpt).
//
// Visible steps:
//  - take the op's symbol address; `symOp` is derived from it on lines not
//    shown, with a special case when it is an `LLVM::AddrSpaceCastOp`;
//  - emit "Addressing symbol not found" unless `symOp` is an
//    `LLVM::AddressOfOp`;
//  - resolve the referenced `LLVM::GlobalOp` through the translation's symbol
//    table and look up its already-translated `llvm::GlobalValue`;
//  - compute a size via the data layout's `getTypeStoreSize` (argument not
//    visible, presumably `type` — TODO confirm) and wrap it in an i64 constant;
//  - call `createCachedThreadPrivate` with a cache name formed from the
//    global's symbol name plus ".cache".
//
// NOTE(review): the `dyn_cast<LLVM::AddressOfOp>` below follows an `isa<>`
// guard, so it cannot fail on that path; `cast<>` would express that.
4159 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4161 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4166 Value symAddr = threadprivateOp.getSymAddr();
4169 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4172 if (!isa<LLVM::AddressOfOp>(symOp))
4173 return opInst.
emitError(
"Addressing symbol not found");
4174 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4176 LLVM::GlobalOp global =
4177 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
4178 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
4179 llvm::Type *type = globalValue->getValueType();
4180 llvm::TypeSize typeSize =
4181 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4183 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4184 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4185 ompLoc, globalValue, size, global.getSymName() +
".cache");
/// File-local helper translating a declare-target device type into the
/// OffloadEntriesInfoManager device-clause kind (its name and parameter line
/// are not visible in this excerpt):
///   host   -> OMPTargetDeviceClauseHost
///   nohost -> OMPTargetDeviceClauseNoHost
///   any    -> OMPTargetDeviceClauseAny
/// Any other enumerator reaches `llvm_unreachable`.
4191static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4193 switch (deviceClause) {
4194 case mlir::omp::DeclareTargetDeviceType::host:
4195 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4197 case mlir::omp::DeclareTargetDeviceType::nohost:
4198 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4200 case mlir::omp::DeclareTargetDeviceType::any:
4201 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4204 llvm_unreachable(
"unhandled device clause");
/// File-local helper translating a declare-target capture clause into the
/// OffloadEntriesInfoManager global-variable entry kind (its name line is not
/// visible in this excerpt):
///   to    -> OMPTargetGlobalVarEntryTo
///   link  -> OMPTargetGlobalVarEntryLink
///   enter -> OMPTargetGlobalVarEntryEnter
///   none  -> OMPTargetGlobalVarEntryNone
/// Any other enumerator reaches `llvm_unreachable`.
4207static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4209 mlir::omp::DeclareTargetCaptureClause captureClause) {
4210 switch (captureClause) {
4211 case mlir::omp::DeclareTargetCaptureClause::to:
4212 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4213 case mlir::omp::DeclareTargetCaptureClause::link:
4214 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4215 case mlir::omp::DeclareTargetCaptureClause::enter:
4216 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4217 case mlir::omp::DeclareTargetCaptureClause::none:
4218 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4220 llvm_unreachable(
"unhandled capture clause");
// Body fragment of a helper that resolves an operation `op` to the global
// symbol it addresses (the signature is not visible in this excerpt).
//
// `dyn_cast_if_present` tolerates a null `op`. An `LLVM::AddrSpaceCastOp` is
// handled first (its handling is not visible; presumably it is looked
// through — TODO confirm). For an `LLVM::AddressOfOp`, the enclosing
// `mlir::ModuleOp` is located and the global's name is looked up in it.
//
// NOTE(review): `modOp` is used without a visible null check — confirm the op
// is always nested in a module here.
4225 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4227 if (
auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4228 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4229 return modOp.lookupSymbol(addressOfOp.getGlobalName());
/// File-local helper returning an `llvm::SmallString<64>` suffix (its name and
/// first parameter are not visible in this excerpt).
///
/// Visible behavior: text is streamed into `suffix` through a
/// `raw_svector_ostream`. `fileInfoCallBack` reports the file name and line of
/// `loc`; it captures `loc` by reference and is only used within this
/// function in the visible lines. The callback and the real file system are
/// passed to `getTargetEntryUniqueInfo`, whose `FileID` is consumed by a call
/// whose head is not visible (presumably formatted into the stream — TODO
/// confirm). The literal "_decl_tgt_ref_ptr" is appended afterwards.
4234static llvm::SmallString<64>
4236 llvm::OpenMPIRBuilder &ompBuilder) {
4238 llvm::raw_svector_ostream os(suffix);
4241 auto fileInfoCallBack = [&loc]() {
4242 return std::pair<std::string, uint64_t>(
4243 llvm::StringRef(loc.getFilename()), loc.getLine());
4246 auto vfs = llvm::vfs::getRealFileSystem();
4249 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
4251 os <<
"_decl_tgt_ref_ptr";
4257 if (
auto declareTargetGlobal =
4258 dyn_cast_if_present<omp::DeclareTargetInterface>(
4260 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4261 omp::DeclareTargetCaptureClause::link)
4267 if (
auto declareTargetGlobal =
4268 dyn_cast_if_present<omp::DeclareTargetInterface>(
4270 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4271 omp::DeclareTargetCaptureClause::to ||
4272 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4273 omp::DeclareTargetCaptureClause::enter)
4287 if (
auto declareTargetGlobal =
4288 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
4291 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
4292 omp::DeclareTargetCaptureClause::link) ||
4293 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4294 omp::DeclareTargetCaptureClause::to &&
4295 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
4299 if (gOp.getSymName().contains(suffix))
4304 (gOp.getSymName().str() + suffix.str()).str());
/// Extends `llvm::OpenMPIRBuilder::MapInfosTy` with a parallel `Mappers`
/// array of MLIR operations (presumably one entry per mapped item, naming
/// its user-defined mapper — TODO confirm against the code that fills it).
4313struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
4314 SmallVector<Operation *, 4> Mappers;
// Appends the entries of `curInfo`: first its `Mappers`, then the
// base-class arrays via the inherited `append`.
4317 void append(MapInfosTy &curInfo) {
4318 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
4319 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
/// Extends `MapInfosTy` with further per-entry arrays: boolean flags
/// (`IsDeclareTarget`, `IsAMember`, `IsAMapping`), the originating map-clause
/// operation, the original LLVM value and the base LLVM type.
4328struct MapInfoData : MapInfosTy {
4329 llvm::SmallVector<bool, 4> IsDeclareTarget;
4330 llvm::SmallVector<bool, 4> IsAMember;
4332 llvm::SmallVector<bool, 4> IsAMapping;
4333 llvm::SmallVector<mlir::Operation *, 4> MapClause;
4334 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
4337 llvm::SmallVector<llvm::Type *, 4> BaseType;
// Appends the entries of `CurInfo` to this object: `IsDeclareTarget`,
// `MapClause`, `OriginalValue`, `BaseType`, then the inherited arrays.
// NOTE(review): `IsAMember` and `IsAMapping` are NOT appended, so after a
// call they are shorter than the other arrays — confirm whether this is
// intentional or whether no caller relies on them after `append`.
4340 void append(MapInfoData &CurInfo) {
4341 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
4342 CurInfo.IsDeclareTarget.end());
4343 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
4344 OriginalValue.append(CurInfo.OriginalValue.begin(),
4345 CurInfo.OriginalValue.end());
4346 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
4347 MapInfosTy::append(CurInfo);
/// Scoped enumeration (underlying type `uint32_t`) identifying a target
/// directive. Only `TargetEnterData = 3` is visible in this excerpt; the
/// other enumerators and their values are not shown.
4351enum class TargetDirectiveEnumTy : uint32_t {
4355 TargetEnterData = 3,
/// Classifies `op` by its concrete OpenMP operation type:
///   omp::TargetDataOp      -> TargetDirectiveEnumTy::TargetData
///   omp::TargetEnterDataOp -> TargetDirectiveEnumTy::TargetEnterData
///   omp::TargetExitDataOp  -> TargetDirectiveEnumTy::TargetExitData
///   omp::TargetUpdateOp    -> TargetDirectiveEnumTy::TargetUpdate
///   omp::TargetOp          -> TargetDirectiveEnumTy::Target
/// Every other operation yields TargetDirectiveEnumTy::None.
///
/// NOTE(review): the `Default` lambda's parameter `op` shadows the function
/// parameter of the same name and is unused; the `[&]` captures on the later
/// cases capture nothing that is used.
4360static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4361 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4362 .Case([](omp::TargetDataOp) {
return TargetDirectiveEnumTy::TargetData; })
4363 .Case([](omp::TargetEnterDataOp) {
4364 return TargetDirectiveEnumTy::TargetEnterData;
4366 .Case([&](omp::TargetExitDataOp) {
4367 return TargetDirectiveEnumTy::TargetExitData;
4369 .Case([&](omp::TargetUpdateOp) {
4370 return TargetDirectiveEnumTy::TargetUpdate;
4372 .Case([&](omp::TargetOp) {
return TargetDirectiveEnumTy::Target; })
4373 .Default([&](Operation *op) {
return TargetDirectiveEnumTy::None; });
4380 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4381 arrTy.getElementType()))
4398 llvm::Value *basePointer,
4399 llvm::Type *baseType,
4400 llvm::IRBuilderBase &builder,
4402 if (
auto memberClause =
4403 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
4408 if (!memberClause.getBounds().empty()) {
4409 llvm::Value *elementCount = builder.getInt64(1);
4410 for (
auto bounds : memberClause.getBounds()) {
4411 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
4412 bounds.getDefiningOp())) {
4417 elementCount = builder.CreateMul(
4421 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
4422 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
4423 builder.getInt64(1)));
4430 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
4438 return builder.CreateMul(elementCount,
4439 builder.getInt64(underlyingTypeSzInBits / 8));
4450static llvm::omp::OpenMPOffloadMappingFlags
// Predicate over the captured `mlirFlags`: yields true only when every bit of
// `flag` is present in `mlirFlags` (mask-and-compare against `flag` itself).
4452 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
4453 return (mlirFlags & flag) == flag;
// hasExplicitMap is true when, after clearing the is_device_ptr bit, at least
// one other flag remains set in `mlirFlags`.
4455 const bool hasExplicitMap =
4456 (mlirFlags &
~omp::ClauseMapFlags::is_device_ptr) !=
4457 omp::ClauseMapFlags::none;
4459 llvm::omp::OpenMPOffloadMappingFlags mapType =
4460 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4463 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
4466 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
4469 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4472 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4475 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4478 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4481 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4484 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4487 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4490 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4493 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4496 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4499 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4500 if (!hasExplicitMap)
4501 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4511 ArrayRef<Value> useDevAddrOperands = {},
4512 ArrayRef<Value> hasDevAddrOperands = {}) {
4513 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
4521 for (Value mapValue : mapVars) {
4522 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4523 for (
auto member : map.getMembers())
4524 if (member == mapOp)
4531 for (Value mapValue : mapVars) {
4532 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4534 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4535 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
4536 mapData.Pointers.push_back(mapData.OriginalValue.back());
4538 if (llvm::Value *refPtr =
4540 mapData.IsDeclareTarget.push_back(
true);
4541 mapData.BasePointers.push_back(refPtr);
4543 mapData.IsDeclareTarget.push_back(
true);
4544 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4546 mapData.IsDeclareTarget.push_back(
false);
4547 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4550 mapData.BaseType.push_back(
4551 moduleTranslation.
convertType(mapOp.getVarType()));
4552 mapData.Sizes.push_back(
4553 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
4554 mapData.BaseType.back(), builder, moduleTranslation));
4555 mapData.MapClause.push_back(mapOp.getOperation());
4557 mapData.Names.push_back(LLVM::createMappingInformation(
4559 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4560 if (mapOp.getMapperId())
4561 mapData.Mappers.push_back(
4563 mapOp, mapOp.getMapperIdAttr()));
4565 mapData.Mappers.push_back(
nullptr);
4566 mapData.IsAMapping.push_back(
true);
4567 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
// Searches the values already recorded in mapData.OriginalValue for `val`.
// For an entry that equals `val` and is flagged in mapData.IsAMapping, the
// entry's map type gains OMP_MAP_RETURN_PARAM and its device-pointer kind is
// overwritten with `devInfoTy`.
// NOTE(review): `index` is declared and advanced on lines not visible here;
// it presumably tracks the position of `basePtr` within OriginalValue. The
// result is consumed as a boolean by the caller (`!findMapInfo(...)`) - confirm
// the exact return statements against the full file.
4570 auto findMapInfo = [&mapData](llvm::Value *val,
4571 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4574 for (llvm::Value *basePtr : mapData.OriginalValue) {
4575 if (basePtr == val && mapData.IsAMapping[index]) {
4577 mapData.Types[index] |=
4578 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4579 mapData.DevicePointers[index] = devInfoTy;
// Registers each operand in `useDevOperands` in `mapData` with the device
// info kind `devInfoTy`. An operand whose LLVM value is already handled by
// findMapInfo is left to that lambda; otherwise a fresh entry is appended to
// every parallel container of `mapData`.
4587 auto addDevInfos = [&](
const llvm::ArrayRef<Value> &useDevOperands,
4588 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4589 for (Value mapValue : useDevOperands) {
// Each operand is expected to be produced by an omp::MapInfoOp (cast<>
// asserts otherwise). The var-ptr-ptr is preferred over the var-ptr when
// it is present.
4590 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4592 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4593 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
// Only add a new entry when findMapInfo did not report an existing one.
4596 if (!findMapInfo(origValue, devInfoTy)) {
// The original value doubles as both the pointer and the base pointer.
4597 mapData.OriginalValue.push_back(origValue);
4598 mapData.Pointers.push_back(mapData.OriginalValue.back());
4599 mapData.IsDeclareTarget.push_back(
false);
4600 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4601 mapData.BaseType.push_back(
4602 moduleTranslation.
convertType(mapOp.getVarType()));
// The new entry has size zero and carries only the RETURN_PARAM map type.
4603 mapData.Sizes.push_back(builder.getInt64(0));
4604 mapData.MapClause.push_back(mapOp.getOperation());
4605 mapData.Types.push_back(
4606 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
4607 mapData.Names.push_back(LLVM::createMappingInformation(
4609 mapData.DevicePointers.push_back(devInfoTy);
// No user-defined mapper, and the entry is marked as not being a mapping.
4610 mapData.Mappers.push_back(
nullptr);
4611 mapData.IsAMapping.push_back(
false);
4612 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
4617 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4618 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
4620 for (Value mapValue : hasDevAddrOperands) {
4621 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4623 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4624 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
4626 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4628 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4629 omp::ClauseMapFlags::none;
4631 mapData.OriginalValue.push_back(origValue);
4632 mapData.BasePointers.push_back(origValue);
4633 mapData.Pointers.push_back(origValue);
4634 mapData.IsDeclareTarget.push_back(
false);
4635 mapData.BaseType.push_back(
4636 moduleTranslation.
convertType(mapOp.getVarType()));
4637 mapData.Sizes.push_back(
4638 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
4639 mapData.MapClause.push_back(mapOp.getOperation());
4640 if (llvm::to_underlying(mapType & mapTypeAlways)) {
4644 mapData.Types.push_back(mapType);
4648 if (mapOp.getMapperId()) {
4649 mapData.Mappers.push_back(
4651 mapOp, mapOp.getMapperIdAttr()));
4653 mapData.Mappers.push_back(
nullptr);
4658 mapData.Types.push_back(
4659 isDevicePtr ? mapType
4660 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4661 mapData.Mappers.push_back(
nullptr);
4663 mapData.Names.push_back(LLVM::createMappingInformation(
4665 mapData.DevicePointers.push_back(
4666 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4667 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4668 mapData.IsAMapping.push_back(
false);
4669 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
4674 auto *res = llvm::find(mapData.MapClause, memberOp);
4675 assert(res != mapData.MapClause.end() &&
4676 "MapInfoOp for member not found in MapData, cannot return index");
4677 return std::distance(mapData.MapClause.begin(), res);
4681 omp::MapInfoOp mapInfo) {
4682 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4692 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
4693 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
4695 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
4696 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
4697 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
4699 if (aIndex == bIndex)
4702 if (aIndex < bIndex)
4705 if (aIndex > bIndex)
4712 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
4714 occludedChildren.push_back(
b);
4716 occludedChildren.push_back(a);
4717 return memberAParent;
4723 for (
auto v : occludedChildren)
4730 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4732 if (indexAttr.size() == 1)
4733 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
4737 return llvm::cast<omp::MapInfoOp>(
4762static std::vector<llvm::Value *>
4764 llvm::IRBuilderBase &builder,
bool isArrayTy,
4766 std::vector<llvm::Value *> idx;
4777 idx.push_back(builder.getInt64(0));
4778 for (
int i = bounds.size() - 1; i >= 0; --i) {
4779 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4780 bounds[i].getDefiningOp())) {
4781 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
4799 std::vector<llvm::Value *> dimensionIndexSizeOffset;
4800 for (
int i = bounds.size() - 1; i >= 0; --i) {
4801 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4802 bounds[i].getDefiningOp())) {
4803 if (i == ((
int)bounds.size() - 1))
4805 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
4807 idx.back() = builder.CreateAdd(
4808 builder.CreateMul(idx.back(), moduleTranslation.
lookupValue(
4809 boundOp.getExtent())),
4810 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
4819 llvm::transform(values, std::back_inserter(ints), [](
Attribute value) {
4820 return cast<IntegerAttr>(value).getInt();
4828 omp::MapInfoOp parentOp) {
4830 if (parentOp.getMembers().empty())
4834 if (parentOp.getMembers().size() == 1) {
4835 overlapMapDataIdxs.push_back(0);
4841 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
4842 for (
auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
4843 memberByIndex.push_back(
4844 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
4849 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
4850 [&](
auto a,
auto b) { return a.second.size() < b.second.size(); });
4856 for (
auto v : memberByIndex) {
4860 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](
auto x) {
4863 llvm::SmallVector<int64_t> xArr(x.second.size());
4864 getAsIntegers(x.second, xArr);
4865 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
4866 xArr.size() >= vArr.size();
4872 for (
auto v : memberByIndex)
4873 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
4874 overlapMapDataIdxs.push_back(v.first);
4886 if (mapOp.getVarPtrPtr())
4915 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
4916 MapInfoData &mapData, uint64_t mapDataIndex,
4917 TargetDirectiveEnumTy targetDirective) {
4918 assert(!ompBuilder.Config.isTargetDevice() &&
4919 "function only supported for host device codegen");
4922 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4924 auto *parentMapper = mapData.Mappers[mapDataIndex];
4930 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
4931 (targetDirective == TargetDirectiveEnumTy::Target &&
4932 !mapData.IsDeclareTarget[mapDataIndex])
4933 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
4934 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4937 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
4941 mapFlags parentFlags = mapData.Types[mapDataIndex];
4942 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
4943 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
4944 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
4945 baseFlag |= (parentFlags & preserve);
4948 combinedInfo.Types.emplace_back(baseFlag);
4949 combinedInfo.DevicePointers.emplace_back(
4950 mapData.DevicePointers[mapDataIndex]);
4954 combinedInfo.Mappers.emplace_back(
4955 parentMapper && !parentClause.getPartialMap() ? parentMapper :
nullptr);
4957 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4958 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
4967 llvm::Value *lowAddr, *highAddr;
4968 if (!parentClause.getPartialMap()) {
4969 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4970 builder.getPtrTy());
4971 highAddr = builder.CreatePointerCast(
4972 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4973 mapData.Pointers[mapDataIndex], 1),
4974 builder.getPtrTy());
4975 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4977 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4980 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
4981 builder.getPtrTy());
4984 highAddr = builder.CreatePointerCast(
4985 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
4986 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
4987 builder.getPtrTy());
4988 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
4991 llvm::Value *size = builder.CreateIntCast(
4992 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
4993 builder.getInt64Ty(),
4995 combinedInfo.Sizes.push_back(size);
4997 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
4998 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
5006 if (!parentClause.getPartialMap()) {
5011 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
5012 bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
5013 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
5014 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5015 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5017 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
5018 combinedInfo.Types.emplace_back(mapFlag);
5019 combinedInfo.DevicePointers.emplace_back(
5020 mapData.DevicePointers[mapDataIndex]);
5022 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5023 combinedInfo.BasePointers.emplace_back(
5024 mapData.BasePointers[mapDataIndex]);
5025 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5026 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
5027 combinedInfo.Mappers.emplace_back(
nullptr);
5038 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5039 builder.getPtrTy());
5040 highAddr = builder.CreatePointerCast(
5041 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5042 mapData.Pointers[mapDataIndex], 1),
5043 builder.getPtrTy());
5050 for (
auto v : overlapIdxs) {
5053 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
5054 combinedInfo.Types.emplace_back(mapFlag);
5055 combinedInfo.DevicePointers.emplace_back(
5056 mapData.DevicePointers[mapDataOverlapIdx]);
5058 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5059 combinedInfo.BasePointers.emplace_back(
5060 mapData.BasePointers[mapDataIndex]);
5061 combinedInfo.Mappers.emplace_back(
nullptr);
5062 combinedInfo.Pointers.emplace_back(lowAddr);
5063 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5064 builder.CreatePtrDiff(builder.getInt8Ty(),
5065 mapData.OriginalValue[mapDataOverlapIdx],
5067 builder.getInt64Ty(),
true));
5068 lowAddr = builder.CreateConstGEP1_32(
5070 mapData.MapClause[mapDataOverlapIdx]))
5071 ? builder.getPtrTy()
5072 : mapData.BaseType[mapDataOverlapIdx],
5073 mapData.BasePointers[mapDataOverlapIdx], 1);
5076 combinedInfo.Types.emplace_back(mapFlag);
5077 combinedInfo.DevicePointers.emplace_back(
5078 mapData.DevicePointers[mapDataIndex]);
5080 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5081 combinedInfo.BasePointers.emplace_back(
5082 mapData.BasePointers[mapDataIndex]);
5083 combinedInfo.Mappers.emplace_back(
nullptr);
5084 combinedInfo.Pointers.emplace_back(lowAddr);
5085 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5086 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5087 builder.getInt64Ty(),
true));
5090 return memberOfFlag;
5096 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
5097 MapInfoData &mapData, uint64_t mapDataIndex,
5098 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5099 TargetDirectiveEnumTy targetDirective) {
5100 assert(!ompBuilder.Config.isTargetDevice() &&
5101 "function only supported for host device codegen");
5104 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5106 for (
auto mappedMembers : parentClause.getMembers()) {
5108 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
5111 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
5122 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5123 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5124 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5125 combinedInfo.Types.emplace_back(mapFlag);
5126 combinedInfo.DevicePointers.emplace_back(
5127 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5128 combinedInfo.Mappers.emplace_back(
nullptr);
5129 combinedInfo.Names.emplace_back(
5131 combinedInfo.BasePointers.emplace_back(
5132 mapData.BasePointers[mapDataIndex]);
5133 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
5134 combinedInfo.Sizes.emplace_back(builder.getInt64(
5135 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
5141 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5142 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5143 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5145 ? parentClause.getVarPtr()
5146 : parentClause.getVarPtrPtr());
5149 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5150 targetDirective != TargetDirectiveEnumTy::TargetData))) {
5151 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5154 combinedInfo.Types.emplace_back(mapFlag);
5155 combinedInfo.DevicePointers.emplace_back(
5156 mapData.DevicePointers[memberDataIdx]);
5157 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5158 combinedInfo.Names.emplace_back(
5160 uint64_t basePointerIndex =
5162 combinedInfo.BasePointers.emplace_back(
5163 mapData.BasePointers[basePointerIndex]);
5164 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
5166 llvm::Value *size = mapData.Sizes[memberDataIdx];
5168 size = builder.CreateSelect(
5169 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5170 builder.getInt64(0), size);
5173 combinedInfo.Sizes.emplace_back(size);
5178 MapInfosTy &combinedInfo,
5179 TargetDirectiveEnumTy targetDirective,
5180 int mapDataParentIdx = -1) {
5184 auto mapFlag = mapData.Types[mapDataIdx];
5185 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5189 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5191 if (targetDirective == TargetDirectiveEnumTy::Target &&
5192 !mapData.IsDeclareTarget[mapDataIdx])
5193 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5195 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5197 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5202 if (mapDataParentIdx >= 0)
5203 combinedInfo.BasePointers.emplace_back(
5204 mapData.BasePointers[mapDataParentIdx]);
5206 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5208 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5209 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5210 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5211 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5212 combinedInfo.Types.emplace_back(mapFlag);
5213 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
5217 llvm::IRBuilderBase &builder,
5218 llvm::OpenMPIRBuilder &ompBuilder,
5220 MapInfoData &mapData, uint64_t mapDataIndex,
5221 TargetDirectiveEnumTy targetDirective) {
5222 assert(!ompBuilder.Config.isTargetDevice() &&
5223 "function only supported for host device codegen");
5226 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5231 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
5232 auto memberClause = llvm::cast<omp::MapInfoOp>(
5233 parentClause.getMembers()[0].getDefiningOp());
5250 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
5252 combinedInfo, mapData, mapDataIndex,
5255 combinedInfo, mapData, mapDataIndex,
5256 memberOfParentFlag, targetDirective);
5266 llvm::IRBuilderBase &builder) {
5268 "function only supported for host device codegen");
5269 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5271 if (!mapData.IsDeclareTarget[i]) {
5272 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5273 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
5283 switch (captureKind) {
5284 case omp::VariableCaptureKind::ByRef: {
5285 llvm::Value *newV = mapData.Pointers[i];
5287 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
5290 newV = builder.CreateLoad(builder.getPtrTy(), newV);
5292 if (!offsetIdx.empty())
5293 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
5295 mapData.Pointers[i] = newV;
5297 case omp::VariableCaptureKind::ByCopy: {
5298 llvm::Type *type = mapData.BaseType[i];
5300 if (mapData.Pointers[i]->getType()->isPointerTy())
5301 newV = builder.CreateLoad(type, mapData.Pointers[i]);
5303 newV = mapData.Pointers[i];
5306 auto curInsert = builder.saveIP();
5307 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
5309 auto *memTempAlloc =
5310 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
5311 builder.SetCurrentDebugLocation(DbgLoc);
5312 builder.restoreIP(curInsert);
5314 builder.CreateStore(newV, memTempAlloc);
5315 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
5318 mapData.Pointers[i] = newV;
5319 mapData.BasePointers[i] = newV;
5321 case omp::VariableCaptureKind::This:
5322 case omp::VariableCaptureKind::VLAType:
5323 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
5334 MapInfoData &mapData,
5335 TargetDirectiveEnumTy targetDirective) {
5337 "function only supported for host device codegen");
5358 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5361 if (mapData.IsAMember[i])
5364 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5365 if (!mapInfoOp.getMembers().empty()) {
5367 combinedInfo, mapData, i, targetDirective);
5375static llvm::Expected<llvm::Function *>
5377 LLVM::ModuleTranslation &moduleTranslation,
5378 llvm::StringRef mapperFuncName,
5379 TargetDirectiveEnumTy targetDirective);
5381static llvm::Expected<llvm::Function *>
5384 TargetDirectiveEnumTy targetDirective) {
5386 "function only supported for host device codegen");
5387 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5388 std::string mapperFuncName =
5390 {
"omp_mapper", declMapperOp.getSymName()});
5392 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
5400 if (llvm::Function *existingFunc =
5401 moduleTranslation.
getLLVMModule()->getFunction(mapperFuncName)) {
5402 moduleTranslation.
mapFunction(mapperFuncName, existingFunc);
5403 return existingFunc;
5407 mapperFuncName, targetDirective);
5410static llvm::Expected<llvm::Function *>
5413 llvm::StringRef mapperFuncName,
5414 TargetDirectiveEnumTy targetDirective) {
5416 "function only supported for host device codegen");
5417 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5418 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
5421 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
5424 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5427 MapInfosTy combinedInfo;
5429 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
5430 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
5431 builder.restoreIP(codeGenIP);
5432 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
5433 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
5434 builder.GetInsertBlock());
5435 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
5438 return llvm::make_error<PreviouslyReportedError>();
5439 MapInfoData mapData;
5442 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
5448 return combinedInfo;
5452 if (!combinedInfo.Mappers[i])
5455 moduleTranslation, targetDirective);
5459 genMapInfoCB, varType, mapperFuncName, customMapperCB);
5461 return newFn.takeError();
5462 if ([[maybe_unused]] llvm::Function *mappedFunc =
5464 assert(mappedFunc == *newFn &&
5465 "mapper function mapping disagrees with emitted function");
5467 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
5475 llvm::Value *ifCond =
nullptr;
5476 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
5480 llvm::omp::RuntimeFunction RTLFn;
5482 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
5485 llvm::OpenMPIRBuilder::TargetDataInfo info(
5488 assert(!ompBuilder->Config.isTargetDevice() &&
5489 "target data/enter/exit/update are host ops");
5490 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
// Translates the MLIR device value `dev` into the LLVM value used as the
// device id: the value is looked up through `moduleTranslation` and then
// integer-cast to the builder's 64-bit integer type. The trailing `true`
// is the third CreateIntCast argument (the signedness flag in the LLVM
// IRBuilder API).
5492 auto getDeviceID = [&](
mlir::Value dev) -> llvm::Value * {
5493 llvm::Value *v = moduleTranslation.
lookupValue(dev);
5494 return builder.CreateIntCast(v, builder.getInt64Ty(),
true);
5499 .Case([&](omp::TargetDataOp dataOp) {
5503 if (
auto ifVar = dataOp.getIfExpr())
5507 deviceID = getDeviceID(devId);
5509 mapVars = dataOp.getMapVars();
5510 useDevicePtrVars = dataOp.getUseDevicePtrVars();
5511 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
5514 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
5518 if (
auto ifVar = enterDataOp.getIfExpr())
5522 deviceID = getDeviceID(devId);
5525 enterDataOp.getNowait()
5526 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
5527 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
5528 mapVars = enterDataOp.getMapVars();
5529 info.HasNoWait = enterDataOp.getNowait();
5532 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
5536 if (
auto ifVar = exitDataOp.getIfExpr())
5540 deviceID = getDeviceID(devId);
5542 RTLFn = exitDataOp.getNowait()
5543 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
5544 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
5545 mapVars = exitDataOp.getMapVars();
5546 info.HasNoWait = exitDataOp.getNowait();
5549 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
5553 if (
auto ifVar = updateDataOp.getIfExpr())
5557 deviceID = getDeviceID(devId);
5560 updateDataOp.getNowait()
5561 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
5562 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
5563 mapVars = updateDataOp.getMapVars();
5564 info.HasNoWait = updateDataOp.getNowait();
5567 .DefaultUnreachable(
"unexpected operation");
5572 if (!isOffloadEntry)
5573 ifCond = builder.getFalse();
5575 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5576 MapInfoData mapData;
5578 builder, useDevicePtrVars, useDeviceAddrVars);
5581 MapInfosTy combinedInfo;
// Callback that produces the combined map information. It moves the builder
// to `codeGenIP`, calls genMapInfos to fill `combinedInfo` from `mapData`,
// and returns a reference to the captured `combinedInfo` object (not a copy).
5582 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
5583 builder.restoreIP(codeGenIP);
5584 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
5586 return combinedInfo;
5592 [&moduleTranslation](
5593 llvm::OpenMPIRBuilder::DeviceInfoTy type,
5597 for (
auto [arg, useDevVar] :
5598 llvm::zip_equal(blockArgs, useDeviceVars)) {
5600 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
5601 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
5602 : mapInfoOp.getVarPtr();
5605 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
5606 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
5607 mapInfoData.MapClause, mapInfoData.DevicePointers,
5608 mapInfoData.BasePointers)) {
5609 auto mapOp = cast<omp::MapInfoOp>(mapClause);
5610 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
5611 devicePointer != type)
5614 if (llvm::Value *devPtrInfoMap =
5615 mapper ? mapper(basePointer) : basePointer) {
5616 moduleTranslation.
mapValue(arg, devPtrInfoMap);
5623 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
5624 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
5625 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5628 builder.restoreIP(codeGenIP);
5629 assert(isa<omp::TargetDataOp>(op) &&
5630 "BodyGen requested for non TargetDataOp");
5631 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
5632 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
5633 switch (bodyGenType) {
5634 case BodyGenTy::Priv:
5636 if (!info.DevicePtrInfoMap.empty()) {
5637 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5638 blockArgIface.getUseDeviceAddrBlockArgs(),
5639 useDeviceAddrVars, mapData,
5640 [&](llvm::Value *basePointer) -> llvm::Value * {
5641 if (!info.DevicePtrInfoMap[basePointer].second)
5643 return builder.CreateLoad(
5645 info.DevicePtrInfoMap[basePointer].second);
5647 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5648 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5649 mapData, [&](llvm::Value *basePointer) {
5650 return info.DevicePtrInfoMap[basePointer].second;
5654 moduleTranslation)))
5655 return llvm::make_error<PreviouslyReportedError>();
5658 case BodyGenTy::DupNoPriv:
5659 if (info.DevicePtrInfoMap.empty()) {
5662 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5663 blockArgIface.getUseDeviceAddrBlockArgs(),
5664 useDeviceAddrVars, mapData);
5665 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5666 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5670 case BodyGenTy::NoPriv:
5672 if (info.DevicePtrInfoMap.empty()) {
5674 moduleTranslation)))
5675 return llvm::make_error<PreviouslyReportedError>();
5679 return builder.saveIP();
5682 auto customMapperCB =
5684 if (!combinedInfo.Mappers[i])
5686 info.HasMapper =
true;
5688 moduleTranslation, targetDirective);
5691 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5692 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5694 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
5695 if (isa<omp::TargetDataOp>(op))
5696 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5697 deviceID, ifCond, info, genMapInfoCB,
5701 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5702 deviceID, ifCond, info, genMapInfoCB,
5703 customMapperCB, &RTLFn);
5709 builder.restoreIP(*afterIP);
5717 auto distributeOp = cast<omp::DistributeOp>(opInst);
5724 bool doDistributeReduction =
5728 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
5733 if (doDistributeReduction) {
5734 isByRef =
getIsByRef(teamsOp.getReductionByref());
5735 assert(isByRef.size() == teamsOp.getNumReductionVars());
5738 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5742 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
5743 .getReductionBlockArgs();
5746 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
5747 reductionDecls, privateReductionVariables, reductionVariableMap,
5752 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5753 auto bodyGenCB = [&](InsertPointTy allocaIP,
5754 InsertPointTy codeGenIP) -> llvm::Error {
5758 moduleTranslation, allocaIP);
5761 builder.restoreIP(codeGenIP);
5767 return llvm::make_error<PreviouslyReportedError>();
5772 return llvm::make_error<PreviouslyReportedError>();
5775 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
5777 distributeOp.getPrivateNeedsBarrier())))
5778 return llvm::make_error<PreviouslyReportedError>();
5781 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5784 builder, moduleTranslation);
5786 return regionBlock.takeError();
5787 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
5792 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
5795 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
5796 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
5797 : omp::ClauseScheduleKind::Static;
5799 bool isOrdered = hasDistSchedule;
5800 std::optional<omp::ScheduleModifier> scheduleMod;
5801 bool isSimd =
false;
5802 llvm::omp::WorksharingLoopType workshareLoopType =
5803 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
5804 bool loopNeedsBarrier =
false;
5805 llvm::Value *chunk = moduleTranslation.
lookupValue(
5806 distributeOp.getDistScheduleChunkSize());
5807 llvm::CanonicalLoopInfo *loopInfo =
5809 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
5810 ompBuilder->applyWorkshareLoop(
5811 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
5812 convertToScheduleKind(schedule), chunk, isSimd,
5813 scheduleMod == omp::ScheduleModifier::monotonic,
5814 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
5815 workshareLoopType,
false, hasDistSchedule, chunk);
5818 return wsloopIP.takeError();
5821 distributeOp.getLoc(), privVarsInfo.
llvmVars,
5823 return llvm::make_error<PreviouslyReportedError>();
5825 return llvm::Error::success();
5828 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5830 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5831 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5832 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
5837 builder.restoreIP(*afterIP);
5839 if (doDistributeReduction) {
5842 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
5843 privateReductionVariables, isByRef,
5855 if (!cast<mlir::ModuleOp>(op))
5860 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
5861 attribute.getOpenmpDeviceVersion());
5863 if (attribute.getNoGpuLib())
5866 ompBuilder->createGlobalFlag(
5867 attribute.getDebugKind() ,
5868 "__omp_rtl_debug_kind");
5869 ompBuilder->createGlobalFlag(
5871 .getAssumeTeamsOversubscription()
5873 "__omp_rtl_assume_teams_oversubscription");
5874 ompBuilder->createGlobalFlag(
5876 .getAssumeThreadsOversubscription()
5878 "__omp_rtl_assume_threads_oversubscription");
5879 ompBuilder->createGlobalFlag(
5880 attribute.getAssumeNoThreadState() ,
5881 "__omp_rtl_assume_no_thread_state");
5882 ompBuilder->createGlobalFlag(
5884 .getAssumeNoNestedParallelism()
5886 "__omp_rtl_assume_no_nested_parallelism");
5891 omp::TargetOp targetOp,
5892 llvm::StringRef parentName =
"") {
5893 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
5895 assert(fileLoc &&
"No file found from location");
5896 StringRef fileName = fileLoc.getFilename().getValue();
5898 llvm::sys::fs::UniqueID id;
5899 uint64_t line = fileLoc.getLine();
5900 if (
auto ec = llvm::sys::fs::getUniqueID(fileName,
id)) {
5901 size_t fileHash = llvm::hash_value(fileName.str());
5902 size_t deviceId = 0xdeadf17e;
5904 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
5906 targetInfo = llvm::TargetRegionEntryInfo(parentName,
id.
getDevice(),
5907 id.getFile(), line);
5914 llvm::IRBuilderBase &builder, llvm::Function *
func) {
5916 "function only supported for target device codegen");
5917 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5918 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5931 if (mapData.IsDeclareTarget[i]) {
5938 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
5939 convertUsersOfConstantsToInstructions(constant,
func,
false);
5946 for (llvm::User *user : mapData.OriginalValue[i]->users())
5947 userVec.push_back(user);
5949 for (llvm::User *user : userVec) {
5950 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
5951 if (insn->getFunction() ==
func) {
5952 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5953 llvm::Value *substitute = mapData.BasePointers[i];
5955 : mapOp.getVarPtr())) {
5956 builder.SetCurrentDebugLocation(insn->getDebugLoc());
5957 substitute = builder.CreateLoad(
5958 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
5959 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
5961 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
6008static llvm::IRBuilderBase::InsertPoint
6010 llvm::Value *input, llvm::Value *&retVal,
6011 llvm::IRBuilderBase &builder,
6012 llvm::OpenMPIRBuilder &ompBuilder,
6014 llvm::IRBuilderBase::InsertPoint allocaIP,
6015 llvm::IRBuilderBase::InsertPoint codeGenIP) {
6016 assert(ompBuilder.Config.isTargetDevice() &&
6017 "function only supported for target device codegen");
6018 builder.restoreIP(allocaIP);
6020 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
6022 ompBuilder.M.getContext());
6023 unsigned alignmentValue = 0;
6025 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
6026 if (mapData.OriginalValue[i] == input) {
6027 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6028 capture = mapOp.getMapCaptureType();
6031 mapOp.getVarType(), ompBuilder.M.getDataLayout());
6035 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
6036 unsigned int defaultAS =
6037 ompBuilder.M.getDataLayout().getProgramAddressSpace();
6040 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
6042 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
6043 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
6045 builder.CreateStore(&arg, v);
6047 builder.restoreIP(codeGenIP);
6050 case omp::VariableCaptureKind::ByCopy: {
6054 case omp::VariableCaptureKind::ByRef: {
6055 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
6057 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
6072 if (v->getType()->isPointerTy() && alignmentValue) {
6073 llvm::MDBuilder MDB(builder.getContext());
6074 loadInst->setMetadata(
6075 llvm::LLVMContext::MD_align,
6076 llvm::MDNode::get(builder.getContext(),
6077 MDB.createConstant(llvm::ConstantInt::get(
6078 llvm::Type::getInt64Ty(builder.getContext()),
6085 case omp::VariableCaptureKind::This:
6086 case omp::VariableCaptureKind::VLAType:
6089 assert(
false &&
"Currently unsupported capture kind");
6093 return builder.saveIP();
6110 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6111 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6112 blockArgIface.getHostEvalBlockArgs())) {
6113 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6117 .Case([&](omp::TeamsOp teamsOp) {
6118 if (teamsOp.getNumTeamsLower() == blockArg)
6119 numTeamsLower = hostEvalVar;
6120 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6122 numTeamsUpper = hostEvalVar;
6123 else if (!teamsOp.getThreadLimitVars().empty() &&
6124 teamsOp.getThreadLimit(0) == blockArg)
6125 threadLimit = hostEvalVar;
6127 llvm_unreachable(
"unsupported host_eval use");
6129 .Case([&](omp::ParallelOp parallelOp) {
6130 if (!parallelOp.getNumThreadsVars().empty() &&
6131 parallelOp.getNumThreads(0) == blockArg)
6132 numThreads = hostEvalVar;
6134 llvm_unreachable(
"unsupported host_eval use");
6136 .Case([&](omp::LoopNestOp loopOp) {
6137 auto processBounds =
6141 for (
auto [i, lb] : llvm::enumerate(opBounds)) {
6142 if (lb == blockArg) {
6145 (*outBounds)[i] = hostEvalVar;
6151 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6152 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6154 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6156 assert(found &&
"unsupported host_eval use");
6158 .DefaultUnreachable(
"unsupported host_eval use");
6170template <
typename OpTy>
6175 if (OpTy casted = dyn_cast<OpTy>(op))
6178 if (immediateParent)
6179 return dyn_cast_if_present<OpTy>(op->
getParentOp());
6188 return std::nullopt;
6191 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6192 return constAttr.getInt();
6194 return std::nullopt;
6199 uint64_t sizeInBytes = sizeInBits / 8;
6203template <
typename OpTy>
6205 if (op.getNumReductionVars() > 0) {
6210 members.reserve(reductions.size());
6211 for (omp::DeclareReductionOp &red : reductions)
6212 members.push_back(red.getType());
6214 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
6230 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
6231 bool isTargetDevice,
bool isGPU) {
6234 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
6235 if (!isTargetDevice) {
6243 numTeamsLower = teamsOp.getNumTeamsLower();
6245 if (!teamsOp.getNumTeamsUpperVars().empty())
6246 numTeamsUpper = teamsOp.getNumTeams(0);
6247 if (!teamsOp.getThreadLimitVars().empty())
6248 threadLimit = teamsOp.getThreadLimit(0);
6252 if (!parallelOp.getNumThreadsVars().empty())
6253 numThreads = parallelOp.getNumThreads(0);
6259 int32_t minTeamsVal = 1, maxTeamsVal = -1;
6263 if (numTeamsUpper) {
6265 minTeamsVal = maxTeamsVal = *val;
6267 minTeamsVal = maxTeamsVal = 0;
6273 minTeamsVal = maxTeamsVal = 1;
6275 minTeamsVal = maxTeamsVal = -1;
6280 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &
result) {
6294 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
6295 if (!targetOp.getThreadLimitVars().empty())
6296 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
6297 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
6300 int32_t maxThreadsVal = -1;
6302 setMaxValueFromClause(numThreads, maxThreadsVal);
6310 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
6311 if (combinedMaxThreadsVal < 0 ||
6312 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
6313 combinedMaxThreadsVal = teamsThreadLimitVal;
6315 if (combinedMaxThreadsVal < 0 ||
6316 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
6317 combinedMaxThreadsVal = maxThreadsVal;
6319 int32_t reductionDataSize = 0;
6320 if (isGPU && capturedOp) {
6326 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
6328 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
6329 omp::TargetRegionFlags::spmd) &&
6330 "invalid kernel flags");
6332 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
6333 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
6334 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
6335 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
6336 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
6337 if (omp::bitEnumContainsAll(kernelFlags,
6338 omp::TargetRegionFlags::spmd |
6339 omp::TargetRegionFlags::no_loop) &&
6340 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
6341 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
6343 attrs.MinTeams = minTeamsVal;
6344 attrs.MaxTeams.front() = maxTeamsVal;
6345 attrs.MinThreads = 1;
6346 attrs.MaxThreads.front() = combinedMaxThreadsVal;
6347 attrs.ReductionDataSize = reductionDataSize;
6350 if (attrs.ReductionDataSize != 0)
6351 attrs.ReductionBufferLength = 1024;
6363 omp::TargetOp targetOp,
Operation *capturedOp,
6364 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
6366 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
6368 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
6372 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
6375 if (!targetOp.getThreadLimitVars().empty()) {
6376 Value targetThreadLimit = targetOp.getThreadLimit(0);
6377 attrs.TargetThreadLimit.front() =
6385 attrs.MinTeams = builder.CreateSExtOrTrunc(
6386 moduleTranslation.
lookupValue(numTeamsLower), builder.getInt32Ty());
6389 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
6390 moduleTranslation.
lookupValue(numTeamsUpper), builder.getInt32Ty());
6392 if (teamsThreadLimit)
6393 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
6394 moduleTranslation.
lookupValue(teamsThreadLimit), builder.getInt32Ty());
6397 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
6399 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
6400 omp::TargetRegionFlags::trip_count)) {
6402 attrs.LoopTripCount =
nullptr;
6407 for (
auto [loopLower, loopUpper, loopStep] :
6408 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
6409 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
6410 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
6411 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
6413 if (!lowerBound || !upperBound || !step) {
6414 attrs.LoopTripCount =
nullptr;
6418 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
6419 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
6420 loc, lowerBound, upperBound, step,
true,
6421 loopOp.getLoopInclusive());
6423 if (!attrs.LoopTripCount) {
6424 attrs.LoopTripCount = tripCount;
6429 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
6434 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6436 attrs.DeviceID = moduleTranslation.
lookupValue(devId);
6438 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
6445 auto targetOp = cast<omp::TargetOp>(opInst);
6449 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
6458 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
6459 assert(parentBB &&
"No insert block is set for the builder");
6460 llvm::Function *parentLLVMFn = parentBB->getParent();
6461 assert(parentLLVMFn &&
"Parent Function must be valid");
6462 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
6463 builder.SetCurrentDebugLocation(llvm::DILocation::get(
6464 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
6465 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
6468 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
6469 bool isGPU = ompBuilder->Config.isGPU();
6472 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
6473 auto &targetRegion = targetOp.getRegion();
6490 llvm::Function *llvmOutlinedFn =
nullptr;
6491 TargetDirectiveEnumTy targetDirective =
6492 getTargetDirectiveEnumTyFromOp(&opInst);
6496 bool isOffloadEntry =
6497 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
6504 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
6506 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
6507 std::optional<DenseI64ArrayAttr> privateMapIndices =
6508 targetOp.getPrivateMapsAttr();
6510 for (
auto [privVarIdx, privVarSymPair] :
6511 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
6512 auto privVar = std::get<0>(privVarSymPair);
6513 auto privSym = std::get<1>(privVarSymPair);
6515 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
6516 omp::PrivateClauseOp privatizer =
6519 if (!privatizer.needsMap())
6523 targetOp.getMappedValueForPrivateVar(privVarIdx);
6524 assert(mappedValue &&
"Expected to find mapped value for a privatized "
6525 "variable that needs mapping");
6530 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
6531 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
6535 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
6537 varType == privVar.getType() &&
6538 "Type of private var doesn't match the type of the mapped value");
6542 mappedPrivateVars.insert(
6544 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
6545 (*privateMapIndices)[privVarIdx])});
6549 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6550 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
6551 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6552 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6553 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6556 llvm::Function *llvmParentFn =
6558 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
6559 assert(llvmParentFn && llvmOutlinedFn &&
6560 "Both parent and outlined functions must exist at this point");
6562 if (outlinedFnLoc && llvmParentFn->getSubprogram())
6563 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
6565 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
6566 attr.isStringAttribute())
6567 llvmOutlinedFn->addFnAttr(attr);
6569 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
6570 attr.isStringAttribute())
6571 llvmOutlinedFn->addFnAttr(attr);
6573 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
6574 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6575 llvm::Value *mapOpValue =
6576 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6577 moduleTranslation.
mapValue(arg, mapOpValue);
6579 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
6580 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6581 llvm::Value *mapOpValue =
6582 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6583 moduleTranslation.
mapValue(arg, mapOpValue);
6592 allocaIP, &mappedPrivateVars);
6595 return llvm::make_error<PreviouslyReportedError>();
6597 builder.restoreIP(codeGenIP);
6599 &mappedPrivateVars),
6602 return llvm::make_error<PreviouslyReportedError>();
6605 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
6607 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
6608 return llvm::make_error<PreviouslyReportedError>();
6612 std::back_inserter(privateCleanupRegions),
6613 [](omp::PrivateClauseOp privatizer) {
6614 return &privatizer.getDeallocRegion();
6618 targetRegion,
"omp.target", builder, moduleTranslation);
6621 return exitBlock.takeError();
6623 builder.SetInsertPoint(*exitBlock);
6624 if (!privateCleanupRegions.empty()) {
6626 privateCleanupRegions, privateVarsInfo.
llvmVars,
6627 moduleTranslation, builder,
"omp.targetop.private.cleanup",
6629 return llvm::createStringError(
6630 "failed to inline `dealloc` region of `omp.private` "
6631 "op in the target region");
6633 return builder.saveIP();
6636 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
6639 StringRef parentName = parentFn.getName();
6641 llvm::TargetRegionEntryInfo entryInfo;
6645 MapInfoData mapData;
6650 MapInfosTy combinedInfos;
6652 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
6653 builder.restoreIP(codeGenIP);
6654 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
6656 return combinedInfos;
6659 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
6660 llvm::Value *&retVal, InsertPointTy allocaIP,
6661 InsertPointTy codeGenIP)
6662 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6663 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6664 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6670 if (!isTargetDevice) {
6671 retVal = cast<llvm::Value>(&arg);
6676 *ompBuilder, moduleTranslation,
6677 allocaIP, codeGenIP);
6680 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
6681 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
6682 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
6684 isTargetDevice, isGPU);
6688 if (!isTargetDevice)
6690 targetCapturedOp, runtimeAttrs);
6698 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
6699 llvm::Value *value = moduleTranslation.
lookupValue(var);
6700 moduleTranslation.
mapValue(arg, value);
6702 if (!llvm::isa<llvm::Constant>(value))
6703 kernelInput.push_back(value);
6706 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
6713 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
6714 kernelInput.push_back(mapData.OriginalValue[i]);
6719 moduleTranslation, dds);
6721 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6723 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6725 llvm::OpenMPIRBuilder::TargetDataInfo info(
6729 auto customMapperCB =
6731 if (!combinedInfos.Mappers[i])
6733 info.HasMapper =
true;
6735 moduleTranslation, targetDirective);
6738 llvm::Value *ifCond =
nullptr;
6739 if (
Value targetIfCond = targetOp.getIfExpr())
6740 ifCond = moduleTranslation.
lookupValue(targetIfCond);
6742 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6744 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
6745 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
6746 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
6751 builder.restoreIP(*afterIP);
6772 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
6773 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
6775 if (!offloadMod.getIsTargetDevice())
6778 omp::DeclareTargetDeviceType declareType =
6779 attribute.getDeviceType().getValue();
6781 if (declareType == omp::DeclareTargetDeviceType::host) {
6782 llvm::Function *llvmFunc =
6784 llvmFunc->dropAllReferences();
6785 llvmFunc->eraseFromParent();
6791 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
6792 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
6793 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
6795 bool isDeclaration = gOp.isDeclaration();
6796 bool isExternallyVisible =
6799 llvm::StringRef mangledName = gOp.getSymName();
6800 auto captureClause =
6806 std::vector<llvm::GlobalVariable *> generatedRefs;
6808 std::vector<llvm::Triple> targetTriple;
6809 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
6811 LLVM::LLVMDialect::getTargetTripleAttrName()));
6812 if (targetTripleAttr)
6813 targetTriple.emplace_back(targetTripleAttr.data());
6815 auto fileInfoCallBack = [&loc]() {
6816 std::string filename =
"";
6817 std::uint64_t lineNo = 0;
6820 filename = loc.getFilename().str();
6821 lineNo = loc.getLine();
6824 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
6828 auto vfs = llvm::vfs::getRealFileSystem();
6830 ompBuilder->registerTargetGlobalVariable(
6831 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6832 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6833 mangledName, generatedRefs,
false, targetTriple,
6835 gVal->getType(), gVal);
6837 if (ompBuilder->Config.isTargetDevice() &&
6838 (attribute.getCaptureClause().getValue() !=
6839 mlir::omp::DeclareTargetCaptureClause::to ||
6840 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
6841 ompBuilder->getAddrOfDeclareTargetVar(
6842 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6843 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6844 mangledName, generatedRefs,
false, targetTriple,
6845 gVal->getType(),
nullptr,
6858class OpenMPDialectLLVMIRTranslationInterface
6859 :
public LLVMTranslationDialectInterface {
6866 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
6867 LLVM::ModuleTranslation &moduleTranslation)
const final;
6872 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
6873 NamedAttribute attribute,
6874 LLVM::ModuleTranslation &moduleTranslation)
const final;
6879LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6880 Operation *op, ArrayRef<llvm::Instruction *> instructions,
6881 NamedAttribute attribute,
6882 LLVM::ModuleTranslation &moduleTranslation)
const {
6883 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
6885 .Case(
"omp.is_target_device",
6886 [&](Attribute attr) {
6887 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6888 llvm::OpenMPIRBuilderConfig &
config =
6890 config.setIsTargetDevice(deviceAttr.getValue());
6896 [&](Attribute attr) {
6897 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6898 llvm::OpenMPIRBuilderConfig &
config =
6900 config.setIsGPU(gpuAttr.getValue());
6905 .Case(
"omp.host_ir_filepath",
6906 [&](Attribute attr) {
6907 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6908 llvm::OpenMPIRBuilder *ompBuilder =
6910 auto VFS = llvm::vfs::getRealFileSystem();
6911 ompBuilder->loadOffloadInfoMetadata(*VFS,
6912 filepathAttr.getValue());
6918 [&](Attribute attr) {
6919 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6923 .Case(
"omp.version",
6924 [&](Attribute attr) {
6925 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6926 llvm::OpenMPIRBuilder *ompBuilder =
6928 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
6929 versionAttr.getVersion());
6934 .Case(
"omp.declare_target",
6935 [&](Attribute attr) {
6936 if (
auto declareTargetAttr =
6937 dyn_cast<omp::DeclareTargetAttr>(attr))
6942 .Case(
"omp.requires",
6943 [&](Attribute attr) {
6944 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6945 using Requires = omp::ClauseRequires;
6946 Requires flags = requiresAttr.getValue();
6947 llvm::OpenMPIRBuilderConfig &
config =
6949 config.setHasRequiresReverseOffload(
6950 bitEnumContainsAll(flags, Requires::reverse_offload));
6951 config.setHasRequiresUnifiedAddress(
6952 bitEnumContainsAll(flags, Requires::unified_address));
6953 config.setHasRequiresUnifiedSharedMemory(
6954 bitEnumContainsAll(flags, Requires::unified_shared_memory));
6955 config.setHasRequiresDynamicAllocators(
6956 bitEnumContainsAll(flags, Requires::dynamic_allocators));
6961 .Case(
"omp.target_triples",
6962 [&](Attribute attr) {
6963 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
6964 llvm::OpenMPIRBuilderConfig &
config =
6966 config.TargetTriples.clear();
6967 config.TargetTriples.reserve(triplesAttr.size());
6968 for (Attribute tripleAttr : triplesAttr) {
6969 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
6970 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
6978 .Default([](Attribute) {
6994 if (
auto declareTargetIface =
6995 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
6996 parentFn.getOperation()))
6997 if (declareTargetIface.isDeclareTarget() &&
6998 declareTargetIface.getDeclareTargetDeviceType() !=
6999 mlir::omp::DeclareTargetDeviceType::host)
7009 llvm::Module *llvmModule) {
7010 llvm::Type *i64Ty = builder.getInt64Ty();
7011 llvm::Type *i32Ty = builder.getInt32Ty();
7012 llvm::Type *returnType = builder.getPtrTy(0);
7013 llvm::FunctionType *fnType =
7014 llvm::FunctionType::get(returnType, {i64Ty, i32Ty},
false);
7015 llvm::Function *
func = cast<llvm::Function>(
7016 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
7023 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
7028 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7032 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
7034 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
7035 mlir::Type heapTy = allocMemOp.getAllocatedType();
7036 llvm::Type *llvmHeapTy = moduleTranslation.
convertType(heapTy);
7037 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
7038 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7039 for (
auto typeParam : allocMemOp.getTypeparams())
7041 builder.CreateMul(allocSize, moduleTranslation.
lookupValue(typeParam));
7043 llvm::CallInst *call =
7044 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
7045 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
7048 moduleTranslation.
mapValue(allocMemOp.getResult(), resultI64);
7053 llvm::Module *llvmModule) {
7054 llvm::Type *ptrTy = builder.getPtrTy(0);
7055 llvm::Type *i32Ty = builder.getInt32Ty();
7056 llvm::Type *voidTy = builder.getVoidTy();
7057 llvm::FunctionType *fnType =
7058 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty},
false);
7059 llvm::Function *
func = dyn_cast<llvm::Function>(
7060 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
7067 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
7072 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7076 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
7079 llvm::Value *llvmHeapref = moduleTranslation.
lookupValue(heapref);
7081 llvm::Value *intToPtr =
7082 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
7083 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
7089LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7090 Operation *op, llvm::IRBuilderBase &builder,
7091 LLVM::ModuleTranslation &moduleTranslation)
const {
7094 if (ompBuilder->Config.isTargetDevice() &&
7095 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
7098 return op->
emitOpError() <<
"unsupported host op found in device";
7106 bool isOutermostLoopWrapper =
7107 isa_and_present<omp::LoopWrapperInterface>(op) &&
7108 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
7110 if (isOutermostLoopWrapper)
7111 moduleTranslation.
stackPush<OpenMPLoopInfoStackFrame>();
7114 llvm::TypeSwitch<Operation *, LogicalResult>(op)
7115 .Case([&](omp::BarrierOp op) -> LogicalResult {
7119 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7120 ompBuilder->createBarrier(builder.saveIP(),
7121 llvm::omp::OMPD_barrier);
7123 if (res.succeeded()) {
7126 builder.restoreIP(*afterIP);
7130 .Case([&](omp::TaskyieldOp op) {
7134 ompBuilder->createTaskyield(builder.saveIP());
7137 .Case([&](omp::FlushOp op) {
7149 ompBuilder->createFlush(builder.saveIP());
7152 .Case([&](omp::ParallelOp op) {
7155 .Case([&](omp::MaskedOp) {
7158 .Case([&](omp::MasterOp) {
7161 .Case([&](omp::CriticalOp) {
7164 .Case([&](omp::OrderedRegionOp) {
7167 .Case([&](omp::OrderedOp) {
7170 .Case([&](omp::WsloopOp) {
7173 .Case([&](omp::SimdOp) {
7176 .Case([&](omp::AtomicReadOp) {
7179 .Case([&](omp::AtomicWriteOp) {
7182 .Case([&](omp::AtomicUpdateOp op) {
7185 .Case([&](omp::AtomicCaptureOp op) {
7188 .Case([&](omp::CancelOp op) {
7191 .Case([&](omp::CancellationPointOp op) {
7194 .Case([&](omp::SectionsOp) {
7197 .Case([&](omp::SingleOp op) {
7200 .Case([&](omp::TeamsOp op) {
7203 .Case([&](omp::TaskOp op) {
7206 .Case([&](omp::TaskloopOp op) {
7209 .Case([&](omp::TaskgroupOp op) {
7212 .Case([&](omp::TaskwaitOp op) {
7215 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
7216 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
7217 omp::CriticalDeclareOp>([](
auto op) {
7230 .Case([&](omp::ThreadprivateOp) {
7233 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
7234 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
7237 .Case([&](omp::TargetOp) {
7240 .Case([&](omp::DistributeOp) {
7243 .Case([&](omp::LoopNestOp) {
7246 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
7253 .Case([&](omp::NewCliOp op) {
7258 .Case([&](omp::CanonicalLoopOp op) {
7261 .Case([&](omp::UnrollHeuristicOp op) {
7270 .Case([&](omp::TileOp op) {
7271 return applyTile(op, builder, moduleTranslation);
7273 .Case([&](omp::TargetAllocMemOp) {
7276 .Case([&](omp::TargetFreeMemOp) {
7279 .Default([&](Operation *inst) {
7281 <<
"not yet implemented: " << inst->
getName();
7284 if (isOutermostLoopWrapper)
7291 registry.
insert<omp::OpenMPDialect>();
7293 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
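If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate parent operation (or some other higher-level parent, if immediateParent is false) is of that type, return that parent casted to the given type.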
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declaration.
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation `from` and returns the PrivateClauseOp with name `symbolName`.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
LLVMTranslationDialectInterface(Dialect *dialect)
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars