#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
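// Maps the MLIR OpenMP dialect schedule clause kind onto the corresponding
// OpenMPIRBuilder schedule kind, defaulting when no schedule clause was given.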
static llvm::omp::ScheduleKind
convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
  if (!schedKind.has_value())
    return llvm::omp::OMP_SCHEDULE_Default;
  switch (schedKind.value()) {
  case omp::ClauseScheduleKind::Static:
    return llvm::omp::OMP_SCHEDULE_Static;
  case omp::ClauseScheduleKind::Dynamic:
    return llvm::omp::OMP_SCHEDULE_Dynamic;
  case omp::ClauseScheduleKind::Guided:
    return llvm::omp::OMP_SCHEDULE_Guided;
  case omp::ClauseScheduleKind::Auto:
    return llvm::omp::OMP_SCHEDULE_Auto;
  case omp::ClauseScheduleKind::Runtime:
    return llvm::omp::OMP_SCHEDULE_Runtime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}
class OpenMPAllocaStackFrame

  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}

  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
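// Stack frame holding the llvm::CanonicalLoopInfo for the loop construct
// currently being translated.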
class OpenMPLoopInfoStackFrame

  llvm::CanonicalLoopInfo *loopInfo = nullptr;
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {

  void log(raw_ostream &) const override {

  std::error_code convertToErrorCode() const override {
    llvm_unreachable(
        "PreviouslyReportedError doesn't support ECError conversion");
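// Resolves the omp.private privatizer op referenced by a symbol, searching the
// nearest enclosing symbol table starting from the given operation.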
                                             SymbolRefAttr symbolName) {
  omp::PrivateClauseOp privatizer =
      SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
                                                                 symbolName);
  assert(privatizer && "privatizer not found in the symbol table");
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
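  // Each of the check* helpers below reports a "not yet implemented" error and
  // marks `result` as failed when the corresponding unsupported clause is
  // present on the operation being translated.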
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  };
  auto checkBare = [&todo](auto op, LogicalResult &result) {

      result = todo("ompx_bare");
  };
  auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
    omp::ClauseCancellationConstructType cancelledDirective =
        op.getCancelDirective();

    if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {

      if (isa_and_nonnull<omp::TaskloopOp>(parent))
        result = todo("cancel directive inside of taskloop");
    }
  };
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  };
  auto checkDevice = [&todo](auto op, LogicalResult &result) {

      result = todo("device");
  };
  auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
    if (op.getDistScheduleChunkSize())
      result = todo("dist_schedule with chunk_size");
  };
  auto checkHint = [](auto op, LogicalResult &) {
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  };
  auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
    if (!op.getIsDevicePtrVars().empty())
      result = todo("is_device_ptr");
  };
  auto checkLinear = [&todo](auto op, LogicalResult &result) {
    if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
      result = todo("linear");
  };
  auto checkNowait = [&todo](auto op, LogicalResult &result) {

      result = todo("nowait");
  };
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  };
  auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
    if (op.getParLevelSimd())
      result = todo("parallelization-level");
  };
  auto checkPriority = [&todo](auto op, LogicalResult &result) {
    if (op.getPriority())
      result = todo("priority");
  };
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {

      if (!op.getPrivateVars().empty() && op.getNowait())
        result = todo("privatization for deferred target tasks");
    } else {
      if (!op.getPrivateVars().empty() || op.getPrivateSyms())
        result = todo("privatization");
    }
  };
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  };
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  };
  auto checkUntied = [&todo](auto op, LogicalResult &result) {

      result = todo("untied");
  };
  LogicalResult result = success();

      .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
      .Case([&](omp::CancellationPointOp op) {
        checkCancelDirective(op, result);
      })
      .Case([&](omp::DistributeOp op) {
        checkAllocate(op, result);
        checkDistSchedule(op, result);
        checkOrder(op, result);
      })
      .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TaskOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
      })
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      })
      .Case([&](omp::TaskwaitOp op) {
        checkDepend(op, result);
        checkNowait(op, result);
      })
      .Case([&](omp::TaskloopOp op) {

        checkUntied(op, result);
        checkPriority(op, result);
      })
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkLinear(op, result);
        checkOrder(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SimdOp op) {
        checkLinear(op, result);
        checkReduction(op, result);
      })
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkBare(op, result);
        checkDevice(op, result);
        checkInReduction(op, result);
        checkIsDevicePtr(op, result);
        checkPrivate(op, result);
  LogicalResult result = success();

  llvm::handleAllErrors(

      [&](const PreviouslyReportedError &) { result = failure(); },
      [&](const llvm::ErrorInfoBase &err) {
template <typename T>

static llvm::OpenMPIRBuilder::InsertPointTy

  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
  moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
      [&](OpenMPAllocaStackFrame &frame) {
        allocaInsertPoint = frame.allocaInsertPoint;

  return allocaInsertPoint;
  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);
  }

  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
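// Walks the module translation stack to find the llvm::CanonicalLoopInfo of
// the innermost enclosing loop construct, if any.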
static llvm::CanonicalLoopInfo *

  llvm::CanonicalLoopInfo *loopInfo = nullptr;
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        loopInfo = frame.loopInfo;
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,

  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());

  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  unsigned numYields = 0;

  if (!isLoopWrapper) {
    bool operandsProcessed = false;

      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));

          operandsProcessed = true;

          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());

            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");
  if (!continuationBlockPHITypes.empty())
    assert(
        continuationBlockPHIs &&
        "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);

    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);

    llvm::IRBuilderBase::InsertPointGuard guard(builder);

    if (failed(
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      return llvm::make_error<PreviouslyReportedError>();

      builder.CreateBr(continuationBlock);

    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);

      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(

  return continuationBlock;
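// Maps the MLIR proc_bind clause kind onto the matching llvm::omp::ProcBindKind.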
533 case omp::ClauseProcBindKind::Close:
534 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
535 case omp::ClauseProcBindKind::Master:
536 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
537 case omp::ClauseProcBindKind::Primary:
538 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
539 case omp::ClauseProcBindKind::Spread:
540 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
542 llvm_unreachable("Unknown ClauseProcBindKind kind");
552 omp::BlockArgOpenMPOpInterface blockArgIface) {
554 blockArgIface.getBlockArgsPairs(blockArgsPairs);
555 for (auto [var, arg] : blockArgsPairs)
572 .Case([&](omp::SimdOp op) {
574 cast<omp::BlockArgOpenMPOpInterface>(*op));
575 op.emitWarning() << "simd information on composite construct discarded";
579 return op->emitError() << "cannot ignore wrapper";
587 auto maskedOp = cast<omp::MaskedOp>(opInst);
588 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
593 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
595 auto &region = maskedOp.getRegion();
596 builder.restoreIP(codeGenIP);
604 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
606 llvm::Value *filterVal = nullptr;
607 if (auto filterVar = maskedOp.getFilteredThreadId()) {
608 filterVal = moduleTranslation.lookupValue(filterVar);
610 llvm::LLVMContext &llvmContext = builder.getContext();
614 assert(filterVal != nullptr);
615 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
616 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
623 builder.restoreIP(*afterIP);
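// Converts an omp.master region to LLVM IR via the OpenMPIRBuilder.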
631 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
632 auto masterOp = cast<omp::MasterOp>(opInst);
637 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
639 auto &region = masterOp.getRegion();
640 builder.restoreIP(codeGenIP);
648 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
650 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
651 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
658 builder.restoreIP(*afterIP);
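// Converts an omp.critical region; an optional omp.critical.declare symbol
// supplies the hint constant passed to the runtime.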
666 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
667 auto criticalOp = cast<omp::CriticalOp>(opInst);
672 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
674 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
675 builder.restoreIP(codeGenIP);
683 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
685 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
686 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
687 llvm::Constant *hint = nullptr;
690 if (criticalOp.getNameAttr()) {
693 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
694 auto criticalDeclareOp =
695 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
699 static_cast<int>(criticalDeclareOp.getHint()));
701 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
703 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
708 builder.restoreIP(*afterIP);
715 template <
typename OP>
718 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
719 mlirVars.reserve(blockArgs.size());
720 llvmVars.reserve(blockArgs.size());
721 collectPrivatizationDecls<OP>(op);
724 mlirVars.push_back(privateVar);
736 void collectPrivatizationDecls(OP op) {
737 std::optional<ArrayAttr> attr = op.getPrivateSyms();
741 privatizers.reserve(privatizers.size() + attr->size());
742 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
749 template <
typename T>
753 std::optional<ArrayAttr> attr = op.getReductionSyms();
757 reductions.reserve(reductions.size() + op.getNumReductionVars());
758 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
759 reductions.push_back(
760 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
771 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
779 if (llvm::hasSingleElement(region)) {
780 llvm::Instruction *potentialTerminator =
781 builder.GetInsertBlock()->empty() ? nullptr
782 : &builder.GetInsertBlock()->back();
784 if (potentialTerminator && potentialTerminator->isTerminator())
785 potentialTerminator->removeFromParent();
786 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
789 region.front(), true, builder)))
793 if (continuationBlockArgs)
795 *continuationBlockArgs,
802 if (potentialTerminator && potentialTerminator->isTerminator()) {
803 llvm::BasicBlock *block = builder.GetInsertBlock();
804 if (block->empty()) {
810 potentialTerminator->insertInto(block, block->begin());
812 potentialTerminator->insertAfter(&block->back());
826 if (continuationBlockArgs)
827 llvm::append_range(*continuationBlockArgs, phis);
828 builder.SetInsertPoint(*continuationBlock,
829 (*continuationBlock)->getFirstInsertionPt());
836 using OwningReductionGen =
837 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
838 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
840 using OwningAtomicReductionGen =
841 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
842 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
849 static OwningReductionGen
855 OwningReductionGen gen =
856 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
857 llvm::Value *lhs, llvm::Value *rhs,
858 llvm::Value *&result)
mutable
859 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
860 moduleTranslation.
mapValue(decl.getReductionLhsArg(), lhs);
861 moduleTranslation.
mapValue(decl.getReductionRhsArg(), rhs);
862 builder.restoreIP(insertPoint);
865 "omp.reduction.nonatomic.body", builder,
866 moduleTranslation, &phis)))
867 return llvm::createStringError(
868 "failed to inline `combiner` region of `omp.declare_reduction`");
869 result = llvm::getSingleElement(phis);
870 return builder.saveIP();
879 static OwningAtomicReductionGen
881 llvm::IRBuilderBase &builder,
883 if (decl.getAtomicReductionRegion().empty())
884 return OwningAtomicReductionGen();
889 OwningAtomicReductionGen atomicGen =
890 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
891 llvm::Value *lhs, llvm::Value *rhs)
mutable
892 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
893 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(), lhs);
894 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(), rhs);
895 builder.restoreIP(insertPoint);
898 "omp.reduction.atomic.body", builder,
899 moduleTranslation, &phis)))
900 return llvm::createStringError(
901 "failed to inline `atomic` region of `omp.declare_reduction`");
902 assert(phis.empty());
903 return builder.saveIP();
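// Converts omp.ordered with doacross dependences: the depend variables are
// grouped per loop nest and lowered through createOrderedDepend.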
912 auto orderedOp = cast<omp::OrderedOp>(opInst);
917 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
918 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
919 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
921 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
923 size_t indexVecValues = 0;
924 while (indexVecValues < vecValues.size()) {
926 storeValues.reserve(numLoops);
927 for (
unsigned i = 0; i < numLoops; i++) {
928 storeValues.push_back(vecValues[indexVecValues]);
931 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
933 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
934 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
935 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
945 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
946 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
951 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
953 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
954 builder.restoreIP(codeGenIP);
962 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
964 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
965 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
967 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
972 builder.restoreIP(*afterIP);
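// A value/address pair for a store whose emission is deferred until the
// corresponding reduction allocations are in place.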
978 struct DeferredStore {
979 DeferredStore(llvm::Value *value, llvm::Value *address)
980 : value(value), address(address) {}
983 llvm::Value *address;
990 template <
typename T>
993 llvm::IRBuilderBase &builder,
995 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1001 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1002 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1005 deferredStores.reserve(loop.getNumReductionVars());
1007 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1008 Region &allocRegion = reductionDecls[i].getAllocRegion();
1010 if (allocRegion.
empty())
1015 builder, moduleTranslation, &phis)))
1016 return loop.emitError(
1017 "failed to inline `alloc` region of `omp.declare_reduction`");
1019 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1020 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1024 llvm::Value *var = builder.CreateAlloca(
1025 moduleTranslation.
convertType(reductionDecls[i].getType()));
1027 llvm::Type *ptrTy = builder.getPtrTy();
1028 llvm::Value *castVar =
1029 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1030 llvm::Value *castPhi =
1031 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1033 deferredStores.emplace_back(castPhi, castVar);
1035 privateReductionVariables[i] = castVar;
1036 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1037 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1039 assert(allocRegion.empty() &&
1040 "allocation is implicit for by-val reduction");
1041 llvm::Value *var = builder.CreateAlloca(
1042 moduleTranslation.
convertType(reductionDecls[i].getType()));
1044 llvm::Type *ptrTy = builder.getPtrTy();
1045 llvm::Value *castVar =
1046 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1048 moduleTranslation.
mapValue(reductionArgs[i], castVar);
1049 privateReductionVariables[i] = castVar;
1050 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1058 template <
typename T>
1065 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1066 Region &initializerRegion = reduction.getInitializerRegion();
1069 mlir::Value mlirSource = loop.getReductionVars()[i];
1070 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1071 assert(llvmSource &&
"lookup reduction var");
1072 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), llvmSource);
1075 llvm::Value *allocation =
1076 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1077 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1083 llvm::BasicBlock *block =
nullptr) {
1084 if (block ==
nullptr)
1085 block = builder.GetInsertBlock();
1087 if (block->empty() || block->getTerminator() ==
nullptr)
1088 builder.SetInsertPoint(block);
1090 builder.SetInsertPoint(block->getTerminator());
1098 template <
typename OP>
1099 static LogicalResult
1101 llvm::IRBuilderBase &builder,
1103 llvm::BasicBlock *latestAllocaBlock,
1109 if (op.getNumReductionVars() == 0)
1112 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1113 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1114 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1115 builder.restoreIP(allocaIP);
1118 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1120 if (!reductionDecls[i].getAllocRegion().empty())
1126 byRefVars[i] = builder.CreateAlloca(
1127 moduleTranslation.
convertType(reductionDecls[i].getType()));
1135 for (
auto [data, addr] : deferredStores)
1136 builder.CreateStore(data, addr);
1141 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1146 reductionVariableMap, i);
1149 "omp.reduction.neutral", builder,
1150 moduleTranslation, &phis)))
1153 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1154 "reduction neutral element declaration region");
1159 if (!reductionDecls[i].getAllocRegion().empty())
1168 builder.CreateStore(phis[0], byRefVars[i]);
1170 privateReductionVariables[i] = byRefVars[i];
1171 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1172 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1175 builder.CreateStore(phis[0], privateReductionVariables[i]);
1182 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
1189 template <
typename T>
1191 T loop, llvm::IRBuilderBase &builder,
1198 unsigned numReductions = loop.getNumReductionVars();
1200 for (
unsigned i = 0; i < numReductions; ++i) {
1201 owningReductionGens.push_back(
1203 owningAtomicReductionGens.push_back(
1208 reductionInfos.reserve(numReductions);
1209 for (
unsigned i = 0; i < numReductions; ++i) {
1210 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen =
nullptr;
1211 if (owningAtomicReductionGens[i])
1212 atomicGen = owningAtomicReductionGens[i];
1213 llvm::Value *variable =
1214 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
1215 reductionInfos.push_back(
1217 privateReductionVariables[i],
1218 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1219 owningReductionGens[i],
1220 nullptr, atomicGen});
1225 static LogicalResult
1229 llvm::IRBuilderBase &builder, StringRef regionName,
1230 bool shouldLoadCleanupRegionArg =
true) {
1232 if (cleanupRegion->empty())
1238 llvm::Instruction *potentialTerminator =
1239 builder.GetInsertBlock()->empty() ? nullptr
1240 : &builder.GetInsertBlock()->back();
1241 if (potentialTerminator && potentialTerminator->isTerminator())
1242 builder.SetInsertPoint(potentialTerminator);
1243 llvm::Value *privateVarValue =
1244 shouldLoadCleanupRegionArg
1245 ? builder.CreateLoad(
1247 privateVariables[i])
1248 : privateVariables[i];
1253 moduleTranslation)))
1266 OP op, llvm::IRBuilderBase &builder,
1268 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1271 bool isNowait =
false,
bool isTeamsReduction =
false) {
1273 if (op.getNumReductionVars() == 0)
1285 owningReductionGens, owningAtomicReductionGens,
1286 privateReductionVariables, reductionInfos);
1291 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1292 builder.SetInsertPoint(tempTerminator);
1293 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1294 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1295 isByRef, isNowait, isTeamsReduction);
1300 if (!contInsertPoint->getBlock())
1301 return op->emitOpError() <<
"failed to convert reductions";
1303 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1304 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1309 tempTerminator->eraseFromParent();
1310 builder.restoreIP(*afterIP);
1314 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1315 [](omp::DeclareReductionOp reductionDecl) {
1316 return &reductionDecl.getCleanupRegion();
1319 moduleTranslation, builder,
1320 "omp.reduction.cleanup");
1331 template <
typename OP>
1335 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1340 if (op.getNumReductionVars() == 0)
1346 allocaIP, reductionDecls,
1347 privateReductionVariables, reductionVariableMap,
1348 deferredStores, isByRef)))
1352 allocaIP.getBlock(), reductionDecls,
1353 privateReductionVariables, reductionVariableMap,
1354 isByRef, deferredStores);
1364 static llvm::Value *
1368 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1371 Value blockArg = (*mappedPrivateVars)[privateVar];
1374 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1375 "A block argument corresponding to a mapped var should have "
1378 if (privVarType == blockArgType)
1385 if (!isa<LLVM::LLVMPointerType>(privVarType))
1386 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1399 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1401 Region &initRegion = privDecl.getInitRegion();
1402 if (initRegion.
empty())
1403 return llvmPrivateVar;
1407 mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1408 assert(nonPrivateVar);
1409 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1410 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1415 moduleTranslation, &phis)))
1416 return llvm::createStringError(
1417 "failed to inline `init` region of `omp.private`");
1419 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1438 return llvm::Error::success();
1440 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1446 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1448 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1449 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1452 return privVarOrErr.takeError();
1454 llvmPrivateVar = privVarOrErr.get();
1455 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1460 return llvm::Error::success();
1470 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1473 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1474 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1475 allocaTerminator->getIterator()),
1476 true, allocaTerminator->getStableDebugLoc(),
1477 "omp.region.after_alloca");
1479 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1481 allocaTerminator = allocaIP.getBlock()->getTerminator();
1482 builder.SetInsertPoint(allocaTerminator);
1484 assert(allocaTerminator->getNumSuccessors() == 1 &&
1485 "This is an unconditional branch created by splitBB");
1487 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1488 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1490 unsigned int allocaAS =
1491 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1494 .getProgramAddressSpace();
1496 for (
auto [privDecl, mlirPrivVar, blockArg] :
1499 llvm::Type *llvmAllocType =
1500 moduleTranslation.
convertType(privDecl.getType());
1501 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1502 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1503 llvmAllocType,
nullptr,
"omp.private.alloc");
1504 if (allocaAS != defaultAS)
1505 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1506 builder.getPtrTy(defaultAS));
1508 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1511 return afterAllocas;
1521 bool needsFirstprivate =
1522 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1523 return privOp.getDataSharingType() ==
1524 omp::DataSharingClauseType::FirstPrivate;
1527 if (!needsFirstprivate)
1530 llvm::BasicBlock *copyBlock =
1531 splitBB(builder,
true,
"omp.private.copy");
1534 for (
auto [decl, mlirVar, llvmVar] :
1535 llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
1536 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1540 Region &copyRegion = decl.getCopyRegion();
1544 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1545 assert(nonPrivateVar);
1546 moduleTranslation.
mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1549 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1553 moduleTranslation)))
1554 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1569 static LogicalResult
1576 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1577 [](omp::PrivateClauseOp privatizer) {
1578 return &privatizer.getDeallocRegion();
1582 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1583 "omp.private.dealloc",
false)))
1584 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1585 "`omp.private` op in");
1597 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1604 static LogicalResult
1607 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1608 using StorableBodyGenCallbackTy =
1609 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1611 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1617 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1621 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1625 sectionsOp.getNumReductionVars());
1629 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1632 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1633 reductionDecls, privateReductionVariables, reductionVariableMap,
1640 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1644 Region &region = sectionOp.getRegion();
1645 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
1646 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1647 builder.restoreIP(codeGenIP);
1654 sectionsOp.getRegion().getNumArguments());
1655 for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
1656 sectionsOp.getRegion().getArguments(), region.getArguments())) {
1657 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
1659 moduleTranslation.mapValue(sectionArg, llvmVal);
1666 sectionCBs.push_back(sectionCB);
1672 if (sectionCBs.empty())
1675 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1680 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1681 llvm::Value &vPtr, llvm::Value *&replacementValue)
1682 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1683 replacementValue = &vPtr;
1689 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1693 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1694 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1696 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1697 sectionsOp.getNowait());
1702 builder.restoreIP(*afterIP);
1706 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1707 privateReductionVariables, isByRef, sectionsOp.getNowait());
1711 static LogicalResult
1714 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1715 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1720 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1721 builder.restoreIP(codegenIP);
1723 builder, moduleTranslation)
1726 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1730 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1733 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1734 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
1735 auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
1736 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1737 llvmCPFuncs.push_back(
1741 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1743 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1749 builder.restoreIP(*afterIP);
1755 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1760 for (
auto ra : iface.getReductionBlockArgs())
1761 for (
auto &use : ra.getUses()) {
1762 auto *useOp = use.getOwner();
1764 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1765 debugUses.push_back(useOp);
1769 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
1774 Operation *currentOp = currentDistOp.getOperation();
1775 if (distOp && (distOp != currentOp))
1784 for (
auto use : debugUses)
1790 static LogicalResult
1793 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1798 unsigned numReductionVars = op.getNumReductionVars();
1802 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1808 if (doTeamsReduction) {
1809 isByRef =
getIsByRef(op.getReductionByref());
1811 assert(isByRef.size() == op.getNumReductionVars());
1814 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
1819 op, reductionArgs, builder, moduleTranslation, allocaIP,
1820 reductionDecls, privateReductionVariables, reductionVariableMap,
1825 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1827 moduleTranslation, allocaIP);
1828 builder.restoreIP(codegenIP);
1834 llvm::Value *numTeamsLower =
nullptr;
1835 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
1836 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
1838 llvm::Value *numTeamsUpper =
nullptr;
1839 if (
Value numTeamsUpperVar = op.getNumTeamsUpper())
1840 numTeamsUpper = moduleTranslation.
lookupValue(numTeamsUpperVar);
1842 llvm::Value *threadLimit =
nullptr;
1843 if (
Value threadLimitVar = op.getThreadLimit())
1844 threadLimit = moduleTranslation.
lookupValue(threadLimitVar);
1846 llvm::Value *ifExpr =
nullptr;
1847 if (
Value ifVar = op.getIfExpr())
1850 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1851 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1853 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
1858 builder.restoreIP(*afterIP);
1859 if (doTeamsReduction) {
1862 op, builder, moduleTranslation, allocaIP, reductionDecls,
1863 privateReductionVariables, isByRef,
1873 if (dependVars.empty())
1875 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
1876 llvm::omp::RTLDependenceKindTy type;
1878 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
1879 case mlir::omp::ClauseTaskDepend::taskdependin:
1880 type = llvm::omp::RTLDependenceKindTy::DepIn;
1885 case mlir::omp::ClauseTaskDepend::taskdependout:
1886 case mlir::omp::ClauseTaskDepend::taskdependinout:
1887 type = llvm::omp::RTLDependenceKindTy::DepInOut;
1889 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
1890 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
1892 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
1893 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
1896 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
1897 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
1898 dds.emplace_back(dd);
1910 llvm::IRBuilderBase &llvmBuilder,
1912 llvm::omp::Directive cancelDirective) {
1913 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
1914 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
1918 llvmBuilder.restoreIP(ip);
1924 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
1925 return llvm::Error::success();
1930 ompBuilder.pushFinalizationCB(
1940 llvm::OpenMPIRBuilder &ompBuilder,
1941 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
1942 ompBuilder.popFinalizationCB();
1943 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
1944 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
1945 assert(cancelBranch->getNumSuccessors() == 1 &&
1946 "cancel branch should have one target");
1947 cancelBranch->setSuccessor(0, constructFini);
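// Manages a heap-allocated context structure holding the first-private copies
// of a task's private variables; it provides helpers to allocate the struct,
// create GEPs to its fields, and free it again.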
1954 class TaskContextStructManager {
1956 TaskContextStructManager(llvm::IRBuilderBase &builder,
1959 : builder{builder}, moduleTranslation{moduleTranslation},
1960 privateDecls{privateDecls} {}
1966 void generateTaskContextStruct();
1972 void createGEPsToPrivateVars();
1975 void freeStructPtr();
1978 return llvmPrivateVarGEPs;
1981 llvm::Value *getStructPtr() {
return structPtr; }
1984 llvm::IRBuilderBase &builder;
1996 llvm::Value *structPtr =
nullptr;
1998 llvm::Type *structTy =
nullptr;
2002 void TaskContextStructManager::generateTaskContextStruct() {
2003 if (privateDecls.empty())
2005 privateVarTypes.reserve(privateDecls.size());
2007 for (omp::PrivateClauseOp &privOp : privateDecls) {
2010 if (!privOp.readsFromMold())
2012 Type mlirType = privOp.getType();
2013 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2019 llvm::DataLayout dataLayout =
2020 builder.GetInsertBlock()->getModule()->getDataLayout();
2021 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2022 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2025 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2027 "omp.task.context_ptr");
2030 void TaskContextStructManager::createGEPsToPrivateVars() {
2032 assert(privateVarTypes.empty());
2037 llvmPrivateVarGEPs.clear();
2038 llvmPrivateVarGEPs.reserve(privateDecls.size());
2039 llvm::Value *zero = builder.getInt32(0);
2041 for (
auto privDecl : privateDecls) {
2042 if (!privDecl.readsFromMold()) {
2044 llvmPrivateVarGEPs.push_back(
nullptr);
2047 llvm::Value *iVal = builder.getInt32(i);
2048 llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
2049 llvmPrivateVarGEPs.push_back(gep);
2054 void TaskContextStructManager::freeStructPtr() {
2058 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2060 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2061 builder.CreateFree(structPtr);
2065 static LogicalResult
2068 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2073 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2085 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2090 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2091 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2092 builder.getContext(),
"omp.task.start",
2093 builder.GetInsertBlock()->getParent());
2094 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2095 builder.SetInsertPoint(branchToTaskStartBlock);
2098 llvm::BasicBlock *copyBlock =
2099 splitBB(builder,
true,
"omp.private.copy");
2100 llvm::BasicBlock *initBlock =
2101 splitBB(builder,
true,
"omp.private.init");
2117 moduleTranslation, allocaIP);
2120 builder.SetInsertPoint(initBlock->getTerminator());
2123 taskStructMgr.generateTaskContextStruct();
2130 taskStructMgr.createGEPsToPrivateVars();
2132 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2135 taskStructMgr.getLLVMPrivateVarGEPs())) {
2137 if (!privDecl.readsFromMold())
2139 assert(llvmPrivateVarAlloc &&
2140 "reads from mold so shouldn't have been skipped");
2143 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2144 blockArg, llvmPrivateVarAlloc, initBlock);
2145 if (!privateVarOrErr)
2146 return handleError(privateVarOrErr, *taskOp.getOperation());
2148 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2149 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2156 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2157 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2158 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2160 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2161 llvmPrivateVarAlloc);
2163 assert(llvmPrivateVarAlloc->getType() ==
2164 moduleTranslation.
convertType(blockArg.getType()));
2174 builder, moduleTranslation, privateVarsInfo.
mlirVars,
2175 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers)))
2176 return llvm::failure();
2179 builder.SetInsertPoint(taskStartBlock);
2181 auto bodyCB = [&](InsertPointTy allocaIP,
2182 InsertPointTy codegenIP) -> llvm::Error {
2186 moduleTranslation, allocaIP);
2189 builder.restoreIP(codegenIP);
2191 llvm::BasicBlock *privInitBlock =
nullptr;
2196 auto [blockArg, privDecl, mlirPrivVar] = zip;
2198 if (privDecl.readsFromMold())
2201 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2202 llvm::Type *llvmAllocType =
2203 moduleTranslation.
convertType(privDecl.getType());
2204 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2205 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2206 llvmAllocType,
nullptr,
"omp.private.alloc");
2209 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2210 blockArg, llvmPrivateVar, privInitBlock);
2211 if (!privateVarOrError)
2212 return privateVarOrError.takeError();
2213 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2214 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2217 taskStructMgr.createGEPsToPrivateVars();
2218 for (
auto [i, llvmPrivVar] :
2221 assert(privateVarsInfo.
llvmVars[i] &&
2222 "This is added in the loop above");
2225 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2230 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2234 if (!privateDecl.readsFromMold())
2237 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2238 llvmPrivateVar = builder.CreateLoad(
2239 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2241 assert(llvmPrivateVar->getType() ==
2242 moduleTranslation.
convertType(blockArg.getType()));
2243 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2247 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
2248 if (failed(
handleError(continuationBlockOrError, *taskOp)))
2249 return llvm::make_error<PreviouslyReportedError>();
2251 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2256 return llvm::make_error<PreviouslyReportedError>();
2259 taskStructMgr.freeStructPtr();
2261 return llvm::Error::success();
2270 llvm::omp::Directive::OMPD_taskgroup);
2274 moduleTranslation, dds);
2276 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2277 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2279 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2281 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dds,
2282 taskOp.getMergeable(),
2283 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
2284 moduleTranslation.
lookupValue(taskOp.getPriority()));
2292 builder.restoreIP(*afterIP);
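// Converts omp.taskgroup by inlining its region through a body callback and
// letting the OpenMPIRBuilder emit the surrounding taskgroup runtime calls.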
2297 static LogicalResult
2300 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2304 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2305 builder.restoreIP(codegenIP);
2307 builder, moduleTranslation)
2312 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2313 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2320 builder.restoreIP(*afterIP);
2324 static LogicalResult
2335 static LogicalResult
2339 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2343 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2345 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2349 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2352 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
2353 llvm::Type *ivType = step->getType();
2354 llvm::Value *chunk =
nullptr;
2355 if (wsloopOp.getScheduleChunk()) {
2356 llvm::Value *chunkVar =
2357 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
2358 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2365 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2369 wsloopOp.getNumReductionVars());
2372 builder, moduleTranslation, privateVarsInfo, allocaIP);
2379 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2384 moduleTranslation, allocaIP, reductionDecls,
2385 privateReductionVariables, reductionVariableMap,
2386 deferredStores, isByRef)))
2395 builder, moduleTranslation, privateVarsInfo.
mlirVars,
2399 assert(afterAllocas.get()->getSinglePredecessor());
2402 afterAllocas.get()->getSinglePredecessor(),
2403 reductionDecls, privateReductionVariables,
2404 reductionVariableMap, isByRef, deferredStores)))
2408 bool isOrdered = wsloopOp.getOrdered().has_value();
2409 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2410 bool isSimd = wsloopOp.getScheduleSimd();
2411 bool loopNeedsBarrier = !wsloopOp.getNowait();
2416 llvm::omp::WorksharingLoopType workshareLoopType =
2417 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
2418 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2419 : llvm::omp::WorksharingLoopType::ForStaticLoop;
2423 llvm::omp::Directive::OMPD_for);
2425 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2427 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
2432 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2435 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2436 ompBuilder->applyWorkshareLoop(
2437 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2438 convertToScheduleKind(schedule), chunk, isSimd,
2439 scheduleMod == omp::ScheduleModifier::monotonic,
2440 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2451 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2452 privateReductionVariables, isByRef, wsloopOp.getNowait(),
2462 static LogicalResult
2465 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2467 assert(isByRef.size() == opInst.getNumReductionVars());
2479 opInst.getNumReductionVars());
2482 auto bodyGenCB = [&](InsertPointTy allocaIP,
2483 InsertPointTy codeGenIP) -> llvm::Error {
2485 builder, moduleTranslation, privateVarsInfo, allocaIP);
2487 return llvm::make_error<PreviouslyReportedError>();
2493 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
2496 InsertPointTy(allocaIP.getBlock(),
2497 allocaIP.getBlock()->getTerminator()->getIterator());
2500 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
2501 reductionDecls, privateReductionVariables, reductionVariableMap,
2502 deferredStores, isByRef)))
2503 return llvm::make_error<PreviouslyReportedError>();
2505 assert(afterAllocas.get()->getSinglePredecessor());
2506 builder.restoreIP(codeGenIP);
2512 return llvm::make_error<PreviouslyReportedError>();
2515 builder, moduleTranslation, privateVarsInfo.
mlirVars,
2517 return llvm::make_error<PreviouslyReportedError>();
2521 afterAllocas.get()->getSinglePredecessor(),
2522 reductionDecls, privateReductionVariables,
2523 reductionVariableMap, isByRef, deferredStores)))
2524 return llvm::make_error<PreviouslyReportedError>();
2529 moduleTranslation, allocaIP);
2533 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
2535 return regionBlock.takeError();
2538 if (opInst.getNumReductionVars() > 0) {
2544 owningReductionGens, owningAtomicReductionGens,
2545 privateReductionVariables, reductionInfos);
2548 builder.SetInsertPoint((*regionBlock)->getTerminator());
2551 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
2552 builder.SetInsertPoint(tempTerminator);
2554 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
2555 ompBuilder->createReductions(
2556 builder.saveIP(), allocaIP, reductionInfos, isByRef,
2558 if (!contInsertPoint)
2559 return contInsertPoint.takeError();
2561 if (!contInsertPoint->getBlock())
2562 return llvm::make_error<PreviouslyReportedError>();
2564 tempTerminator->eraseFromParent();
2565 builder.restoreIP(*contInsertPoint);
2568 return llvm::Error::success();
2571 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2572 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
2581 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2582 InsertPointTy oldIP = builder.saveIP();
2583 builder.restoreIP(codeGenIP);
2588 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
2589 [](omp::DeclareReductionOp reductionDecl) {
2590 return &reductionDecl.getCleanupRegion();
2593 reductionCleanupRegions, privateReductionVariables,
2594 moduleTranslation, builder,
"omp.reduction.cleanup")))
2595 return llvm::createStringError(
2596 "failed to inline `cleanup` region of `omp.declare_reduction`");
2601 return llvm::make_error<PreviouslyReportedError>();
2603 builder.restoreIP(oldIP);
2604 return llvm::Error::success();
2607 llvm::Value *ifCond =
nullptr;
2608 if (
auto ifVar = opInst.getIfExpr())
2610 llvm::Value *numThreads =
nullptr;
2611 if (
auto numThreadsVar = opInst.getNumThreads())
2612 numThreads = moduleTranslation.
lookupValue(numThreadsVar);
2613 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
2614 if (
auto bind = opInst.getProcBindKind())
2618 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2620 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2622 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2623 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
2624 ifCond, numThreads, pbKind, isCancellable);
2629 builder.restoreIP(*afterIP);
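// Maps the MLIR order clause kind onto the corresponding llvm::omp::OrderKind.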
2634 static llvm::omp::OrderKind
2637 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2639 case omp::ClauseOrderKind::Concurrent:
2640 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2642 llvm_unreachable("Unknown ClauseOrderKind kind");
2646 static LogicalResult
2650 auto simdOp = cast<omp::SimdOp>(opInst);
2656 if (simdOp.isComposite()) {
2661 builder, moduleTranslation);
2669 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2673 builder, moduleTranslation, privateVarsInfo, allocaIP);
2682 llvm::ConstantInt *simdlen =
nullptr;
2683 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
2684 simdlen = builder.getInt64(simdlenVar.value());
2686 llvm::ConstantInt *safelen =
nullptr;
2687 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
2688 safelen = builder.getInt64(safelenVar.value());
2690 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
2693 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
2694 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
2696 for (
size_t i = 0; i < operands.size(); ++i) {
2697 llvm::Value *alignment =
nullptr;
2698 llvm::Value *llvmVal = moduleTranslation.
lookupValue(operands[i]);
2699 llvm::Type *ty = llvmVal->getType();
2701 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
2702 alignment = builder.getInt64(intAttr.getInt());
2703 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
2704 assert(alignment &&
"Invalid alignment value");
2705 auto curInsert = builder.saveIP();
2706 builder.SetInsertPoint(sourceBlock);
2707 llvmVal = builder.CreateLoad(ty, llvmVal);
2708 builder.restoreIP(curInsert);
2709 alignedVars[llvmVal] = alignment;
2713 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
2718 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2720 ompBuilder->applySimd(loopInfo, alignedVars,
2722 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
2724 order, simdlen, safelen);
2732 static LogicalResult
2736 auto loopOp = cast<omp::LoopNestOp>(opInst);
2739 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2744 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
2745 llvm::Value *iv) -> llvm::Error {
2748 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
2753 bodyInsertPoints.push_back(ip);
2755 if (loopInfos.size() != loopOp.getNumLoops() - 1)
2756 return llvm::Error::success();
2759 builder.restoreIP(ip);
2761 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
2763 return regionBlock.takeError();
2765 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2766 return llvm::Error::success();
2774 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
2775 llvm::Value *lowerBound =
2776 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
2777 llvm::Value *upperBound =
2778 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
2779 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
2784 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
2785 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
2787 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
2789 computeIP = loopInfos.front()->getPreheaderIP();
2793 ompBuilder->createCanonicalLoop(
2794 loc, bodyGen, lowerBound, upperBound, step,
2795 true, loopOp.getLoopInclusive(), computeIP);
2800 loopInfos.push_back(*loopResult);
2805 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
2806 loopInfos.front()->getAfterIP();
2810 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
2811 [&](OpenMPLoopInfoStackFrame &frame) {
2812 frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
2820 builder.restoreIP(afterIP);
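// Maps the MLIR memory order clause kind onto the corresponding
// llvm::AtomicOrdering.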
2825 static llvm::AtomicOrdering
2828 return llvm::AtomicOrdering::Monotonic;
2831 case omp::ClauseMemoryOrderKind::Seq_cst:
2832 return llvm::AtomicOrdering::SequentiallyConsistent;
2833 case omp::ClauseMemoryOrderKind::Acq_rel:
2834 return llvm::AtomicOrdering::AcquireRelease;
2835 case omp::ClauseMemoryOrderKind::Acquire:
2836 return llvm::AtomicOrdering::Acquire;
2837 case omp::ClauseMemoryOrderKind::Release:
2838 return llvm::AtomicOrdering::Release;
2839 case omp::ClauseMemoryOrderKind::Relaxed:
2840 return llvm::AtomicOrdering::Monotonic;
2842 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
2846 static LogicalResult
2849 auto readOp = cast<omp::AtomicReadOp>(opInst);
2854 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2857 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2860 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
2861 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
2863 llvm::Type *elementType =
2864 moduleTranslation.
convertType(readOp.getElementType());
2866 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
2867 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
2868 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
2873 static LogicalResult
2876 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
2881 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2884 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2886 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
2887 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
2888 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
2889 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
2892 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
      .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
      .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
      .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
      .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
      .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
      .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
      .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
      .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
      .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
      .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
}
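// Example of how this mapping is used (sketch): an omp.atomic.update region
// containing a single `llvm.add %arg, %expr` maps to BinOp::Add, which lets
// OpenMPIRBuilder emit an `atomicrmw add` directly. Any other region shape
// yields BAD_BINOP and typically falls back to a compare-and-exchange update
// loop instead.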
/// Converts an omp.atomic.update operation to LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The region contains the update operation and the terminator. If the
    // update is a supported binary operation on the region argument, an
    // atomicrmw instruction can be generated.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(),
                            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    binop = convertBinOpToAtomic(innerOp);
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    // The update region contains more than one operation, so no atomicrmw
    // instruction is attempted.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }

  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  // Generate update code.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
/// Converts an omp.atomic.capture operation to LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    if (innerOpList.size() == 2) {
      // Find the binary update operation that uses the region argument.
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
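// Sketch of the capture forms handled above:
//   omp.atomic.capture { omp.atomic.read ...; omp.atomic.update ... }
//     -> postfix update (v captures the old value of x),
//   omp.atomic.capture { omp.atomic.update ...; omp.atomic.read ... }
//     -> prefix update (v captures the new value of x),
//   omp.atomic.capture { omp.atomic.read ...; omp.atomic.write ... }
//     -> capture of the old value followed by a plain atomic store.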
static llvm::omp::Directive convertCancellationConstructType(
    omp::ClauseCancellationConstructType directive) {
  switch (directive) {
  case omp::ClauseCancellationConstructType::Loop:
    return llvm::omp::Directive::OMPD_for;
  case omp::ClauseCancellationConstructType::Parallel:
    return llvm::omp::Directive::OMPD_parallel;
  case omp::ClauseCancellationConstructType::Sections:
    return llvm::omp::Directive::OMPD_sections;
  case omp::ClauseCancellationConstructType::Taskgroup:
    return llvm::omp::Directive::OMPD_taskgroup;
  }
  llvm_unreachable("Unhandled cancellation construct type");
}
static LogicalResult
convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  if (failed(checkImplementationStatus(*op.getOperation())))
    return failure();

  llvm::Value *ifCond = nullptr;
  if (Value ifVar = op.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);

  llvm::omp::Directive cancelledDirective =
      convertCancellationConstructType(op.getCancelDirective());
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
  if (failed(handleError(afterIP, *op.getOperation())))
    return failure();

  builder.restoreIP(afterIP.get());
  return success();
}
static LogicalResult
convertOmpCancellationPoint(omp::CancellationPointOp op,
                            llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*op.getOperation())))
    return failure();
  llvm::omp::Directive cancelledDirective =
      convertCancellationConstructType(op.getCancelDirective());
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
  if (failed(handleError(afterIP, *op.getOperation())))
    return failure();
  builder.restoreIP(afterIP.get());
  return success();
}
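// Note (informal): both convertOmpCancel and convertOmpCancellationPoint
// delegate to OpenMPIRBuilder, which typically emits a call to the matching
// libomp entry point (__kmpc_cancel / __kmpc_cancellationpoint) followed by a
// conditional branch to the cancellation exit block of the enclosing
// construct.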
/// Converts an omp.threadprivate operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  Value symAddr = threadprivateOp.getSymAddr();
  Operation *symOp = symAddr.getDefiningOp();

  if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
    symOp = asCast.getOperand().getDefiningOp();

  if (!isa<LLVM::AddressOfOp>(symOp))
    return opInst.emitError("Addressing symbol not found");
  LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);

  LLVM::GlobalOp global =
      addressOfOp.getGlobal(moduleTranslation.symbolTable());
  llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);

  if (!ompBuilder->Config.isTargetDevice()) {
    llvm::Type *type = globalValue->getValueType();
    llvm::TypeSize typeSize =
        builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
            type);
    llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
    llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
        ompLoc, globalValue, size, global.getSymName() + ".cache");
    moduleTranslation.mapValue(opInst.getResult(0), callInst);
  } else {
    moduleTranslation.mapValue(opInst.getResult(0), globalValue);
  }

  return success();
}
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
  switch (deviceClause) {
  case mlir::omp::DeclareTargetDeviceType::host:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
  case mlir::omp::DeclareTargetDeviceType::nohost:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
  case mlir::omp::DeclareTargetDeviceType::any:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
  }
  llvm_unreachable("unhandled device clause");
}

static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
convertToCaptureClauseKind(
    mlir::omp::DeclareTargetCaptureClause captureClause) {
  switch (captureClause) {
  case mlir::omp::DeclareTargetCaptureClause::to:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
  case mlir::omp::DeclareTargetCaptureClause::link:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
  case mlir::omp::DeclareTargetCaptureClause::enter:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
  }
  llvm_unreachable("unhandled capture clause");
}
static llvm::SmallString<64>
getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
                             llvm::OpenMPIRBuilder &ompBuilder) {
  llvm::SmallString<64> suffix;
  llvm::raw_svector_ostream os(suffix);
  if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
    auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
    auto fileInfoCallBack = [&loc]() {
      return std::pair<std::string, uint64_t>(
          llvm::StringRef(loc.getFilename()), loc.getLine());
    };

    os << llvm::format(
        "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
  }
  os << "_decl_tgt_ref_ptr";

  return suffix;
}

static bool isDeclareTargetLink(mlir::Value value) {
  if (auto addressOfOp =
          llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
    auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
    Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
    if (auto declareTargetGlobal =
            llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
      if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
          mlir::omp::DeclareTargetCaptureClause::link)
        return true;
  }
  return false;
}
3260 static llvm::Value *
3267 if (
auto addressOfOp =
3268 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.
getDefiningOp())) {
3269 if (
auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
3270 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
3271 addressOfOp.getGlobalName()))) {
3273 if (
auto declareTargetGlobal =
3274 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
3275 gOp.getOperation())) {
3279 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
3280 mlir::omp::DeclareTargetCaptureClause::link) ||
3281 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3282 mlir::omp::DeclareTargetCaptureClause::to &&
3283 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
3287 if (gOp.getSymName().contains(suffix))
3292 (gOp.getSymName().str() + suffix.str()).str());
3303 struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
3307 void append(MapInfosTy &curInfo) {
3308 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
3309 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
3318 struct MapInfoData : MapInfosTy {
3330 void append(MapInfoData &CurInfo) {
3331 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
3332 CurInfo.IsDeclareTarget.end());
3333 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
3334 OriginalValue.append(CurInfo.OriginalValue.begin(),
3335 CurInfo.OriginalValue.end());
3336 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
3337 MapInfosTy::append(CurInfo);
3343 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3344 arrTy.getElementType()))
3360 Operation *clauseOp, llvm::Value *basePointer,
3361 llvm::Type *baseType, llvm::IRBuilderBase &builder,
3363 if (
auto memberClause =
3364 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
3369 if (!memberClause.getBounds().empty()) {
3370 llvm::Value *elementCount = builder.getInt64(1);
3371 for (
auto bounds : memberClause.getBounds()) {
3372 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
3373 bounds.getDefiningOp())) {
3378 elementCount = builder.CreateMul(
3382 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
3383 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
3384 builder.getInt64(1)));
3391 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
3399 return builder.CreateMul(elementCount,
3400 builder.getInt64(underlyingTypeSzInBits / 8));
3413 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
3421 for (
Value mapValue : mapVars) {
3422 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3423 for (
auto member : map.getMembers())
3424 if (member == mapOp)
3431 for (
Value mapValue : mapVars) {
3432 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3434 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3435 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
3436 mapData.Pointers.push_back(mapData.OriginalValue.back());
3438 if (llvm::Value *refPtr =
3440 moduleTranslation)) {
3441 mapData.IsDeclareTarget.push_back(
true);
3442 mapData.BasePointers.push_back(refPtr);
3444 mapData.IsDeclareTarget.push_back(
false);
3445 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3448 mapData.BaseType.push_back(
3449 moduleTranslation.
convertType(mapOp.getVarType()));
3450 mapData.Sizes.push_back(
3451 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
3452 mapData.BaseType.back(), builder, moduleTranslation));
3453 mapData.MapClause.push_back(mapOp.getOperation());
3454 mapData.Types.push_back(
3455 llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType()));
3459 if (mapOp.getMapperId())
3460 mapData.Mappers.push_back(
3461 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3462 mapOp, mapOp.getMapperIdAttr()));
3464 mapData.Mappers.push_back(
nullptr);
3465 mapData.IsAMapping.push_back(
true);
3466 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
3469 auto findMapInfo = [&mapData](llvm::Value *val,
3470 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3473 for (llvm::Value *basePtr : mapData.OriginalValue) {
3474 if (basePtr == val && mapData.IsAMapping[index]) {
3476 mapData.Types[index] |=
3477 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
3478 mapData.DevicePointers[index] = devInfoTy;
3487 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3488 for (
Value mapValue : useDevOperands) {
3489 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3491 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3492 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
3495 if (!findMapInfo(origValue, devInfoTy)) {
3496 mapData.OriginalValue.push_back(origValue);
3497 mapData.Pointers.push_back(mapData.OriginalValue.back());
3498 mapData.IsDeclareTarget.push_back(
false);
3499 mapData.BasePointers.push_back(mapData.OriginalValue.back());
3500 mapData.BaseType.push_back(
3501 moduleTranslation.
convertType(mapOp.getVarType()));
3502 mapData.Sizes.push_back(builder.getInt64(0));
3503 mapData.MapClause.push_back(mapOp.getOperation());
3504 mapData.Types.push_back(
3505 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
3508 mapData.DevicePointers.push_back(devInfoTy);
3509 mapData.Mappers.push_back(
nullptr);
3510 mapData.IsAMapping.push_back(
false);
3511 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
3516 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3517 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
3519 for (
Value mapValue : hasDevAddrOperands) {
3520 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3522 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3523 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
3525 static_cast<llvm::omp::OpenMPOffloadMappingFlags
>(mapOp.getMapType());
3526 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
3528 mapData.OriginalValue.push_back(origValue);
3529 mapData.BasePointers.push_back(origValue);
3530 mapData.Pointers.push_back(origValue);
3531 mapData.IsDeclareTarget.push_back(
false);
3532 mapData.BaseType.push_back(
3533 moduleTranslation.
convertType(mapOp.getVarType()));
3534 mapData.Sizes.push_back(
3535 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
3536 mapData.MapClause.push_back(mapOp.getOperation());
3537 if (llvm::to_underlying(mapType & mapTypeAlways)) {
3541 mapData.Types.push_back(mapType);
3545 if (mapOp.getMapperId()) {
3546 mapData.Mappers.push_back(
3547 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3548 mapOp, mapOp.getMapperIdAttr()));
3550 mapData.Mappers.push_back(
nullptr);
3553 mapData.Types.push_back(
3554 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
3555 mapData.Mappers.push_back(
nullptr);
3559 mapData.DevicePointers.push_back(
3560 llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3561 mapData.IsAMapping.push_back(
false);
3562 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
  auto *res = llvm::find(mapData.MapClause, memberOp);
  assert(res != mapData.MapClause.end() &&
         "MapInfoOp for member not found in MapData, cannot return index");
  return std::distance(mapData.MapClause.begin(), res);
}

static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
                                                    bool first) {
  ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
  // Only one member has been mapped, so we can return it directly.
  if (indexAttr.size() == 1)
    return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());

  llvm::SmallVector<size_t> indices(indexAttr.size());
  std::iota(indices.begin(), indices.end(), 0);

  llvm::sort(indices.begin(), indices.end(),
             [&](const size_t a, const size_t b) {
               auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
               auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
               for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
                 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
                 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();

                 if (aIndex == bIndex)
                   continue;
                 if (aIndex < bIndex)
                   return first;
                 if (aIndex > bIndex)
                   return !first;
               }

               // The prefixes compared equal; prefer the member with the
               // lowest index count, i.e. the "parent".
               return memberIndicesA.size() < memberIndicesB.size();
             });

  return llvm::cast<omp::MapInfoOp>(
      mapInfo.getMembers()[indices.front()].getDefiningOp());
}
3630 std::vector<llvm::Value *>
3632 llvm::IRBuilderBase &builder,
bool isArrayTy,
3634 std::vector<llvm::Value *> idx;
3645 idx.push_back(builder.getInt64(0));
3646 for (
int i = bounds.size() - 1; i >= 0; --i) {
3647 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3648 bounds[i].getDefiningOp())) {
3649 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
3671 std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
3672 for (
size_t i = 1; i < bounds.size(); ++i) {
3673 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3674 bounds[i].getDefiningOp())) {
3675 dimensionIndexSizeOffset.push_back(builder.CreateMul(
3676 moduleTranslation.
lookupValue(boundOp.getExtent()),
3677 dimensionIndexSizeOffset[i - 1]));
3685 for (
int i = bounds.size() - 1; i >= 0; --i) {
3686 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3687 bounds[i].getDefiningOp())) {
3689 idx.emplace_back(builder.CreateMul(
3690 moduleTranslation.
lookupValue(boundOp.getLowerBound()),
3691 dimensionIndexSizeOffset[i]));
3693 idx.back() = builder.CreateAdd(
3694 idx.back(), builder.CreateMul(moduleTranslation.
lookupValue(
3695 boundOp.getLowerBound()),
3696 dimensionIndexSizeOffset[i]));
3721 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
3722 MapInfoData &mapData, uint64_t mapDataIndex,
bool isTargetParams) {
3723 assert(!ompBuilder.Config.isTargetDevice() &&
3724 "function only supported for host device codegen");
3727 combinedInfo.Types.emplace_back(
3729 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
3730 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
3731 combinedInfo.DevicePointers.emplace_back(
3732 mapData.DevicePointers[mapDataIndex]);
3733 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
3735 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3736 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3746 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3748 llvm::Value *lowAddr, *highAddr;
3749 if (!parentClause.getPartialMap()) {
3750 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
3751 builder.getPtrTy());
3752 highAddr = builder.CreatePointerCast(
3753 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
3754 mapData.Pointers[mapDataIndex], 1),
3755 builder.getPtrTy());
3756 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3758 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3761 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
3762 builder.getPtrTy());
3765 highAddr = builder.CreatePointerCast(
3766 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
3767 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
3768 builder.getPtrTy());
3769 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
3772 llvm::Value *size = builder.CreateIntCast(
3773 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
3774 builder.getInt64Ty(),
3776 combinedInfo.Sizes.push_back(size);
3778 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
3779 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
3787 if (!parentClause.getPartialMap()) {
3792 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
3793 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3794 combinedInfo.Types.emplace_back(mapFlag);
3795 combinedInfo.DevicePointers.emplace_back(
3797 combinedInfo.Mappers.emplace_back(
nullptr);
3799 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3800 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3801 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3802 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
3804 return memberOfFlag;
3816 if (mapOp.getVarPtrPtr())
3831 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
3832 MapInfoData &mapData, uint64_t mapDataIndex,
3833 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
3834 assert(!ompBuilder.Config.isTargetDevice() &&
3835 "function only supported for host device codegen");
3838 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3840 for (
auto mappedMembers : parentClause.getMembers()) {
3842 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
3845 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
3856 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
3857 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3858 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
3859 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3860 combinedInfo.Types.emplace_back(mapFlag);
3861 combinedInfo.DevicePointers.emplace_back(
3863 combinedInfo.Mappers.emplace_back(
nullptr);
3864 combinedInfo.Names.emplace_back(
3866 combinedInfo.BasePointers.emplace_back(
3867 mapData.BasePointers[mapDataIndex]);
3868 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
3869 combinedInfo.Sizes.emplace_back(builder.getInt64(
3870 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
3876 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
3877 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3878 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
3879 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3881 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
3883 combinedInfo.Types.emplace_back(mapFlag);
3884 combinedInfo.DevicePointers.emplace_back(
3885 mapData.DevicePointers[memberDataIdx]);
3886 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
3887 combinedInfo.Names.emplace_back(
3889 uint64_t basePointerIndex =
3891 combinedInfo.BasePointers.emplace_back(
3892 mapData.BasePointers[basePointerIndex]);
3893 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
3895 llvm::Value *size = mapData.Sizes[memberDataIdx];
3897 size = builder.CreateSelect(
3898 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
3899 builder.getInt64(0), size);
3902 combinedInfo.Sizes.emplace_back(size);
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
                                 MapInfosTy &combinedInfo, bool isTargetParams,
                                 int mapDataParentIdx = -1) {
  // Declare-target mappings are excluded from being marked as
  // OMP_MAP_TARGET_PARAM as they are not passed as parameters.
  auto mapFlag = mapData.Types[mapDataIdx];
  auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);

  bool isPtrTy = checkIfPointerMap(mapInfoOp);
  if (isPtrTy)
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;

  if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;

  if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
      !isPtrTy)
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;

  // If a mapDataParentIdx is provided, the data being mapped is part of a
  // larger object and the base pointer of the parent object is used instead.
  if (mapDataParentIdx >= 0)
    combinedInfo.BasePointers.emplace_back(
        mapData.BasePointers[mapDataParentIdx]);
  else
    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);

  combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
  combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
  combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
  combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
  combinedInfo.Types.emplace_back(mapFlag);
  combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
}
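// Informal example of the flag composition above: a scalar mapped by-copy on
// an omp.target op usually ends up with OMP_MAP_TARGET_PARAM | OMP_MAP_LITERAL
// (the value is passed directly in the kernel argument buffer), whereas a
// pointer-backed map keeps a separate base-pointer/pointer pair and gains
// OMP_MAP_PTR_AND_OBJ instead.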
3944 llvm::IRBuilderBase &builder,
3945 llvm::OpenMPIRBuilder &ompBuilder,
3947 MapInfoData &mapData, uint64_t mapDataIndex,
3948 bool isTargetParams) {
3949 assert(!ompBuilder.Config.isTargetDevice() &&
3950 "function only supported for host device codegen");
3953 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3958 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
3959 auto memberClause = llvm::cast<omp::MapInfoOp>(
3960 parentClause.getMembers()[0].getDefiningOp());
3977 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
3979 combinedInfo, mapData, mapDataIndex, isTargetParams);
3981 combinedInfo, mapData, mapDataIndex,
3982 memberOfParentFlag);
3992 llvm::IRBuilderBase &builder) {
3994 "function only supported for host device codegen");
3995 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
3997 if (!mapData.IsDeclareTarget[i]) {
3998 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
3999 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
4009 switch (captureKind) {
4010 case omp::VariableCaptureKind::ByRef: {
4011 llvm::Value *newV = mapData.Pointers[i];
4013 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
4016 newV = builder.CreateLoad(builder.getPtrTy(), newV);
4018 if (!offsetIdx.empty())
4019 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
4021 mapData.Pointers[i] = newV;
4023 case omp::VariableCaptureKind::ByCopy: {
4024 llvm::Type *type = mapData.BaseType[i];
4026 if (mapData.Pointers[i]->getType()->isPointerTy())
4027 newV = builder.CreateLoad(type, mapData.Pointers[i]);
4029 newV = mapData.Pointers[i];
4032 auto curInsert = builder.saveIP();
4034 auto *memTempAlloc =
4035 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
4036 builder.restoreIP(curInsert);
4038 builder.CreateStore(newV, memTempAlloc);
4039 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
4042 mapData.Pointers[i] = newV;
4043 mapData.BasePointers[i] = newV;
4045 case omp::VariableCaptureKind::This:
4046 case omp::VariableCaptureKind::VLAType:
4047 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
4058 MapInfoData &mapData,
bool isTargetParams =
false) {
4060 "function only supported for host device codegen");
4082 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4085 if (mapData.IsAMember[i])
4088 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
4089 if (!mapInfoOp.getMembers().empty()) {
4091 combinedInfo, mapData, i, isTargetParams);
4102 llvm::StringRef mapperFuncName);
4108 "function only supported for host device codegen");
4109 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4110 std::string mapperFuncName =
4112 {
"omp_mapper", declMapperOp.getSymName()});
4114 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
4124 llvm::StringRef mapperFuncName) {
4126 "function only supported for host device codegen");
4127 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4128 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
4131 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
4134 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4137 MapInfosTy combinedInfo;
4139 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
4140 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
4141 builder.restoreIP(codeGenIP);
4142 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
4143 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
4144 builder.GetInsertBlock());
4145 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
4148 return llvm::make_error<PreviouslyReportedError>();
4149 MapInfoData mapData;
4152 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);
4157 return combinedInfo;
4161 if (!combinedInfo.Mappers[i])
4168 genMapInfoCB, varType, mapperFuncName, customMapperCB);
4170 return newFn.takeError();
4171 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
4175 static LogicalResult
4178 llvm::Value *ifCond =
nullptr;
4179 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
4183 llvm::omp::RuntimeFunction RTLFn;
4187 llvm::OpenMPIRBuilder::TargetDataInfo info(
true,
4190 LogicalResult result =
4192 .Case([&](omp::TargetDataOp dataOp) {
4196 if (
auto ifVar = dataOp.getIfExpr())
4199 if (
auto devId = dataOp.getDevice())
4201 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4202 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4203 deviceID = intAttr.getInt();
4205 mapVars = dataOp.getMapVars();
4206 useDevicePtrVars = dataOp.getUseDevicePtrVars();
4207 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
4210 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
4214 if (
auto ifVar = enterDataOp.getIfExpr())
4217 if (
auto devId = enterDataOp.getDevice())
4219 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4220 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4221 deviceID = intAttr.getInt();
4223 enterDataOp.getNowait()
4224 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
4225 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
4226 mapVars = enterDataOp.getMapVars();
4227 info.HasNoWait = enterDataOp.getNowait();
4230 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
4234 if (
auto ifVar = exitDataOp.getIfExpr())
4237 if (
auto devId = exitDataOp.getDevice())
4239 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4240 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4241 deviceID = intAttr.getInt();
4243 RTLFn = exitDataOp.getNowait()
4244 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
4245 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
4246 mapVars = exitDataOp.getMapVars();
4247 info.HasNoWait = exitDataOp.getNowait();
4250 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
4254 if (
auto ifVar = updateDataOp.getIfExpr())
4257 if (
auto devId = updateDataOp.getDevice())
4259 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4260 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4261 deviceID = intAttr.getInt();
4264 updateDataOp.getNowait()
4265 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
4266 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
4267 mapVars = updateDataOp.getMapVars();
4268 info.HasNoWait = updateDataOp.getNowait();
4272 llvm_unreachable(
"unexpected operation");
4279 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4280 MapInfoData mapData;
4282 builder, useDevicePtrVars, useDeviceAddrVars);
4285 MapInfosTy combinedInfo;
4286 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
4287 builder.restoreIP(codeGenIP);
4288 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
4289 return combinedInfo;
4295 [&moduleTranslation](
4296 llvm::OpenMPIRBuilder::DeviceInfoTy type,
4300 for (
auto [arg, useDevVar] :
4301 llvm::zip_equal(blockArgs, useDeviceVars)) {
4303 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
4304 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
4305 : mapInfoOp.getVarPtr();
4308 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
4309 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
4310 mapInfoData.MapClause, mapInfoData.DevicePointers,
4311 mapInfoData.BasePointers)) {
4312 auto mapOp = cast<omp::MapInfoOp>(mapClause);
4313 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
4314 devicePointer != type)
4317 if (llvm::Value *devPtrInfoMap =
4318 mapper ? mapper(basePointer) : basePointer) {
4319 moduleTranslation.
mapValue(arg, devPtrInfoMap);
4326 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
4327 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
4328 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
4329 builder.restoreIP(codeGenIP);
4330 assert(isa<omp::TargetDataOp>(op) &&
4331 "BodyGen requested for non TargetDataOp");
4332 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
4333 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
4334 switch (bodyGenType) {
4335 case BodyGenTy::Priv:
4337 if (!info.DevicePtrInfoMap.empty()) {
4338 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4339 blockArgIface.getUseDeviceAddrBlockArgs(),
4340 useDeviceAddrVars, mapData,
4341 [&](llvm::Value *basePointer) -> llvm::Value * {
4342 if (!info.DevicePtrInfoMap[basePointer].second)
4344 return builder.CreateLoad(
4346 info.DevicePtrInfoMap[basePointer].second);
4348 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4349 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
4350 mapData, [&](llvm::Value *basePointer) {
4351 return info.DevicePtrInfoMap[basePointer].second;
4355 moduleTranslation)))
4356 return llvm::make_error<PreviouslyReportedError>();
4359 case BodyGenTy::DupNoPriv:
4362 builder.restoreIP(codeGenIP);
4364 case BodyGenTy::NoPriv:
4366 if (info.DevicePtrInfoMap.empty()) {
4369 if (ompBuilder->Config.IsTargetDevice.value_or(
false)) {
4370 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4371 blockArgIface.getUseDeviceAddrBlockArgs(),
4372 useDeviceAddrVars, mapData);
4373 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4374 blockArgIface.getUseDevicePtrBlockArgs(),
4375 useDevicePtrVars, mapData);
4379 moduleTranslation)))
4380 return llvm::make_error<PreviouslyReportedError>();
4384 return builder.saveIP();
4387 auto customMapperCB =
4389 if (!combinedInfo.Mappers[i])
4391 info.HasMapper =
true;
4396 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4397 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4399 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
4400 if (isa<omp::TargetDataOp>(op))
4401 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
4402 builder.getInt64(deviceID), ifCond,
4403 info, genMapInfoCB, customMapperCB,
4406 return ompBuilder->createTargetData(
4407 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
4408 info, genMapInfoCB, customMapperCB, &RTLFn);
4414 builder.restoreIP(*afterIP);
4418 static LogicalResult
4422 auto distributeOp = cast<omp::DistributeOp>(opInst);
4429 bool doDistributeReduction =
4433 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
4438 if (doDistributeReduction) {
4439 isByRef =
getIsByRef(teamsOp.getReductionByref());
4440 assert(isByRef.size() == teamsOp.getNumReductionVars());
4443 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4447 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
4448 .getReductionBlockArgs();
4451 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
4452 reductionDecls, privateReductionVariables, reductionVariableMap,
4457 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4458 auto bodyGenCB = [&](InsertPointTy allocaIP,
4459 InsertPointTy codeGenIP) -> llvm::Error {
4463 moduleTranslation, allocaIP);
4466 builder.restoreIP(codeGenIP);
4472 return llvm::make_error<PreviouslyReportedError>();
4477 return llvm::make_error<PreviouslyReportedError>();
4480 builder, moduleTranslation, privVarsInfo.
mlirVars,
4482 return llvm::make_error<PreviouslyReportedError>();
4485 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4488 builder, moduleTranslation);
4490 return regionBlock.takeError();
4491 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4496 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
4499 auto schedule = omp::ClauseScheduleKind::Static;
4500 bool isOrdered =
false;
4501 std::optional<omp::ScheduleModifier> scheduleMod;
4502 bool isSimd =
false;
4503 llvm::omp::WorksharingLoopType workshareLoopType =
4504 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
4505 bool loopNeedsBarrier =
false;
4506 llvm::Value *chunk =
nullptr;
4508 llvm::CanonicalLoopInfo *loopInfo =
4510 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4511 ompBuilder->applyWorkshareLoop(
4512 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
4513 convertToScheduleKind(schedule), chunk, isSimd,
4514 scheduleMod == omp::ScheduleModifier::monotonic,
4515 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4519 return wsloopIP.takeError();
4523 distributeOp.getLoc(), privVarsInfo.
llvmVars,
4525 return llvm::make_error<PreviouslyReportedError>();
4527 return llvm::Error::success();
4530 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4532 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4533 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4534 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
4539 builder.restoreIP(*afterIP);
4541 if (doDistributeReduction) {
4544 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
4545 privateReductionVariables, isByRef,
  if (!cast<mlir::ModuleOp>(op))
    return failure();

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
                              attribute.getOpenmpDeviceVersion());

  if (attribute.getNoGpuLib())
    return success();

  ompBuilder->createGlobalFlag(attribute.getDebugKind(),
                               "__omp_rtl_debug_kind");
  ompBuilder->createGlobalFlag(attribute.getAssumeTeamsOversubscription(),
                               "__omp_rtl_assume_teams_oversubscription");
  ompBuilder->createGlobalFlag(attribute.getAssumeThreadsOversubscription(),
                               "__omp_rtl_assume_threads_oversubscription");
  ompBuilder->createGlobalFlag(attribute.getAssumeNoThreadState(),
                               "__omp_rtl_assume_no_thread_state");
  ompBuilder->createGlobalFlag(attribute.getAssumeNoNestedParallelism(),
                               "__omp_rtl_assume_no_nested_parallelism");
  return success();
}
static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
                                     omp::TargetOp targetOp,
                                     llvm::StringRef parentName = "") {
  auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();

  assert(fileLoc && "No file found from location");
  StringRef fileName = fileLoc.getFilename().getValue();

  llvm::sys::fs::UniqueID id;
  uint64_t line = fileLoc.getLine();
  if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
    // If the file cannot be uniquely identified, fall back to a hash of its
    // name and a sentinel device id so the entry is still deterministic.
    size_t fileHash = llvm::hash_value(fileName.str());
    size_t deviceId = 0xdeadf17e;
    targetInfo =
        llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
  } else {
    targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
                                             id.getFile(), line);
  }

  return true;
}
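// For reference (the naming is produced by OpenMPIRBuilder, not here): the
// device/file/line triple collected above is what makes offload entry names
// unique, typically of the form
//   __omp_offloading_<device-id>_<file-id>_<parent-name>_l<line>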
4615 llvm::IRBuilderBase &builder, llvm::Function *func) {
4617 "function only supported for target device codegen");
4618 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4631 if (mapData.IsDeclareTarget[i]) {
4638 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
4639 convertUsersOfConstantsToInstructions(constant, func,
false);
4646 for (llvm::User *user : mapData.OriginalValue[i]->users())
4647 userVec.push_back(user);
4649 for (llvm::User *user : userVec) {
4650 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
4651 if (insn->getFunction() == func) {
4652 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
4653 mapData.BasePointers[i]);
4654 load->moveBefore(insn->getIterator());
4655 user->replaceUsesOfWith(mapData.OriginalValue[i], load);
4702 static llvm::IRBuilderBase::InsertPoint
4704 llvm::Value *input, llvm::Value *&retVal,
4705 llvm::IRBuilderBase &builder,
4706 llvm::OpenMPIRBuilder &ompBuilder,
4708 llvm::IRBuilderBase::InsertPoint allocaIP,
4709 llvm::IRBuilderBase::InsertPoint codeGenIP) {
4710 assert(ompBuilder.Config.isTargetDevice() &&
4711 "function only supported for target device codegen");
4712 builder.restoreIP(allocaIP);
4714 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
4716 ompBuilder.M.getContext());
4717 unsigned alignmentValue = 0;
4719 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
4720 if (mapData.OriginalValue[i] == input) {
4721 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4722 capture = mapOp.getMapCaptureType();
4725 mapOp.getVarType(), ompBuilder.M.getDataLayout());
4729 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
4730 unsigned int defaultAS =
4731 ompBuilder.M.getDataLayout().getProgramAddressSpace();
4734 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
4736 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
4737 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
4739 builder.CreateStore(&arg, v);
4741 builder.restoreIP(codeGenIP);
4744 case omp::VariableCaptureKind::ByCopy: {
4748 case omp::VariableCaptureKind::ByRef: {
4749 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
4751 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
4766 if (v->getType()->isPointerTy() && alignmentValue) {
4767 llvm::MDBuilder MDB(builder.getContext());
4768 loadInst->setMetadata(
4769 llvm::LLVMContext::MD_align,
4772 llvm::Type::getInt64Ty(builder.getContext()),
4779 case omp::VariableCaptureKind::This:
4780 case omp::VariableCaptureKind::VLAType:
4783 assert(
false &&
"Currently unsupported capture kind");
4787 return builder.saveIP();
4804 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
4805 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
4806 blockArgIface.getHostEvalBlockArgs())) {
4807 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
4811 .Case([&](omp::TeamsOp teamsOp) {
4812 if (teamsOp.getNumTeamsLower() == blockArg)
4813 numTeamsLower = hostEvalVar;
4814 else if (teamsOp.getNumTeamsUpper() == blockArg)
4815 numTeamsUpper = hostEvalVar;
4816 else if (teamsOp.getThreadLimit() == blockArg)
4817 threadLimit = hostEvalVar;
4819 llvm_unreachable(
"unsupported host_eval use");
4821 .Case([&](omp::ParallelOp parallelOp) {
4822 if (parallelOp.getNumThreads() == blockArg)
4823 numThreads = hostEvalVar;
4825 llvm_unreachable(
"unsupported host_eval use");
4827 .Case([&](omp::LoopNestOp loopOp) {
4828 auto processBounds =
4833 if (lb == blockArg) {
4836 (*outBounds)[i] = hostEvalVar;
4842 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
4843 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
4845 found = processBounds(loopOp.getLoopSteps(), steps) || found;
4847 assert(found &&
"unsupported host_eval use");
4850 llvm_unreachable(
"unsupported host_eval use");
/// If \p op, or one of its parents, is of type \p OpTy, return it; otherwise,
/// return an empty value.
template <typename OpTy>
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
  if (OpTy casted = dyn_cast<OpTy>(op))
    return casted;

  if (immediateParent)
    return dyn_cast_if_present<OpTy>(op->getParentOp());

  return op->getParentOfType<OpTy>();
}

/// If \p value is defined by an llvm.mlir.constant holding an integer
/// attribute, return that integer value.
static std::optional<int64_t> extractConstInteger(Value value) {
  if (!value)
    return std::nullopt;

  if (auto constOp =
          dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
    if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
      return constAttr.getInt();

  return std::nullopt;
}
4893 uint64_t sizeInBytes = sizeInBits / 8;
4897 template <
typename OpTy>
4899 if (op.getNumReductionVars() > 0) {
4904 members.reserve(reductions.size());
4905 for (omp::DeclareReductionOp &red : reductions)
4906 members.push_back(red.getType());
4908 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
4924 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
4925 bool isTargetDevice,
bool isGPU) {
4928 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
4929 if (!isTargetDevice) {
4936 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
4937 numTeamsLower = teamsOp.getNumTeamsLower();
4938 numTeamsUpper = teamsOp.getNumTeamsUpper();
4939 threadLimit = teamsOp.getThreadLimit();
4942 if (
auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
4943 numThreads = parallelOp.getNumThreads();
4948 int32_t minTeamsVal = 1, maxTeamsVal = -1;
4949 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
4952 if (numTeamsUpper) {
4954 minTeamsVal = maxTeamsVal = *val;
4956 minTeamsVal = maxTeamsVal = 0;
4958 }
else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
4960 castOrGetParentOfType<omp::SimdOp>(capturedOp,
4962 minTeamsVal = maxTeamsVal = 1;
4964 minTeamsVal = maxTeamsVal = -1;
4969 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &result) {
4983 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
4984 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
4985 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
4988 int32_t maxThreadsVal = -1;
4989 if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
4990 setMaxValueFromClause(numThreads, maxThreadsVal);
4991 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
4998 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
4999 if (combinedMaxThreadsVal < 0 ||
5000 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
5001 combinedMaxThreadsVal = teamsThreadLimitVal;
5003 if (combinedMaxThreadsVal < 0 ||
5004 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
5005 combinedMaxThreadsVal = maxThreadsVal;
5007 int32_t reductionDataSize = 0;
5008 if (isGPU && capturedOp) {
5009 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
5014 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5016 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5017 omp::TargetRegionFlags::spmd) &&
5018 "invalid kernel flags");
5020 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5021 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5022 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5023 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5024 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5025 attrs.MinTeams = minTeamsVal;
5026 attrs.MaxTeams.front() = maxTeamsVal;
5027 attrs.MinThreads = 1;
5028 attrs.MaxThreads.front() = combinedMaxThreadsVal;
5029 attrs.ReductionDataSize = reductionDataSize;
5032 if (attrs.ReductionDataSize != 0)
5033 attrs.ReductionBufferLength = 1024;
5045 omp::TargetOp targetOp,
Operation *capturedOp,
5046 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
5047 omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
5048 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
5050 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
5054 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
5057 if (
Value targetThreadLimit = targetOp.getThreadLimit())
5058 attrs.TargetThreadLimit.front() =
5062 attrs.MinTeams = moduleTranslation.
lookupValue(numTeamsLower);
5065 attrs.MaxTeams.front() = moduleTranslation.
lookupValue(numTeamsUpper);
5067 if (teamsThreadLimit)
5068 attrs.TeamsThreadLimit.front() =
5072 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
5074 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
5075 omp::TargetRegionFlags::trip_count)) {
5077 attrs.LoopTripCount =
nullptr;
5082 for (
auto [loopLower, loopUpper, loopStep] :
5083 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
5084 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
5085 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
5086 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
5088 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
5089 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
5090 loc, lowerBound, upperBound, step,
true,
5091 loopOp.getLoopInclusive());
5093 if (!attrs.LoopTripCount) {
5094 attrs.LoopTripCount = tripCount;
5099 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
5105 static LogicalResult
5108 auto targetOp = cast<omp::TargetOp>(opInst);
5113 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
5114 bool isGPU = ompBuilder->Config.isGPU();
5117 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
5118 auto &targetRegion = targetOp.getRegion();
5135 llvm::Function *llvmOutlinedFn =
nullptr;
5139 bool isOffloadEntry =
5140 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
5147 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
5149 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
5150 std::optional<DenseI64ArrayAttr> privateMapIndices =
5151 targetOp.getPrivateMapsAttr();
5153 for (
auto [privVarIdx, privVarSymPair] :
5155 auto privVar = std::get<0>(privVarSymPair);
5156 auto privSym = std::get<1>(privVarSymPair);
5158 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
5159 omp::PrivateClauseOp privatizer =
5162 if (!privatizer.needsMap())
5166 targetOp.getMappedValueForPrivateVar(privVarIdx);
5167 assert(mappedValue &&
"Expected to find mapped value for a privatized "
5168 "variable that needs mapping");
5173 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
5174 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
5178 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
5180 varType == privVar.getType() &&
5181 "Type of private var doesn't match the type of the mapped value");
5185 mappedPrivateVars.insert(
5187 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
5188 (*privateMapIndices)[privVarIdx])});
5192 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5193 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
5194 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5195 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5196 builder.SetCurrentDebugLocation(llvm::DebugLoc());
5199 llvm::Function *llvmParentFn =
5201 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
5202 assert(llvmParentFn && llvmOutlinedFn &&
5203 "Both parent and outlined functions must exist at this point");
5205 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
5206 attr.isStringAttribute())
5207 llvmOutlinedFn->addFnAttr(attr);
5209 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
5210 attr.isStringAttribute())
5211 llvmOutlinedFn->addFnAttr(attr);
5213 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
5214 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5215 llvm::Value *mapOpValue =
5216 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
5217 moduleTranslation.
mapValue(arg, mapOpValue);
5219 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
5220 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5221 llvm::Value *mapOpValue =
5222 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
5223 moduleTranslation.
mapValue(arg, mapOpValue);
5232 allocaIP, &mappedPrivateVars);
5235 return llvm::make_error<PreviouslyReportedError>();
5237 builder.restoreIP(codeGenIP);
5239 &mappedPrivateVars),
5242 return llvm::make_error<PreviouslyReportedError>();
5245 builder, moduleTranslation, privateVarsInfo.
mlirVars,
5247 &mappedPrivateVars)))
5248 return llvm::make_error<PreviouslyReportedError>();
5252 std::back_inserter(privateCleanupRegions),
5253 [](omp::PrivateClauseOp privatizer) {
5254 return &privatizer.getDeallocRegion();
5258 targetRegion,
"omp.target", builder, moduleTranslation);
5261 return exitBlock.takeError();
5263 builder.SetInsertPoint(*exitBlock);
5264 if (!privateCleanupRegions.empty()) {
5266 privateCleanupRegions, privateVarsInfo.
llvmVars,
5267 moduleTranslation, builder,
"omp.targetop.private.cleanup",
5269 return llvm::createStringError(
5270 "failed to inline `dealloc` region of `omp.private` "
5271 "op in the target region");
5273 return builder.saveIP();
5276 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
5279 StringRef parentName = parentFn.getName();
5281 llvm::TargetRegionEntryInfo entryInfo;
5285 MapInfoData mapData;
5290 MapInfosTy combinedInfos;
5292 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
5293 builder.restoreIP(codeGenIP);
5294 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
true);
5295 return combinedInfos;
5298 auto argAccessorCB = [&](
llvm::Argument &arg, llvm::Value *input,
5299 llvm::Value *&retVal, InsertPointTy allocaIP,
5300 InsertPointTy codeGenIP)
5301 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5302 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5303 builder.SetCurrentDebugLocation(llvm::DebugLoc());
5309 if (!isTargetDevice) {
5310 retVal = cast<llvm::Value>(&arg);
5315 *ompBuilder, moduleTranslation,
5316 allocaIP, codeGenIP);
5319 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
5320 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
5321 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
5323 isTargetDevice, isGPU);
5327 if (!isTargetDevice)
5329 targetCapturedOp, runtimeAttrs);
5337 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
5338 llvm::Value *value = moduleTranslation.
lookupValue(var);
5339 moduleTranslation.
mapValue(arg, value);
5341 if (!llvm::isa<llvm::Constant>(value))
5342 kernelInput.push_back(value);
5345 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
5352 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
5353 kernelInput.push_back(mapData.OriginalValue[i]);
5358 moduleTranslation, dds);
5360 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5362 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5364 llvm::OpenMPIRBuilder::TargetDataInfo info(
5368 auto customMapperCB =
5370 if (!combinedInfos.Mappers[i])
5372 info.HasMapper =
true;
5377 llvm::Value *ifCond =
nullptr;
5378 if (
Value targetIfCond = targetOp.getIfExpr())
5379 ifCond = moduleTranslation.
lookupValue(targetIfCond);
5381 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5383 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
5384 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
5385 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
5390 builder.restoreIP(*afterIP);
5401 static LogicalResult
5411 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
5412 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
5414 if (!offloadMod.getIsTargetDevice())
5417 omp::DeclareTargetDeviceType declareType =
5418 attribute.getDeviceType().getValue();
5420 if (declareType == omp::DeclareTargetDeviceType::host) {
5421 llvm::Function *llvmFunc =
5423 llvmFunc->dropAllReferences();
5424 llvmFunc->eraseFromParent();
5430 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
5431 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
5432 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
5434 bool isDeclaration = gOp.isDeclaration();
5435 bool isExternallyVisible =
5438 llvm::StringRef mangledName = gOp.getSymName();
5439 auto captureClause =
5445 std::vector<llvm::GlobalVariable *> generatedRefs;
5447 std::vector<llvm::Triple> targetTriple;
5448 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
5450 LLVM::LLVMDialect::getTargetTripleAttrName()));
5451 if (targetTripleAttr)
5452 targetTriple.emplace_back(targetTripleAttr.data());
5454 auto fileInfoCallBack = [&loc]() {
5455 std::string filename =
"";
5456 std::uint64_t lineNo = 0;
5459 filename = loc.getFilename().str();
5460 lineNo = loc.getLine();
5463 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
5467 ompBuilder->registerTargetGlobalVariable(
5468 captureClause, deviceClause, isDeclaration, isExternallyVisible,
5469 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
5470 generatedRefs,
false, targetTriple,
5472 gVal->getType(), gVal);
5474 if (ompBuilder->Config.isTargetDevice() &&
5475 (attribute.getCaptureClause().getValue() !=
5476 mlir::omp::DeclareTargetCaptureClause::to ||
5477 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
5478 ompBuilder->getAddrOfDeclareTargetVar(
5479 captureClause, deviceClause, isDeclaration, isExternallyVisible,
5480 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
5481 generatedRefs,
false, targetTriple, gVal->getType(),
  if (mlir::isa<omp::ThreadprivateOp>(op))
    return true;

  if (auto declareTargetIface =
          llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
              parentFn.getOperation()))
    if (declareTargetIface.isDeclareTarget() &&
        declareTargetIface.getDeclareTargetDeviceType() !=
            mlir::omp::DeclareTargetDeviceType::host)
      return true;

  return false;
}
static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  bool isOutermostLoopWrapper =
      isa_and_present<omp::LoopWrapperInterface>(op) &&
      !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
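  // Dispatch on the concrete operation type. Simple constructs are emitted
  // directly through the OpenMPIRBuilder; everything else is delegated to a
  // convertOmp* helper.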
  LogicalResult result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::BarrierOp op) -> LogicalResult {
            llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
                ompBuilder->createBarrier(builder.saveIP(),
                                          llvm::omp::OMPD_barrier);
            return handleError(afterIP, *op);
          })
          .Case([&](omp::TaskyieldOp op) {
            ompBuilder->createTaskyield(builder.saveIP());
            return success();
          })
          .Case([&](omp::FlushOp op) {
            ompBuilder->createFlush(builder.saveIP());
            return success();
          })
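          // Each of the following cases forwards to the dedicated translation
          // helper for the matched operation (convertOmpParallel,
          // convertOmpMasked, convertOmpWsloop, convertOmpTarget, and so on).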
          .Case([&](omp::ParallelOp op) {
          .Case([&](omp::MaskedOp) {
          .Case([&](omp::MasterOp) {
          .Case([&](omp::CriticalOp) {
          .Case([&](omp::OrderedRegionOp) {
          .Case([&](omp::OrderedOp) {
          .Case([&](omp::WsloopOp) {
          .Case([&](omp::SimdOp) {
          .Case([&](omp::AtomicReadOp) {
          .Case([&](omp::AtomicWriteOp) {
          .Case([&](omp::AtomicUpdateOp op) {
          .Case([&](omp::AtomicCaptureOp op) {
          .Case([&](omp::CancelOp op) {
          .Case([&](omp::CancellationPointOp op) {
          .Case([&](omp::SectionsOp) {
          .Case([&](omp::SingleOp op) {
          .Case([&](omp::TeamsOp op) {
          .Case([&](omp::TaskOp op) {
          .Case([&](omp::TaskgroupOp op) {
          .Case([&](omp::TaskwaitOp op) {
          .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
                omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
                omp::CriticalDeclareOp>([](auto op) {
          .Case([&](omp::ThreadprivateOp) {
          .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
                omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
          .Case([&](omp::TargetOp) {
          .Case([&](omp::DistributeOp) {
          .Case([&](omp::LoopNestOp) {
          .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
          .Default([&](Operation *inst) {
            return inst->emitError()
                   << "not yet implemented: " << inst->getName();
          });

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPop();

  return result;
}
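// Translation entry points used on the device pass: operations that must be
// code-generated for the target device are handled directly, while omp.target
// and omp.target_data regions nested inside host code are found by walking
// the IR.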
static LogicalResult
convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  return convertHostOrTargetOperation(op, builder, moduleTranslation);
}

static LogicalResult
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  if (isa<omp::TargetOp>(op))
    return convertOmpTarget(*op, builder, moduleTranslation);
  if (isa<omp::TargetDataOp>(op))
    return convertOmpTargetData(op, builder, moduleTranslation);

  bool interrupted = op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
    if (isa<omp::TargetOp>(oper)) {
      if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
        return WalkResult::interrupt();
      return WalkResult::skip();
    }
    if (isa<omp::TargetDataOp>(oper)) {
      if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
        return WalkResult::interrupt();
      return WalkResult::skip();
    }
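    // Other OpenMP constructs reached on the device pass are not emitted as
    // real code; their block arguments and results are still mapped so that
    // LLVM dialect operations depending on them can be translated, and their
    // regions are converted into throw-away ".fake.region" blocks below.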
    if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
        oper->getParentOfType<LLVM::LLVMFuncOp>() &&
        !oper->getRegions().empty()) {
      if (auto blockArgsIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
        forwardArgs(moduleTranslation, blockArgsIface);

      if (isa<mlir::omp::AtomicUpdateOp>(oper))
        for (auto [operand, arg] :
             llvm::zip_equal(oper->getOperands(),
                             oper->getRegion(0).getArguments())) {
          moduleTranslation.mapValue(
              arg, builder.CreateLoad(
                       moduleTranslation.convertType(arg.getType()),
                       moduleTranslation.lookupValue(operand)));
        }
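      // Induction variables of an omp.loop_nest are likewise given
      // placeholder LLVM values so that IR referencing them stays valid.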
      if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
        assert(builder.GetInsertBlock() &&
               "No insert block is set for the builder");
        for (auto iv : loopNest.getIVs()) {
          moduleTranslation.mapValue(
              iv, llvm::PoisonValue::get(
                      moduleTranslation.convertType(iv.getType())));
        }
      }
      for (Region &region : oper->getRegions()) {
        SmallVector<llvm::PHINode *> phis;
        llvm::Expected<llvm::BasicBlock *> result = convertOmpOpRegions(
            region, oper->getName().getStringRef().str() + ".fake.region",
            builder, moduleTranslation, &phis);
        if (failed(handleError(result, *oper)))
          return WalkResult::interrupt();

        builder.SetInsertPoint(result.get(), result.get()->end());
      }
      return WalkResult::skip();
    }

    return WalkResult::advance();
  }).wasInterrupted();

  return failure(interrupted);
}
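/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR.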
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;

  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;
};
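/// Handles the discardable "omp." module attributes (offload configuration,
/// OpenMP version, requires clauses, target triples and declare-target
/// information) by updating the OpenMPIRBuilder configuration or registering
/// the corresponding offload entries.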
LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
    Operation *op, ArrayRef<llvm::Instruction *> instructions,
    NamedAttribute attribute,
    LLVM::ModuleTranslation &moduleTranslation) const {
  return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
             attribute.getName())
      .Case("omp.is_target_device",
            [&](Attribute attr) {
              if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsTargetDevice(deviceAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.is_gpu",
            [&](Attribute attr) {
              if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsGPU(gpuAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.host_ir_filepath",
            [&](Attribute attr) {
              if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.flags",
            [&](Attribute attr) {
              if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
                return convertFlagsAttr(op, rtlAttr, moduleTranslation);
              return failure();
            })
      .Case("omp.version",
            [&](Attribute attr) {
              if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
                                            versionAttr.getVersion());
                return success();
              }
              return failure();
            })
      .Case("omp.declare_target",
            [&](Attribute attr) {
              if (auto declareTargetAttr =
                      dyn_cast<omp::DeclareTargetAttr>(attr))
                return convertDeclareTargetAttr(op, declareTargetAttr,
                                                moduleTranslation);
              return failure();
            })
      .Case("omp.requires",
            [&](Attribute attr) {
              if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
                using Requires = omp::ClauseRequires;
                Requires flags = requiresAttr.getValue();
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setHasRequiresReverseOffload(
                    bitEnumContainsAll(flags, Requires::reverse_offload));
                config.setHasRequiresUnifiedAddress(
                    bitEnumContainsAll(flags, Requires::unified_address));
                config.setHasRequiresUnifiedSharedMemory(
                    bitEnumContainsAll(flags, Requires::unified_shared_memory));
                config.setHasRequiresDynamicAllocators(
                    bitEnumContainsAll(flags, Requires::dynamic_allocators));
                return success();
              }
              return failure();
            })
      .Case("omp.target_triples",
            [&](Attribute attr) {
              if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.TargetTriples.clear();
                config.TargetTriples.reserve(triplesAttr.size());
                for (Attribute tripleAttr : triplesAttr) {
                  if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
                    config.TargetTriples.emplace_back(tripleStrAttr.getValue());
                }
                return success();
              }
              return failure();
            })
      .Default([](Attribute) { return success(); })(attribute.getValue());
}
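/// Translates the given OpenMP dialect operation to LLVM IR. When compiling
/// for a target device, only operations relevant to the device are emitted;
/// everything else is lowered through convertHostOrTargetOperation.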
LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
    Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (ompBuilder->Config.isTargetDevice()) {
    if (isTargetDeviceOp(op))
      return convertTargetDeviceOp(op, builder, moduleTranslation);
    return convertTargetOpsInNest(op, builder, moduleTranslation);
  }

  return convertHostOrTargetOperation(op, builder, moduleTranslation);
}
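/// Registers the OpenMP dialect and its LLVM IR translation interface with
/// the given dialect registry.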
void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
  registry.insert<omp::OpenMPDialect>();
  registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
    dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
  });
}