24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Frontend/OpenMP/OMPConstants.h"
28#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DebugInfoMetadata.h"
31#include "llvm/IR/DerivedTypes.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/MDBuilder.h"
34#include "llvm/IR/ReplaceConstant.h"
35#include "llvm/Support/FileSystem.h"
36#include "llvm/Support/VirtualFileSystem.h"
37#include "llvm/TargetParser/Triple.h"
38#include "llvm/Transforms/Utils/ModuleUtils.h"
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
67 llvm_unreachable(
"unhandled schedule clause argument");
72class OpenMPAllocaStackFrame
77 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
78 : allocaInsertPoint(allocaIP) {}
79 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
85class OpenMPLoopInfoStackFrame
89 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
108class PreviouslyReportedError
109 :
public llvm::ErrorInfo<PreviouslyReportedError> {
111 void log(raw_ostream &)
const override {
115 std::error_code convertToErrorCode()
const override {
117 "PreviouslyReportedError doesn't support ECError conversion");
124char PreviouslyReportedError::ID = 0;
135class LinearClauseProcessor {
138 SmallVector<llvm::Value *> linearPreconditionVars;
139 SmallVector<llvm::Value *> linearLoopBodyTemps;
140 SmallVector<llvm::Value *> linearOrigVal;
141 SmallVector<llvm::Value *> linearSteps;
142 SmallVector<llvm::Type *> linearVarTypes;
143 llvm::BasicBlock *linearFinalizationBB;
144 llvm::BasicBlock *linearExitBB;
145 llvm::BasicBlock *linearLastIterExitBB;
149 void registerType(LLVM::ModuleTranslation &moduleTranslation,
150 mlir::Attribute &ty) {
151 linearVarTypes.push_back(moduleTranslation.
convertType(
152 mlir::cast<mlir::TypeAttr>(ty).getValue()));
156 void createLinearVar(llvm::IRBuilderBase &builder,
157 LLVM::ModuleTranslation &moduleTranslation,
158 llvm::Value *linearVar,
int idx) {
159 linearPreconditionVars.push_back(
160 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_var"));
161 llvm::Value *linearLoopBodyTemp =
162 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_result");
163 linearOrigVal.push_back(linearVar);
164 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
168 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
169 mlir::Value &linearStep) {
170 linearSteps.push_back(moduleTranslation.
lookupValue(linearStep));
174 void initLinearVar(llvm::IRBuilderBase &builder,
175 LLVM::ModuleTranslation &moduleTranslation,
176 llvm::BasicBlock *loopPreHeader) {
177 builder.SetInsertPoint(loopPreHeader->getTerminator());
178 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
179 llvm::LoadInst *linearVarLoad =
180 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
181 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
186 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
187 llvm::Value *loopInductionVar) {
188 builder.SetInsertPoint(loopBody->getTerminator());
189 for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
190 llvm::Type *linearVarType = linearVarTypes[index];
191 llvm::Value *iv = loopInductionVar;
192 llvm::Value *step = linearSteps[index];
194 if (!iv->getType()->isIntegerTy())
195 llvm_unreachable(
"OpenMP loop induction variable must be an integer "
198 if (linearVarType->isIntegerTy()) {
200 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
201 step = builder.CreateSExtOrTrunc(step, linearVarType);
203 llvm::LoadInst *linearVarStart =
204 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
205 llvm::Value *mulInst = builder.CreateMul(iv, step);
206 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
207 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
208 }
else if (linearVarType->isFloatingPointTy()) {
210 step = builder.CreateSExtOrTrunc(step, iv->getType());
211 llvm::Value *mulInst = builder.CreateMul(iv, step);
213 llvm::LoadInst *linearVarStart =
214 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
215 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
216 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
217 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
220 "Linear variable must be of integer or floating-point type");
227 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
228 llvm::BasicBlock *loopExit) {
229 linearFinalizationBB = loopExit->splitBasicBlock(
230 loopExit->getTerminator(),
"omp_loop.linear_finalization");
231 linearExitBB = linearFinalizationBB->splitBasicBlock(
232 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
233 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
234 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
238 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
239 finalizeLinearVar(llvm::IRBuilderBase &builder,
240 LLVM::ModuleTranslation &moduleTranslation,
241 llvm::Value *lastIter) {
243 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
244 llvm::Value *loopLastIterLoad = builder.CreateLoad(
245 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
246 llvm::Value *isLast =
247 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
248 llvm::ConstantInt::get(
249 llvm::Type::getInt32Ty(builder.getContext()), 0));
251 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
252 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
253 llvm::LoadInst *linearVarTemp =
254 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
255 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
261 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
262 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
263 linearFinalizationBB->getTerminator()->eraseFromParent();
265 builder.SetInsertPoint(linearExitBB->getTerminator());
267 builder.saveIP(), llvm::omp::OMPD_barrier);
272 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
273 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
274 llvm::LoadInst *linearVarTemp =
275 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
276 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
282 void rewriteInPlace(llvm::IRBuilderBase &builder,
const std::string &BBName,
284 llvm::SmallVector<llvm::User *> users;
285 for (llvm::User *user : linearOrigVal[varIndex]->users())
286 users.push_back(user);
287 for (
auto *user : users) {
288 if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
289 if (userInst->getParent()->getName().str().find(BBName) !=
291 user->replaceUsesOfWith(linearOrigVal[varIndex],
292 linearLoopBodyTemps[varIndex]);
303 SymbolRefAttr symbolName) {
304 omp::PrivateClauseOp privatizer =
307 assert(privatizer &&
"privatizer not found in the symbol table");
318 auto todo = [&op](StringRef clauseName) {
319 return op.
emitError() <<
"not yet implemented: Unhandled clause "
320 << clauseName <<
" in " << op.
getName()
324 auto checkAllocate = [&todo](
auto op, LogicalResult &
result) {
325 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
326 result = todo(
"allocate");
328 auto checkBare = [&todo](
auto op, LogicalResult &
result) {
330 result = todo(
"ompx_bare");
332 auto checkCollapse = [&todo](
auto op, LogicalResult &
result) {
333 if (op.getCollapseNumLoops() > 1)
334 result = todo(
"collapse");
336 auto checkDepend = [&todo](
auto op, LogicalResult &
result) {
337 if (!op.getDependVars().empty() || op.getDependKinds())
340 auto checkHint = [](
auto op, LogicalResult &) {
344 auto checkInReduction = [&todo](
auto op, LogicalResult &
result) {
345 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
346 op.getInReductionSyms())
347 result = todo(
"in_reduction");
349 auto checkNowait = [&todo](
auto op, LogicalResult &
result) {
353 auto checkOrder = [&todo](
auto op, LogicalResult &
result) {
354 if (op.getOrder() || op.getOrderMod())
357 auto checkParLevelSimd = [&todo](
auto op, LogicalResult &
result) {
358 if (op.getParLevelSimd())
359 result = todo(
"parallelization-level");
361 auto checkPrivate = [&todo](
auto op, LogicalResult &
result) {
362 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
363 result = todo(
"privatization");
365 auto checkReduction = [&todo](
auto op, LogicalResult &
result) {
366 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
367 if (!op.getReductionVars().empty() || op.getReductionByref() ||
368 op.getReductionSyms())
369 result = todo(
"reduction");
370 if (op.getReductionMod() &&
371 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
372 result = todo(
"reduction with modifier");
374 auto checkTaskReduction = [&todo](
auto op, LogicalResult &
result) {
375 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
376 op.getTaskReductionSyms())
377 result = todo(
"task_reduction");
379 auto checkNumTeams = [&todo](
auto op, LogicalResult &
result) {
380 if (op.hasNumTeamsMultiDim())
381 result = todo(
"num_teams with multi-dimensional values");
383 auto checkNumThreads = [&todo](
auto op, LogicalResult &
result) {
384 if (op.hasNumThreadsMultiDim())
385 result = todo(
"num_threads with multi-dimensional values");
388 auto checkThreadLimit = [&todo](
auto op, LogicalResult &
result) {
389 if (op.hasThreadLimitMultiDim())
390 result = todo(
"thread_limit with multi-dimensional values");
395 .Case([&](omp::DistributeOp op) {
396 checkAllocate(op,
result);
399 .Case([&](omp::LoopNestOp op) {
400 if (mlir::isa<omp::TaskloopOp>(op.getOperation()->
getParentOp()))
401 checkCollapse(op,
result);
403 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op,
result); })
404 .Case([&](omp::SectionsOp op) {
405 checkAllocate(op,
result);
407 checkReduction(op,
result);
409 .Case([&](omp::SingleOp op) {
410 checkAllocate(op,
result);
413 .Case([&](omp::TeamsOp op) {
414 checkAllocate(op,
result);
416 checkNumTeams(op,
result);
417 checkThreadLimit(op,
result);
419 .Case([&](omp::TaskOp op) {
420 checkAllocate(op,
result);
421 checkInReduction(op,
result);
423 .Case([&](omp::TaskgroupOp op) {
424 checkAllocate(op,
result);
425 checkTaskReduction(op,
result);
427 .Case([&](omp::TaskwaitOp op) {
431 .Case([&](omp::TaskloopOp op) {
432 checkAllocate(op,
result);
433 checkInReduction(op,
result);
434 checkReduction(op,
result);
436 .Case([&](omp::WsloopOp op) {
437 checkAllocate(op,
result);
439 checkReduction(op,
result);
441 .Case([&](omp::ParallelOp op) {
442 checkAllocate(op,
result);
443 checkReduction(op,
result);
444 checkNumThreads(op,
result);
446 .Case([&](omp::SimdOp op) { checkReduction(op,
result); })
447 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
448 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op,
result); })
449 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
450 [&](
auto op) { checkDepend(op,
result); })
451 .Case([&](omp::TargetUpdateOp op) { checkDepend(op,
result); })
452 .Case([&](omp::TargetOp op) {
453 checkAllocate(op,
result);
455 checkInReduction(op,
result);
456 checkThreadLimit(op,
result);
468 llvm::handleAllErrors(
470 [&](
const PreviouslyReportedError &) {
result = failure(); },
471 [&](
const llvm::ErrorInfoBase &err) {
488static llvm::OpenMPIRBuilder::InsertPointTy
494 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
496 [&](OpenMPAllocaStackFrame &frame) {
497 allocaInsertPoint = frame.allocaInsertPoint;
505 allocaInsertPoint.getBlock()->getParent() ==
506 builder.GetInsertBlock()->getParent())
507 return allocaInsertPoint;
516 if (builder.GetInsertBlock() ==
517 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
518 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
519 "Assuming end of basic block");
520 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
521 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
522 builder.GetInsertBlock()->getNextNode());
523 builder.CreateBr(entryBB);
524 builder.SetInsertPoint(entryBB);
527 llvm::BasicBlock &funcEntryBlock =
528 builder.GetInsertBlock()->getParent()->getEntryBlock();
529 return llvm::OpenMPIRBuilder::InsertPointTy(
530 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
536static llvm::CanonicalLoopInfo *
538 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
539 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
540 [&](OpenMPLoopInfoStackFrame &frame) {
541 loopInfo = frame.loopInfo;
553 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
556 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
558 llvm::BasicBlock *continuationBlock =
559 splitBB(builder,
true,
"omp.region.cont");
560 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
562 llvm::LLVMContext &llvmContext = builder.getContext();
563 for (
Block &bb : region) {
564 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
565 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
566 builder.GetInsertBlock()->getNextNode());
567 moduleTranslation.
mapBlock(&bb, llvmBB);
570 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
577 unsigned numYields = 0;
579 if (!isLoopWrapper) {
580 bool operandsProcessed =
false;
582 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
583 if (!operandsProcessed) {
584 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
585 continuationBlockPHITypes.push_back(
586 moduleTranslation.
convertType(yield->getOperand(i).getType()));
588 operandsProcessed =
true;
590 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
591 "mismatching number of values yielded from the region");
592 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
593 llvm::Type *operandType =
594 moduleTranslation.
convertType(yield->getOperand(i).getType());
596 assert(continuationBlockPHITypes[i] == operandType &&
597 "values of mismatching types yielded from the region");
607 if (!continuationBlockPHITypes.empty())
609 continuationBlockPHIs &&
610 "expected continuation block PHIs if converted regions yield values");
611 if (continuationBlockPHIs) {
612 llvm::IRBuilderBase::InsertPointGuard guard(builder);
613 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
614 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
615 for (llvm::Type *ty : continuationBlockPHITypes)
616 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
622 for (
Block *bb : blocks) {
623 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
626 if (bb->isEntryBlock()) {
627 assert(sourceTerminator->getNumSuccessors() == 1 &&
628 "provided entry block has multiple successors");
629 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
630 "ContinuationBlock is not the successor of the entry block");
631 sourceTerminator->setSuccessor(0, llvmBB);
634 llvm::IRBuilderBase::InsertPointGuard guard(builder);
636 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
637 return llvm::make_error<PreviouslyReportedError>();
642 builder.CreateBr(continuationBlock);
653 Operation *terminator = bb->getTerminator();
654 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
655 builder.CreateBr(continuationBlock);
657 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
658 (*continuationBlockPHIs)[i]->addIncoming(
672 return continuationBlock;
678 case omp::ClauseProcBindKind::Close:
679 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
680 case omp::ClauseProcBindKind::Master:
681 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
682 case omp::ClauseProcBindKind::Primary:
683 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
684 case omp::ClauseProcBindKind::Spread:
685 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
687 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
694 auto maskedOp = cast<omp::MaskedOp>(opInst);
695 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
700 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
702 auto ®ion = maskedOp.getRegion();
703 builder.restoreIP(codeGenIP);
711 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
713 llvm::Value *filterVal =
nullptr;
714 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
715 filterVal = moduleTranslation.
lookupValue(filterVar);
717 llvm::LLVMContext &llvmContext = builder.getContext();
719 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
721 assert(filterVal !=
nullptr);
722 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
723 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
730 builder.restoreIP(*afterIP);
738 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
739 auto masterOp = cast<omp::MasterOp>(opInst);
744 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
746 auto ®ion = masterOp.getRegion();
747 builder.restoreIP(codeGenIP);
755 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
757 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
758 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
765 builder.restoreIP(*afterIP);
773 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
774 auto criticalOp = cast<omp::CriticalOp>(opInst);
779 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
781 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
782 builder.restoreIP(codeGenIP);
790 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
792 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
793 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
794 llvm::Constant *hint =
nullptr;
797 if (criticalOp.getNameAttr()) {
800 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
801 auto criticalDeclareOp =
805 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
806 static_cast<int>(criticalDeclareOp.getHint()));
808 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
810 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
815 builder.restoreIP(*afterIP);
822 template <
typename OP>
825 cast<
omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
828 collectPrivatizationDecls<OP>(op);
843 void collectPrivatizationDecls(OP op) {
844 std::optional<ArrayAttr> attr = op.getPrivateSyms();
849 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
860 std::optional<ArrayAttr> attr = op.getReductionSyms();
864 reductions.reserve(reductions.size() + op.getNumReductionVars());
865 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
866 reductions.push_back(
878 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
887 llvm::Instruction *potentialTerminator =
888 builder.GetInsertBlock()->empty() ?
nullptr
889 : &builder.GetInsertBlock()->back();
891 if (potentialTerminator && potentialTerminator->isTerminator())
892 potentialTerminator->removeFromParent();
893 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
896 region.
front(),
true, builder)))
900 if (continuationBlockArgs)
902 *continuationBlockArgs,
909 if (potentialTerminator && potentialTerminator->isTerminator()) {
910 llvm::BasicBlock *block = builder.GetInsertBlock();
911 if (block->empty()) {
917 potentialTerminator->insertInto(block, block->begin());
919 potentialTerminator->insertAfter(&block->back());
933 if (continuationBlockArgs)
934 llvm::append_range(*continuationBlockArgs, phis);
935 builder.SetInsertPoint(*continuationBlock,
936 (*continuationBlock)->getFirstInsertionPt());
943using OwningReductionGen =
944 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
945 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
947using OwningAtomicReductionGen =
948 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
949 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
951using OwningDataPtrPtrReductionGen =
952 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
953 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
959static OwningReductionGen
965 OwningReductionGen gen =
966 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
967 llvm::Value *
lhs, llvm::Value *
rhs,
968 llvm::Value *&
result)
mutable
969 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
970 moduleTranslation.
mapValue(decl.getReductionLhsArg(),
lhs);
971 moduleTranslation.
mapValue(decl.getReductionRhsArg(),
rhs);
972 builder.restoreIP(insertPoint);
975 "omp.reduction.nonatomic.body", builder,
976 moduleTranslation, &phis)))
977 return llvm::createStringError(
978 "failed to inline `combiner` region of `omp.declare_reduction`");
979 result = llvm::getSingleElement(phis);
980 return builder.saveIP();
989static OwningAtomicReductionGen
991 llvm::IRBuilderBase &builder,
993 if (decl.getAtomicReductionRegion().empty())
994 return OwningAtomicReductionGen();
1000 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1001 llvm::Value *
lhs, llvm::Value *
rhs)
mutable
1002 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1003 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(),
lhs);
1004 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(),
rhs);
1005 builder.restoreIP(insertPoint);
1008 "omp.reduction.atomic.body", builder,
1009 moduleTranslation, &phis)))
1010 return llvm::createStringError(
1011 "failed to inline `atomic` region of `omp.declare_reduction`");
1012 assert(phis.empty());
1013 return builder.saveIP();
1022static OwningDataPtrPtrReductionGen
1026 return OwningDataPtrPtrReductionGen();
1028 OwningDataPtrPtrReductionGen refDataPtrGen =
1029 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1030 llvm::Value *byRefVal, llvm::Value *&
result)
mutable
1031 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1032 moduleTranslation.
mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1033 builder.restoreIP(insertPoint);
1036 "omp.data_ptr_ptr.body", builder,
1037 moduleTranslation, &phis)))
1038 return llvm::createStringError(
1039 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1040 result = llvm::getSingleElement(phis);
1041 return builder.saveIP();
1044 return refDataPtrGen;
1051 auto orderedOp = cast<omp::OrderedOp>(opInst);
1056 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1057 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1058 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1060 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
1062 size_t indexVecValues = 0;
1063 while (indexVecValues < vecValues.size()) {
1065 storeValues.reserve(numLoops);
1066 for (
unsigned i = 0; i < numLoops; i++) {
1067 storeValues.push_back(vecValues[indexVecValues]);
1070 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1072 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1073 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
1074 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
1084 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1085 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1090 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1092 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1093 builder.restoreIP(codeGenIP);
1101 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1103 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1104 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1106 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1111 builder.restoreIP(*afterIP);
1117struct DeferredStore {
1118 DeferredStore(llvm::Value *value, llvm::Value *address)
1119 : value(value), address(address) {}
1122 llvm::Value *address;
1129template <
typename T>
1132 llvm::IRBuilderBase &builder,
1134 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1140 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1141 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1144 deferredStores.reserve(loop.getNumReductionVars());
1146 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1147 Region &allocRegion = reductionDecls[i].getAllocRegion();
1149 if (allocRegion.
empty())
1154 builder, moduleTranslation, &phis)))
1155 return loop.emitError(
1156 "failed to inline `alloc` region of `omp.declare_reduction`");
1158 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1159 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1163 llvm::Value *var = builder.CreateAlloca(
1164 moduleTranslation.
convertType(reductionDecls[i].getType()));
1166 llvm::Type *ptrTy = builder.getPtrTy();
1167 llvm::Value *castVar =
1168 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1169 llvm::Value *castPhi =
1170 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1172 deferredStores.emplace_back(castPhi, castVar);
1174 privateReductionVariables[i] = castVar;
1175 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1176 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1178 assert(allocRegion.
empty() &&
1179 "allocaction is implicit for by-val reduction");
1180 llvm::Value *var = builder.CreateAlloca(
1181 moduleTranslation.
convertType(reductionDecls[i].getType()));
1183 llvm::Type *ptrTy = builder.getPtrTy();
1184 llvm::Value *castVar =
1185 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1187 moduleTranslation.
mapValue(reductionArgs[i], castVar);
1188 privateReductionVariables[i] = castVar;
1189 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1197template <
typename T>
1200 llvm::IRBuilderBase &builder,
1205 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1206 Region &initializerRegion = reduction.getInitializerRegion();
1209 mlir::Value mlirSource = loop.getReductionVars()[i];
1210 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1211 llvm::Value *origVal = llvmSource;
1213 if (!isa<LLVM::LLVMPointerType>(
1214 reduction.getInitializerMoldArg().getType()) &&
1215 isa<LLVM::LLVMPointerType>(mlirSource.
getType())) {
1218 reduction.getInitializerMoldArg().getType()),
1219 llvmSource,
"omp_orig");
1221 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), origVal);
1224 llvm::Value *allocation =
1225 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1226 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1232 llvm::BasicBlock *block =
nullptr) {
1233 if (block ==
nullptr)
1234 block = builder.GetInsertBlock();
1236 if (block->empty() || block->getTerminator() ==
nullptr)
1237 builder.SetInsertPoint(block);
1239 builder.SetInsertPoint(block->getTerminator());
1247template <
typename OP>
1250 llvm::IRBuilderBase &builder,
1252 llvm::BasicBlock *latestAllocaBlock,
1258 if (op.getNumReductionVars() == 0)
1261 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1262 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1263 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1264 builder.restoreIP(allocaIP);
1267 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1269 if (!reductionDecls[i].getAllocRegion().empty())
1275 byRefVars[i] = builder.CreateAlloca(
1276 moduleTranslation.
convertType(reductionDecls[i].getType()));
1284 for (
auto [data, addr] : deferredStores)
1285 builder.CreateStore(data, addr);
1290 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1295 reductionVariableMap, i);
1303 "omp.reduction.neutral", builder,
1304 moduleTranslation, &phis)))
1307 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1308 "reduction neutral element declaration region");
1313 if (!reductionDecls[i].getAllocRegion().empty())
1322 builder.CreateStore(phis[0], byRefVars[i]);
1324 privateReductionVariables[i] = byRefVars[i];
1325 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1326 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1329 builder.CreateStore(phis[0], privateReductionVariables[i]);
1336 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
1343template <
typename T>
1344static void collectReductionInfo(
1345 T loop, llvm::IRBuilderBase &builder,
1354 unsigned numReductions = loop.getNumReductionVars();
1356 for (
unsigned i = 0; i < numReductions; ++i) {
1359 owningAtomicReductionGens.push_back(
1362 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1366 reductionInfos.reserve(numReductions);
1367 for (
unsigned i = 0; i < numReductions; ++i) {
1368 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1369 if (owningAtomicReductionGens[i])
1370 atomicGen = owningAtomicReductionGens[i];
1371 llvm::Value *variable =
1372 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
1375 if (
auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1376 allocatedType = alloca.getElemType();
1383 reductionInfos.push_back(
1385 privateReductionVariables[i],
1386 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1390 allocatedType ? moduleTranslation.
convertType(allocatedType) :
nullptr,
1391 reductionDecls[i].getByrefElementType()
1393 *reductionDecls[i].getByrefElementType())
1403 llvm::IRBuilderBase &builder, StringRef regionName,
1404 bool shouldLoadCleanupRegionArg =
true) {
1405 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1406 if (cleanupRegion->empty())
1412 llvm::Instruction *potentialTerminator =
1413 builder.GetInsertBlock()->empty() ?
nullptr
1414 : &builder.GetInsertBlock()->back();
1415 if (potentialTerminator && potentialTerminator->isTerminator())
1416 builder.SetInsertPoint(potentialTerminator);
1417 llvm::Value *privateVarValue =
1418 shouldLoadCleanupRegionArg
1419 ? builder.CreateLoad(
1421 privateVariables[i])
1422 : privateVariables[i];
1427 moduleTranslation)))
1440 OP op, llvm::IRBuilderBase &builder,
1442 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1445 bool isNowait =
false,
bool isTeamsReduction =
false) {
1447 if (op.getNumReductionVars() == 0)
1459 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1461 owningReductionGenRefDataPtrGens,
1462 privateReductionVariables, reductionInfos, isByRef);
1467 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1468 builder.SetInsertPoint(tempTerminator);
1469 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1470 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1471 isByRef, isNowait, isTeamsReduction);
1476 if (!contInsertPoint->getBlock())
1477 return op->emitOpError() <<
"failed to convert reductions";
1479 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1480 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1485 tempTerminator->eraseFromParent();
1486 builder.restoreIP(*afterIP);
1490 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1491 [](omp::DeclareReductionOp reductionDecl) {
1492 return &reductionDecl.getCleanupRegion();
1495 moduleTranslation, builder,
1496 "omp.reduction.cleanup");
1507template <
typename OP>
1511 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1516 if (op.getNumReductionVars() == 0)
1522 allocaIP, reductionDecls,
1523 privateReductionVariables, reductionVariableMap,
1524 deferredStores, isByRef)))
1527 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1528 allocaIP.getBlock(), reductionDecls,
1529 privateReductionVariables, reductionVariableMap,
1530 isByRef, deferredStores);
1544 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1547 Value blockArg = (*mappedPrivateVars)[privateVar];
1550 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1551 "A block argument corresponding to a mapped var should have "
1554 if (privVarType == blockArgType)
1561 if (!isa<LLVM::LLVMPointerType>(privVarType))
1562 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1575 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1577 llvm::BasicBlock *privInitBlock,
1579 Region &initRegion = privDecl.getInitRegion();
1580 if (initRegion.
empty())
1581 return llvmPrivateVar;
1583 assert(nonPrivateVar);
1584 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1585 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1590 moduleTranslation, &phis)))
1591 return llvm::createStringError(
1592 "failed to inline `init` region of `omp.private`");
1594 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1611 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1614 builder, moduleTranslation, privDecl,
1617 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1626 return llvm::Error::success();
1628 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1631 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1634 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1636 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1637 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1640 return privVarOrErr.takeError();
1642 llvmPrivateVar = privVarOrErr.get();
1643 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1648 return llvm::Error::success();
1658 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1661 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1662 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1663 allocaTerminator->getIterator()),
1664 true, allocaTerminator->getStableDebugLoc(),
1665 "omp.region.after_alloca");
1667 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1669 allocaTerminator = allocaIP.getBlock()->getTerminator();
1670 builder.SetInsertPoint(allocaTerminator);
1672 assert(allocaTerminator->getNumSuccessors() == 1 &&
1673 "This is an unconditional branch created by splitBB");
1675 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1676 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1678 unsigned int allocaAS =
1679 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1682 .getProgramAddressSpace();
1684 for (
auto [privDecl, mlirPrivVar, blockArg] :
1687 llvm::Type *llvmAllocType =
1688 moduleTranslation.
convertType(privDecl.getType());
1689 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1690 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1691 llvmAllocType,
nullptr,
"omp.private.alloc");
1692 if (allocaAS != defaultAS)
1693 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1694 builder.getPtrTy(defaultAS));
1696 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1699 return afterAllocas;
1707 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1716 if (mlir::isa<omp::ParallelOp>(parent))
1730 bool needsFirstprivate =
1731 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1732 return privOp.getDataSharingType() ==
1733 omp::DataSharingClauseType::FirstPrivate;
1736 if (!needsFirstprivate)
1739 llvm::BasicBlock *copyBlock =
1740 splitBB(builder,
true,
"omp.private.copy");
1743 for (
auto [decl, moldVar, llvmVar] :
1744 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1745 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1749 Region ©Region = decl.getCopyRegion();
1751 moduleTranslation.
mapValue(decl.getCopyMoldArg(), moldVar);
1754 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1758 moduleTranslation)))
1759 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1773 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1774 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1790 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](
mlir::Value mlirVar) {
1792 llvm::Value *moldVar = findAssociatedValue(
1793 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1798 llvmPrivateVars, privateDecls, insertBarrier,
1809 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1810 [](omp::PrivateClauseOp privatizer) {
1811 return &privatizer.getDeallocRegion();
1815 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1816 "omp.private.dealloc",
false)))
1817 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1818 "`omp.private` op in");
1830 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1840 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1841 using StorableBodyGenCallbackTy =
1842 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1844 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1850 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1854 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1858 sectionsOp.getNumReductionVars());
1862 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1865 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1866 reductionDecls, privateReductionVariables, reductionVariableMap,
1873 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1877 Region ®ion = sectionOp.getRegion();
1878 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1879 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1880 builder.restoreIP(codeGenIP);
1887 sectionsOp.getRegion().getNumArguments());
1888 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1889 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1890 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
1892 moduleTranslation.
mapValue(sectionArg, llvmVal);
1899 sectionCBs.push_back(sectionCB);
1905 if (sectionCBs.empty())
1908 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1913 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1914 llvm::Value &vPtr, llvm::Value *&replacementValue)
1915 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1916 replacementValue = &vPtr;
1922 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1926 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1927 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1929 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1930 sectionsOp.getNowait());
1935 builder.restoreIP(*afterIP);
1939 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1940 privateReductionVariables, isByRef, sectionsOp.getNowait());
1947 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1948 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1953 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1954 builder.restoreIP(codegenIP);
1956 builder, moduleTranslation)
1959 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1963 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1966 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1967 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
1969 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1970 llvmCPFuncs.push_back(
1974 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1976 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1982 builder.restoreIP(*afterIP);
1988 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1993 for (
auto ra : iface.getReductionBlockArgs())
1994 for (
auto &use : ra.getUses()) {
1995 auto *useOp = use.getOwner();
1997 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1998 debugUses.push_back(useOp);
2002 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
2007 Operation *currentOp = currentDistOp.getOperation();
2008 if (distOp && (distOp != currentOp))
2017 for (
auto *use : debugUses)
2026 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2031 unsigned numReductionVars = op.getNumReductionVars();
2035 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2041 if (doTeamsReduction) {
2042 isByRef =
getIsByRef(op.getReductionByref());
2044 assert(isByRef.size() == op.getNumReductionVars());
2047 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2052 op, reductionArgs, builder, moduleTranslation, allocaIP,
2053 reductionDecls, privateReductionVariables, reductionVariableMap,
2058 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2060 moduleTranslation, allocaIP);
2061 builder.restoreIP(codegenIP);
2067 llvm::Value *numTeamsLower =
nullptr;
2068 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
2069 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
2071 llvm::Value *numTeamsUpper =
nullptr;
2072 if (!op.getNumTeamsUpperVars().empty())
2073 numTeamsUpper = moduleTranslation.
lookupValue(op.getNumTeams(0));
2075 llvm::Value *threadLimit =
nullptr;
2076 if (!op.getThreadLimitVars().empty())
2077 threadLimit = moduleTranslation.
lookupValue(op.getThreadLimit(0));
2079 llvm::Value *ifExpr =
nullptr;
2080 if (
Value ifVar = op.getIfExpr())
2083 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2084 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2086 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2091 builder.restoreIP(*afterIP);
2092 if (doTeamsReduction) {
2095 op, builder, moduleTranslation, allocaIP, reductionDecls,
2096 privateReductionVariables, isByRef,
2106 if (dependVars.empty())
2108 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2109 llvm::omp::RTLDependenceKindTy type;
2111 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2112 case mlir::omp::ClauseTaskDepend::taskdependin:
2113 type = llvm::omp::RTLDependenceKindTy::DepIn;
2118 case mlir::omp::ClauseTaskDepend::taskdependout:
2119 case mlir::omp::ClauseTaskDepend::taskdependinout:
2120 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2122 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2123 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2125 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2126 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2129 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2130 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2131 dds.emplace_back(dd);
2143 llvm::IRBuilderBase &llvmBuilder,
2145 llvm::omp::Directive cancelDirective) {
2146 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2147 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2151 llvmBuilder.restoreIP(ip);
2157 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2158 return llvm::Error::success();
2163 ompBuilder.pushFinalizationCB(
2173 llvm::OpenMPIRBuilder &ompBuilder,
2174 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2175 ompBuilder.popFinalizationCB();
2176 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2177 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2178 assert(cancelBranch->getNumSuccessors() == 1 &&
2179 "cancel branch should have one target");
2180 cancelBranch->setSuccessor(0, constructFini);
2187class TaskContextStructManager {
2189 TaskContextStructManager(llvm::IRBuilderBase &builder,
2190 LLVM::ModuleTranslation &moduleTranslation,
2191 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2192 : builder{builder}, moduleTranslation{moduleTranslation},
2193 privateDecls{privateDecls} {}
2199 void generateTaskContextStruct();
2205 void createGEPsToPrivateVars();
2211 SmallVector<llvm::Value *>
2212 createGEPsToPrivateVars(llvm::Value *altStructPtr)
const;
2215 void freeStructPtr();
2217 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2218 return llvmPrivateVarGEPs;
2221 llvm::Value *getStructPtr() {
return structPtr; }
2224 llvm::IRBuilderBase &builder;
2225 LLVM::ModuleTranslation &moduleTranslation;
2226 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
2229 SmallVector<llvm::Type *> privateVarTypes;
2233 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
2236 llvm::Value *structPtr =
nullptr;
2238 llvm::Type *structTy =
nullptr;
2242void TaskContextStructManager::generateTaskContextStruct() {
2243 if (privateDecls.empty())
2245 privateVarTypes.reserve(privateDecls.size());
2247 for (omp::PrivateClauseOp &privOp : privateDecls) {
2250 if (!privOp.readsFromMold())
2252 Type mlirType = privOp.getType();
2253 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
2256 if (privateVarTypes.empty())
2259 structTy = llvm::StructType::get(moduleTranslation.
getLLVMContext(),
2262 llvm::DataLayout dataLayout =
2263 builder.GetInsertBlock()->getModule()->getDataLayout();
2264 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2265 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2268 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2270 "omp.task.context_ptr");
2273SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2274 llvm::Value *altStructPtr)
const {
2275 SmallVector<llvm::Value *> ret;
2278 ret.reserve(privateDecls.size());
2279 llvm::Value *zero = builder.getInt32(0);
2281 for (
auto privDecl : privateDecls) {
2282 if (!privDecl.readsFromMold()) {
2284 ret.push_back(
nullptr);
2287 llvm::Value *iVal = builder.getInt32(i);
2288 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2295void TaskContextStructManager::createGEPsToPrivateVars() {
2297 assert(privateVarTypes.empty());
2301 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2304void TaskContextStructManager::freeStructPtr() {
2308 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2310 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2311 builder.CreateFree(structPtr);
2318 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2323 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2335 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2340 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2341 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2342 builder.getContext(),
"omp.task.start",
2343 builder.GetInsertBlock()->getParent());
2344 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2345 builder.SetInsertPoint(branchToTaskStartBlock);
2348 llvm::BasicBlock *copyBlock =
2349 splitBB(builder,
true,
"omp.private.copy");
2350 llvm::BasicBlock *initBlock =
2351 splitBB(builder,
true,
"omp.private.init");
2367 moduleTranslation, allocaIP);
2370 builder.SetInsertPoint(initBlock->getTerminator());
2373 taskStructMgr.generateTaskContextStruct();
2380 taskStructMgr.createGEPsToPrivateVars();
2382 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2385 taskStructMgr.getLLVMPrivateVarGEPs())) {
2387 if (!privDecl.readsFromMold())
2389 assert(llvmPrivateVarAlloc &&
2390 "reads from mold so shouldn't have been skipped");
2393 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2394 blockArg, llvmPrivateVarAlloc, initBlock);
2395 if (!privateVarOrErr)
2396 return handleError(privateVarOrErr, *taskOp.getOperation());
2405 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2406 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2407 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2409 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2410 llvmPrivateVarAlloc);
2412 assert(llvmPrivateVarAlloc->getType() ==
2413 moduleTranslation.
convertType(blockArg.getType()));
2423 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2424 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2425 taskOp.getPrivateNeedsBarrier())))
2426 return llvm::failure();
2429 builder.SetInsertPoint(taskStartBlock);
2431 auto bodyCB = [&](InsertPointTy allocaIP,
2432 InsertPointTy codegenIP) -> llvm::Error {
2436 moduleTranslation, allocaIP);
2439 builder.restoreIP(codegenIP);
2441 llvm::BasicBlock *privInitBlock =
nullptr;
2443 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2446 auto [blockArg, privDecl, mlirPrivVar] = zip;
2448 if (privDecl.readsFromMold())
2451 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2452 llvm::Type *llvmAllocType =
2453 moduleTranslation.
convertType(privDecl.getType());
2454 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2455 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2456 llvmAllocType,
nullptr,
"omp.private.alloc");
2459 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2460 blockArg, llvmPrivateVar, privInitBlock);
2461 if (!privateVarOrError)
2462 return privateVarOrError.takeError();
2463 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2464 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2467 taskStructMgr.createGEPsToPrivateVars();
2468 for (
auto [i, llvmPrivVar] :
2469 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2471 assert(privateVarsInfo.
llvmVars[i] &&
2472 "This is added in the loop above");
2475 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2480 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2484 if (!privateDecl.readsFromMold())
2487 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2488 llvmPrivateVar = builder.CreateLoad(
2489 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2491 assert(llvmPrivateVar->getType() ==
2492 moduleTranslation.
convertType(blockArg.getType()));
2493 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2497 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
2498 if (failed(
handleError(continuationBlockOrError, *taskOp)))
2499 return llvm::make_error<PreviouslyReportedError>();
2501 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2506 return llvm::make_error<PreviouslyReportedError>();
2509 taskStructMgr.freeStructPtr();
2511 return llvm::Error::success();
2520 llvm::omp::Directive::OMPD_taskgroup);
2524 moduleTranslation, dds);
2526 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2527 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2529 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2531 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dds,
2532 taskOp.getMergeable(),
2533 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
2534 moduleTranslation.
lookupValue(taskOp.getPriority()));
2542 builder.restoreIP(*afterIP);
2550 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2551 auto taskloopOp = cast<omp::TaskloopOp>(opInst);
2559 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2562 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2565 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2566 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
2567 builder.getContext(),
"omp.taskloop.start",
2568 builder.GetInsertBlock()->getParent());
2569 llvm::Instruction *branchToTaskloopStartBlock =
2570 builder.CreateBr(taskloopStartBlock);
2571 builder.SetInsertPoint(branchToTaskloopStartBlock);
2573 llvm::BasicBlock *copyBlock =
2574 splitBB(builder,
true,
"omp.private.copy");
2575 llvm::BasicBlock *initBlock =
2576 splitBB(builder,
true,
"omp.private.init");
2579 moduleTranslation, allocaIP);
2582 builder.SetInsertPoint(initBlock->getTerminator());
2585 taskStructMgr.generateTaskContextStruct();
2586 taskStructMgr.createGEPsToPrivateVars();
2588 llvmFirstPrivateVars.resize(privateVarsInfo.
blockArgs.size());
2590 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2592 privateVarsInfo.
blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
2593 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
2595 if (!privDecl.readsFromMold())
2597 assert(llvmPrivateVarAlloc &&
2598 "reads from mold so shouldn't have been skipped");
2601 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2602 blockArg, llvmPrivateVarAlloc, initBlock);
2603 if (!privateVarOrErr)
2604 return handleError(privateVarOrErr, *taskloopOp.getOperation());
2606 llvmFirstPrivateVars[i] = privateVarOrErr.get();
2608 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2609 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2611 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2612 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2613 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2615 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2616 llvmPrivateVarAlloc);
2618 assert(llvmPrivateVarAlloc->getType() ==
2619 moduleTranslation.
convertType(blockArg.getType()));
2625 taskloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2626 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2627 taskloopOp.getPrivateNeedsBarrier())))
2628 return llvm::failure();
2631 builder.SetInsertPoint(taskloopStartBlock);
2633 auto bodyCB = [&](InsertPointTy allocaIP,
2634 InsertPointTy codegenIP) -> llvm::Error {
2638 moduleTranslation, allocaIP);
2641 builder.restoreIP(codegenIP);
2643 llvm::BasicBlock *privInitBlock =
nullptr;
2645 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2648 auto [blockArg, privDecl, mlirPrivVar] = zip;
2650 if (privDecl.readsFromMold())
2653 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2654 llvm::Type *llvmAllocType =
2655 moduleTranslation.
convertType(privDecl.getType());
2656 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2657 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2658 llvmAllocType,
nullptr,
"omp.private.alloc");
2661 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2662 blockArg, llvmPrivateVar, privInitBlock);
2663 if (!privateVarOrError)
2664 return privateVarOrError.takeError();
2665 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2666 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2669 taskStructMgr.createGEPsToPrivateVars();
2670 for (
auto [i, llvmPrivVar] :
2671 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2673 assert(privateVarsInfo.
llvmVars[i] &&
2674 "This is added in the loop above");
2677 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2682 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2686 if (!privateDecl.readsFromMold())
2689 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2690 llvmPrivateVar = builder.CreateLoad(
2691 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2693 assert(llvmPrivateVar->getType() ==
2694 moduleTranslation.
convertType(blockArg.getType()));
2695 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2698 auto continuationBlockOrError =
2700 builder, moduleTranslation);
2702 if (failed(
handleError(continuationBlockOrError, opInst)))
2703 return llvm::make_error<PreviouslyReportedError>();
2705 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2713 taskloopOp.getLoc(), privateVarsInfo.
llvmVars,
2715 return llvm::make_error<PreviouslyReportedError>();
2718 taskStructMgr.freeStructPtr();
2720 return llvm::Error::success();
2726 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2727 llvm::Value *destPtr, llvm::Value *srcPtr)
2729 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2730 builder.restoreIP(codegenIP);
2733 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
2735 builder.CreateLoad(ptrTy, srcPtr,
"omp.taskloop.context.src");
2737 TaskContextStructManager &srcStructMgr = taskStructMgr;
2738 TaskContextStructManager destStructMgr(builder, moduleTranslation,
2740 destStructMgr.generateTaskContextStruct();
2741 llvm::Value *dest = destStructMgr.getStructPtr();
2742 dest->setName(
"omp.taskloop.context.dest");
2743 builder.CreateStore(dest, destPtr);
2746 srcStructMgr.createGEPsToPrivateVars(src);
2748 destStructMgr.createGEPsToPrivateVars(dest);
2751 for (
auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
2752 llvm::zip_equal(privateVarsInfo.
privatizers, srcGEPs,
2755 if (!privDecl.readsFromMold())
2757 assert(llvmPrivateVarAlloc &&
2758 "reads from mold so shouldn't have been skipped");
2761 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
2762 llvmPrivateVarAlloc, builder.GetInsertBlock());
2763 if (!privateVarOrErr)
2764 return privateVarOrErr.takeError();
2773 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2774 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2775 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2777 llvmPrivateVarAlloc = builder.CreateLoad(
2778 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
2780 assert(llvmPrivateVarAlloc->getType() ==
2781 moduleTranslation.
convertType(blockArg.getType()));
2789 &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
2790 privateVarsInfo.
privatizers, taskloopOp.getPrivateNeedsBarrier())))
2791 return llvm::make_error<PreviouslyReportedError>();
2793 return builder.saveIP();
2796 auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());
2803 llvm::Value *ifCond =
nullptr;
2804 llvm::Value *grainsize =
nullptr;
2806 mlir::Value grainsizeVal = taskloopOp.getGrainsize();
2807 mlir::Value numTasksVal = taskloopOp.getNumTasks();
2808 if (
Value ifVar = taskloopOp.getIfExpr())
2811 grainsize = moduleTranslation.
lookupValue(grainsizeVal);
2813 }
else if (numTasksVal) {
2814 grainsize = moduleTranslation.
lookupValue(numTasksVal);
2818 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull =
nullptr;
2819 if (taskStructMgr.getStructPtr())
2820 taskDupOrNull = taskDupCB;
2830 llvm::omp::Directive::OMPD_taskgroup);
2832 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2833 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2835 ompLoc, allocaIP, bodyCB, loopInfo,
2836 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[0]),
2837 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[0]),
2838 moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]),
2839 taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
2840 sched, moduleTranslation.
lookupValue(taskloopOp.getFinal()),
2841 taskloopOp.getMergeable(),
2842 moduleTranslation.
lookupValue(taskloopOp.getPriority()),
2843 taskDupOrNull, taskStructMgr.getStructPtr());
2850 builder.restoreIP(*afterIP);
2858 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2862 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2863 builder.restoreIP(codegenIP);
2865 builder, moduleTranslation)
2870 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2871 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2878 builder.restoreIP(*afterIP);
2897 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2901 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2903 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2907 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2910 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
2911 llvm::Type *ivType = step->getType();
2912 llvm::Value *chunk =
nullptr;
2913 if (wsloopOp.getScheduleChunk()) {
2914 llvm::Value *chunkVar =
2915 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
2916 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2919 omp::DistributeOp distributeOp =
nullptr;
2920 llvm::Value *distScheduleChunk =
nullptr;
2921 bool hasDistSchedule =
false;
2922 if (llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())) {
2923 distributeOp = cast<omp::DistributeOp>(opInst.
getParentOp());
2924 hasDistSchedule = distributeOp.getDistScheduleStatic();
2925 if (distributeOp.getDistScheduleChunkSize()) {
2926 llvm::Value *chunkVar = moduleTranslation.
lookupValue(
2927 distributeOp.getDistScheduleChunkSize());
2928 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2936 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2940 wsloopOp.getNumReductionVars());
2943 builder, moduleTranslation, privateVarsInfo, allocaIP);
2950 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2955 moduleTranslation, allocaIP, reductionDecls,
2956 privateReductionVariables, reductionVariableMap,
2957 deferredStores, isByRef)))
2966 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2968 wsloopOp.getPrivateNeedsBarrier())))
2971 assert(afterAllocas.get()->getSinglePredecessor());
2972 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
2974 afterAllocas.get()->getSinglePredecessor(),
2975 reductionDecls, privateReductionVariables,
2976 reductionVariableMap, isByRef, deferredStores)))
2980 bool isOrdered = wsloopOp.getOrdered().has_value();
2981 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2982 bool isSimd = wsloopOp.getScheduleSimd();
2983 bool loopNeedsBarrier = !wsloopOp.getNowait();
2988 llvm::omp::WorksharingLoopType workshareLoopType =
2989 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
2990 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2991 : llvm::omp::WorksharingLoopType::ForStaticLoop;
2995 llvm::omp::Directive::OMPD_for);
2997 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3000 LinearClauseProcessor linearClauseProcessor;
3002 if (!wsloopOp.getLinearVars().empty()) {
3003 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3005 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3007 for (
auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3008 linearClauseProcessor.createLinearVar(
3009 builder, moduleTranslation, moduleTranslation.
lookupValue(linearVar),
3011 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
3012 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3016 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
3024 if (!wsloopOp.getLinearVars().empty()) {
3025 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3026 loopInfo->getPreheader());
3027 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3029 builder.saveIP(), llvm::omp::OMPD_barrier);
3032 builder.restoreIP(*afterBarrierIP);
3033 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3034 loopInfo->getIndVar());
3035 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3038 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3041 bool noLoopMode =
false;
3042 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3044 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3048 if (loopOp == targetCapturedOp) {
3049 omp::TargetRegionFlags kernelFlags =
3050 targetOp.getKernelExecFlags(targetCapturedOp);
3051 if (omp::bitEnumContainsAll(kernelFlags,
3052 omp::TargetRegionFlags::spmd |
3053 omp::TargetRegionFlags::no_loop) &&
3054 !omp::bitEnumContainsAny(kernelFlags,
3055 omp::TargetRegionFlags::generic))
3060 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3061 ompBuilder->applyWorkshareLoop(
3062 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3063 convertToScheduleKind(schedule), chunk, isSimd,
3064 scheduleMod == omp::ScheduleModifier::monotonic,
3065 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3066 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3072 if (!wsloopOp.getLinearVars().empty()) {
3073 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3074 assert(loopInfo->getLastIter() &&
3075 "`lastiter` in CanonicalLoopInfo is nullptr");
3076 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3077 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3078 loopInfo->getLastIter());
3081 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
3082 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
3084 builder.restoreIP(oldIP);
3092 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3093 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3106 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3108 assert(isByRef.size() == opInst.getNumReductionVars());
3121 opInst.getNumReductionVars());
3124 auto bodyGenCB = [&](InsertPointTy allocaIP,
3125 InsertPointTy codeGenIP) -> llvm::Error {
3127 builder, moduleTranslation, privateVarsInfo, allocaIP);
3129 return llvm::make_error<PreviouslyReportedError>();
3135 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3138 InsertPointTy(allocaIP.getBlock(),
3139 allocaIP.getBlock()->getTerminator()->getIterator());
3142 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3143 reductionDecls, privateReductionVariables, reductionVariableMap,
3144 deferredStores, isByRef)))
3145 return llvm::make_error<PreviouslyReportedError>();
3147 assert(afterAllocas.get()->getSinglePredecessor());
3148 builder.restoreIP(codeGenIP);
3154 return llvm::make_error<PreviouslyReportedError>();
3157 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3159 opInst.getPrivateNeedsBarrier())))
3160 return llvm::make_error<PreviouslyReportedError>();
3163 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3164 afterAllocas.get()->getSinglePredecessor(),
3165 reductionDecls, privateReductionVariables,
3166 reductionVariableMap, isByRef, deferredStores)))
3167 return llvm::make_error<PreviouslyReportedError>();
3172 moduleTranslation, allocaIP);
3176 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
3178 return regionBlock.takeError();
3181 if (opInst.getNumReductionVars() > 0) {
3186 owningReductionGenRefDataPtrGens;
3188 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3190 owningReductionGenRefDataPtrGens,
3191 privateReductionVariables, reductionInfos, isByRef);
3194 builder.SetInsertPoint((*regionBlock)->getTerminator());
3197 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3198 builder.SetInsertPoint(tempTerminator);
3200 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3201 ompBuilder->createReductions(
3202 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3204 if (!contInsertPoint)
3205 return contInsertPoint.takeError();
3207 if (!contInsertPoint->getBlock())
3208 return llvm::make_error<PreviouslyReportedError>();
3210 tempTerminator->eraseFromParent();
3211 builder.restoreIP(*contInsertPoint);
3214 return llvm::Error::success();
3217 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3218 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
3227 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
3228 InsertPointTy oldIP = builder.saveIP();
3229 builder.restoreIP(codeGenIP);
3234 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
3235 [](omp::DeclareReductionOp reductionDecl) {
3236 return &reductionDecl.getCleanupRegion();
3239 reductionCleanupRegions, privateReductionVariables,
3240 moduleTranslation, builder,
"omp.reduction.cleanup")))
3241 return llvm::createStringError(
3242 "failed to inline `cleanup` region of `omp.declare_reduction`");
3247 return llvm::make_error<PreviouslyReportedError>();
3251 if (isCancellable) {
3252 auto IPOrErr = ompBuilder->createBarrier(
3253 llvm::OpenMPIRBuilder::LocationDescription(builder),
3254 llvm::omp::Directive::OMPD_unknown,
3258 return IPOrErr.takeError();
3261 builder.restoreIP(oldIP);
3262 return llvm::Error::success();
3265 llvm::Value *ifCond =
nullptr;
3266 if (
auto ifVar = opInst.getIfExpr())
3268 llvm::Value *numThreads =
nullptr;
3269 if (!opInst.getNumThreadsVars().empty())
3270 numThreads = moduleTranslation.
lookupValue(opInst.getNumThreads(0));
3271 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
3272 if (
auto bind = opInst.getProcBindKind())
3275 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3277 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3279 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3280 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
3281 ifCond, numThreads, pbKind, isCancellable);
3286 builder.restoreIP(*afterIP);
3291static llvm::omp::OrderKind
3294 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3296 case omp::ClauseOrderKind::Concurrent:
3297 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3299 llvm_unreachable(
"Unknown ClauseOrderKind kind");
3307 auto simdOp = cast<omp::SimdOp>(opInst);
3315 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3318 simdOp.getNumReductionVars());
3323 assert(isByRef.size() == simdOp.getNumReductionVars());
3325 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3329 builder, moduleTranslation, privateVarsInfo, allocaIP);
3334 LinearClauseProcessor linearClauseProcessor;
3336 if (!simdOp.getLinearVars().empty()) {
3337 auto linearVarTypes = simdOp.getLinearVarTypes().value();
3339 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3340 for (
auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
3341 bool isImplicit =
false;
3342 for (
auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
3346 if (linearVar == mlirPrivVar) {
3348 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
3349 llvmPrivateVar, idx);
3355 linearClauseProcessor.createLinearVar(
3356 builder, moduleTranslation,
3359 for (
mlir::Value linearStep : simdOp.getLinearStepVars())
3360 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3364 moduleTranslation, allocaIP, reductionDecls,
3365 privateReductionVariables, reductionVariableMap,
3366 deferredStores, isByRef)))
3377 assert(afterAllocas.get()->getSinglePredecessor());
3378 if (failed(initReductionVars(simdOp, reductionArgs, builder,
3380 afterAllocas.get()->getSinglePredecessor(),
3381 reductionDecls, privateReductionVariables,
3382 reductionVariableMap, isByRef, deferredStores)))
3385 llvm::ConstantInt *simdlen =
nullptr;
3386 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
3387 simdlen = builder.getInt64(simdlenVar.value());
3389 llvm::ConstantInt *safelen =
nullptr;
3390 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
3391 safelen = builder.getInt64(safelenVar.value());
3393 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
3396 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3397 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
3399 for (
size_t i = 0; i < operands.size(); ++i) {
3400 llvm::Value *alignment =
nullptr;
3401 llvm::Value *llvmVal = moduleTranslation.
lookupValue(operands[i]);
3402 llvm::Type *ty = llvmVal->getType();
3404 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
3405 alignment = builder.getInt64(intAttr.getInt());
3406 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
3407 assert(alignment &&
"Invalid alignment value");
3411 if (!intAttr.getValue().isPowerOf2())
3414 auto curInsert = builder.saveIP();
3415 builder.SetInsertPoint(sourceBlock);
3416 llvmVal = builder.CreateLoad(ty, llvmVal);
3417 builder.restoreIP(curInsert);
3418 alignedVars[llvmVal] = alignment;
3422 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
3429 if (simdOp.getLinearVars().size()) {
3430 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3431 loopInfo->getPreheader());
3433 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3434 loopInfo->getIndVar());
3436 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3438 ompBuilder->applySimd(loopInfo, alignedVars,
3440 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
3442 order, simdlen, safelen);
3444 linearClauseProcessor.emitStoresForLinearVar(builder);
3445 for (
size_t index = 0;
index < simdOp.getLinearVars().size();
index++)
3446 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
3453 for (
auto [i, tuple] : llvm::enumerate(
3454 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
3455 privateReductionVariables))) {
3456 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
3458 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
3459 llvm::Value *originalVariable = moduleTranslation.
lookupValue(reductionVar);
3460 llvm::Type *reductionType = moduleTranslation.
convertType(decl.getType());
3464 llvm::Value *redValue = originalVariable;
3467 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
3468 llvm::Value *privateRedValue = builder.CreateLoad(
3469 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
3470 llvm::Value *reduced;
3472 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
3475 builder.restoreIP(res.get());
3479 builder.CreateStore(reduced, originalVariable);
3484 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
3485 [](omp::DeclareReductionOp reductionDecl) {
3486 return &reductionDecl.getCleanupRegion();
3489 moduleTranslation, builder,
3490 "omp.reduction.cleanup")))
3503 auto loopOp = cast<omp::LoopNestOp>(opInst);
3509 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3514 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
3515 llvm::Value *iv) -> llvm::Error {
3518 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
3523 bodyInsertPoints.push_back(ip);
3525 if (loopInfos.size() != loopOp.getNumLoops() - 1)
3526 return llvm::Error::success();
3529 builder.restoreIP(ip);
3531 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
3533 return regionBlock.takeError();
3535 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3536 return llvm::Error::success();
3544 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3545 llvm::Value *lowerBound =
3546 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
3547 llvm::Value *upperBound =
3548 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
3549 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
3554 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3555 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3557 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3559 computeIP = loopInfos.front()->getPreheaderIP();
3563 ompBuilder->createCanonicalLoop(
3564 loc, bodyGen, lowerBound, upperBound, step,
3565 true, loopOp.getLoopInclusive(), computeIP);
3570 loopInfos.push_back(*loopResult);
3573 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
3574 loopInfos.front()->getAfterIP();
3577 if (
const auto &tiles = loopOp.getTileSizes()) {
3578 llvm::Type *ivType = loopInfos.front()->getIndVarType();
3581 for (
auto tile : tiles.value()) {
3582 llvm::Value *tileVal = llvm::ConstantInt::get(ivType,
tile);
3583 tileSizes.push_back(tileVal);
3586 std::vector<llvm::CanonicalLoopInfo *> newLoops =
3587 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
3591 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
3592 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
3593 afterIP = {afterAfterBB, afterAfterBB->begin()};
3597 for (
const auto &newLoop : newLoops)
3598 loopInfos.push_back(newLoop);
3602 const auto &numCollapse = loopOp.getCollapseNumLoops();
3604 loopInfos.begin(), loopInfos.begin() + (numCollapse));
3606 auto newTopLoopInfo =
3607 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
3609 assert(newTopLoopInfo &&
"New top loop information is missing");
3610 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
3611 [&](OpenMPLoopInfoStackFrame &frame) {
3612 frame.loopInfo = newTopLoopInfo;
3620 builder.restoreIP(afterIP);
3630 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
3631 Value loopIV = op.getInductionVar();
3632 Value loopTC = op.getTripCount();
3634 llvm::Value *llvmTC = moduleTranslation.
lookupValue(loopTC);
3637 ompBuilder->createCanonicalLoop(
3639 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
3642 moduleTranslation.
mapValue(loopIV, llvmIV);
3644 builder.restoreIP(ip);
3649 return bodyGenStatus.takeError();
3651 llvmTC,
"omp.loop");
3653 return op.emitError(llvm::toString(llvmOrError.takeError()));
3655 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
3656 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
3657 builder.restoreIP(afterIP);
3660 if (
Value cli = op.getCli())
3673 Value applyee = op.getApplyee();
3674 assert(applyee &&
"Loop to apply unrolling on required");
3676 llvm::CanonicalLoopInfo *consBuilderCLI =
3678 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3679 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
3687static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3690 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3695 for (
Value size : op.getSizes()) {
3696 llvm::Value *translatedSize = moduleTranslation.
lookupValue(size);
3697 assert(translatedSize &&
3698 "sizes clause arguments must already be translated");
3699 translatedSizes.push_back(translatedSize);
3702 for (
Value applyee : op.getApplyees()) {
3703 llvm::CanonicalLoopInfo *consBuilderCLI =
3705 assert(applyee &&
"Canonical loop must already been translated");
3706 translatedLoops.push_back(consBuilderCLI);
3709 auto generatedLoops =
3710 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3711 if (!op.getGeneratees().empty()) {
3712 for (
auto [mlirLoop,
genLoop] :
3713 zip_equal(op.getGeneratees(), generatedLoops))
3718 for (
Value applyee : op.getApplyees())
3725static llvm::AtomicOrdering
3728 return llvm::AtomicOrdering::Monotonic;
3731 case omp::ClauseMemoryOrderKind::Seq_cst:
3732 return llvm::AtomicOrdering::SequentiallyConsistent;
3733 case omp::ClauseMemoryOrderKind::Acq_rel:
3734 return llvm::AtomicOrdering::AcquireRelease;
3735 case omp::ClauseMemoryOrderKind::Acquire:
3736 return llvm::AtomicOrdering::Acquire;
3737 case omp::ClauseMemoryOrderKind::Release:
3738 return llvm::AtomicOrdering::Release;
3739 case omp::ClauseMemoryOrderKind::Relaxed:
3740 return llvm::AtomicOrdering::Monotonic;
3742 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
3749 auto readOp = cast<omp::AtomicReadOp>(opInst);
3754 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3757 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3760 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
3761 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
3763 llvm::Type *elementType =
3764 moduleTranslation.
convertType(readOp.getElementType());
3766 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
3767 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
3768 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3776 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3781 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3784 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3786 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
3787 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
3788 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
3789 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
3792 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3800 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
3801 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
3802 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
3803 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
3804 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
3805 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
3806 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
3807 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
3808 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
3809 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3813 bool &isIgnoreDenormalMode,
3814 bool &isFineGrainedMemory,
3815 bool &isRemoteMemory) {
3816 isIgnoreDenormalMode =
false;
3817 isFineGrainedMemory =
false;
3818 isRemoteMemory =
false;
3819 if (atomicUpdateOp &&
3820 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3821 mlir::omp::AtomicControlAttr atomicControlAttr =
3822 atomicUpdateOp.getAtomicControlAttr();
3823 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3824 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3825 isRemoteMemory = atomicControlAttr.getRemoteMemory();
3832 llvm::IRBuilderBase &builder,
3839 auto &innerOpList = opInst.getRegion().front().getOperations();
3840 bool isXBinopExpr{
false};
3841 llvm::AtomicRMWInst::BinOp binop;
3843 llvm::Value *llvmExpr =
nullptr;
3844 llvm::Value *llvmX =
nullptr;
3845 llvm::Type *llvmXElementType =
nullptr;
3846 if (innerOpList.size() == 2) {
3852 opInst.getRegion().getArgument(0))) {
3853 return opInst.emitError(
"no atomic update operation with region argument"
3854 " as operand found inside atomic.update region");
3857 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
3859 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
3863 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3865 llvmX = moduleTranslation.
lookupValue(opInst.getX());
3867 opInst.getRegion().getArgument(0).getType());
3868 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3872 llvm::AtomicOrdering atomicOrdering =
3877 [&opInst, &moduleTranslation](
3878 llvm::Value *atomicx,
3881 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
3882 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
3883 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
3884 return llvm::make_error<PreviouslyReportedError>();
3886 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
3887 assert(yieldop && yieldop.getResults().size() == 1 &&
3888 "terminator must be omp.yield op and it must have exactly one "
3890 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
3893 bool isIgnoreDenormalMode;
3894 bool isFineGrainedMemory;
3895 bool isRemoteMemory;
3900 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3901 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3902 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
3903 atomicOrdering, binop, updateFn,
3904 isXBinopExpr, isIgnoreDenormalMode,
3905 isFineGrainedMemory, isRemoteMemory);
3910 builder.restoreIP(*afterIP);
3916 llvm::IRBuilderBase &builder,
3923 bool isXBinopExpr =
false, isPostfixUpdate =
false;
3924 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3926 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
3927 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
3929 assert((atomicUpdateOp || atomicWriteOp) &&
3930 "internal op must be an atomic.update or atomic.write op");
3932 if (atomicWriteOp) {
3933 isPostfixUpdate =
true;
3934 mlirExpr = atomicWriteOp.getExpr();
3936 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
3937 atomicCaptureOp.getAtomicUpdateOp().getOperation();
3938 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
3941 if (innerOpList.size() == 2) {
3944 atomicUpdateOp.getRegion().getArgument(0))) {
3945 return atomicUpdateOp.emitError(
3946 "no atomic update operation with region argument"
3947 " as operand found inside atomic.update region");
3951 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
3954 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3958 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
3959 llvm::Value *llvmX =
3960 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
3961 llvm::Value *llvmV =
3962 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
3963 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
3964 atomicCaptureOp.getAtomicReadOp().getElementType());
3965 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3968 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
3972 llvm::AtomicOrdering atomicOrdering =
3976 [&](llvm::Value *atomicx,
3979 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
3980 Block &bb = *atomicUpdateOp.getRegion().
begin();
3981 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
3983 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
3984 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
3985 return llvm::make_error<PreviouslyReportedError>();
3987 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
3988 assert(yieldop && yieldop.getResults().size() == 1 &&
3989 "terminator must be omp.yield op and it must have exactly one "
3991 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
3994 bool isIgnoreDenormalMode;
3995 bool isFineGrainedMemory;
3996 bool isRemoteMemory;
3998 isFineGrainedMemory, isRemoteMemory);
4001 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4002 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4003 ompBuilder->createAtomicCapture(
4004 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4005 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4006 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4008 if (failed(
handleError(afterIP, *atomicCaptureOp)))
4011 builder.restoreIP(*afterIP);
4016 omp::ClauseCancellationConstructType directive) {
4017 switch (directive) {
4018 case omp::ClauseCancellationConstructType::Loop:
4019 return llvm::omp::Directive::OMPD_for;
4020 case omp::ClauseCancellationConstructType::Parallel:
4021 return llvm::omp::Directive::OMPD_parallel;
4022 case omp::ClauseCancellationConstructType::Sections:
4023 return llvm::omp::Directive::OMPD_sections;
4024 case omp::ClauseCancellationConstructType::Taskgroup:
4025 return llvm::omp::Directive::OMPD_taskgroup;
4027 llvm_unreachable(
"Unhandled cancellation construct type");
4036 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4039 llvm::Value *ifCond =
nullptr;
4040 if (
Value ifVar = op.getIfExpr())
4043 llvm::omp::Directive cancelledDirective =
4046 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4047 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4049 if (failed(
handleError(afterIP, *op.getOperation())))
4052 builder.restoreIP(afterIP.get());
4059 llvm::IRBuilderBase &builder,
4064 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4067 llvm::omp::Directive cancelledDirective =
4070 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4071 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4073 if (failed(
handleError(afterIP, *op.getOperation())))
4076 builder.restoreIP(afterIP.get());
4086 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4088 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4093 Value symAddr = threadprivateOp.getSymAddr();
4096 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4099 if (!isa<LLVM::AddressOfOp>(symOp))
4100 return opInst.
emitError(
"Addressing symbol not found");
4101 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4103 LLVM::GlobalOp global =
4104 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
4105 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
4106 llvm::Type *type = globalValue->getValueType();
4107 llvm::TypeSize typeSize =
4108 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4110 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4111 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4112 ompLoc, globalValue, size, global.getSymName() +
".cache");
4118static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4120 switch (deviceClause) {
4121 case mlir::omp::DeclareTargetDeviceType::host:
4122 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4124 case mlir::omp::DeclareTargetDeviceType::nohost:
4125 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4127 case mlir::omp::DeclareTargetDeviceType::any:
4128 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4131 llvm_unreachable(
"unhandled device clause");
4134static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4136 mlir::omp::DeclareTargetCaptureClause captureClause) {
4137 switch (captureClause) {
4138 case mlir::omp::DeclareTargetCaptureClause::to:
4139 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4140 case mlir::omp::DeclareTargetCaptureClause::link:
4141 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4142 case mlir::omp::DeclareTargetCaptureClause::enter:
4143 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4144 case mlir::omp::DeclareTargetCaptureClause::none:
4145 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4147 llvm_unreachable(
"unhandled capture clause");
4152 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4154 if (
auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4155 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4156 return modOp.lookupSymbol(addressOfOp.getGlobalName());
4161static llvm::SmallString<64>
4163 llvm::OpenMPIRBuilder &ompBuilder) {
4165 llvm::raw_svector_ostream os(suffix);
4168 auto fileInfoCallBack = [&loc]() {
4169 return std::pair<std::string, uint64_t>(
4170 llvm::StringRef(loc.getFilename()), loc.getLine());
4173 auto vfs = llvm::vfs::getRealFileSystem();
4176 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
4178 os <<
"_decl_tgt_ref_ptr";
4184 if (
auto declareTargetGlobal =
4185 dyn_cast_if_present<omp::DeclareTargetInterface>(
4187 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4188 omp::DeclareTargetCaptureClause::link)
4194 if (
auto declareTargetGlobal =
4195 dyn_cast_if_present<omp::DeclareTargetInterface>(
4197 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4198 omp::DeclareTargetCaptureClause::to ||
4199 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4200 omp::DeclareTargetCaptureClause::enter)
4214 if (
auto declareTargetGlobal =
4215 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
4218 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
4219 omp::DeclareTargetCaptureClause::link) ||
4220 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4221 omp::DeclareTargetCaptureClause::to &&
4222 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
4226 if (gOp.getSymName().contains(suffix))
4231 (gOp.getSymName().str() + suffix.str()).str());
4240struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
4241 SmallVector<Operation *, 4> Mappers;
4244 void append(MapInfosTy &curInfo) {
4245 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
4246 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
4255struct MapInfoData : MapInfosTy {
4256 llvm::SmallVector<bool, 4> IsDeclareTarget;
4257 llvm::SmallVector<bool, 4> IsAMember;
4259 llvm::SmallVector<bool, 4> IsAMapping;
4260 llvm::SmallVector<mlir::Operation *, 4> MapClause;
4261 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
4264 llvm::SmallVector<llvm::Type *, 4> BaseType;
4267 void append(MapInfoData &CurInfo) {
4268 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
4269 CurInfo.IsDeclareTarget.end());
4270 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
4271 OriginalValue.append(CurInfo.OriginalValue.begin(),
4272 CurInfo.OriginalValue.end());
4273 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
4274 MapInfosTy::append(CurInfo);
4278enum class TargetDirectiveEnumTy : uint32_t {
4282 TargetEnterData = 3,
4287static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4288 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4289 .Case([](omp::TargetDataOp) {
return TargetDirectiveEnumTy::TargetData; })
4290 .Case([](omp::TargetEnterDataOp) {
4291 return TargetDirectiveEnumTy::TargetEnterData;
4293 .Case([&](omp::TargetExitDataOp) {
4294 return TargetDirectiveEnumTy::TargetExitData;
4296 .Case([&](omp::TargetUpdateOp) {
4297 return TargetDirectiveEnumTy::TargetUpdate;
4299 .Case([&](omp::TargetOp) {
return TargetDirectiveEnumTy::Target; })
4300 .Default([&](Operation *op) {
return TargetDirectiveEnumTy::None; });
4307 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4308 arrTy.getElementType()))
4325 llvm::Value *basePointer,
4326 llvm::Type *baseType,
4327 llvm::IRBuilderBase &builder,
4329 if (
auto memberClause =
4330 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
4335 if (!memberClause.getBounds().empty()) {
4336 llvm::Value *elementCount = builder.getInt64(1);
4337 for (
auto bounds : memberClause.getBounds()) {
4338 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
4339 bounds.getDefiningOp())) {
4344 elementCount = builder.CreateMul(
4348 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
4349 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
4350 builder.getInt64(1)));
4357 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
4365 return builder.CreateMul(elementCount,
4366 builder.getInt64(underlyingTypeSzInBits / 8));
4377static llvm::omp::OpenMPOffloadMappingFlags
4379 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
4380 return (mlirFlags & flag) == flag;
4382 const bool hasExplicitMap =
4383 (mlirFlags &
~omp::ClauseMapFlags::is_device_ptr) !=
4384 omp::ClauseMapFlags::none;
4386 llvm::omp::OpenMPOffloadMappingFlags mapType =
4387 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4390 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
4393 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
4396 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4399 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4402 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4405 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4408 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4411 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4414 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4417 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4420 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4423 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4426 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4427 if (!hasExplicitMap)
4428 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4438 ArrayRef<Value> useDevAddrOperands = {},
4439 ArrayRef<Value> hasDevAddrOperands = {}) {
4440 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
4448 for (Value mapValue : mapVars) {
4449 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4450 for (
auto member : map.getMembers())
4451 if (member == mapOp)
4458 for (Value mapValue : mapVars) {
4459 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4461 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4462 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
4463 mapData.Pointers.push_back(mapData.OriginalValue.back());
4465 if (llvm::Value *refPtr =
4467 mapData.IsDeclareTarget.push_back(
true);
4468 mapData.BasePointers.push_back(refPtr);
4470 mapData.IsDeclareTarget.push_back(
true);
4471 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4473 mapData.IsDeclareTarget.push_back(
false);
4474 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4477 mapData.BaseType.push_back(
4478 moduleTranslation.
convertType(mapOp.getVarType()));
4479 mapData.Sizes.push_back(
4480 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
4481 mapData.BaseType.back(), builder, moduleTranslation));
4482 mapData.MapClause.push_back(mapOp.getOperation());
4484 mapData.Names.push_back(LLVM::createMappingInformation(
4486 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4487 if (mapOp.getMapperId())
4488 mapData.Mappers.push_back(
4490 mapOp, mapOp.getMapperIdAttr()));
4492 mapData.Mappers.push_back(
nullptr);
4493 mapData.IsAMapping.push_back(
true);
4494 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
4497 auto findMapInfo = [&mapData](llvm::Value *val,
4498 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4501 for (llvm::Value *basePtr : mapData.OriginalValue) {
4502 if (basePtr == val && mapData.IsAMapping[index]) {
4504 mapData.Types[index] |=
4505 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4506 mapData.DevicePointers[index] = devInfoTy;
4514 auto addDevInfos = [&](
const llvm::ArrayRef<Value> &useDevOperands,
4515 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4516 for (Value mapValue : useDevOperands) {
4517 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4519 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4520 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
4523 if (!findMapInfo(origValue, devInfoTy)) {
4524 mapData.OriginalValue.push_back(origValue);
4525 mapData.Pointers.push_back(mapData.OriginalValue.back());
4526 mapData.IsDeclareTarget.push_back(
false);
4527 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4528 mapData.BaseType.push_back(
4529 moduleTranslation.
convertType(mapOp.getVarType()));
4530 mapData.Sizes.push_back(builder.getInt64(0));
4531 mapData.MapClause.push_back(mapOp.getOperation());
4532 mapData.Types.push_back(
4533 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
4534 mapData.Names.push_back(LLVM::createMappingInformation(
4536 mapData.DevicePointers.push_back(devInfoTy);
4537 mapData.Mappers.push_back(
nullptr);
4538 mapData.IsAMapping.push_back(
false);
4539 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
4544 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4545 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
4547 for (Value mapValue : hasDevAddrOperands) {
4548 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4550 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4551 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
4553 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4555 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4556 omp::ClauseMapFlags::none;
4558 mapData.OriginalValue.push_back(origValue);
4559 mapData.BasePointers.push_back(origValue);
4560 mapData.Pointers.push_back(origValue);
4561 mapData.IsDeclareTarget.push_back(
false);
4562 mapData.BaseType.push_back(
4563 moduleTranslation.
convertType(mapOp.getVarType()));
4564 mapData.Sizes.push_back(
4565 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
4566 mapData.MapClause.push_back(mapOp.getOperation());
4567 if (llvm::to_underlying(mapType & mapTypeAlways)) {
4571 mapData.Types.push_back(mapType);
4575 if (mapOp.getMapperId()) {
4576 mapData.Mappers.push_back(
4578 mapOp, mapOp.getMapperIdAttr()));
4580 mapData.Mappers.push_back(
nullptr);
4585 mapData.Types.push_back(
4586 isDevicePtr ? mapType
4587 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4588 mapData.Mappers.push_back(
nullptr);
4590 mapData.Names.push_back(LLVM::createMappingInformation(
4592 mapData.DevicePointers.push_back(
4593 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4594 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4595 mapData.IsAMapping.push_back(
false);
4596 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
4601 auto *res = llvm::find(mapData.MapClause, memberOp);
4602 assert(res != mapData.MapClause.end() &&
4603 "MapInfoOp for member not found in MapData, cannot return index");
4604 return std::distance(mapData.MapClause.begin(), res);
4608 omp::MapInfoOp mapInfo) {
4609 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4619 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
4620 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
4622 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
4623 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
4624 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
4626 if (aIndex == bIndex)
4629 if (aIndex < bIndex)
4632 if (aIndex > bIndex)
4639 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
4641 occludedChildren.push_back(
b);
4643 occludedChildren.push_back(a);
4644 return memberAParent;
4650 for (
auto v : occludedChildren)
4657 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4659 if (indexAttr.size() == 1)
4660 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
4664 return llvm::cast<omp::MapInfoOp>(
4689static std::vector<llvm::Value *>
4691 llvm::IRBuilderBase &builder,
bool isArrayTy,
4693 std::vector<llvm::Value *> idx;
4704 idx.push_back(builder.getInt64(0));
4705 for (
int i = bounds.size() - 1; i >= 0; --i) {
4706 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4707 bounds[i].getDefiningOp())) {
4708 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
4726 std::vector<llvm::Value *> dimensionIndexSizeOffset;
4727 for (
int i = bounds.size() - 1; i >= 0; --i) {
4728 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4729 bounds[i].getDefiningOp())) {
4730 if (i == ((
int)bounds.size() - 1))
4732 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
4734 idx.back() = builder.CreateAdd(
4735 builder.CreateMul(idx.back(), moduleTranslation.
lookupValue(
4736 boundOp.getExtent())),
4737 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
4746 llvm::transform(values, std::back_inserter(ints), [](
Attribute value) {
4747 return cast<IntegerAttr>(value).getInt();
4755 omp::MapInfoOp parentOp) {
4757 if (parentOp.getMembers().empty())
4761 if (parentOp.getMembers().size() == 1) {
4762 overlapMapDataIdxs.push_back(0);
4768 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
4769 for (
auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
4770 memberByIndex.push_back(
4771 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
4776 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
4777 [&](
auto a,
auto b) { return a.second.size() < b.second.size(); });
4783 for (
auto v : memberByIndex) {
4787 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](
auto x) {
4790 llvm::SmallVector<int64_t> xArr(x.second.size());
4791 getAsIntegers(x.second, xArr);
4792 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
4793 xArr.size() >= vArr.size();
4799 for (
auto v : memberByIndex)
4800 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
4801 overlapMapDataIdxs.push_back(v.first);
4813 if (mapOp.getVarPtrPtr())
4842 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
4843 MapInfoData &mapData, uint64_t mapDataIndex,
4844 TargetDirectiveEnumTy targetDirective) {
4845 assert(!ompBuilder.Config.isTargetDevice() &&
4846 "function only supported for host device codegen");
4849 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4851 auto *parentMapper = mapData.Mappers[mapDataIndex];
4857 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
4858 (targetDirective == TargetDirectiveEnumTy::Target &&
4859 !mapData.IsDeclareTarget[mapDataIndex])
4860 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
4861 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4864 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
4868 mapFlags parentFlags = mapData.Types[mapDataIndex];
4869 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
4870 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
4871 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
4872 baseFlag |= (parentFlags & preserve);
4875 combinedInfo.Types.emplace_back(baseFlag);
4876 combinedInfo.DevicePointers.emplace_back(
4877 mapData.DevicePointers[mapDataIndex]);
4881 combinedInfo.Mappers.emplace_back(
4882 parentMapper && !parentClause.getPartialMap() ? parentMapper :
nullptr);
4884 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4885 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
4894 llvm::Value *lowAddr, *highAddr;
4895 if (!parentClause.getPartialMap()) {
4896 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4897 builder.getPtrTy());
4898 highAddr = builder.CreatePointerCast(
4899 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4900 mapData.Pointers[mapDataIndex], 1),
4901 builder.getPtrTy());
4902 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4904 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4907 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
4908 builder.getPtrTy());
4911 highAddr = builder.CreatePointerCast(
4912 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
4913 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
4914 builder.getPtrTy());
4915 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
4918 llvm::Value *size = builder.CreateIntCast(
4919 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
4920 builder.getInt64Ty(),
4922 combinedInfo.Sizes.push_back(size);
4924 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
4925 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
4933 if (!parentClause.getPartialMap()) {
4938 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
4939 bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
4940 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
4941 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4942 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4944 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
4945 combinedInfo.Types.emplace_back(mapFlag);
4946 combinedInfo.DevicePointers.emplace_back(
4947 mapData.DevicePointers[mapDataIndex]);
4949 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4950 combinedInfo.BasePointers.emplace_back(
4951 mapData.BasePointers[mapDataIndex]);
4952 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4953 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
4954 combinedInfo.Mappers.emplace_back(
nullptr);
4965 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4966 builder.getPtrTy());
4967 highAddr = builder.CreatePointerCast(
4968 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4969 mapData.Pointers[mapDataIndex], 1),
4970 builder.getPtrTy());
4977 for (
auto v : overlapIdxs) {
4980 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
4981 combinedInfo.Types.emplace_back(mapFlag);
4982 combinedInfo.DevicePointers.emplace_back(
4983 mapData.DevicePointers[mapDataOverlapIdx]);
4985 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4986 combinedInfo.BasePointers.emplace_back(
4987 mapData.BasePointers[mapDataIndex]);
4988 combinedInfo.Mappers.emplace_back(
nullptr);
4989 combinedInfo.Pointers.emplace_back(lowAddr);
4990 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
4991 builder.CreatePtrDiff(builder.getInt8Ty(),
4992 mapData.OriginalValue[mapDataOverlapIdx],
4994 builder.getInt64Ty(),
true));
4995 lowAddr = builder.CreateConstGEP1_32(
4997 mapData.MapClause[mapDataOverlapIdx]))
4998 ? builder.getPtrTy()
4999 : mapData.BaseType[mapDataOverlapIdx],
5000 mapData.BasePointers[mapDataOverlapIdx], 1);
5003 combinedInfo.Types.emplace_back(mapFlag);
5004 combinedInfo.DevicePointers.emplace_back(
5005 mapData.DevicePointers[mapDataIndex]);
5007 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5008 combinedInfo.BasePointers.emplace_back(
5009 mapData.BasePointers[mapDataIndex]);
5010 combinedInfo.Mappers.emplace_back(
nullptr);
5011 combinedInfo.Pointers.emplace_back(lowAddr);
5012 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5013 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5014 builder.getInt64Ty(),
true));
5017 return memberOfFlag;
5023 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
5024 MapInfoData &mapData, uint64_t mapDataIndex,
5025 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5026 TargetDirectiveEnumTy targetDirective) {
5027 assert(!ompBuilder.Config.isTargetDevice() &&
5028 "function only supported for host device codegen");
5031 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5033 for (
auto mappedMembers : parentClause.getMembers()) {
5035 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
5038 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
5049 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5050 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5051 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5052 combinedInfo.Types.emplace_back(mapFlag);
5053 combinedInfo.DevicePointers.emplace_back(
5054 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5055 combinedInfo.Mappers.emplace_back(
nullptr);
5056 combinedInfo.Names.emplace_back(
5058 combinedInfo.BasePointers.emplace_back(
5059 mapData.BasePointers[mapDataIndex]);
5060 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
5061 combinedInfo.Sizes.emplace_back(builder.getInt64(
5062 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
5068 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5069 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5070 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5072 ? parentClause.getVarPtr()
5073 : parentClause.getVarPtrPtr());
5076 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5077 targetDirective != TargetDirectiveEnumTy::TargetData))) {
5078 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5081 combinedInfo.Types.emplace_back(mapFlag);
5082 combinedInfo.DevicePointers.emplace_back(
5083 mapData.DevicePointers[memberDataIdx]);
5084 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5085 combinedInfo.Names.emplace_back(
5087 uint64_t basePointerIndex =
5089 combinedInfo.BasePointers.emplace_back(
5090 mapData.BasePointers[basePointerIndex]);
5091 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
5093 llvm::Value *size = mapData.Sizes[memberDataIdx];
5095 size = builder.CreateSelect(
5096 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5097 builder.getInt64(0), size);
5100 combinedInfo.Sizes.emplace_back(size);
5105 MapInfosTy &combinedInfo,
5106 TargetDirectiveEnumTy targetDirective,
5107 int mapDataParentIdx = -1) {
5111 auto mapFlag = mapData.Types[mapDataIdx];
5112 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5116 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5118 if (targetDirective == TargetDirectiveEnumTy::Target &&
5119 !mapData.IsDeclareTarget[mapDataIdx])
5120 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5122 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5124 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5129 if (mapDataParentIdx >= 0)
5130 combinedInfo.BasePointers.emplace_back(
5131 mapData.BasePointers[mapDataParentIdx]);
5133 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5135 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5136 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5137 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5138 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5139 combinedInfo.Types.emplace_back(mapFlag);
5140 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
5144 llvm::IRBuilderBase &builder,
5145 llvm::OpenMPIRBuilder &ompBuilder,
5147 MapInfoData &mapData, uint64_t mapDataIndex,
5148 TargetDirectiveEnumTy targetDirective) {
5149 assert(!ompBuilder.Config.isTargetDevice() &&
5150 "function only supported for host device codegen");
5153 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5158 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
5159 auto memberClause = llvm::cast<omp::MapInfoOp>(
5160 parentClause.getMembers()[0].getDefiningOp());
5177 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
5179 combinedInfo, mapData, mapDataIndex,
5182 combinedInfo, mapData, mapDataIndex,
5183 memberOfParentFlag, targetDirective);
5193 llvm::IRBuilderBase &builder) {
5195 "function only supported for host device codegen");
5196 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5198 if (!mapData.IsDeclareTarget[i]) {
5199 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5200 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
5210 switch (captureKind) {
5211 case omp::VariableCaptureKind::ByRef: {
5212 llvm::Value *newV = mapData.Pointers[i];
5214 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
5217 newV = builder.CreateLoad(builder.getPtrTy(), newV);
5219 if (!offsetIdx.empty())
5220 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
5222 mapData.Pointers[i] = newV;
5224 case omp::VariableCaptureKind::ByCopy: {
5225 llvm::Type *type = mapData.BaseType[i];
5227 if (mapData.Pointers[i]->getType()->isPointerTy())
5228 newV = builder.CreateLoad(type, mapData.Pointers[i]);
5230 newV = mapData.Pointers[i];
5233 auto curInsert = builder.saveIP();
5234 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
5236 auto *memTempAlloc =
5237 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
5238 builder.SetCurrentDebugLocation(DbgLoc);
5239 builder.restoreIP(curInsert);
5241 builder.CreateStore(newV, memTempAlloc);
5242 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
5245 mapData.Pointers[i] = newV;
5246 mapData.BasePointers[i] = newV;
5248 case omp::VariableCaptureKind::This:
5249 case omp::VariableCaptureKind::VLAType:
5250 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
5261 MapInfoData &mapData,
5262 TargetDirectiveEnumTy targetDirective) {
5264 "function only supported for host device codegen");
5285 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5288 if (mapData.IsAMember[i])
5291 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5292 if (!mapInfoOp.getMembers().empty()) {
5294 combinedInfo, mapData, i, targetDirective);
5302static llvm::Expected<llvm::Function *>
5304 LLVM::ModuleTranslation &moduleTranslation,
5305 llvm::StringRef mapperFuncName,
5306 TargetDirectiveEnumTy targetDirective);
5308static llvm::Expected<llvm::Function *>
5311 TargetDirectiveEnumTy targetDirective) {
5313 "function only supported for host device codegen");
5314 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5315 std::string mapperFuncName =
5317 {
"omp_mapper", declMapperOp.getSymName()});
5319 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
5327 if (llvm::Function *existingFunc =
5328 moduleTranslation.
getLLVMModule()->getFunction(mapperFuncName)) {
5329 moduleTranslation.
mapFunction(mapperFuncName, existingFunc);
5330 return existingFunc;
5334 mapperFuncName, targetDirective);
5337static llvm::Expected<llvm::Function *>
5340 llvm::StringRef mapperFuncName,
5341 TargetDirectiveEnumTy targetDirective) {
5343 "function only supported for host device codegen");
5344 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5345 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
5348 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
5351 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5354 MapInfosTy combinedInfo;
5356 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
5357 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
5358 builder.restoreIP(codeGenIP);
5359 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
5360 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
5361 builder.GetInsertBlock());
5362 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
5365 return llvm::make_error<PreviouslyReportedError>();
5366 MapInfoData mapData;
5369 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
5375 return combinedInfo;
5379 if (!combinedInfo.Mappers[i])
5382 moduleTranslation, targetDirective);
5386 genMapInfoCB, varType, mapperFuncName, customMapperCB);
5388 return newFn.takeError();
5389 if (llvm::Function *mappedFunc =
5391 assert(mappedFunc == *newFn &&
5392 "mapper function mapping disagrees with emitted function");
5394 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
5402 llvm::Value *ifCond =
nullptr;
5403 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
5407 llvm::omp::RuntimeFunction RTLFn;
5409 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
5412 llvm::OpenMPIRBuilder::TargetDataInfo info(
5415 assert(!ompBuilder->Config.isTargetDevice() &&
5416 "target data/enter/exit/update are host ops");
5417 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
5419 auto getDeviceID = [&](
mlir::Value dev) -> llvm::Value * {
5420 llvm::Value *v = moduleTranslation.
lookupValue(dev);
5421 return builder.CreateIntCast(v, builder.getInt64Ty(),
true);
5426 .Case([&](omp::TargetDataOp dataOp) {
5430 if (
auto ifVar = dataOp.getIfExpr())
5434 deviceID = getDeviceID(devId);
5436 mapVars = dataOp.getMapVars();
5437 useDevicePtrVars = dataOp.getUseDevicePtrVars();
5438 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
5441 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
5445 if (
auto ifVar = enterDataOp.getIfExpr())
5449 deviceID = getDeviceID(devId);
5452 enterDataOp.getNowait()
5453 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
5454 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
5455 mapVars = enterDataOp.getMapVars();
5456 info.HasNoWait = enterDataOp.getNowait();
5459 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
5463 if (
auto ifVar = exitDataOp.getIfExpr())
5467 deviceID = getDeviceID(devId);
5469 RTLFn = exitDataOp.getNowait()
5470 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
5471 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
5472 mapVars = exitDataOp.getMapVars();
5473 info.HasNoWait = exitDataOp.getNowait();
5476 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
5480 if (
auto ifVar = updateDataOp.getIfExpr())
5484 deviceID = getDeviceID(devId);
5487 updateDataOp.getNowait()
5488 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
5489 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
5490 mapVars = updateDataOp.getMapVars();
5491 info.HasNoWait = updateDataOp.getNowait();
5494 .DefaultUnreachable(
"unexpected operation");
5499 if (!isOffloadEntry)
5500 ifCond = builder.getFalse();
5502 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5503 MapInfoData mapData;
5505 builder, useDevicePtrVars, useDeviceAddrVars);
5508 MapInfosTy combinedInfo;
5509 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
5510 builder.restoreIP(codeGenIP);
5511 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
5513 return combinedInfo;
5519 [&moduleTranslation](
5520 llvm::OpenMPIRBuilder::DeviceInfoTy type,
5524 for (
auto [arg, useDevVar] :
5525 llvm::zip_equal(blockArgs, useDeviceVars)) {
5527 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
5528 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
5529 : mapInfoOp.getVarPtr();
5532 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
5533 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
5534 mapInfoData.MapClause, mapInfoData.DevicePointers,
5535 mapInfoData.BasePointers)) {
5536 auto mapOp = cast<omp::MapInfoOp>(mapClause);
5537 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
5538 devicePointer != type)
5541 if (llvm::Value *devPtrInfoMap =
5542 mapper ? mapper(basePointer) : basePointer) {
5543 moduleTranslation.
mapValue(arg, devPtrInfoMap);
5550 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
5551 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
5552 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5555 builder.restoreIP(codeGenIP);
5556 assert(isa<omp::TargetDataOp>(op) &&
5557 "BodyGen requested for non TargetDataOp");
5558 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
5559 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
5560 switch (bodyGenType) {
5561 case BodyGenTy::Priv:
5563 if (!info.DevicePtrInfoMap.empty()) {
5564 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5565 blockArgIface.getUseDeviceAddrBlockArgs(),
5566 useDeviceAddrVars, mapData,
5567 [&](llvm::Value *basePointer) -> llvm::Value * {
5568 if (!info.DevicePtrInfoMap[basePointer].second)
5570 return builder.CreateLoad(
5572 info.DevicePtrInfoMap[basePointer].second);
5574 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5575 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5576 mapData, [&](llvm::Value *basePointer) {
5577 return info.DevicePtrInfoMap[basePointer].second;
5581 moduleTranslation)))
5582 return llvm::make_error<PreviouslyReportedError>();
5585 case BodyGenTy::DupNoPriv:
5586 if (info.DevicePtrInfoMap.empty()) {
5589 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5590 blockArgIface.getUseDeviceAddrBlockArgs(),
5591 useDeviceAddrVars, mapData);
5592 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5593 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5597 case BodyGenTy::NoPriv:
5599 if (info.DevicePtrInfoMap.empty()) {
5601 moduleTranslation)))
5602 return llvm::make_error<PreviouslyReportedError>();
5606 return builder.saveIP();
5609 auto customMapperCB =
5611 if (!combinedInfo.Mappers[i])
5613 info.HasMapper =
true;
5615 moduleTranslation, targetDirective);
5618 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5619 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5621 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
5622 if (isa<omp::TargetDataOp>(op))
5623 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5624 deviceID, ifCond, info, genMapInfoCB,
5628 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5629 deviceID, ifCond, info, genMapInfoCB,
5630 customMapperCB, &RTLFn);
5636 builder.restoreIP(*afterIP);
5644 auto distributeOp = cast<omp::DistributeOp>(opInst);
5651 bool doDistributeReduction =
5655 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
5660 if (doDistributeReduction) {
5661 isByRef =
getIsByRef(teamsOp.getReductionByref());
5662 assert(isByRef.size() == teamsOp.getNumReductionVars());
5665 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5669 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
5670 .getReductionBlockArgs();
5673 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
5674 reductionDecls, privateReductionVariables, reductionVariableMap,
5679 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5680 auto bodyGenCB = [&](InsertPointTy allocaIP,
5681 InsertPointTy codeGenIP) -> llvm::Error {
5685 moduleTranslation, allocaIP);
5688 builder.restoreIP(codeGenIP);
5694 return llvm::make_error<PreviouslyReportedError>();
5699 return llvm::make_error<PreviouslyReportedError>();
5702 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
5704 distributeOp.getPrivateNeedsBarrier())))
5705 return llvm::make_error<PreviouslyReportedError>();
5708 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5711 builder, moduleTranslation);
5713 return regionBlock.takeError();
5714 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
5719 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
5722 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
5723 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
5724 : omp::ClauseScheduleKind::Static;
5726 bool isOrdered = hasDistSchedule;
5727 std::optional<omp::ScheduleModifier> scheduleMod;
5728 bool isSimd =
false;
5729 llvm::omp::WorksharingLoopType workshareLoopType =
5730 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
5731 bool loopNeedsBarrier =
false;
5732 llvm::Value *chunk = moduleTranslation.
lookupValue(
5733 distributeOp.getDistScheduleChunkSize());
5734 llvm::CanonicalLoopInfo *loopInfo =
5736 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
5737 ompBuilder->applyWorkshareLoop(
5738 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
5739 convertToScheduleKind(schedule), chunk, isSimd,
5740 scheduleMod == omp::ScheduleModifier::monotonic,
5741 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
5742 workshareLoopType,
false, hasDistSchedule, chunk);
5745 return wsloopIP.takeError();
5748 distributeOp.getLoc(), privVarsInfo.
llvmVars,
5750 return llvm::make_error<PreviouslyReportedError>();
5752 return llvm::Error::success();
5755 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5757 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5758 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5759 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
5764 builder.restoreIP(*afterIP);
5766 if (doDistributeReduction) {
5769 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
5770 privateReductionVariables, isByRef,
5782 if (!cast<mlir::ModuleOp>(op))
5787 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
5788 attribute.getOpenmpDeviceVersion());
5790 if (attribute.getNoGpuLib())
5793 ompBuilder->createGlobalFlag(
5794 attribute.getDebugKind() ,
5795 "__omp_rtl_debug_kind");
5796 ompBuilder->createGlobalFlag(
5798 .getAssumeTeamsOversubscription()
5800 "__omp_rtl_assume_teams_oversubscription");
5801 ompBuilder->createGlobalFlag(
5803 .getAssumeThreadsOversubscription()
5805 "__omp_rtl_assume_threads_oversubscription");
5806 ompBuilder->createGlobalFlag(
5807 attribute.getAssumeNoThreadState() ,
5808 "__omp_rtl_assume_no_thread_state");
5809 ompBuilder->createGlobalFlag(
5811 .getAssumeNoNestedParallelism()
5813 "__omp_rtl_assume_no_nested_parallelism");
5818 omp::TargetOp targetOp,
5819 llvm::StringRef parentName =
"") {
5820 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
5822 assert(fileLoc &&
"No file found from location");
5823 StringRef fileName = fileLoc.getFilename().getValue();
5825 llvm::sys::fs::UniqueID id;
5826 uint64_t line = fileLoc.getLine();
5827 if (
auto ec = llvm::sys::fs::getUniqueID(fileName,
id)) {
5828 size_t fileHash = llvm::hash_value(fileName.str());
5829 size_t deviceId = 0xdeadf17e;
5831 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
5833 targetInfo = llvm::TargetRegionEntryInfo(parentName,
id.
getDevice(),
5834 id.getFile(), line);
5841 llvm::IRBuilderBase &builder, llvm::Function *
func) {
5843 "function only supported for target device codegen");
5844 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5845 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5858 if (mapData.IsDeclareTarget[i]) {
5865 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
5866 convertUsersOfConstantsToInstructions(constant,
func,
false);
5873 for (llvm::User *user : mapData.OriginalValue[i]->users())
5874 userVec.push_back(user);
5876 for (llvm::User *user : userVec) {
5877 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
5878 if (insn->getFunction() ==
func) {
5879 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5880 llvm::Value *substitute = mapData.BasePointers[i];
5882 : mapOp.getVarPtr())) {
5883 builder.SetCurrentDebugLocation(insn->getDebugLoc());
5884 substitute = builder.CreateLoad(
5885 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
5886 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
5888 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
5935static llvm::IRBuilderBase::InsertPoint
5937 llvm::Value *input, llvm::Value *&retVal,
5938 llvm::IRBuilderBase &builder,
5939 llvm::OpenMPIRBuilder &ompBuilder,
5941 llvm::IRBuilderBase::InsertPoint allocaIP,
5942 llvm::IRBuilderBase::InsertPoint codeGenIP) {
5943 assert(ompBuilder.Config.isTargetDevice() &&
5944 "function only supported for target device codegen");
5945 builder.restoreIP(allocaIP);
5947 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
5949 ompBuilder.M.getContext());
5950 unsigned alignmentValue = 0;
5952 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
5953 if (mapData.OriginalValue[i] == input) {
5954 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5955 capture = mapOp.getMapCaptureType();
5958 mapOp.getVarType(), ompBuilder.M.getDataLayout());
5962 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
5963 unsigned int defaultAS =
5964 ompBuilder.M.getDataLayout().getProgramAddressSpace();
5967 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
5969 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
5970 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
5972 builder.CreateStore(&arg, v);
5974 builder.restoreIP(codeGenIP);
5977 case omp::VariableCaptureKind::ByCopy: {
5981 case omp::VariableCaptureKind::ByRef: {
5982 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
5984 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
5999 if (v->getType()->isPointerTy() && alignmentValue) {
6000 llvm::MDBuilder MDB(builder.getContext());
6001 loadInst->setMetadata(
6002 llvm::LLVMContext::MD_align,
6003 llvm::MDNode::get(builder.getContext(),
6004 MDB.createConstant(llvm::ConstantInt::get(
6005 llvm::Type::getInt64Ty(builder.getContext()),
6012 case omp::VariableCaptureKind::This:
6013 case omp::VariableCaptureKind::VLAType:
6016 assert(
false &&
"Currently unsupported capture kind");
6020 return builder.saveIP();
6037 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6038 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6039 blockArgIface.getHostEvalBlockArgs())) {
6040 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6044 .Case([&](omp::TeamsOp teamsOp) {
6045 if (teamsOp.getNumTeamsLower() == blockArg)
6046 numTeamsLower = hostEvalVar;
6047 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6049 numTeamsUpper = hostEvalVar;
6050 else if (!teamsOp.getThreadLimitVars().empty() &&
6051 teamsOp.getThreadLimit(0) == blockArg)
6052 threadLimit = hostEvalVar;
6054 llvm_unreachable(
"unsupported host_eval use");
6056 .Case([&](omp::ParallelOp parallelOp) {
6057 if (!parallelOp.getNumThreadsVars().empty() &&
6058 parallelOp.getNumThreads(0) == blockArg)
6059 numThreads = hostEvalVar;
6061 llvm_unreachable(
"unsupported host_eval use");
6063 .Case([&](omp::LoopNestOp loopOp) {
6064 auto processBounds =
6068 for (
auto [i, lb] : llvm::enumerate(opBounds)) {
6069 if (lb == blockArg) {
6072 (*outBounds)[i] = hostEvalVar;
6078 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6079 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6081 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6083 assert(found &&
"unsupported host_eval use");
6085 .DefaultUnreachable(
"unsupported host_eval use");
6097template <
typename OpTy>
6102 if (OpTy casted = dyn_cast<OpTy>(op))
6105 if (immediateParent)
6106 return dyn_cast_if_present<OpTy>(op->
getParentOp());
6115 return std::nullopt;
6118 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6119 return constAttr.getInt();
6121 return std::nullopt;
6126 uint64_t sizeInBytes = sizeInBits / 8;
6130template <
typename OpTy>
6132 if (op.getNumReductionVars() > 0) {
6137 members.reserve(reductions.size());
6138 for (omp::DeclareReductionOp &red : reductions)
6139 members.push_back(red.getType());
6141 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
6157 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
6158 bool isTargetDevice,
bool isGPU) {
6161 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
6162 if (!isTargetDevice) {
6170 numTeamsLower = teamsOp.getNumTeamsLower();
6172 if (!teamsOp.getNumTeamsUpperVars().empty())
6173 numTeamsUpper = teamsOp.getNumTeams(0);
6174 if (!teamsOp.getThreadLimitVars().empty())
6175 threadLimit = teamsOp.getThreadLimit(0);
6179 if (!parallelOp.getNumThreadsVars().empty())
6180 numThreads = parallelOp.getNumThreads(0);
6186 int32_t minTeamsVal = 1, maxTeamsVal = -1;
6190 if (numTeamsUpper) {
6192 minTeamsVal = maxTeamsVal = *val;
6194 minTeamsVal = maxTeamsVal = 0;
6200 minTeamsVal = maxTeamsVal = 1;
6202 minTeamsVal = maxTeamsVal = -1;
6207 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &
result) {
6221 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
6222 if (!targetOp.getThreadLimitVars().empty())
6223 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
6224 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
6227 int32_t maxThreadsVal = -1;
6229 setMaxValueFromClause(numThreads, maxThreadsVal);
6237 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
6238 if (combinedMaxThreadsVal < 0 ||
6239 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
6240 combinedMaxThreadsVal = teamsThreadLimitVal;
6242 if (combinedMaxThreadsVal < 0 ||
6243 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
6244 combinedMaxThreadsVal = maxThreadsVal;
6246 int32_t reductionDataSize = 0;
6247 if (isGPU && capturedOp) {
6253 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
6255 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
6256 omp::TargetRegionFlags::spmd) &&
6257 "invalid kernel flags");
6259 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
6260 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
6261 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
6262 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
6263 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
6264 if (omp::bitEnumContainsAll(kernelFlags,
6265 omp::TargetRegionFlags::spmd |
6266 omp::TargetRegionFlags::no_loop) &&
6267 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
6268 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
6270 attrs.MinTeams = minTeamsVal;
6271 attrs.MaxTeams.front() = maxTeamsVal;
6272 attrs.MinThreads = 1;
6273 attrs.MaxThreads.front() = combinedMaxThreadsVal;
6274 attrs.ReductionDataSize = reductionDataSize;
6277 if (attrs.ReductionDataSize != 0)
6278 attrs.ReductionBufferLength = 1024;
6290 omp::TargetOp targetOp,
Operation *capturedOp,
6291 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
6293 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
6295 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
6299 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
6302 if (!targetOp.getThreadLimitVars().empty()) {
6303 Value targetThreadLimit = targetOp.getThreadLimit(0);
6304 attrs.TargetThreadLimit.front() =
6312 attrs.MinTeams = builder.CreateSExtOrTrunc(
6313 moduleTranslation.
lookupValue(numTeamsLower), builder.getInt32Ty());
6316 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
6317 moduleTranslation.
lookupValue(numTeamsUpper), builder.getInt32Ty());
6319 if (teamsThreadLimit)
6320 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
6321 moduleTranslation.
lookupValue(teamsThreadLimit), builder.getInt32Ty());
6324 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
6326 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
6327 omp::TargetRegionFlags::trip_count)) {
6329 attrs.LoopTripCount =
nullptr;
6334 for (
auto [loopLower, loopUpper, loopStep] :
6335 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
6336 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
6337 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
6338 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
6340 if (!lowerBound || !upperBound || !step) {
6341 attrs.LoopTripCount =
nullptr;
6345 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
6346 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
6347 loc, lowerBound, upperBound, step,
true,
6348 loopOp.getLoopInclusive());
6350 if (!attrs.LoopTripCount) {
6351 attrs.LoopTripCount = tripCount;
6356 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
6361 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6363 attrs.DeviceID = moduleTranslation.
lookupValue(devId);
6365 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
6372 auto targetOp = cast<omp::TargetOp>(opInst);
6376 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
6385 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
6386 assert(parentBB &&
"No insert block is set for the builder");
6387 llvm::Function *parentLLVMFn = parentBB->getParent();
6388 assert(parentLLVMFn &&
"Parent Function must be valid");
6389 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
6390 builder.SetCurrentDebugLocation(llvm::DILocation::get(
6391 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
6392 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
6395 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
6396 bool isGPU = ompBuilder->Config.isGPU();
6399 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
6400 auto &targetRegion = targetOp.getRegion();
6417 llvm::Function *llvmOutlinedFn =
nullptr;
6418 TargetDirectiveEnumTy targetDirective =
6419 getTargetDirectiveEnumTyFromOp(&opInst);
6423 bool isOffloadEntry =
6424 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
6431 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
6433 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
6434 std::optional<DenseI64ArrayAttr> privateMapIndices =
6435 targetOp.getPrivateMapsAttr();
6437 for (
auto [privVarIdx, privVarSymPair] :
6438 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
6439 auto privVar = std::get<0>(privVarSymPair);
6440 auto privSym = std::get<1>(privVarSymPair);
6442 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
6443 omp::PrivateClauseOp privatizer =
6446 if (!privatizer.needsMap())
6450 targetOp.getMappedValueForPrivateVar(privVarIdx);
6451 assert(mappedValue &&
"Expected to find mapped value for a privatized "
6452 "variable that needs mapping");
6457 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
6458 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
6462 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
6464 varType == privVar.getType() &&
6465 "Type of private var doesn't match the type of the mapped value");
6469 mappedPrivateVars.insert(
6471 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
6472 (*privateMapIndices)[privVarIdx])});
6476 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6477 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
6478 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6479 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6480 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6483 llvm::Function *llvmParentFn =
6485 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
6486 assert(llvmParentFn && llvmOutlinedFn &&
6487 "Both parent and outlined functions must exist at this point");
6489 if (outlinedFnLoc && llvmParentFn->getSubprogram())
6490 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
6492 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
6493 attr.isStringAttribute())
6494 llvmOutlinedFn->addFnAttr(attr);
6496 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
6497 attr.isStringAttribute())
6498 llvmOutlinedFn->addFnAttr(attr);
6500 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
6501 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6502 llvm::Value *mapOpValue =
6503 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6504 moduleTranslation.
mapValue(arg, mapOpValue);
6506 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
6507 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6508 llvm::Value *mapOpValue =
6509 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6510 moduleTranslation.
mapValue(arg, mapOpValue);
6519 allocaIP, &mappedPrivateVars);
6522 return llvm::make_error<PreviouslyReportedError>();
6524 builder.restoreIP(codeGenIP);
6526 &mappedPrivateVars),
6529 return llvm::make_error<PreviouslyReportedError>();
6532 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
6534 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
6535 return llvm::make_error<PreviouslyReportedError>();
6539 std::back_inserter(privateCleanupRegions),
6540 [](omp::PrivateClauseOp privatizer) {
6541 return &privatizer.getDeallocRegion();
6545 targetRegion,
"omp.target", builder, moduleTranslation);
6548 return exitBlock.takeError();
6550 builder.SetInsertPoint(*exitBlock);
6551 if (!privateCleanupRegions.empty()) {
6553 privateCleanupRegions, privateVarsInfo.
llvmVars,
6554 moduleTranslation, builder,
"omp.targetop.private.cleanup",
6556 return llvm::createStringError(
6557 "failed to inline `dealloc` region of `omp.private` "
6558 "op in the target region");
6560 return builder.saveIP();
6563 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
6566 StringRef parentName = parentFn.getName();
6568 llvm::TargetRegionEntryInfo entryInfo;
6572 MapInfoData mapData;
6577 MapInfosTy combinedInfos;
6579 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
6580 builder.restoreIP(codeGenIP);
6581 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
6583 return combinedInfos;
6586 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
6587 llvm::Value *&retVal, InsertPointTy allocaIP,
6588 InsertPointTy codeGenIP)
6589 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6590 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6591 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6597 if (!isTargetDevice) {
6598 retVal = cast<llvm::Value>(&arg);
6603 *ompBuilder, moduleTranslation,
6604 allocaIP, codeGenIP);
6607 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
6608 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
6609 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
6611 isTargetDevice, isGPU);
6615 if (!isTargetDevice)
6617 targetCapturedOp, runtimeAttrs);
6625 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
6626 llvm::Value *value = moduleTranslation.
lookupValue(var);
6627 moduleTranslation.
mapValue(arg, value);
6629 if (!llvm::isa<llvm::Constant>(value))
6630 kernelInput.push_back(value);
6633 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
6640 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
6641 kernelInput.push_back(mapData.OriginalValue[i]);
6646 moduleTranslation, dds);
6648 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6650 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6652 llvm::OpenMPIRBuilder::TargetDataInfo info(
6656 auto customMapperCB =
6658 if (!combinedInfos.Mappers[i])
6660 info.HasMapper =
true;
6662 moduleTranslation, targetDirective);
6665 llvm::Value *ifCond =
nullptr;
6666 if (
Value targetIfCond = targetOp.getIfExpr())
6667 ifCond = moduleTranslation.
lookupValue(targetIfCond);
6669 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6671 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
6672 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
6673 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
6678 builder.restoreIP(*afterIP);
6699 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
6700 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
6702 if (!offloadMod.getIsTargetDevice())
6705 omp::DeclareTargetDeviceType declareType =
6706 attribute.getDeviceType().getValue();
6708 if (declareType == omp::DeclareTargetDeviceType::host) {
6709 llvm::Function *llvmFunc =
6711 llvmFunc->dropAllReferences();
6712 llvmFunc->eraseFromParent();
6718 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
6719 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
6720 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
6722 bool isDeclaration = gOp.isDeclaration();
6723 bool isExternallyVisible =
6726 llvm::StringRef mangledName = gOp.getSymName();
6727 auto captureClause =
6733 std::vector<llvm::GlobalVariable *> generatedRefs;
6735 std::vector<llvm::Triple> targetTriple;
6736 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
6738 LLVM::LLVMDialect::getTargetTripleAttrName()));
6739 if (targetTripleAttr)
6740 targetTriple.emplace_back(targetTripleAttr.data());
6742 auto fileInfoCallBack = [&loc]() {
6743 std::string filename =
"";
6744 std::uint64_t lineNo = 0;
6747 filename = loc.getFilename().str();
6748 lineNo = loc.getLine();
6751 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
6755 auto vfs = llvm::vfs::getRealFileSystem();
6757 ompBuilder->registerTargetGlobalVariable(
6758 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6759 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6760 mangledName, generatedRefs,
false, targetTriple,
6762 gVal->getType(), gVal);
6764 if (ompBuilder->Config.isTargetDevice() &&
6765 (attribute.getCaptureClause().getValue() !=
6766 mlir::omp::DeclareTargetCaptureClause::to ||
6767 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
6768 ompBuilder->getAddrOfDeclareTargetVar(
6769 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6770 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6771 mangledName, generatedRefs,
false, targetTriple,
6772 gVal->getType(),
nullptr,
6785class OpenMPDialectLLVMIRTranslationInterface
6786 :
public LLVMTranslationDialectInterface {
6793 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
6794 LLVM::ModuleTranslation &moduleTranslation)
const final;
6799 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
6800 NamedAttribute attribute,
6801 LLVM::ModuleTranslation &moduleTranslation)
const final;
6806LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6807 Operation *op, ArrayRef<llvm::Instruction *> instructions,
6808 NamedAttribute attribute,
6809 LLVM::ModuleTranslation &moduleTranslation)
const {
6810 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
6812 .Case(
"omp.is_target_device",
6813 [&](Attribute attr) {
6814 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6815 llvm::OpenMPIRBuilderConfig &
config =
6817 config.setIsTargetDevice(deviceAttr.getValue());
6823 [&](Attribute attr) {
6824 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6825 llvm::OpenMPIRBuilderConfig &
config =
6827 config.setIsGPU(gpuAttr.getValue());
6832 .Case(
"omp.host_ir_filepath",
6833 [&](Attribute attr) {
6834 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6835 llvm::OpenMPIRBuilder *ompBuilder =
6837 auto VFS = llvm::vfs::getRealFileSystem();
6838 ompBuilder->loadOffloadInfoMetadata(*VFS,
6839 filepathAttr.getValue());
6845 [&](Attribute attr) {
6846 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6850 .Case(
"omp.version",
6851 [&](Attribute attr) {
6852 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6853 llvm::OpenMPIRBuilder *ompBuilder =
6855 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
6856 versionAttr.getVersion());
6861 .Case(
"omp.declare_target",
6862 [&](Attribute attr) {
6863 if (
auto declareTargetAttr =
6864 dyn_cast<omp::DeclareTargetAttr>(attr))
6869 .Case(
"omp.requires",
6870 [&](Attribute attr) {
6871 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6872 using Requires = omp::ClauseRequires;
6873 Requires flags = requiresAttr.getValue();
6874 llvm::OpenMPIRBuilderConfig &
config =
6876 config.setHasRequiresReverseOffload(
6877 bitEnumContainsAll(flags, Requires::reverse_offload));
6878 config.setHasRequiresUnifiedAddress(
6879 bitEnumContainsAll(flags, Requires::unified_address));
6880 config.setHasRequiresUnifiedSharedMemory(
6881 bitEnumContainsAll(flags, Requires::unified_shared_memory));
6882 config.setHasRequiresDynamicAllocators(
6883 bitEnumContainsAll(flags, Requires::dynamic_allocators));
6888 .Case(
"omp.target_triples",
6889 [&](Attribute attr) {
6890 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
6891 llvm::OpenMPIRBuilderConfig &
config =
6893 config.TargetTriples.clear();
6894 config.TargetTriples.reserve(triplesAttr.size());
6895 for (Attribute tripleAttr : triplesAttr) {
6896 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
6897 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
6905 .Default([](Attribute) {
6921 if (
auto declareTargetIface =
6922 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
6923 parentFn.getOperation()))
6924 if (declareTargetIface.isDeclareTarget() &&
6925 declareTargetIface.getDeclareTargetDeviceType() !=
6926 mlir::omp::DeclareTargetDeviceType::host)
6936 llvm::Module *llvmModule) {
6937 llvm::Type *i64Ty = builder.getInt64Ty();
6938 llvm::Type *i32Ty = builder.getInt32Ty();
6939 llvm::Type *returnType = builder.getPtrTy(0);
6940 llvm::FunctionType *fnType =
6941 llvm::FunctionType::get(returnType, {i64Ty, i32Ty},
false);
6942 llvm::Function *
func = cast<llvm::Function>(
6943 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
6950 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
6955 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
6959 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
6961 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
6962 mlir::Type heapTy = allocMemOp.getAllocatedType();
6963 llvm::Type *llvmHeapTy = moduleTranslation.
convertType(heapTy);
6964 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
6965 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
6966 for (
auto typeParam : allocMemOp.getTypeparams())
6968 builder.CreateMul(allocSize, moduleTranslation.
lookupValue(typeParam));
6970 llvm::CallInst *call =
6971 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
6972 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
6975 moduleTranslation.
mapValue(allocMemOp.getResult(), resultI64);
6980 llvm::Module *llvmModule) {
6981 llvm::Type *ptrTy = builder.getPtrTy(0);
6982 llvm::Type *i32Ty = builder.getInt32Ty();
6983 llvm::Type *voidTy = builder.getVoidTy();
6984 llvm::FunctionType *fnType =
6985 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty},
false);
6986 llvm::Function *
func = dyn_cast<llvm::Function>(
6987 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
6994 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
6999 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7003 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
7006 llvm::Value *llvmHeapref = moduleTranslation.
lookupValue(heapref);
7008 llvm::Value *intToPtr =
7009 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
7010 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
7016LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7017 Operation *op, llvm::IRBuilderBase &builder,
7018 LLVM::ModuleTranslation &moduleTranslation)
const {
7021 if (ompBuilder->Config.isTargetDevice() &&
7022 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
7025 return op->
emitOpError() <<
"unsupported host op found in device";
7033 bool isOutermostLoopWrapper =
7034 isa_and_present<omp::LoopWrapperInterface>(op) &&
7035 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
7037 if (isOutermostLoopWrapper)
7038 moduleTranslation.
stackPush<OpenMPLoopInfoStackFrame>();
7041 llvm::TypeSwitch<Operation *, LogicalResult>(op)
7042 .Case([&](omp::BarrierOp op) -> LogicalResult {
7046 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7047 ompBuilder->createBarrier(builder.saveIP(),
7048 llvm::omp::OMPD_barrier);
7050 if (res.succeeded()) {
7053 builder.restoreIP(*afterIP);
7057 .Case([&](omp::TaskyieldOp op) {
7061 ompBuilder->createTaskyield(builder.saveIP());
7064 .Case([&](omp::FlushOp op) {
7076 ompBuilder->createFlush(builder.saveIP());
7079 .Case([&](omp::ParallelOp op) {
7082 .Case([&](omp::MaskedOp) {
7085 .Case([&](omp::MasterOp) {
7088 .Case([&](omp::CriticalOp) {
7091 .Case([&](omp::OrderedRegionOp) {
7094 .Case([&](omp::OrderedOp) {
7097 .Case([&](omp::WsloopOp) {
7100 .Case([&](omp::SimdOp) {
7103 .Case([&](omp::AtomicReadOp) {
7106 .Case([&](omp::AtomicWriteOp) {
7109 .Case([&](omp::AtomicUpdateOp op) {
7112 .Case([&](omp::AtomicCaptureOp op) {
7115 .Case([&](omp::CancelOp op) {
7118 .Case([&](omp::CancellationPointOp op) {
7121 .Case([&](omp::SectionsOp) {
7124 .Case([&](omp::SingleOp op) {
7127 .Case([&](omp::TeamsOp op) {
7130 .Case([&](omp::TaskOp op) {
7133 .Case([&](omp::TaskloopOp op) {
7136 .Case([&](omp::TaskgroupOp op) {
7139 .Case([&](omp::TaskwaitOp op) {
7142 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
7143 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
7144 omp::CriticalDeclareOp>([](
auto op) {
7157 .Case([&](omp::ThreadprivateOp) {
7160 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
7161 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
7164 .Case([&](omp::TargetOp) {
7167 .Case([&](omp::DistributeOp) {
7170 .Case([&](omp::LoopNestOp) {
7173 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
7180 .Case([&](omp::NewCliOp op) {
7185 .Case([&](omp::CanonicalLoopOp op) {
7188 .Case([&](omp::UnrollHeuristicOp op) {
7197 .Case([&](omp::TileOp op) {
7198 return applyTile(op, builder, moduleTranslation);
7200 .Case([&](omp::TargetAllocMemOp) {
7203 .Case([&](omp::TargetFreeMemOp) {
7206 .Default([&](Operation *inst) {
7208 <<
"not yet implemented: " << inst->
getName();
7211 if (isOutermostLoopWrapper)
7218 registry.
insert<omp::OpenMPDialect>();
7220 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it cast to that type. Otherwise, if its immediate parent...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
LLVMTranslationDialectInterface(Dialect *dialect)
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars