#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
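
// Maps the MLIR schedule clause kind to the LLVM OpenMP runtime's schedule
// kind; an absent clause falls back to the default schedule.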
static llvm::omp::ScheduleKind
convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
  if (!schedKind.has_value())
    return llvm::omp::OMP_SCHEDULE_Default;
  switch (schedKind.value()) {
  case omp::ClauseScheduleKind::Static:
    return llvm::omp::OMP_SCHEDULE_Static;
  case omp::ClauseScheduleKind::Dynamic:
    return llvm::omp::OMP_SCHEDULE_Dynamic;
  case omp::ClauseScheduleKind::Guided:
    return llvm::omp::OMP_SCHEDULE_Guided;
  case omp::ClauseScheduleKind::Auto:
    return llvm::omp::OMP_SCHEDULE_Auto;
  case omp::ClauseScheduleKind::Runtime:
    return llvm::omp::OMP_SCHEDULE_Runtime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}
class OpenMPAllocaStackFrame
  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;

class OpenMPLoopInfoStackFrame
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
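
// Error type used to signal that a translation failure has already been
// reported to the user, so callers only need to propagate the failure.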
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
  void log(raw_ostream &) const override {
  std::error_code convertToErrorCode() const override {
        "PreviouslyReportedError doesn't support ECError conversion");
                                             SymbolRefAttr symbolName) {
  omp::PrivateClauseOp privatizer =
      SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
  assert(privatizer && "privatizer not found in the symbol table");
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()

  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  auto checkBare = [&todo](auto op, LogicalResult &result) {
      result = todo("ompx_bare");
  auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
    omp::ClauseCancellationConstructType cancelledDirective =
        op.getCancelDirective();
    if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
      if (isa_and_nonnull<omp::TaskloopOp>(parent))
        result = todo("cancel directive inside of taskloop");
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  auto checkDevice = [&todo](auto op, LogicalResult &result) {
      result = todo("device");
  auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
    if (op.getDistScheduleChunkSize())
      result = todo("dist_schedule with chunk_size");
  auto checkHint = [](auto op, LogicalResult &) {
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
    if (!op.getIsDevicePtrVars().empty())
      result = todo("is_device_ptr");
  auto checkLinear = [&todo](auto op, LogicalResult &result) {
    if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
      result = todo("linear");
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
      result = todo("nowait");
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
    if (op.getParLevelSimd())
      result = todo("parallelization-level");
  auto checkPriority = [&todo](auto op, LogicalResult &result) {
    if (op.getPriority())
      result = todo("priority");
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
      if (!op.getPrivateVars().empty() && op.getNowait())
        result = todo("privatization for deferred target tasks");
      if (!op.getPrivateVars().empty() || op.getPrivateSyms())
        result = todo("privatization");
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  auto checkUntied = [&todo](auto op, LogicalResult &result) {
      result = todo("untied");

  LogicalResult result = success();
      .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
      .Case([&](omp::CancellationPointOp op) {
        checkCancelDirective(op, result);
      .Case([&](omp::DistributeOp op) {
        checkAllocate(op, result);
        checkDistSchedule(op, result);
        checkOrder(op, result);
      .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      .Case([&](omp::TaskOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      .Case([&](omp::TaskwaitOp op) {
        checkDepend(op, result);
        checkNowait(op, result);
      .Case([&](omp::TaskloopOp op) {
        checkUntied(op, result);
        checkPriority(op, result);
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkLinear(op, result);
        checkOrder(op, result);
        checkReduction(op, result);
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      .Case([&](omp::SimdOp op) {
        checkLinear(op, result);
        checkReduction(op, result);
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkBare(op, result);
        checkDevice(op, result);
        checkInReduction(op, result);
        checkIsDevicePtr(op, result);
        checkPrivate(op, result);
  LogicalResult result = success();
  llvm::handleAllErrors(
      [&](const PreviouslyReportedError &) { result = failure(); },
      [&](const llvm::ErrorInfoBase &err) {

template <typename T>

static llvm::OpenMPIRBuilder::InsertPointTy
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
      [&](OpenMPAllocaStackFrame &frame) {
        allocaInsertPoint = frame.allocaInsertPoint;
    return allocaInsertPoint;
  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);
  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
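
// Finds the llvm::CanonicalLoopInfo of the innermost enclosing loop by
// walking the OpenMPLoopInfoStackFrame entries on the translation stack.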
static llvm::CanonicalLoopInfo *
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        loopInfo = frame.loopInfo;
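
// Converts the blocks of an OpenMP region to LLVM IR. A continuation block is
// split off at the current insertion point; values produced by omp.yield are
// forwarded to PHI nodes created in that continuation block.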
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
  unsigned numYields = 0;
  if (!isLoopWrapper) {
    bool operandsProcessed = false;
      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));
          operandsProcessed = true;
          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());
            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");
  if (!continuationBlockPHITypes.empty())
           continuationBlockPHIs &&
           "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      return llvm::make_error<PreviouslyReportedError>();
      builder.CreateBr(continuationBlock);
    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);
      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(
  return continuationBlock;
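
// Maps the MLIR proc_bind clause kind to the LLVM OpenMP runtime's
// ProcBindKind.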
  case omp::ClauseProcBindKind::Close:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
  case omp::ClauseProcBindKind::Master:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
  case omp::ClauseProcBindKind::Primary:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
  case omp::ClauseProcBindKind::Spread:
    return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
  llvm_unreachable("Unknown ClauseProcBindKind kind");
                        omp::BlockArgOpenMPOpInterface blockArgIface) {
  blockArgIface.getBlockArgsPairs(blockArgsPairs);
  for (auto [var, arg] : blockArgsPairs)

      .Case([&](omp::SimdOp op) {
            cast<omp::BlockArgOpenMPOpInterface>(*op));
        op.emitWarning() << "simd information on composite construct discarded";
        return op->emitError() << "cannot ignore wrapper";
  auto maskedOp = cast<omp::MaskedOp>(opInst);
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = maskedOp.getRegion();
    builder.restoreIP(codeGenIP);
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
  llvm::Value *filterVal = nullptr;
  if (auto filterVar = maskedOp.getFilteredThreadId()) {
    filterVal = moduleTranslation.lookupValue(filterVar);
  llvm::LLVMContext &llvmContext = builder.getContext();
  assert(filterVal != nullptr);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
  builder.restoreIP(*afterIP);

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto masterOp = cast<omp::MasterOp>(opInst);
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = masterOp.getRegion();
    builder.restoreIP(codeGenIP);
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
  builder.restoreIP(*afterIP);

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto criticalOp = cast<omp::CriticalOp>(opInst);
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = cast<omp::CriticalOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
  llvm::Constant *hint = nullptr;
  if (criticalOp.getNameAttr()) {
    auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
    auto criticalDeclareOp =
        SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
        static_cast<int>(criticalDeclareOp.getHint()));
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
  builder.restoreIP(*afterIP);
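
// Helper that gathers, for each private clause on the operation, the MLIR
// clause operand, the corresponding entry-block argument, and the
// `omp.private` privatizer referenced by getPrivateSyms().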
template <typename OP>
       cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
    mlirVars.reserve(blockArgs.size());
    llvmVars.reserve(blockArgs.size());
    collectPrivatizationDecls<OP>(op);
      mlirVars.push_back(privateVar);

  void collectPrivatizationDecls(OP op) {
    std::optional<ArrayAttr> attr = op.getPrivateSyms();
    privatizers.reserve(privatizers.size() + attr->size());
    for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {

template <typename T>
  std::optional<ArrayAttr> attr = op.getReductionSyms();
  reductions.reserve(reductions.size() + op.getNumReductionVars());
  for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
    reductions.push_back(
        SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
  if (llvm::hasSingleElement(region)) {
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();
    if (potentialTerminator && potentialTerminator->isTerminator())
      potentialTerminator->removeFromParent();
    moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
            region.front(), true, builder)))
    if (continuationBlockArgs)
          *continuationBlockArgs,
    if (potentialTerminator && potentialTerminator->isTerminator()) {
      llvm::BasicBlock *block = builder.GetInsertBlock();
      if (block->empty()) {
        potentialTerminator->insertInto(block, block->begin());
        potentialTerminator->insertAfter(&block->back());
  if (continuationBlockArgs)
    llvm::append_range(*continuationBlockArgs, phis);
  builder.SetInsertPoint(*continuationBlock,
                         (*continuationBlock)->getFirstInsertionPt());
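
// Owning std::function wrappers for the OpenMPIRBuilder reduction callbacks,
// so the generated combiner lambdas can capture the matching
// omp.declare_reduction op by value.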
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,

static OwningReductionGen
  OwningReductionGen gen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *lhs, llvm::Value *rhs,
                llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
                                       "omp.reduction.nonatomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `combiner` region of `omp.declare_reduction`");
    result = llvm::getSingleElement(phis);
    return builder.saveIP();

static OwningAtomicReductionGen
                        llvm::IRBuilderBase &builder,
  if (decl.getAtomicReductionRegion().empty())
    return OwningAtomicReductionGen();
  OwningAtomicReductionGen atomicGen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
                llvm::Value *lhs, llvm::Value *rhs) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
                                       "omp.reduction.atomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `atomic` region of `omp.declare_reduction`");
    assert(phis.empty());
    return builder.saveIP();
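
// Translates omp.ordered in its doacross depend form: the depend variables
// are stored numLoops at a time and passed to createOrderedDepend.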
  auto orderedOp = cast<omp::OrderedOp>(opInst);
  omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
  bool isDependSource = dependType == omp::ClauseDepend::dependsource;
  unsigned numLoops = *orderedOp.getDoacrossNumLoops();
      moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
  size_t indexVecValues = 0;
  while (indexVecValues < vecValues.size()) {
    storeValues.reserve(numLoops);
    for (unsigned i = 0; i < numLoops; i++) {
      storeValues.push_back(vecValues[indexVecValues]);
    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
    builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
        ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
  builder.restoreIP(*afterIP);
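
// A store of a by-ref reduction pointer whose emission is deferred until
// after all reduction allocas have been created.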
struct DeferredStore {
  DeferredStore(llvm::Value *value, llvm::Value *address)
      : value(value), address(address) {}
  llvm::Value *address;
template <typename T>
                   llvm::IRBuilderBase &builder,
                   const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
  deferredStores.reserve(loop.getNumReductionVars());
  for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
    Region &allocRegion = reductionDecls[i].getAllocRegion();
      if (allocRegion.empty())
                                         builder, moduleTranslation, &phis)))
        return loop.emitError(
            "failed to inline `alloc` region of `omp.declare_reduction`");
      assert(phis.size() == 1 && "expected one allocation to be yielded");
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));
      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Value *castVar =
          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
      llvm::Value *castPhi =
          builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
      deferredStores.emplace_back(castPhi, castVar);
      privateReductionVariables[i] = castVar;
      moduleTranslation.mapValue(reductionArgs[i], castPhi);
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
      assert(allocRegion.empty() &&
             "allocation is implicit for by-val reduction");
      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));
      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Value *castVar =
          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
      moduleTranslation.mapValue(reductionArgs[i], castVar);
      privateReductionVariables[i] = castVar;
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
template <typename T>
    mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
    Region &initializerRegion = reduction.getInitializerRegion();
    mlir::Value mlirSource = loop.getReductionVars()[i];
    llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
    assert(llvmSource && "lookup reduction var");
    moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
    llvm::Value *allocation =
        reductionVariableMap.lookup(loop.getReductionVars()[i]);
    moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);

                                 llvm::BasicBlock *block = nullptr) {
  if (block == nullptr)
    block = builder.GetInsertBlock();
  if (block->empty() || block->getTerminator() == nullptr)
    builder.SetInsertPoint(block);
  else
    builder.SetInsertPoint(block->getTerminator());
template <typename OP>
static LogicalResult
                  llvm::IRBuilderBase &builder,
                  llvm::BasicBlock *latestAllocaBlock,
  if (op.getNumReductionVars() == 0)
  llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
  auto allocaIP = llvm::IRBuilderBase::InsertPoint(
      latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
  builder.restoreIP(allocaIP);
  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
    if (!reductionDecls[i].getAllocRegion().empty())
    byRefVars[i] = builder.CreateAlloca(
        moduleTranslation.convertType(reductionDecls[i].getType()));
  for (auto [data, addr] : deferredStores)
    builder.CreateStore(data, addr);
  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
                               reductionVariableMap, i);
                                       "omp.reduction.neutral", builder,
                                       moduleTranslation, &phis)))
    assert(phis.size() == 1 && "expected one value to be yielded from the "
                               "reduction neutral element declaration region");
    if (!reductionDecls[i].getAllocRegion().empty())
      builder.CreateStore(phis[0], byRefVars[i]);
      privateReductionVariables[i] = byRefVars[i];
      moduleTranslation.mapValue(reductionArgs[i], phis[0]);
      reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
      builder.CreateStore(phis[0], privateReductionVariables[i]);
    moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
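
// Builds the OpenMPIRBuilder::ReductionInfo list from the collected
// omp.declare_reduction ops, wiring in the non-atomic combiner and, when the
// declaration provides one, the atomic combiner callback.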
template <typename T>
    T loop, llvm::IRBuilderBase &builder,
  unsigned numReductions = loop.getNumReductionVars();
  for (unsigned i = 0; i < numReductions; ++i) {
    owningReductionGens.push_back(
    owningAtomicReductionGens.push_back(
  reductionInfos.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
    if (owningAtomicReductionGens[i])
      atomicGen = owningAtomicReductionGens[i];
    llvm::Value *variable =
        moduleTranslation.lookupValue(loop.getReductionVars()[i]);
    reductionInfos.push_back(
         privateReductionVariables[i],
         llvm::OpenMPIRBuilder::EvalKind::Scalar, owningReductionGens[i],
         nullptr, atomicGen});

static LogicalResult
                       llvm::IRBuilderBase &builder, StringRef regionName,
                       bool shouldLoadCleanupRegionArg = true) {
    if (cleanupRegion->empty())
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();
    if (potentialTerminator && potentialTerminator->isTerminator())
      builder.SetInsertPoint(potentialTerminator);
    llvm::Value *privateVarValue =
        shouldLoadCleanupRegionArg
            ? builder.CreateLoad(
                  privateVariables[i])
            : privateVariables[i];
                                          moduleTranslation)))
    OP op, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
    bool isNowait = false, bool isTeamsReduction = false) {
  if (op.getNumReductionVars() == 0)
                       owningReductionGens, owningAtomicReductionGens,
                       privateReductionVariables, reductionInfos);
  llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
  builder.SetInsertPoint(tempTerminator);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
      ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
                                   isByRef, isNowait, isTeamsReduction);
  if (!contInsertPoint->getBlock())
    return op->emitOpError() << "failed to convert reductions";
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
  tempTerminator->eraseFromParent();
  builder.restoreIP(*afterIP);
  llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
                  [](omp::DeclareReductionOp reductionDecl) {
                    return &reductionDecl.getCleanupRegion();
                                 moduleTranslation, builder,
                                 "omp.reduction.cleanup");

template <typename OP>
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
  if (op.getNumReductionVars() == 0)
                               allocaIP, reductionDecls,
                               privateReductionVariables, reductionVariableMap,
                               deferredStores, isByRef)))
                           allocaIP.getBlock(), reductionDecls,
                           privateReductionVariables, reductionVariableMap,
                           isByRef, deferredStores);

static llvm::Value *
  if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
  Value blockArg = (*mappedPrivateVars)[privateVar];
  assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
         "A block argument corresponding to a mapped var should have "
  if (privVarType == blockArgType)
  if (!isa<LLVM::LLVMPointerType>(privVarType))
    return builder.CreateLoad(moduleTranslation.convertType(privVarType),
    llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
  Region &initRegion = privDecl.getInitRegion();
  if (initRegion.empty())
    return llvmPrivateVar;
      mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
  assert(nonPrivateVar);
  moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
  moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
                                     moduleTranslation, &phis)))
    return llvm::createStringError(
        "failed to inline `init` region of `omp.private`");
  assert(phis.size() == 1 && "expected one allocation to be yielded");
    return llvm::Error::success();
  llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
    auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
        builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
        llvmPrivateVar, privInitBlock, mappedPrivateVars);
      return privVarOrErr.takeError();
    llvmPrivateVar = privVarOrErr.get();
    moduleTranslation.mapValue(blockArg, llvmPrivateVar);
  return llvm::Error::success();

    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
  llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
  splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
                                               allocaTerminator->getIterator()),
          true, allocaTerminator->getStableDebugLoc(),
          "omp.region.after_alloca");
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  allocaTerminator = allocaIP.getBlock()->getTerminator();
  builder.SetInsertPoint(allocaTerminator);
  assert(allocaTerminator->getNumSuccessors() == 1 &&
         "This is an unconditional branch created by splitBB");
  llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
  llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
  unsigned int allocaAS =
      moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
          .getProgramAddressSpace();
  for (auto [privDecl, mlirPrivVar, blockArg] :
    llvm::Type *llvmAllocType =
        moduleTranslation.convertType(privDecl.getType());
    builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
    llvm::Value *llvmPrivateVar = builder.CreateAlloca(
        llvmAllocType, nullptr, "omp.private.alloc");
    if (allocaAS != defaultAS)
      llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
                                                   builder.getPtrTy(defaultAS));
    privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
  return afterAllocas;
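
// Runs the `copy` region of each firstprivate `omp.private` declaration to
// copy the original value into the freshly allocated private copy; skipped
// entirely if no declaration uses firstprivate data sharing.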
  bool needsFirstprivate =
      llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
        return privOp.getDataSharingType() ==
               omp::DataSharingClauseType::FirstPrivate;
  if (!needsFirstprivate)
  llvm::BasicBlock *copyBlock =
      splitBB(builder, true, "omp.private.copy");
  for (auto [decl, mlirVar, llvmVar] :
       llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
    if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
    Region &copyRegion = decl.getCopyRegion();
        mlirVar, builder, moduleTranslation, mappedPrivateVars);
    assert(nonPrivateVar);
    moduleTranslation.mapValue(decl.getCopyMoldArg(), nonPrivateVar);
    moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
                                       moduleTranslation)))
      return decl.emitError("failed to inline `copy` region of `omp.private`");

static LogicalResult
  llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
                  [](omp::PrivateClauseOp privatizer) {
                    return &privatizer.getDeallocRegion();
          privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
          "omp.private.dealloc", false)))
    return mlir::emitError(loc, "failed to inline `dealloc` region of an "
                                "`omp.private` op in");

    if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
static LogicalResult
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  using StorableBodyGenCallbackTy =
      llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
  auto sectionsOp = cast<omp::SectionsOp>(opInst);
  assert(isByRef.size() == sectionsOp.getNumReductionVars());
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
                                                 sectionsOp.getNumReductionVars());
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
          sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
          reductionDecls, privateReductionVariables, reductionVariableMap,
    auto sectionOp = dyn_cast<omp::SectionOp>(op);
    Region &region = sectionOp.getRegion();
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
                         InsertPointTy allocaIP, InsertPointTy codeGenIP) {
      builder.restoreIP(codeGenIP);
          sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               sectionsOp.getRegion().getArguments(), region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
        moduleTranslation.mapValue(sectionArg, llvmVal);
    sectionCBs.push_back(sectionCB);
  if (sectionCBs.empty())
  assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
  auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
                    llvm::Value &vPtr, llvm::Value *&replacementValue)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    replacementValue = &vPtr;
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
          ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
          sectionsOp.getNowait());
  builder.restoreIP(*afterIP);
      sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
      privateReductionVariables, isByRef, sectionsOp.getNowait());
static LogicalResult
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(codegenIP);
                                  builder, moduleTranslation)
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
  std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
  for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
    llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
    auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
        singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
    llvmCPFuncs.push_back(
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
  builder.restoreIP(*afterIP);

      llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
  for (auto ra : iface.getReductionBlockArgs())
    for (auto &use : ra.getUses()) {
      auto *useOp = use.getOwner();
      if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
        debugUses.push_back(useOp);
      auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
      Operation *currentOp = currentDistOp.getOperation();
      if (distOp && (distOp != currentOp))
  for (auto use : debugUses)
static LogicalResult
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  unsigned numReductionVars = op.getNumReductionVars();
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
  if (doTeamsReduction) {
    isByRef = getIsByRef(op.getReductionByref());
    assert(isByRef.size() == op.getNumReductionVars());
        llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
            op, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
                                       moduleTranslation, allocaIP);
    builder.restoreIP(codegenIP);
  llvm::Value *numTeamsLower = nullptr;
  if (Value numTeamsLowerVar = op.getNumTeamsLower())
    numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
  llvm::Value *numTeamsUpper = nullptr;
  if (Value numTeamsUpperVar = op.getNumTeamsUpper())
    numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
  llvm::Value *threadLimit = nullptr;
  if (Value threadLimitVar = op.getThreadLimit())
    threadLimit = moduleTranslation.lookupValue(threadLimitVar);
  llvm::Value *ifExpr = nullptr;
  if (Value ifVar = op.getIfExpr())
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
  builder.restoreIP(*afterIP);
  if (doTeamsReduction) {
        op, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
  if (dependVars.empty())
  for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
    llvm::omp::RTLDependenceKindTy type;
        cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
    case mlir::omp::ClauseTaskDepend::taskdependin:
      type = llvm::omp::RTLDependenceKindTy::DepIn;
    case mlir::omp::ClauseTaskDepend::taskdependout:
    case mlir::omp::ClauseTaskDepend::taskdependinout:
      type = llvm::omp::RTLDependenceKindTy::DepInOut;
    case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
    case mlir::omp::ClauseTaskDepend::taskdependinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
    llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
    dds.emplace_back(dd);

                                llvm::IRBuilderBase &llvmBuilder,
                                llvm::omp::Directive cancelDirective) {
  auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
    llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
    llvmBuilder.restoreIP(ip);
    cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
    return llvm::Error::success();
  ompBuilder.pushFinalizationCB(
                           llvm::OpenMPIRBuilder &ompBuilder,
                           const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
  ompBuilder.popFinalizationCB();
  llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
  for (llvm::BranchInst *cancelBranch : cancelTerminators) {
    assert(cancelBranch->getNumSuccessors() == 1 &&
           "cancel branch should have one target");
    cancelBranch->setSuccessor(0, constructFini);
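
// Manages a heap-allocated context structure that holds the firstprivate
// copies for an omp.task: the struct is malloc'ed before the task is created,
// accessed through GEPs inside the task body, and freed afterwards.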
class TaskContextStructManager {
  TaskContextStructManager(llvm::IRBuilderBase &builder,
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}
  void generateTaskContextStruct();
  void createGEPsToPrivateVars();
  void freeStructPtr();
    return llvmPrivateVarGEPs;
  llvm::Value *getStructPtr() { return structPtr; }
  llvm::IRBuilderBase &builder;
  llvm::Value *structPtr = nullptr;
  llvm::Type *structTy = nullptr;

void TaskContextStructManager::generateTaskContextStruct() {
  if (privateDecls.empty())
  privateVarTypes.reserve(privateDecls.size());
  for (omp::PrivateClauseOp &privOp : privateDecls) {
    if (!privOp.readsFromMold())
    Type mlirType = privOp.getType();
    privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
  llvm::DataLayout dataLayout =
      builder.GetInsertBlock()->getModule()->getDataLayout();
  llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
  llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
  structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
                                   "omp.task.context_ptr");

void TaskContextStructManager::createGEPsToPrivateVars() {
    assert(privateVarTypes.empty());
  llvmPrivateVarGEPs.clear();
  llvmPrivateVarGEPs.reserve(privateDecls.size());
  llvm::Value *zero = builder.getInt32(0);
  for (auto privDecl : privateDecls) {
    if (!privDecl.readsFromMold()) {
      llvmPrivateVarGEPs.push_back(nullptr);
    llvm::Value *iVal = builder.getInt32(i);
    llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
    llvmPrivateVarGEPs.push_back(gep);

void TaskContextStructManager::freeStructPtr() {
  llvm::IRBuilderBase::InsertPointGuard guard{builder};
  builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
  builder.CreateFree(structPtr);
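
// Translates omp.task. Private values that are initialized from the original
// variable (readsFromMold) are materialized into the task context struct
// before the task is created and reloaded inside the task body; the remaining
// private values are allocated and initialized inside the body callback.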
static LogicalResult
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  TaskContextStructManager taskStructMgr{builder, moduleTranslation,
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
  assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
  llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
      builder.getContext(), "omp.task.start",
      builder.GetInsertBlock()->getParent());
  llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
  builder.SetInsertPoint(branchToTaskStartBlock);
  llvm::BasicBlock *copyBlock =
      splitBB(builder, true, "omp.private.copy");
  llvm::BasicBlock *initBlock =
      splitBB(builder, true, "omp.private.init");
                                 moduleTranslation, allocaIP);
  builder.SetInsertPoint(initBlock->getTerminator());
  taskStructMgr.generateTaskContextStruct();
  taskStructMgr.createGEPsToPrivateVars();
  for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
                       taskStructMgr.getLLVMPrivateVarGEPs())) {
    if (!privDecl.readsFromMold())
    assert(llvmPrivateVarAlloc &&
           "reads from mold so shouldn't have been skipped");
        initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                       blockArg, llvmPrivateVarAlloc, initBlock);
    if (!privateVarOrErr)
      return handleError(privateVarOrErr, *taskOp.getOperation());
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
    if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
        !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
      builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
      llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
                                               llvmPrivateVarAlloc);
    assert(llvmPrivateVarAlloc->getType() ==
           moduleTranslation.convertType(blockArg.getType()));
          builder, moduleTranslation, privateVarsInfo.mlirVars,
          taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers)))
    return llvm::failure();
  builder.SetInsertPoint(taskStartBlock);
  auto bodyCB = [&](InsertPointTy allocaIP,
                    InsertPointTy codegenIP) -> llvm::Error {
                                   moduleTranslation, allocaIP);
    builder.restoreIP(codegenIP);
    llvm::BasicBlock *privInitBlock = nullptr;
      auto [blockArg, privDecl, mlirPrivVar] = zip;
      if (privDecl.readsFromMold())
      llvm::IRBuilderBase::InsertPointGuard guard(builder);
      llvm::Type *llvmAllocType =
          moduleTranslation.convertType(privDecl.getType());
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
      llvm::Value *llvmPrivateVar = builder.CreateAlloca(
          llvmAllocType, nullptr, "omp.private.alloc");
          initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                         blockArg, llvmPrivateVar, privInitBlock);
      if (!privateVarOrError)
        return privateVarOrError.takeError();
      moduleTranslation.mapValue(blockArg, privateVarOrError.get());
      privateVarsInfo.llvmVars[i] = privateVarOrError.get();
    taskStructMgr.createGEPsToPrivateVars();
    for (auto [i, llvmPrivVar] :
      assert(privateVarsInfo.llvmVars[i] && "This is added in the loop above");
      privateVarsInfo.llvmVars[i] = llvmPrivVar;
    for (auto [blockArg, llvmPrivateVar, privateDecl] :
      if (!privateDecl.readsFromMold())
      if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
        llvmPrivateVar = builder.CreateLoad(
            moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
        assert(llvmPrivateVar->getType() ==
               moduleTranslation.convertType(blockArg.getType()));
      moduleTranslation.mapValue(blockArg, llvmPrivateVar);
        taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
    if (failed(handleError(continuationBlockOrError, *taskOp)))
      return llvm::make_error<PreviouslyReportedError>();
    builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
      return llvm::make_error<PreviouslyReportedError>();
    taskStructMgr.freeStructPtr();
    return llvm::Error::success();
                            llvm::omp::Directive::OMPD_taskgroup);
                  moduleTranslation, dds);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
          ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
          moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
          taskOp.getMergeable(),
          moduleTranslation.lookupValue(taskOp.getEventHandle()),
          moduleTranslation.lookupValue(taskOp.getPriority()));
  builder.restoreIP(*afterIP);
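
// Translates omp.taskgroup; the body region is converted inside the
// OpenMPIRBuilder body callback.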
static LogicalResult
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(codegenIP);
                                  builder, moduleTranslation)
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
  builder.restoreIP(*afterIP);

static LogicalResult

static LogicalResult
  auto wsloopOp = cast<omp::WsloopOp>(opInst);
  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
  assert(isByRef.size() == wsloopOp.getNumReductionVars());
      wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
  llvm::Type *ivType = step->getType();
  llvm::Value *chunk = nullptr;
  if (wsloopOp.getScheduleChunk()) {
    llvm::Value *chunkVar =
        moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
    chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
                                   wsloopOp.getNumReductionVars());
      builder, moduleTranslation, privateVarsInfo, allocaIP);
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
                              moduleTranslation, allocaIP, reductionDecls,
                              privateReductionVariables, reductionVariableMap,
                              deferredStores, isByRef)))
          builder, moduleTranslation, privateVarsInfo.mlirVars,
  assert(afterAllocas.get()->getSinglePredecessor());
                             afterAllocas.get()->getSinglePredecessor(),
                             reductionDecls, privateReductionVariables,
                             reductionVariableMap, isByRef, deferredStores)))
  bool isOrdered = wsloopOp.getOrdered().has_value();
  std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
  bool isSimd = wsloopOp.getScheduleSimd();
  bool loopNeedsBarrier = !wsloopOp.getNowait();
  llvm::omp::WorksharingLoopType workshareLoopType =
      llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
          ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
          : llvm::omp::WorksharingLoopType::ForStaticLoop;
                             llvm::omp::Directive::OMPD_for);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
      wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
      ompBuilder->applyWorkshareLoop(
          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
          convertToScheduleKind(schedule), chunk, isSimd,
          scheduleMod == omp::ScheduleModifier::monotonic,
          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
      wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
      privateReductionVariables, isByRef, wsloopOp.getNowait(),
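
// Translates omp.parallel. Private and reduction variables are allocated and
// initialized in the body callback before the region is converted; reductions
// are then combined through createReductions, and cleanup regions run in the
// finalization callback.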
static LogicalResult
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  assert(isByRef.size() == opInst.getNumReductionVars());
                                   opInst.getNumReductionVars());
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
        builder, moduleTranslation, privateVarsInfo, allocaIP);
      return llvm::make_error<PreviouslyReportedError>();
        cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
        InsertPointTy(allocaIP.getBlock(),
                      allocaIP.getBlock()->getTerminator()->getIterator());
            opInst, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            deferredStores, isByRef)))
      return llvm::make_error<PreviouslyReportedError>();
    assert(afterAllocas.get()->getSinglePredecessor());
    builder.restoreIP(codeGenIP);
      return llvm::make_error<PreviouslyReportedError>();
            builder, moduleTranslation, privateVarsInfo.mlirVars,
      return llvm::make_error<PreviouslyReportedError>();
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
      return llvm::make_error<PreviouslyReportedError>();
                                   moduleTranslation, allocaIP);
        opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
      return regionBlock.takeError();
    if (opInst.getNumReductionVars() > 0) {
                           owningReductionGens, owningAtomicReductionGens,
                           privateReductionVariables, reductionInfos);
      builder.SetInsertPoint((*regionBlock)->getTerminator());
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);
      llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
          ompBuilder->createReductions(
              builder.saveIP(), allocaIP, reductionInfos, isByRef,
      if (!contInsertPoint)
        return contInsertPoint.takeError();
      if (!contInsertPoint->getBlock())
        return llvm::make_error<PreviouslyReportedError>();
      tempTerminator->eraseFromParent();
      builder.restoreIP(*contInsertPoint);
    return llvm::Error::success();
  auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                   llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
  auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(codeGenIP);
    llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
                    [](omp::DeclareReductionOp reductionDecl) {
                      return &reductionDecl.getCleanupRegion();
            reductionCleanupRegions, privateReductionVariables,
            moduleTranslation, builder, "omp.reduction.cleanup")))
      return llvm::createStringError(
          "failed to inline `cleanup` region of `omp.declare_reduction`");
      return llvm::make_error<PreviouslyReportedError>();
    builder.restoreIP(oldIP);
    return llvm::Error::success();
  llvm::Value *ifCond = nullptr;
  if (auto ifVar = opInst.getIfExpr())
  llvm::Value *numThreads = nullptr;
  if (auto numThreadsVar = opInst.getNumThreads())
    numThreads = moduleTranslation.lookupValue(numThreadsVar);
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindKind())
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                 ifCond, numThreads, pbKind, isCancellable);
  builder.restoreIP(*afterIP);
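
// Maps the omp order(concurrent) clause to the LLVM OrderKind; any other or
// absent value is treated as unknown.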
static llvm::omp::OrderKind
    return llvm::omp::OrderKind::OMP_ORDER_unknown;
  case omp::ClauseOrderKind::Concurrent:
    return llvm::omp::OrderKind::OMP_ORDER_concurrent;
  llvm_unreachable("Unknown ClauseOrderKind kind");

static LogicalResult
  auto simdOp = cast<omp::SimdOp>(opInst);
  if (simdOp.isComposite()) {
                                     builder, moduleTranslation);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      builder, moduleTranslation, privateVarsInfo, allocaIP);
  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(simdlenVar.value());
  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(safelenVar.value());
  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
  for (size_t i = 0; i < operands.size(); ++i) {
    llvm::Value *alignment = nullptr;
    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
    llvm::Type *ty = llvmVal->getType();
      auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
      alignment = builder.getInt64(intAttr.getInt());
      assert(ty->isPointerTy() && "Invalid type for aligned variable");
      assert(alignment && "Invalid alignment value");
      auto curInsert = builder.saveIP();
      builder.SetInsertPoint(sourceBlock);
      llvmVal = builder.CreateLoad(ty, llvmVal);
      builder.restoreIP(curInsert);
      alignedVars[llvmVal] = alignment;
      simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
  ompBuilder->applySimd(loopInfo, alignedVars,
                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                        order, simdlen, safelen);
static LogicalResult
  auto loopOp = cast<omp::LoopNestOp>(opInst);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
    bodyInsertPoints.push_back(ip);
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();
    builder.restoreIP(ip);
        loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
      return regionBlock.takeError();
    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
    return llvm::Error::success();
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
      computeIP = loopInfos.front()->getPreheaderIP();
        ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            true, loopOp.getLoopInclusive(), computeIP);
    loopInfos.push_back(*loopResult);
  llvm::OpenMPIRBuilder::InsertPointTy afterIP =
      loopInfos.front()->getAfterIP();
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
  builder.restoreIP(afterIP);
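
// Maps omp.atomic memory order clauses to LLVM atomic orderings; a missing
// clause defaults to monotonic (relaxed).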
static llvm::AtomicOrdering
    return llvm::AtomicOrdering::Monotonic;
  case omp::ClauseMemoryOrderKind::Seq_cst:
    return llvm::AtomicOrdering::SequentiallyConsistent;
  case omp::ClauseMemoryOrderKind::Acq_rel:
    return llvm::AtomicOrdering::AcquireRelease;
  case omp::ClauseMemoryOrderKind::Acquire:
    return llvm::AtomicOrdering::Acquire;
  case omp::ClauseMemoryOrderKind::Release:
    return llvm::AtomicOrdering::Release;
  case omp::ClauseMemoryOrderKind::Relaxed:
    return llvm::AtomicOrdering::Monotonic;
  llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
static LogicalResult
  auto readOp = cast<omp::AtomicReadOp>(opInst);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
  llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
  llvm::Type *elementType =
      moduleTranslation.convertType(readOp.getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
  llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
  builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));

static LogicalResult
  auto writeOp = cast<omp::AtomicWriteOp>(opInst);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
  llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
  llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
  llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, false,
      ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
2900 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
2901 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
2902 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
2903 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
2904 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
2905 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
2906 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
2907 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
2908 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
2909 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The two operations here are the update and the terminator. If we can
    // identify the update operation, we may be able to generate an atomicrmw.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(),
                            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }

    binop = convertBinOpToAtomic(innerOp);
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    // The update region has more than one operation, so an atomicrmw cannot be
    // emitted; fall back to a compare-exchange sequence.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }
  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  // Generate update code.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
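/// Convert omp.atomic.capture to LLVM IR. The inner pair of operations is
/// either {update, read} / {read, update} or {read, write}; the lowering
/// detects which form it is, derives the atomicrmw binary operation where
/// possible, and otherwise falls back to the compare-exchange sequence
/// emitted by OpenMPIRBuilder::createAtomicCapture.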
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument and get
    // the expression to update.
    if (innerOpList.size() == 2) {
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
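/// Map an omp cancellation construct type to the llvm::omp::Directive that
/// the OpenMP runtime expects for cancel / cancellation point calls
/// (loop -> OMPD_for, parallel -> OMPD_parallel, and so on).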
static llvm::omp::Directive convertCancellationConstructType(
    omp::ClauseCancellationConstructType directive) {
  switch (directive) {
  case omp::ClauseCancellationConstructType::Loop:
    return llvm::omp::Directive::OMPD_for;
  case omp::ClauseCancellationConstructType::Parallel:
    return llvm::omp::Directive::OMPD_parallel;
  case omp::ClauseCancellationConstructType::Sections:
    return llvm::omp::Directive::OMPD_sections;
  case omp::ClauseCancellationConstructType::Taskgroup:
    return llvm::omp::Directive::OMPD_taskgroup;
  }
}
static LogicalResult
convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  if (failed(checkImplementationStatus(*op.getOperation())))
    return failure();

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::Value *ifCond = nullptr;
  if (Value ifVar = op.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);

  llvm::omp::Directive cancelledDirective =
      convertCancellationConstructType(op.getCancelDirective());

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);

  if (failed(handleError(afterIP, *op.getOperation())))
    return failure();

  builder.restoreIP(afterIP.get());
  return success();
}
static LogicalResult
convertOmpCancellationPoint(omp::CancellationPointOp op,
                            llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) {
  if (failed(checkImplementationStatus(*op.getOperation())))
    return failure();

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::omp::Directive cancelledDirective =
      convertCancellationConstructType(op.getCancelDirective());

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);

  if (failed(handleError(afterIP, *op.getOperation())))
    return failure();

  builder.restoreIP(afterIP.get());
  return success();
}
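/// Convert omp.threadprivate to LLVM IR. On the host this emits a call to the
/// cached thread-private runtime helper keyed by the global's symbol name; on
/// the target device the global value itself is reused.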
static LogicalResult
convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  Value symAddr = threadprivateOp.getSymAddr();
  auto *symOp = symAddr.getDefiningOp();

  if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
    symOp = asCast.getOperand().getDefiningOp();

  if (!isa<LLVM::AddressOfOp>(symOp))
    return opInst.emitError("Addressing symbol not found");
  LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);

  LLVM::GlobalOp global =
      addressOfOp.getGlobal(moduleTranslation.symbolTable());
  llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (!ompBuilder->Config.isTargetDevice()) {
    llvm::Type *type = globalValue->getValueType();
    llvm::TypeSize typeSize =
        builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
            type);
    llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
    llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
        ompLoc, globalValue, size, global.getSymName() + ".cache");
    moduleTranslation.mapValue(opInst.getResult(0), callInst);
  } else {
    moduleTranslation.mapValue(opInst.getResult(0), globalValue);
  }

  return success();
}
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
  switch (deviceClause) {
  case mlir::omp::DeclareTargetDeviceType::host:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
  case mlir::omp::DeclareTargetDeviceType::nohost:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
  case mlir::omp::DeclareTargetDeviceType::any:
    return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
  }
  llvm_unreachable("unhandled device clause");
}
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
convertToCaptureClauseKind(
    mlir::omp::DeclareTargetCaptureClause captureClause) {
  switch (captureClause) {
  case mlir::omp::DeclareTargetCaptureClause::to:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
  case mlir::omp::DeclareTargetCaptureClause::link:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
  case mlir::omp::DeclareTargetCaptureClause::enter:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
  }
  llvm_unreachable("unhandled capture clause");
}
3225 llvm::OpenMPIRBuilder &ompBuilder) {
3227 llvm::raw_svector_ostream os(suffix);
3230 auto fileInfoCallBack = [&loc]() {
3231 return std::pair<std::string, uint64_t>(
3232 llvm::StringRef(loc.getFilename()), loc.getLine());
3236 "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
3238 os <<
"_decl_tgt_ref_ptr";
static bool isDeclareTargetLink(mlir::Value value) {
  if (auto addressOfOp =
          llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
    auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
    Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
    if (auto declareTargetGlobal =
            llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
      if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
          mlir::omp::DeclareTargetCaptureClause::link)
        return true;
  }
  return false;
}
// Returns the reference pointer generated by the lowering of the declare
// target operation in cases where the link clause is used, or the to clause
// is used in unified-shared-memory mode.
static llvm::Value *
getRefPtrIfDeclareTarget(mlir::Value value,
                         LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  if (auto addressOfOp =
          llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
    if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
            addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
                addressOfOp.getGlobalName()))) {
      if (auto declareTargetGlobal =
              llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
                  gOp.getOperation())) {
        // In this case we must use the reference pointer generated by the
        // declare target operation, similar to Clang.
        if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
             mlir::omp::DeclareTargetCaptureClause::link) ||
            (declareTargetGlobal.getDeclareTargetCaptureClause() ==
                 mlir::omp::DeclareTargetCaptureClause::to &&
             ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
          llvm::SmallString<64> suffix =
              getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);

          if (gOp.getSymName().contains(suffix))
            return moduleTranslation.getLLVMModule()->getNamedValue(
                gOp.getSymName());

          return moduleTranslation.getLLVMModule()->getNamedValue(
              (gOp.getSymName().str() + suffix.str()).str());
        }
      }
    }
  }

  return nullptr;
}
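// The helper structures below extend OpenMPIRBuilder::MapInfosTy with the
// extra per-entry bookkeeping this translation needs: the originating MLIR
// map clause, the original LLVM value, the base type, declare-target and
// membership flags, and any user-defined mapper attached to the entry.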
struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
  SmallVector<Operation *, 4> Mappers;

  /// Append arrays in \a curInfo.
  void append(MapInfosTy &curInfo) {
    Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
  }
};

// A small helper structure that gathers the data needed for map lowering
// upfront to prevent recomputation and duplication.
struct MapInfoData : MapInfosTy {
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  llvm::SmallVector<bool, 4> IsAMember;
  // Identifies whether a mapping was added by a map clause or by a
  // use_device_ptr/use_device_addr clause.
  llvm::SmallVector<bool, 4> IsAMapping;
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  // Stripped-off array/pointer type to get the underlying element type.
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  /// Append arrays in \a CurInfo.
  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    MapInfosTy::append(CurInfo);
  }
};
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
                                          DataLayout &dl) {
  if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
          arrTy.getElementType()))
    return getArrayElementSizeInBits(nestedArrTy, dl);
  return dl.getTypeSizeInBits(arrTy.getElementType());
}

// Calculates the size in bytes to be offloaded for a given type and its
// associated map clause, taking bounds information into account when present.
static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
                                   Operation *clauseOp,
                                   llvm::Value *basePointer,
                                   llvm::Type *baseType,
                                   llvm::IRBuilderBase &builder,
                                   LLVM::ModuleTranslation &moduleTranslation) {
  if (auto memberClause =
          mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
    // The size to transfer is (elemCount * (UB - LB + 1)) across all bounds,
    // multiplied by the underlying element type's byte size.
    if (!memberClause.getBounds().empty()) {
      llvm::Value *elementCount = builder.getInt64(1);
      for (auto bounds : memberClause.getBounds()) {
        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
                bounds.getDefiningOp())) {
          elementCount = builder.CreateMul(
              elementCount,
              builder.CreateAdd(
                  builder.CreateSub(
                      moduleTranslation.lookupValue(boundOp.getUpperBound()),
                      moduleTranslation.lookupValue(boundOp.getLowerBound())),
                  builder.getInt64(1)));
        }
      }

      // For arrays, strip down to the base element type to get its size.
      size_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
      if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
        underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);

      return builder.CreateMul(elementCount,
                               builder.getInt64(underlyingTypeSzInBits / 8));
    }
  }

  return builder.getInt64(dl.getTypeSize(type));
}
  auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
    // Check whether this map operation is listed as a member of any other map
    // operation in the same clause list.
    for (Value mapValue : mapVars) {
      auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      for (auto member : map.getMembers())
        if (member == mapOp)
          return true;
    }
    return false;
  };

  // Process map operands.
  for (Value mapValue : mapVars) {
    auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
    Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
    mapData.Pointers.push_back(mapData.OriginalValue.back());

    if (llvm::Value *refPtr =
            getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
      // declare target variable
      mapData.IsDeclareTarget.push_back(true);
      mapData.BasePointers.push_back(refPtr);
    } else {
      // regular mapped variable
      mapData.IsDeclareTarget.push_back(false);
      mapData.BasePointers.push_back(mapData.OriginalValue.back());
    }

    mapData.BaseType.push_back(
        moduleTranslation.convertType(mapOp.getVarType()));
    mapData.Sizes.push_back(
        getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
                       mapData.BaseType.back(), builder, moduleTranslation));
    mapData.MapClause.push_back(mapOp.getOperation());
    mapData.Types.push_back(
        llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType()));
    if (mapOp.getMapperId())
      mapData.Mappers.push_back(
          SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
              mapOp, mapOp.getMapperIdAttr()));
    else
      mapData.Mappers.push_back(nullptr);
    mapData.IsAMapping.push_back(true);
    mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
  }

  auto findMapInfo = [&mapData](llvm::Value *val,
                                llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    unsigned index = 0;
    bool found = false;
    for (llvm::Value *basePtr : mapData.OriginalValue) {
      if (basePtr == val && mapData.IsAMapping[index]) {
        found = true;
        mapData.Types[index] |=
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
        mapData.DevicePointers[index] = devInfoTy;
      }
      index++;
    }
    return found;
  };

  // Process use_device_ptr / use_device_addr operands.
  auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
                         llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    for (Value mapValue : useDevOperands) {
      auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      Value offloadPtr =
          mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
      llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);

      // Check if map info is already present for this entry.
      if (!findMapInfo(origValue, devInfoTy)) {
        mapData.OriginalValue.push_back(origValue);
        mapData.Pointers.push_back(mapData.OriginalValue.back());
        mapData.IsDeclareTarget.push_back(false);
        mapData.BasePointers.push_back(mapData.OriginalValue.back());
        mapData.BaseType.push_back(
            moduleTranslation.convertType(mapOp.getVarType()));
        mapData.Sizes.push_back(builder.getInt64(0));
        mapData.MapClause.push_back(mapOp.getOperation());
        mapData.Types.push_back(
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
        mapData.DevicePointers.push_back(devInfoTy);
        mapData.Mappers.push_back(nullptr);
        mapData.IsAMapping.push_back(false);
        mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
      }
    }
  };

  addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
  addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
  for (Value mapValue : hasDevAddrOperands) {
    auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
    Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
    auto mapType =
        static_cast<llvm::omp::OpenMPOffloadMappingFlags>(mapOp.getMapType());
    auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;

    mapData.OriginalValue.push_back(origValue);
    mapData.BasePointers.push_back(origValue);
    mapData.Pointers.push_back(origValue);
    mapData.IsDeclareTarget.push_back(false);
    mapData.BaseType.push_back(
        moduleTranslation.convertType(mapOp.getVarType()));
    mapData.Sizes.push_back(
        builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
    mapData.MapClause.push_back(mapOp.getOperation());
    if (llvm::to_underlying(mapType & mapTypeAlways)) {
      // Descriptors are mapped with the ALWAYS flag, since they can get
      // rematerialized, so the address of the descriptor for a given object
      // may change between map regions.
      mapData.Types.push_back(mapType);
      // Technically it's possible for a non-descriptor mapping to have both
      // has-device-addr and ALWAYS, so look up the mapper in case it exists.
      if (mapOp.getMapperId()) {
        mapData.Mappers.push_back(
            SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
                mapOp, mapOp.getMapperIdAttr()));
      } else {
        mapData.Mappers.push_back(nullptr);
      }
    } else {
      mapData.Types.push_back(
          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
      mapData.Mappers.push_back(nullptr);
    }
    mapData.DevicePointers.push_back(
        llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
    mapData.IsAMapping.push_back(false);
    mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
  }
}

static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
  auto *res = llvm::find(mapData.MapClause, memberOp);
  assert(res != mapData.MapClause.end() &&
         "MapInfoOp for member not found in MapData, cannot return index");
  return std::distance(mapData.MapClause.begin(), res);
}
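/// Given a parent omp.map.info with member mappings, return the member whose
/// index vector sorts first (or last) so the parent's mapped address range can
/// be bounded by its lowest and highest mapped members.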
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
                                                    bool first) {
  ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
  // Only one member has been mapped; we can return it directly.
  if (indexAttr.size() == 1)
    return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());

  llvm::SmallVector<size_t> indices(indexAttr.size());
  std::iota(indices.begin(), indices.end(), 0);

  llvm::sort(indices.begin(), indices.end(),
             [&](const size_t a, const size_t b) {
               auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
               auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
               for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
                 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
                 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();

                 if (aIndex == bIndex)
                   continue;

                 if (aIndex < bIndex)
                   return first;

                 if (aIndex > bIndex)
                   return !first;
               }

               // Equal up to the end of the shorter index vector, so select
               // the member with the lower index count (the "parent").
               return memberIndicesA.size() < memberIndicesB.size();
             });

  return llvm::cast<omp::MapInfoOp>(
      mapInfo.getMembers()[indices.front()].getDefiningOp());
}
3630 std::vector<llvm::Value *>
3632 llvm::IRBuilderBase &builder,
bool isArrayTy,
3634 std::vector<llvm::Value *> idx;
3645 idx.push_back(builder.getInt64(0));
3646 for (
int i = bounds.size() - 1; i >= 0; --i) {
3647 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3648 bounds[i].getDefiningOp())) {
3649 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
3671 std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
3672 for (
size_t i = 1; i < bounds.size(); ++i) {
3673 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3674 bounds[i].getDefiningOp())) {
3675 dimensionIndexSizeOffset.push_back(builder.CreateMul(
3676 moduleTranslation.
lookupValue(boundOp.getExtent()),
3677 dimensionIndexSizeOffset[i - 1]));
3685 for (
int i = bounds.size() - 1; i >= 0; --i) {
3686 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3687 bounds[i].getDefiningOp())) {
3689 idx.emplace_back(builder.CreateMul(
3690 moduleTranslation.
lookupValue(boundOp.getLowerBound()),
3691 dimensionIndexSizeOffset[i]));
3693 idx.back() = builder.CreateAdd(
3694 idx.back(), builder.CreateMul(moduleTranslation.
lookupValue(
3695 boundOp.getLowerBound()),
3696 dimensionIndexSizeOffset[i]));
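/// Emit the combined-info entries for a parent object whose members are
/// mapped individually: the parent segment is sized from the address range
/// spanned by its mapped members (or the whole object for non-partial maps),
/// and the returned MEMBER_OF flag is used to link member entries back to it.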
3721 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
3722 MapInfoData &mapData, uint64_t mapDataIndex,
bool isTargetParams) {
3724 combinedInfo.Types.emplace_back(
3726 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
3727 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
3728 combinedInfo.DevicePointers.emplace_back(
3729 mapData.DevicePointers[mapDataIndex]);
3730 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
3732 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3733 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3743 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3745 llvm::Value *lowAddr, *highAddr;
3746 if (!parentClause.getPartialMap()) {
3747 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
3748 builder.getPtrTy());
3749 highAddr = builder.CreatePointerCast(
3750 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
3751 mapData.Pointers[mapDataIndex], 1),
3752 builder.getPtrTy());
3753 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3755 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3758 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
3759 builder.getPtrTy());
3762 highAddr = builder.CreatePointerCast(
3763 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
3764 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
3765 builder.getPtrTy());
3766 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
3769 llvm::Value *size = builder.CreateIntCast(
3770 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
3771 builder.getInt64Ty(),
3773 combinedInfo.Sizes.push_back(size);
3775 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
3776 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
3784 if (!parentClause.getPartialMap()) {
3789 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
3790 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3791 combinedInfo.Types.emplace_back(mapFlag);
3792 combinedInfo.DevicePointers.emplace_back(
3794 combinedInfo.Mappers.emplace_back(
nullptr);
3796 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3797 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3798 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3799 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
3801 return memberOfFlag;
3813 if (mapOp.getVarPtrPtr())
3828 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
3829 MapInfoData &mapData, uint64_t mapDataIndex,
3830 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
3833 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3835 for (
auto mappedMembers : parentClause.getMembers()) {
3837 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
3840 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
3851 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
3852 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3853 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
3854 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3855 combinedInfo.Types.emplace_back(mapFlag);
3856 combinedInfo.DevicePointers.emplace_back(
3858 combinedInfo.Mappers.emplace_back(
nullptr);
3859 combinedInfo.Names.emplace_back(
3861 combinedInfo.BasePointers.emplace_back(
3862 mapData.BasePointers[mapDataIndex]);
3863 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
3864 combinedInfo.Sizes.emplace_back(builder.getInt64(
3865 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
3871 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
3872 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3873 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
3874 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3876 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
3878 combinedInfo.Types.emplace_back(mapFlag);
3879 combinedInfo.DevicePointers.emplace_back(
3880 mapData.DevicePointers[memberDataIdx]);
3881 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
3882 combinedInfo.Names.emplace_back(
3884 uint64_t basePointerIndex =
3886 combinedInfo.BasePointers.emplace_back(
3887 mapData.BasePointers[basePointerIndex]);
3888 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
3890 llvm::Value *size = mapData.Sizes[memberDataIdx];
3892 size = builder.CreateSelect(
3893 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
3894 builder.getInt64(0), size);
3897 combinedInfo.Sizes.emplace_back(size);
3902 MapInfosTy &combinedInfo,
bool isTargetParams,
3903 int mapDataParentIdx = -1) {
3907 auto mapFlag = mapData.Types[mapDataIdx];
3908 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
3912 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
3914 if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
3915 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3917 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
3919 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
3924 if (mapDataParentIdx >= 0)
3925 combinedInfo.BasePointers.emplace_back(
3926 mapData.BasePointers[mapDataParentIdx]);
3928 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
3930 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
3931 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
3932 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
3933 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
3934 combinedInfo.Types.emplace_back(mapFlag);
3935 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
3939 llvm::IRBuilderBase &builder,
3940 llvm::OpenMPIRBuilder &ompBuilder,
3942 MapInfoData &mapData, uint64_t mapDataIndex,
3943 bool isTargetParams) {
3945 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3950 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
3951 auto memberClause = llvm::cast<omp::MapInfoOp>(
3952 parentClause.getMembers()[0].getDefiningOp());
3969 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
3971 combinedInfo, mapData, mapDataIndex, isTargetParams);
3973 combinedInfo, mapData, mapDataIndex,
3974 memberOfParentFlag);
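/// Adjust the pointers recorded in MapInfoData according to each map clause's
/// capture kind (by-ref vs. by-copy), inserting loads, GEPs for array-section
/// lower bounds, or a temporary alloca for by-copy scalars, mirroring what
/// Clang does when preparing kernel arguments.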
3984 llvm::IRBuilderBase &builder) {
3985 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
3987 if (!mapData.IsDeclareTarget[i]) {
3988 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
3989 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
3999 switch (captureKind) {
4000 case omp::VariableCaptureKind::ByRef: {
4001 llvm::Value *newV = mapData.Pointers[i];
4003 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
4006 newV = builder.CreateLoad(builder.getPtrTy(), newV);
4008 if (!offsetIdx.empty())
4009 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
4011 mapData.Pointers[i] = newV;
4013 case omp::VariableCaptureKind::ByCopy: {
4014 llvm::Type *type = mapData.BaseType[i];
4016 if (mapData.Pointers[i]->getType()->isPointerTy())
4017 newV = builder.CreateLoad(type, mapData.Pointers[i]);
4019 newV = mapData.Pointers[i];
4022 auto curInsert = builder.saveIP();
4024 auto *memTempAlloc =
4025 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
4026 builder.restoreIP(curInsert);
4028 builder.CreateStore(newV, memTempAlloc);
4029 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
4032 mapData.Pointers[i] = newV;
4033 mapData.BasePointers[i] = newV;
4035 case omp::VariableCaptureKind::This:
4036 case omp::VariableCaptureKind::VLAType:
4037 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
4048 MapInfoData &mapData,
bool isTargetParams =
false) {
4070 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4073 if (mapData.IsAMember[i])
4076 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
4077 if (!mapInfoOp.getMembers().empty()) {
4079 combinedInfo, mapData, i, isTargetParams);
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName);

static llvm::Expected<llvm::Function *>
getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
                                 LLVM::ModuleTranslation &moduleTranslation) {
  auto declMapperOp = cast<omp::DeclareMapperOp>(op);
  std::string mapperFuncName =
      moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
          {"omp_mapper", declMapperOp.getSymName()});

  if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
    return lookupFunc;

  return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
                               mapperFuncName);
}

static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName) {
  auto declMapperOp = cast<omp::DeclareMapperOp>(op);
  auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
  DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
  SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
          llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
    builder.restoreIP(codeGenIP);
    moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
    moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
                               builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
                                              /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();
    MapInfoData mapData;
    collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                  builder);
    genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);

    // Drop the mapping so the same region can be processed multiple times.
    moduleTranslation.forgetMapping(declMapperOp.getRegion());
    return combinedInfo;
  };

  auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation);
  };

  llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
      genMapInfoCB, varType, mapperFuncName, customMapperCB);
  if (!newFn)
    return newFn.takeError();
  moduleTranslation.mapFunction(mapperFuncName, *newFn);
  return *newFn;
}
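/// Convert omp.target_data, omp.target_enter_data, omp.target_exit_data and
/// omp.target_update to LLVM IR by collecting their map clauses into
/// MapInfoData and delegating to OpenMPIRBuilder::createTargetData.
/// Illustrative input (a sketch, not taken from this file):
///
///   omp.target_data map_entries(%m -> %arg0 : !llvm.ptr) {
///     ...
///     omp.terminator
///   }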
static LogicalResult
convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::Value *ifCond = nullptr;
  int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
  SmallVector<Value> mapVars;
  SmallVector<Value> useDevicePtrVars;
  SmallVector<Value> useDeviceAddrVars;
  llvm::omp::RuntimeFunction RTLFn;
  DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
                                             /*SeparateBeginEndCalls=*/true);

  LogicalResult result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::TargetDataOp dataOp) {
            if (failed(checkImplementationStatus(*dataOp)))
              return failure();

            if (auto ifVar = dataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (auto devId = dataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();

            mapVars = dataOp.getMapVars();
            useDevicePtrVars = dataOp.getUseDevicePtrVars();
            useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
            return success();
          })
          .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*enterDataOp)))
              return failure();

            if (auto ifVar = enterDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (auto devId = enterDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();
            RTLFn =
                enterDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
            mapVars = enterDataOp.getMapVars();
            info.HasNoWait = enterDataOp.getNowait();
            return success();
          })
          .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*exitDataOp)))
              return failure();

            if (auto ifVar = exitDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (auto devId = exitDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();

            RTLFn = exitDataOp.getNowait()
                        ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
                        : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
            mapVars = exitDataOp.getMapVars();
            info.HasNoWait = exitDataOp.getNowait();
            return success();
          })
          .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*updateDataOp)))
              return failure();

            if (auto ifVar = updateDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (auto devId = updateDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();

            RTLFn =
                updateDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
            mapVars = updateDataOp.getMapVars();
            info.HasNoWait = updateDataOp.getNowait();
            return success();
          })
          .Default([&](Operation *op) {
            llvm_unreachable("unexpected operation");
            return failure();
          });

  if (failed(result))
    return failure();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
                                builder, useDevicePtrVars, useDeviceAddrVars);

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
    return combinedInfo;
  };

  // Map use_device_ptr/use_device_addr block arguments to the device pointers
  // recorded for the matching map entries.
  auto mapUseDevice =
      [&moduleTranslation](
          llvm::OpenMPIRBuilder::DeviceInfoTy type,
          llvm::ArrayRef<BlockArgument> blockArgs,
          SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
          llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
        for (auto [arg, useDevVar] :
             llvm::zip_equal(blockArgs, useDeviceVars)) {

          auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
            return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
                                            : mapInfoOp.getVarPtr();
          };

          auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
          for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
                   mapInfoData.MapClause, mapInfoData.DevicePointers,
                   mapInfoData.BasePointers)) {
            auto mapOp = cast<omp::MapInfoOp>(mapClause);
            if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
                devicePointer != type)
              continue;

            if (llvm::Value *devPtrInfoMap =
                    mapper ? mapper(basePointer) : basePointer) {
              moduleTranslation.mapValue(arg, devPtrInfoMap);
              break;
            }
          }
        }
      };

  using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
  auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    builder.restoreIP(codeGenIP);
    assert(isa<omp::TargetDataOp>(op) &&
           "BodyGen requested for non TargetDataOp");
    auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
    Region &region = cast<omp::TargetDataOp>(op).getRegion();
    switch (bodyGenType) {
    case BodyGenTy::Priv:
      // Check if any device ptr/addr info is available.
      if (!info.DevicePtrInfoMap.empty()) {
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                     blockArgIface.getUseDeviceAddrBlockArgs(),
                     useDeviceAddrVars, mapData,
                     [&](llvm::Value *basePointer) -> llvm::Value * {
                       if (!info.DevicePtrInfoMap[basePointer].second)
                         return nullptr;
                       return builder.CreateLoad(
                           builder.getPtrTy(),
                           info.DevicePtrInfoMap[basePointer].second);
                     });
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                     blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
                     mapData, [&](llvm::Value *basePointer) {
                       return info.DevicePtrInfoMap[basePointer].second;
                     });

        if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
                                           moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    case BodyGenTy::DupNoPriv:
      builder.restoreIP(codeGenIP);
      break;
    case BodyGenTy::NoPriv:
      // If device info is available then the region has already been
      // generated.
      if (info.DevicePtrInfoMap.empty()) {
        // For the device pass, link use_device_ptr/addr mappings here before
        // code generation.
        if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
          mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                       blockArgIface.getUseDeviceAddrBlockArgs(),
                       useDeviceAddrVars, mapData);
          mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                       blockArgIface.getUseDevicePtrBlockArgs(),
                       useDevicePtrVars, mapData);
        }

        if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
                                           moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    }
    return builder.saveIP();
  };

  auto customMapperCB =
      [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation);
  };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
    if (isa<omp::TargetDataOp>(op))
      return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
                                          builder.getInt64(deviceID), ifCond,
                                          info, genMapInfoCB, customMapperCB,
                                          /*MapperFunc=*/nullptr, bodyGenCB,
                                          /*DeviceAddrCB=*/nullptr);
    return ompBuilder->createTargetData(
        ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
        info, genMapInfoCB, customMapperCB, &RTLFn);
  }();

  if (failed(handleError(afterIP, *op)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
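/// Convert omp.distribute to LLVM IR via OpenMPIRBuilder::createDistribute.
/// When the distribute op does not wrap an omp.wsloop, a distribute-static
/// workshare loop is applied to the generated canonical loop; teams
/// reductions contained in the distribute region are also finalized here.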
static LogicalResult
convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto distributeOp = cast<omp::DistributeOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Process a teams reduction here if it is contained in this distribute op.
  omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
  bool doDistributeReduction =
      teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;

  DenseMap<Value, llvm::Value *> reductionVariableMap;
  unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
  SmallVector<omp::DeclareReductionOp> reductionDecls;
  SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
  llvm::ArrayRef<bool> isByRef;

  if (doDistributeReduction) {
    isByRef = getIsByRef(teamsOp.getReductionByref());
    assert(isByRef.size() == teamsOp.getNumReductionVars());

    collectReductionDecls(teamsOp, reductionDecls);
    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
        findAllocaInsertPoint(builder, moduleTranslation);

    MutableArrayRef<BlockArgument> reductionArgs =
        llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
            .getReductionBlockArgs();

    if (failed(allocAndInitializeReductionVars(
            teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            isByRef)))
      return failure();
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    // Save the alloca insertion point on the ModuleTranslation stack for use
    // in nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    // DistributeOp has only one region associated with it.
    builder.restoreIP(codeGenIP);
    PrivateVarsInfo privVarsInfo(distributeOp);

    llvm::Expected<llvm::BasicBlock *> afterAllocas =
        allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
    if (handleError(afterAllocas, opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
                    opInst)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            builder, moduleTranslation, privVarsInfo.mlirVars,
            privVarsInfo.llvmVars, privVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
    llvm::Expected<llvm::BasicBlock *> regionBlock =
        convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
                            builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();
    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

    // Skip applying a workshare loop when translating 'distribute parallel do'
    // (it is handled while translating the nested omp.wsloop).
    if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
      // TODO: Add support for clauses which are valid for DISTRIBUTE
      // constructs. Static schedule is the default.
      auto schedule = omp::ClauseScheduleKind::Static;
      bool isOrdered = false;
      std::optional<omp::ScheduleModifier> scheduleMod;
      bool isSimd = false;
      llvm::omp::WorksharingLoopType workshareLoopType =
          llvm::omp::WorksharingLoopType::DistributeStaticLoop;
      bool loopNeedsBarrier = false;
      llvm::Value *chunk = nullptr;

      llvm::CanonicalLoopInfo *loopInfo =
          findCurrentLoopInfo(moduleTranslation);
      llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
          ompBuilder->applyWorkshareLoop(
              ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
              convertToScheduleKind(schedule), chunk, isSimd,
              scheduleMod == omp::ScheduleModifier::monotonic,
              scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
              workshareLoopType);

      if (!wsloopIP)
        return wsloopIP.takeError();
    }

    if (failed(cleanupPrivateVars(builder, moduleTranslation,
                                  distributeOp.getLoc(), privVarsInfo.llvmVars,
                                  privVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    return llvm::Error::success();
  };

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  if (doDistributeReduction) {
    // Process the reductions if required.
    return createReductionsAndCleanup(
        teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
        /*isNoWait=*/false, /*isTeamsReduction=*/true);
  }
  return success();
}
4540 if (!cast<mlir::ModuleOp>(op))
4545 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
4546 attribute.getOpenmpDeviceVersion());
4548 if (attribute.getNoGpuLib())
4551 ompBuilder->createGlobalFlag(
4552 attribute.getDebugKind() ,
4553 "__omp_rtl_debug_kind");
4554 ompBuilder->createGlobalFlag(
4556 .getAssumeTeamsOversubscription()
4558 "__omp_rtl_assume_teams_oversubscription");
4559 ompBuilder->createGlobalFlag(
4561 .getAssumeThreadsOversubscription()
4563 "__omp_rtl_assume_threads_oversubscription");
4564 ompBuilder->createGlobalFlag(
4565 attribute.getAssumeNoThreadState() ,
4566 "__omp_rtl_assume_no_thread_state");
4567 ompBuilder->createGlobalFlag(
4569 .getAssumeNoNestedParallelism()
4571 "__omp_rtl_assume_no_nested_parallelism");
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
                                     omp::TargetOp targetOp,
                                     llvm::StringRef parentName = "") {
  auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();

  assert(fileLoc && "No file found from location");
  StringRef fileName = fileLoc.getFilename().getValue();

  llvm::sys::fs::UniqueID id;
  uint64_t line = fileLoc.getLine();
  if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
    size_t fileHash = llvm::hash_value(fileName.str());
    size_t deviceId = 0xdeadf17e;
    targetInfo =
        llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
  } else {
    targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
                                             id.getFile(), line);
  }
}
4599 llvm::IRBuilderBase &builder, llvm::Function *func) {
4600 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
4613 if (mapData.IsDeclareTarget[i]) {
4620 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
4621 convertUsersOfConstantsToInstructions(constant, func,
false);
4628 for (llvm::User *user : mapData.OriginalValue[i]->users())
4629 userVec.push_back(user);
4631 for (llvm::User *user : userVec) {
4632 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
4633 if (insn->getFunction() == func) {
4634 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
4635 mapData.BasePointers[i]);
4636 load->moveBefore(insn->getIterator());
4637 user->replaceUsesOfWith(mapData.OriginalValue[i], load);
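/// Create the device-side accessor for a kernel argument: the argument is
/// spilled to an alloca (address-space cast to the program AS if needed) and
/// then either returned directly (by-copy) or reloaded with alignment
/// metadata attached (by-ref).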
static llvm::IRBuilderBase::InsertPoint
createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
                             llvm::Value *input, llvm::Value *&retVal,
                             llvm::IRBuilderBase &builder,
                             llvm::OpenMPIRBuilder &ompBuilder,
                             LLVM::ModuleTranslation &moduleTranslation,
                             llvm::IRBuilderBase::InsertPoint allocaIP,
                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
  builder.restoreIP(allocaIP);

  omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
  LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
      ompBuilder.M.getContext());
  unsigned alignmentValue = 0;
  // Find the associated MapInfoData entry for the current input.
  for (size_t i = 0; i < mapData.MapClause.size(); ++i)
    if (mapData.OriginalValue[i] == input) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      capture = mapOp.getMapCaptureType();
      // Get the preferred alignment of the mapped object.
      alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
          mapOp.getVarType(), ompBuilder.M.getDataLayout());
      break;
    }

  unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS =
      ompBuilder.M.getDataLayout().getProgramAddressSpace();

  // Create the alloca for the argument.
  llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);

  if (allocaAS != defaultAS && arg.getType()->isPointerTy())
    v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));

  builder.CreateStore(&arg, v);

  builder.restoreIP(codeGenIP);

  switch (capture) {
  case omp::VariableCaptureKind::ByCopy: {
    retVal = v;
    break;
  }
  case omp::VariableCaptureKind::ByRef: {
    llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
        v->getType(), v,
        ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
    // Attach !align metadata so the optimizer knows the loaded pointer's
    // alignment, e.g.:
    //   %res = load ptr, ptr %input, align 8, !align !align_md_node
    if (v->getType()->isPointerTy() && alignmentValue) {
      llvm::MDBuilder MDB(builder.getContext());
      loadInst->setMetadata(
          llvm::LLVMContext::MD_align,
          llvm::MDNode::get(builder.getContext(),
                            llvm::ConstantAsMetadata::get(
                                llvm::ConstantInt::get(
                                    llvm::Type::getInt64Ty(builder.getContext()),
                                    alignmentValue))));
    }
    retVal = loadInst;
    break;
  }
  case omp::VariableCaptureKind::This:
  case omp::VariableCaptureKind::VLAType:
    // TODO: Consider returning an error to use standard reporting for
    // unimplemented features.
    assert(false && "Currently unsupported capture kind");
    break;
  }

  return builder.saveIP();
}
4784 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
4785 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
4786 blockArgIface.getHostEvalBlockArgs())) {
4787 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
4791 .Case([&](omp::TeamsOp teamsOp) {
4792 if (teamsOp.getNumTeamsLower() == blockArg)
4793 numTeamsLower = hostEvalVar;
4794 else if (teamsOp.getNumTeamsUpper() == blockArg)
4795 numTeamsUpper = hostEvalVar;
4796 else if (teamsOp.getThreadLimit() == blockArg)
4797 threadLimit = hostEvalVar;
4799 llvm_unreachable(
"unsupported host_eval use");
4801 .Case([&](omp::ParallelOp parallelOp) {
4802 if (parallelOp.getNumThreads() == blockArg)
4803 numThreads = hostEvalVar;
4805 llvm_unreachable(
"unsupported host_eval use");
4807 .Case([&](omp::LoopNestOp loopOp) {
4808 auto processBounds =
4813 if (lb == blockArg) {
4816 (*outBounds)[i] = hostEvalVar;
4822 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
4823 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
4825 found = processBounds(loopOp.getLoopSteps(), steps) || found;
4827 assert(found &&
"unsupported host_eval use");
4830 llvm_unreachable(
"unsupported host_eval use");
4843 template <
typename OpTy>
4848 if (OpTy casted = dyn_cast<OpTy>(op))
4851 if (immediateParent)
4852 return dyn_cast_if_present<OpTy>(op->
getParentOp());
4861 return std::nullopt;
4864 dyn_cast_if_present<LLVM::ConstantOp>(value.
getDefiningOp()))
4865 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4866 return constAttr.getInt();
4868 return std::nullopt;
4873 uint64_t sizeInBytes = sizeInBits / 8;
4877 template <
typename OpTy>
4879 if (op.getNumReductionVars() > 0) {
4884 members.reserve(reductions.size());
4885 for (omp::DeclareReductionOp &red : reductions)
4886 members.push_back(red.getType());
4888 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
4904 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
4905 bool isTargetDevice,
bool isGPU) {
4908 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
4909 if (!isTargetDevice) {
4916 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
4917 numTeamsLower = teamsOp.getNumTeamsLower();
4918 numTeamsUpper = teamsOp.getNumTeamsUpper();
4919 threadLimit = teamsOp.getThreadLimit();
4922 if (
auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
4923 numThreads = parallelOp.getNumThreads();
4928 int32_t minTeamsVal = 1, maxTeamsVal = -1;
4929 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
4932 if (numTeamsUpper) {
4934 minTeamsVal = maxTeamsVal = *val;
4936 minTeamsVal = maxTeamsVal = 0;
4938 }
else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
4940 castOrGetParentOfType<omp::SimdOp>(capturedOp,
4942 minTeamsVal = maxTeamsVal = 1;
4944 minTeamsVal = maxTeamsVal = -1;
4949 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &result) {
4963 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
4964 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
4965 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
4968 int32_t maxThreadsVal = -1;
4969 if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
4970 setMaxValueFromClause(numThreads, maxThreadsVal);
4971 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
4978 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
4979 if (combinedMaxThreadsVal < 0 ||
4980 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
4981 combinedMaxThreadsVal = teamsThreadLimitVal;
4983 if (combinedMaxThreadsVal < 0 ||
4984 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
4985 combinedMaxThreadsVal = maxThreadsVal;
4987 int32_t reductionDataSize = 0;
4988 if (isGPU && capturedOp) {
4989 if (
auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
4994 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
4996 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
4997 omp::TargetRegionFlags::spmd) &&
4998 "invalid kernel flags");
5000 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5001 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5002 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5003 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5004 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5005 attrs.MinTeams = minTeamsVal;
5006 attrs.MaxTeams.front() = maxTeamsVal;
5007 attrs.MinThreads = 1;
5008 attrs.MaxThreads.front() = combinedMaxThreadsVal;
5009 attrs.ReductionDataSize = reductionDataSize;
5012 if (attrs.ReductionDataSize != 0)
5013 attrs.ReductionBufferLength = 1024;
/// Gather LLVM runtime values for all clauses evaluated on the host that are
/// passed to the kernel invocation. This must be called only when compiling
/// for the host, after the body of the target op has been generated.
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
      steps(numLoops);
  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

  // TODO: Handle constant 'if' clauses.
  if (Value targetThreadLimit = targetOp.getThreadLimit())
    attrs.TargetThreadLimit.front() =
        moduleTranslation.lookupValue(targetThreadLimit);

  if (numTeamsLower)
    attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);

  if (numTeamsUpper)
    attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);

  if (teamsThreadLimit)
    attrs.TeamsThreadLimit.front() =
        moduleTranslation.lookupValue(teamsThreadLimit);

  if (numThreads)
    attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

  if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
                              omp::TargetRegionFlags::trip_count)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    attrs.LoopTripCount = nullptr;

    // The trip count is the product of the trip counts of every collapsed
    // canonical loop; the loop nests themselves don't need to be created here.
    for (auto [loopLower, loopUpper, loopStep] :
         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
      llvm::Value *step = moduleTranslation.lookupValue(loopStep);

      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
          loopOp.getLoopInclusive());

      if (!attrs.LoopTripCount) {
        attrs.LoopTripCount = tripCount;
        continue;
      }

      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
                                              {}, /*HasNUW=*/true);
    }
  }
}
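/// Convert omp.target to LLVM IR. This gathers map/has_device_addr/host_eval
/// operands, computes default and runtime kernel attributes, and calls
/// OpenMPIRBuilder::createTarget with callbacks that outline the target
/// region body, translate map clauses and privatized variables, and build
/// per-argument accessors for the device.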
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto targetOp = cast<omp::TargetOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isTargetDevice = ompBuilder->Config.isTargetDevice();
  bool isGPU = ompBuilder->Config.isGPU();

  auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
  auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
  auto &targetRegion = targetOp.getRegion();

  // Holds the private vars that have been mapped along with the block
  // argument that corresponds to the MapInfoOp for the private var.
  llvm::DenseMap<Value, Value> mappedPrivateVars;
  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
  SmallVector<Value> mapVars = targetOp.getMapVars();
  SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
  ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
  ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
  llvm::Function *llvmOutlinedFn = nullptr;

  // An offload entry is generated when compiling for the device or when
  // offload target triples are specified on the host.
  bool isOffloadEntry =
      isTargetDevice || !ompBuilder->Config.TargetTriples.empty();

  // Record, for each privatized variable that needs a map, the block argument
  // corresponding to the MapInfoOp that maps its underlying data.
  if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
    OperandRange privateVars = targetOp.getPrivateVars();
    std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
    std::optional<DenseI64ArrayAttr> privateMapIndices =
        targetOp.getPrivateMapsAttr();

    for (auto [privVarIdx, privVarSymPair] :
         llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
      auto privVar = std::get<0>(privVarSymPair);
      auto privSym = std::get<1>(privVarSymPair);

      SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
      omp::PrivateClauseOp privatizer =
          findPrivatizer(targetOp, privatizerName);

      if (!privatizer.needsMap())
        continue;

      mlir::Value mappedValue =
          targetOp.getMappedValueForPrivateVar(privVarIdx);
      assert(mappedValue && "Expected to find mapped value for a privatized "
                            "variable that needs mapping");

      // The MapInfoOp defining the map var is only used for sanity checking.
      auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
      [[maybe_unused]] Type varType = mapInfoOp.getVarType();

      // Check that the type of the private variable matches the type of the
      // variable being mapped.
      if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
        assert(
            varType == privVar.getType() &&
            "Type of private var doesn't match the type of the mapped value");

      // Record the block argument corresponding to this map var.
      mappedPrivateVars.insert(
          {privVar,
           targetRegion.getArgument(argIface.getMapBlockArgsStart() +
                                    (*privateMapIndices)[privVarIdx])});
    }
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // Forward target-cpu and target-features function attributes from the
    // original function to the new outlined function.
    llvm::Function *llvmParentFn =
        moduleTranslation.lookupFunction(parentFn.getName());
    llvmOutlinedFn = codeGenIP.getBlock()->getParent();
    assert(llvmParentFn && llvmOutlinedFn &&
           "Both parent and outlined functions must exist at this point");

    if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    if (auto attr = llvmParentFn->getFnAttribute("target-features");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }
    for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }

    // Do privatization after moduleTranslation has already recorded mapped
    // values.
    PrivateVarsInfo privateVarsInfo(targetOp);

    llvm::Expected<llvm::BasicBlock *> afterAllocas =
        allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
                            allocaIP, &mappedPrivateVars);
    if (failed(handleError(afterAllocas, *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(codeGenIP);
    if (handleError(initPrivateVars(builder, moduleTranslation,
                                    privateVarsInfo, &mappedPrivateVars),
                    *targetOp)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            &mappedPrivateVars)))
      return llvm::make_error<PreviouslyReportedError>();

    SmallVector<Region *> privateCleanupRegions;
    llvm::transform(privateVarsInfo.privatizers,
                    std::back_inserter(privateCleanupRegions),
                    [](omp::PrivateClauseOp privatizer) {
                      return &privatizer.getDeallocRegion();
                    });

    llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
        targetRegion, "omp.target", builder, moduleTranslation);
    if (!exitBlock)
      return exitBlock.takeError();

    builder.SetInsertPoint(*exitBlock);
    if (!privateCleanupRegions.empty()) {
      if (failed(inlineOmpRegionCleanup(
              privateCleanupRegions, privateVarsInfo.llvmVars,
              moduleTranslation, builder, "omp.targetop.private.cleanup",
              /*shouldLoadCleanupRegionArg=*/false))) {
        return llvm::createStringError(
            "failed to inline `dealloc` region of `omp.private` "
            "op in the target region");
      }
      return builder.saveIP();
    }

    return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
  };

  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;
  getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                builder, /*useDevPtrOperands=*/{},
                                /*useDevAddrOperands=*/{}, hdaVars);

  MapInfosTy combinedInfos;
  auto genMapInfoCB =
      [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true);
    return combinedInfos;
  };

  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());

    // For the host fallback function, return the argument unaltered.
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(&arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
                                        *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP);
  };

  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
  initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
                         isTargetDevice, isGPU);

  // Collect host-evaluated values needed to properly launch the kernel from
  // the host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
                           targetCapturedOp, runtimeAttrs);

  // Pass host-evaluated values as parameters to the kernel / host fallback,
  // unless they are constants. In any case, map the MLIR block argument to
  // the corresponding LLVM value.
  llvm::SmallVector<llvm::Value *, 4> kernelInput;
  SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
  ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
  for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(var);
    moduleTranslation.mapValue(arg, value);

    if (!llvm::isa<llvm::Constant>(value))
      kernelInput.push_back(value);
  }

  for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
    // Declare target arguments are not passed to kernels as arguments, and
    // neither are members of larger mapped objects.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(mapData.OriginalValue[i]);
  }

  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/false,
      /*SeparateBeginEndCalls=*/true);

  auto customMapperCB =
      [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfos.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
                                            moduleTranslation);
  };

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(targetIfCond);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTarget(
          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
          argAccessorCB, customMapperCB, dds, targetOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Remap accesses to declare target reference pointers for the device,
  // generating extra loads as necessary.
  if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
    handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
                              llvmOutlinedFn);

  return success();
}
static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
                         LLVM::ModuleTranslation &moduleTranslation) {
  // For functions, delete the IR of host-only outlined wrapper functions when
  // compiling for the device; they must be translated but must not remain at
  // the end of the process.
  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
            op->getParentOfType<ModuleOp>().getOperation())) {
      if (!offloadMod.getIsTargetDevice())
        return success();

      omp::DeclareTargetDeviceType declareType =
          attribute.getDeviceType().getValue();

      if (declareType == omp::DeclareTargetDeviceType::host) {
        llvm::Function *llvmFunc =
            moduleTranslation.lookupFunction(funcOp.getName());
        llvmFunc->dropAllReferences();
        llvmFunc->eraseFromParent();
      }
    }
    return success();
  }

  if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
    llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
    if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
      llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
      bool isDeclaration = gOp.isDeclaration();
      bool isExternallyVisible =
          gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
      auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
      llvm::StringRef mangledName = gOp.getSymName();
      auto captureClause =
          convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
      auto deviceClause =
          convertToDeviceClauseKind(attribute.getDeviceType().getValue());
      // Unused for MLIR at the moment; required in Clang for bookkeeping.
      std::vector<llvm::GlobalVariable *> generatedRefs;

      std::vector<llvm::Triple> targetTriple;
      auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
          op->getParentOfType<mlir::ModuleOp>()->getAttr(
              LLVM::LLVMDialect::getTargetTripleAttrName()));
      if (targetTripleAttr)
        targetTriple.emplace_back(targetTripleAttr.data());

      auto fileInfoCallBack = [&loc]() {
        std::string filename = "";
        std::uint64_t lineNo = 0;

        if (loc) {
          filename = loc.getFilename().str();
          lineNo = loc.getLine();
        }

        return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
                                                     lineNo);
      };

      ompBuilder->registerTargetGlobalVariable(
          captureClause, deviceClause, isDeclaration, isExternallyVisible,
          ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
          generatedRefs, /*OpenMPSimd=*/false, targetTriple,
          /*GlobalInitializer=*/nullptr, /*VariableLinkage=*/nullptr,
          gVal->getType(), gVal);

      if (ompBuilder->Config.isTargetDevice() &&
          (attribute.getCaptureClause().getValue() !=
               mlir::omp::DeclareTargetCaptureClause::to ||
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        ompBuilder->getAddrOfDeclareTargetVar(
            captureClause, deviceClause, isDeclaration, isExternallyVisible,
            ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
            generatedRefs, /*OpenMPSimd=*/false, targetTriple, gVal->getType(),
5483 if (mlir::isa<omp::ThreadprivateOp>(op))
5487 if (
auto declareTargetIface =
5488 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
5489 parentFn.getOperation()))
5490 if (declareTargetIface.isDeclareTarget() &&
5491 declareTargetIface.getDeclareTargetDeviceType() !=
5492 mlir::omp::DeclareTargetDeviceType::host)
5500 static LogicalResult
5511 bool isOutermostLoopWrapper =
5512 isa_and_present<omp::LoopWrapperInterface>(op) &&
5513 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
5515 if (isOutermostLoopWrapper)
5516 moduleTranslation.
stackPush<OpenMPLoopInfoStackFrame>();
  auto result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::BarrierOp op) -> LogicalResult {
            llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
                ompBuilder->createBarrier(builder.saveIP(),
                                          llvm::omp::OMPD_barrier);
            return handleError(afterIP, *op);
          })
          .Case([&](omp::TaskyieldOp op) {
            ompBuilder->createTaskyield(builder.saveIP());
            return success();
          })
          .Case([&](omp::FlushOp op) {
            ompBuilder->createFlush(builder.saveIP());
            return success();
          })
          .Case([&](omp::ParallelOp op) {
            return convertOmpParallel(op, builder, moduleTranslation);
          })
          .Case([&](omp::MaskedOp) {
            return convertOmpMasked(*op, builder, moduleTranslation);
          })
          .Case([&](omp::MasterOp) {
            return convertOmpMaster(*op, builder, moduleTranslation);
          })
          .Case([&](omp::CriticalOp) {
            return convertOmpCritical(*op, builder, moduleTranslation);
          })
          .Case([&](omp::OrderedRegionOp) {
            return convertOmpOrderedRegion(*op, builder, moduleTranslation);
          })
          .Case([&](omp::OrderedOp) {
            return convertOmpOrdered(*op, builder, moduleTranslation);
          })
          .Case([&](omp::WsloopOp) {
            return convertOmpWsloop(*op, builder, moduleTranslation);
          })
          .Case([&](omp::SimdOp) {
            return convertOmpSimd(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicReadOp) {
            return convertOmpAtomicRead(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicWriteOp) {
            return convertOmpAtomicWrite(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicUpdateOp op) {
            return convertOmpAtomicUpdate(op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicCaptureOp op) {
            return convertOmpAtomicCapture(op, builder, moduleTranslation);
          })
          .Case([&](omp::CancelOp op) {
            return convertOmpCancel(op, builder, moduleTranslation);
          })
          .Case([&](omp::CancellationPointOp op) {
            return convertOmpCancellationPoint(op, builder, moduleTranslation);
          })
          .Case([&](omp::SectionsOp) {
            return convertOmpSections(*op, builder, moduleTranslation);
          })
          .Case([&](omp::SingleOp op) {
            return convertOmpSingle(op, builder, moduleTranslation);
          })
          .Case([&](omp::TeamsOp op) {
            return convertOmpTeams(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskOp op) {
            return convertOmpTaskOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskgroupOp op) {
            return convertOmpTaskgroupOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskwaitOp op) {
            return convertOmpTaskwaitOp(op, builder, moduleTranslation);
          })
          .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
                omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
                omp::CriticalDeclareOp>([](auto op) {
            // These ops produce no LLVM IR of their own; they are handled as
            // part of their parent construct.
            return success();
          })
          .Case([&](omp::ThreadprivateOp) {
            return convertOmpThreadprivate(*op, builder, moduleTranslation);
          })
          .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
                omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
            return convertOmpTargetData(op, builder, moduleTranslation);
          })
          .Case([&](omp::TargetOp) {
            return convertOmpTarget(*op, builder, moduleTranslation);
          })
          .Case([&](omp::DistributeOp) {
            return convertOmpDistribute(*op, builder, moduleTranslation);
          })
          .Case([&](omp::LoopNestOp) {
            return convertOmpLoopNest(*op, builder, moduleTranslation);
          })
          .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
              [&](auto op) {
                // No-op: handled by the operations that consume them.
                return success();
              })
          .Default([&](Operation *inst) {
            return inst->emitError()
                   << "not yet implemented: " << inst->getName();
          });

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPop();

  return result;
}
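
// A sketch (based on the stackWalk API, not the exact implementation) of how
// the frame pushed above is consumed later, e.g. by findCurrentLoopInfo:
//
//   llvm::CanonicalLoopInfo *loopInfo = nullptr;
//   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
//       [&](OpenMPLoopInfoStackFrame &frame) {
//         loopInfo = frame.loopInfo;
//         return WalkResult::interrupt();
//       });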
static LogicalResult
convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  return convertHostOrTargetOperation(op, builder, moduleTranslation);
}
static LogicalResult
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  if (isa<omp::TargetOp>(op))
    return convertOmpTarget(*op, builder, moduleTranslation);
  if (isa<omp::TargetDataOp>(op))
    return convertOmpTargetData(op, builder, moduleTranslation);

  bool interrupted =
      op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
          if (isa<omp::TargetOp>(oper)) {
            if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }
          if (isa<omp::TargetDataOp>(oper)) {
            if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }

          // Non-target ops may nest target-related ops, so their regions are
          // translated as non-OpenMP scopes to make values they define
          // available to the nested target-related ops.
          if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
              oper->getParentOfType<LLVM::LLVMFuncOp>() &&
              !oper->getRegions().empty()) {
            if (auto blockArgsIface =
                    dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
              forwardArgs(moduleTranslation, blockArgsIface);
            else if (isa<mlir::omp::AtomicUpdateOp>(oper))
              for (auto [operand, arg] :
                   llvm::zip_equal(oper->getOperands(),
                                   oper->getRegion(0).getArguments())) {
                moduleTranslation.mapValue(
                    arg, builder.CreateLoad(
                             moduleTranslation.convertType(arg.getType()),
                             moduleTranslation.lookupValue(operand)));
              }

            if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
              assert(builder.GetInsertBlock() &&
                     "No insert block is set for the builder");
              // Map induction variables to placeholder values to keep the IR
              // valid; the loop itself is not actually translated here.
              for (auto iv : loopNest.getIVs())
                moduleTranslation.mapValue(
                    iv, llvm::PoisonValue::get(
                            moduleTranslation.convertType(iv.getType())));
            }

            for (Region &region : oper->getRegions()) {
              // These regions are "fake": no OpenMP runtime calls are emitted
              // for the construct; they only prepare values used by nested
              // target ops.
              SmallVector<llvm::PHINode *> phis;
              auto result = convertOmpOpRegions(
                  region, oper->getName().getStringRef().str() + ".fake.region",
                  builder, moduleTranslation, &phis);
              if (failed(handleError(result, *oper)))
                return WalkResult::interrupt();

              builder.SetInsertPoint(result.get(), result.get()->end());
            }
            return WalkResult::skip();
          }

          return WalkResult::advance();
        }).wasInterrupted();

  return failure(interrupted);
}
/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR.
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;

  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;
};
LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
    Operation *op, ArrayRef<llvm::Instruction *> instructions,
    NamedAttribute attribute,
    LLVM::ModuleTranslation &moduleTranslation) const {
  return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
             attribute.getName())
      .Case("omp.is_target_device",
            [&](Attribute attr) {
              if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsTargetDevice(deviceAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.is_gpu",
            [&](Attribute attr) {
              if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setIsGPU(gpuAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.host_ir_filepath",
            [&](Attribute attr) {
              if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
                return success();
              }
              return failure();
            })
      .Case("omp.flags",
            [&](Attribute attr) {
              if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
                return convertFlagsAttr(op, rtlAttr, moduleTranslation);
              return failure();
            })
      .Case("omp.version",
            [&](Attribute attr) {
              if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
                llvm::OpenMPIRBuilder *ompBuilder =
                    moduleTranslation.getOpenMPBuilder();
                ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
                                            versionAttr.getVersion());
                return success();
              }
              return failure();
            })
      .Case("omp.declare_target",
            [&](Attribute attr) {
              if (auto declareTargetAttr =
                      dyn_cast<omp::DeclareTargetAttr>(attr))
                return convertDeclareTargetAttr(op, declareTargetAttr,
                                                moduleTranslation);
              return failure();
            })
      .Case("omp.requires",
            [&](Attribute attr) {
              if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
                using Requires = omp::ClauseRequires;
                Requires flags = requiresAttr.getValue();
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.setHasRequiresReverseOffload(
                    bitEnumContainsAll(flags, Requires::reverse_offload));
                config.setHasRequiresUnifiedAddress(
                    bitEnumContainsAll(flags, Requires::unified_address));
                config.setHasRequiresUnifiedSharedMemory(
                    bitEnumContainsAll(flags, Requires::unified_shared_memory));
                config.setHasRequiresDynamicAllocators(
                    bitEnumContainsAll(flags, Requires::dynamic_allocators));
                return success();
              }
              return failure();
            })
      .Case("omp.target_triples",
            [&](Attribute attr) {
              if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
                llvm::OpenMPIRBuilderConfig &config =
                    moduleTranslation.getOpenMPBuilder()->Config;
                config.TargetTriples.clear();
                config.TargetTriples.reserve(triplesAttr.size());
                for (Attribute tripleAttr : triplesAttr) {
                  if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
                    config.TargetTriples.emplace_back(tripleStrAttr.getValue());
                  else
                    return failure();
                }
                return success();
              }
              return failure();
            })
      .Default([](Attribute) {
        // Fall through for OpenMP attributes that require no lowering here.
        return success();
      })(attribute.getValue());
}
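
// For illustration only (attribute values assumed): the discrete attributes
// handled above are typically attached to the module by the frontend, e.g.
//
//   module attributes {omp.is_target_device = false, omp.is_gpu = false,
//                      omp.version = #omp.version<version = 52>,
//                      omp.target_triples = ["amdgcn-amd-amdhsa"]} { ... }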
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
/// OpenMP runtime calls).
LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
    Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (ompBuilder->Config.isTargetDevice()) {
    if (isTargetDeviceOp(op))
      return convertTargetDeviceOp(op, builder, moduleTranslation);
    return convertTargetOpsInNest(op, builder, moduleTranslation);
  }
  return convertHostOrTargetOperation(op, builder, moduleTranslation);
}

void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
  registry.insert<omp::OpenMPDialect>();
  registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
    dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
  });
}
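
// Typical client usage, sketched under the assumption that translation is
// driven through mlir::translateModuleToLLVMIR (not shown in this file):
//
//   DialectRegistry registry;
//   registerLLVMDialectTranslation(registry);
//   registerOpenMPDialectTranslation(registry);
//   context.appendDialectRegistry(registry);
//   // ... later: translateModuleToLLVMIR(module, llvmContext);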