25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Frontend/OpenMP/OMPConstants.h"
29 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
30 #include "llvm/IR/DebugInfoMetadata.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/ReplaceConstant.h"
33 #include "llvm/Support/FileSystem.h"
34 #include "llvm/TargetParser/Triple.h"
35 #include "llvm/Transforms/Utils/ModuleUtils.h"
// Translates the optional MLIR schedule-clause kind into the matching
// llvm::omp::ScheduleKind enumerator. An empty optional (no kind supplied)
// maps to OMP_SCHEDULE_Default.
47 static llvm::omp::ScheduleKind
48 convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
// No schedule kind supplied: fall back to the default schedule.
49 if (!schedKind.has_value())
50 return llvm::omp::OMP_SCHEDULE_Default;
// One-to-one mapping from each clause kind to its llvm::omp counterpart.
51 switch (schedKind.value()) {
52 case omp::ClauseScheduleKind::Static:
53 return llvm::omp::OMP_SCHEDULE_Static;
54 case omp::ClauseScheduleKind::Dynamic:
55 return llvm::omp::OMP_SCHEDULE_Dynamic;
56 case omp::ClauseScheduleKind::Guided:
57 return llvm::omp::OMP_SCHEDULE_Guided;
58 case omp::ClauseScheduleKind::Auto:
59 return llvm::omp::OMP_SCHEDULE_Auto;
61 return llvm::omp::OMP_SCHEDULE_Runtime;
// Not expected to be reached: every handled clause kind returns above.
63 llvm_unreachable(
"unhandled schedule clause argument");
// Stack frame object that carries a single OpenMPIRBuilder insertion point.
// Judging by the member name, the point is where allocas should be emitted
// while the frame is active -- NOTE(review): confirm against the users of
// this frame.
68 class OpenMPAllocaStackFrame
// Stores `allocaIP` unchanged; explicit, so an InsertPointTy never converts
// to a frame implicitly.
73 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
74 : allocaInsertPoint(allocaIP) {}
// The insertion point recorded at construction; read directly by clients.
75 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
80 class OpenMPVarMappingStackFrame
82 OpenMPVarMappingStackFrame> {
86 explicit OpenMPVarMappingStackFrame(
// llvm::Error payload type (CRTP subclass of llvm::ErrorInfo) that carries no
// data of its own. Elsewhere in this file it is created with
// llvm::make_error<PreviouslyReportedError>() and handled by setting the
// result to failure(). Presumably the diagnostic has already been emitted by
// the time this error is created -- NOTE(review): confirm at the creation
// sites.
109 class PreviouslyReportedError
110 :
public llvm::ErrorInfo<PreviouslyReportedError> {
// ErrorInfo logging hook; the stream parameter is unnamed.
112 void log(raw_ostream &)
const override {
// ErrorInfo hook for std::error_code conversion. The message below states
// that this conversion is not supported for this error type.
116 std::error_code convertToErrorCode()
const override {
118 "PreviouslyReportedError doesn't support ECError conversion");
132 SymbolRefAttr symbolName) {
133 omp::PrivateClauseOp privatizer =
134 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
136 assert(privatizer &&
"privatizer not found in the symbol table");
147 auto todo = [&op](StringRef clauseName) {
148 return op.
emitError() <<
"not yet implemented: Unhandled clause "
149 << clauseName <<
" in " << op.
getName()
153 auto checkAligned = [&todo](
auto op, LogicalResult &result) {
154 if (!op.getAlignedVars().empty() || op.getAlignments())
155 result = todo(
"aligned");
157 auto checkAllocate = [&todo](
auto op, LogicalResult &result) {
158 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
159 result = todo(
"allocate");
161 auto checkBare = [&todo](
auto op, LogicalResult &result) {
163 result = todo(
"ompx_bare");
165 auto checkDepend = [&todo](
auto op, LogicalResult &result) {
166 if (!op.getDependVars().empty() || op.getDependKinds())
167 result = todo(
"depend");
169 auto checkDevice = [&todo](
auto op, LogicalResult &result) {
171 result = todo(
"device");
173 auto checkHasDeviceAddr = [&todo](
auto op, LogicalResult &result) {
174 if (!op.getHasDeviceAddrVars().empty())
175 result = todo(
"has_device_addr");
177 auto checkHint = [](
auto op, LogicalResult &) {
181 auto checkIf = [&todo](
auto op, LogicalResult &result) {
185 auto checkInReduction = [&todo](
auto op, LogicalResult &result) {
186 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
187 op.getInReductionSyms())
188 result = todo(
"in_reduction");
190 auto checkIsDevicePtr = [&todo](
auto op, LogicalResult &result) {
191 if (!op.getIsDevicePtrVars().empty())
192 result = todo(
"is_device_ptr");
194 auto checkLinear = [&todo](
auto op, LogicalResult &result) {
195 if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
196 result = todo(
"linear");
198 auto checkNontemporal = [&todo](
auto op, LogicalResult &result) {
199 if (!op.getNontemporalVars().empty())
200 result = todo(
"nontemporal");
202 auto checkNowait = [&todo](
auto op, LogicalResult &result) {
204 result = todo(
"nowait");
206 auto checkOrder = [&todo](
auto op, LogicalResult &result) {
207 if (op.getOrder() || op.getOrderMod())
208 result = todo(
"order");
210 auto checkParLevelSimd = [&todo](
auto op, LogicalResult &result) {
211 if (op.getParLevelSimd())
212 result = todo(
"parallelization-level");
214 auto checkPriority = [&todo](
auto op, LogicalResult &result) {
215 if (op.getPriority())
216 result = todo(
"priority");
218 auto checkPrivate = [&todo](
auto op, LogicalResult &result) {
219 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
220 result = todo(
"privatization");
222 auto checkReduction = [&todo](
auto op, LogicalResult &result) {
223 if (!op.getReductionVars().empty() || op.getReductionByref() ||
224 op.getReductionSyms())
225 result = todo(
"reduction");
227 auto checkThreadLimit = [&todo](
auto op, LogicalResult &result) {
228 if (op.getThreadLimit())
229 result = todo(
"thread_limit");
231 auto checkTaskReduction = [&todo](
auto op, LogicalResult &result) {
232 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
233 op.getTaskReductionSyms())
234 result = todo(
"task_reduction");
236 auto checkUntied = [&todo](
auto op, LogicalResult &result) {
238 result = todo(
"untied");
241 LogicalResult result = success();
243 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
244 .Case([&](omp::SectionsOp op) {
245 checkAllocate(op, result);
246 checkPrivate(op, result);
248 .Case([&](omp::SingleOp op) {
249 checkAllocate(op, result);
250 checkPrivate(op, result);
252 .Case([&](omp::TeamsOp op) {
253 checkAllocate(op, result);
254 checkPrivate(op, result);
255 checkReduction(op, result);
257 .Case([&](omp::TaskOp op) {
258 checkAllocate(op, result);
259 checkInReduction(op, result);
260 checkPriority(op, result);
262 .Case([&](omp::TaskgroupOp op) {
263 checkAllocate(op, result);
264 checkTaskReduction(op, result);
266 .Case([&](omp::TaskwaitOp op) {
267 checkDepend(op, result);
268 checkNowait(op, result);
270 .Case([&](omp::TaskloopOp op) {
272 checkUntied(op, result);
274 .Case([&](omp::WsloopOp op) {
275 checkAllocate(op, result);
276 checkLinear(op, result);
277 checkOrder(op, result);
279 .Case([&](omp::ParallelOp op) { checkAllocate(op, result); })
280 .Case([&](omp::SimdOp op) {
281 checkAligned(op, result);
282 checkLinear(op, result);
283 checkNontemporal(op, result);
284 checkPrivate(op, result);
285 checkReduction(op, result);
287 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
288 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op, result); })
289 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
290 [&](
auto op) { checkDepend(op, result); })
291 .Case([&](omp::TargetOp op) {
292 checkAllocate(op, result);
293 checkBare(op, result);
294 checkDevice(op, result);
295 checkHasDeviceAddr(op, result);
297 checkInReduction(op, result);
298 checkIsDevicePtr(op, result);
302 if (std::optional<ArrayAttr> privateSyms = op.getPrivateSyms()) {
303 for (
Attribute privatizerNameAttr : *privateSyms) {
305 op.getOperation(), cast<SymbolRefAttr>(privatizerNameAttr));
307 if (privatizer.getDataSharingType() ==
308 omp::DataSharingClauseType::FirstPrivate)
309 result = todo(
"firstprivate");
312 checkThreadLimit(op, result);
322 LogicalResult result = success();
324 llvm::handleAllErrors(
326 [&](
const PreviouslyReportedError &) { result = failure(); },
327 [&](
const llvm::ErrorInfoBase &err) {
334 template <
typename T>
344 static llvm::OpenMPIRBuilder::InsertPointTy
350 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
352 [&](
const OpenMPAllocaStackFrame &frame) {
353 allocaInsertPoint = frame.allocaInsertPoint;
357 return allocaInsertPoint;
366 if (builder.GetInsertBlock() ==
367 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
368 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
369 "Assuming end of basic block");
370 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
371 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
372 builder.GetInsertBlock()->getNextNode());
373 builder.CreateBr(entryBB);
374 builder.SetInsertPoint(entryBB);
377 llvm::BasicBlock &funcEntryBlock =
378 builder.GetInsertBlock()->getParent()->getEntryBlock();
379 return llvm::OpenMPIRBuilder::InsertPointTy(
380 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
389 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
392 llvm::BasicBlock *continuationBlock =
393 splitBB(builder,
true,
"omp.region.cont");
394 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
396 llvm::LLVMContext &llvmContext = builder.getContext();
397 for (
Block &bb : region) {
398 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
399 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
400 builder.GetInsertBlock()->getNextNode());
401 moduleTranslation.
mapBlock(&bb, llvmBB);
404 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
410 bool operandsProcessed =
false;
411 unsigned numYields = 0;
413 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
414 if (!operandsProcessed) {
415 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
416 continuationBlockPHITypes.push_back(
417 moduleTranslation.
convertType(yield->getOperand(i).getType()));
419 operandsProcessed =
true;
421 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
422 "mismatching number of values yielded from the region");
423 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
424 llvm::Type *operandType =
425 moduleTranslation.
convertType(yield->getOperand(i).getType());
427 assert(continuationBlockPHITypes[i] == operandType &&
428 "values of mismatching types yielded from the region");
437 if (!continuationBlockPHITypes.empty())
439 continuationBlockPHIs &&
440 "expected continuation block PHIs if converted regions yield values");
441 if (continuationBlockPHIs) {
442 llvm::IRBuilderBase::InsertPointGuard guard(builder);
443 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
444 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
445 for (llvm::Type *ty : continuationBlockPHITypes)
446 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
452 for (
Block *bb : blocks) {
453 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
456 if (bb->isEntryBlock()) {
457 assert(sourceTerminator->getNumSuccessors() == 1 &&
458 "provided entry block has multiple successors");
459 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
460 "ContinuationBlock is not the successor of the entry block");
461 sourceTerminator->setSuccessor(0, llvmBB);
464 llvm::IRBuilderBase::InsertPointGuard guard(builder);
466 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
467 return llvm::make_error<PreviouslyReportedError>();
476 Operation *terminator = bb->getTerminator();
477 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
478 builder.CreateBr(continuationBlock);
480 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
481 (*continuationBlockPHIs)[i]->addIncoming(
495 return continuationBlock;
501 case omp::ClauseProcBindKind::Close:
502 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
503 case omp::ClauseProcBindKind::Master:
504 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
505 case omp::ClauseProcBindKind::Primary:
506 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
507 case omp::ClauseProcBindKind::Spread:
508 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
510 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
528 for (
auto [arg, var] : llvm::zip_equal(blockArgs, operands))
533 .Case([&](omp::SimdOp op) {
534 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
535 forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
536 forwardArgs(blockArgIface.getReductionBlockArgs(),
537 op.getReductionVars());
538 op.emitWarning() <<
"simd information on composite construct discarded";
542 return op->emitError() <<
"cannot ignore nested wrapper";
554 omp::LoopWrapperInterface parentOp,
557 loopOp.gatherWrappers(wrappers);
561 std::next(std::find(wrappers.rbegin(), wrappers.rend(), parentOp));
562 it != wrappers.rend(); ++it) {
574 auto maskedOp = cast<omp::MaskedOp>(opInst);
575 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
580 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
582 auto ®ion = maskedOp.getRegion();
583 builder.restoreIP(codeGenIP);
591 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
593 llvm::Value *filterVal =
nullptr;
594 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
595 filterVal = moduleTranslation.
lookupValue(filterVar);
597 llvm::LLVMContext &llvmContext = builder.getContext();
601 assert(filterVal !=
nullptr);
602 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
603 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
610 builder.restoreIP(*afterIP);
618 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
619 auto masterOp = cast<omp::MasterOp>(opInst);
624 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
626 auto ®ion = masterOp.getRegion();
627 builder.restoreIP(codeGenIP);
635 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
637 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
638 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
645 builder.restoreIP(*afterIP);
653 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
654 auto criticalOp = cast<omp::CriticalOp>(opInst);
659 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
661 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
662 builder.restoreIP(codeGenIP);
670 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
672 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
673 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
674 llvm::Constant *hint =
nullptr;
677 if (criticalOp.getNameAttr()) {
680 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
681 auto criticalDeclareOp =
682 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
686 static_cast<int>(criticalDeclareOp.getHint()));
688 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
690 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
695 builder.restoreIP(*afterIP);
704 std::optional<ArrayAttr> attr = op.getPrivateSyms();
708 privatizations.reserve(privatizations.size() + attr->size());
709 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
715 template <
typename T>
719 std::optional<ArrayAttr> attr = op.getReductionSyms();
723 reductions.reserve(reductions.size() + op.getNumReductionVars());
724 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
725 reductions.push_back(
726 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
737 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
745 if (llvm::hasSingleElement(region)) {
746 llvm::Instruction *potentialTerminator =
747 builder.GetInsertBlock()->empty() ? nullptr
748 : &builder.GetInsertBlock()->back();
750 if (potentialTerminator && potentialTerminator->isTerminator())
751 potentialTerminator->removeFromParent();
752 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
755 region.
front(),
true, builder)))
759 if (continuationBlockArgs)
761 *continuationBlockArgs,
768 if (potentialTerminator && potentialTerminator->isTerminator()) {
769 llvm::BasicBlock *block = builder.GetInsertBlock();
770 if (block->empty()) {
776 potentialTerminator->insertInto(block, block->begin());
778 potentialTerminator->insertAfter(&block->back());
792 if (continuationBlockArgs)
793 llvm::append_range(*continuationBlockArgs, phis);
794 builder.SetInsertPoint(*continuationBlock,
795 (*continuationBlock)->getFirstInsertionPt());
// std::function type for a reduction generator callback. The callable
// returns an InsertPointOrErrorTy and its parameter list begins with an
// insertion point followed by two llvm::Value pointers. Lambdas taking
// (insertPoint, lhs, rhs, result) are assigned to this type further down in
// this file.
802 using OwningReductionGen =
803 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
804 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
// std::function type for an atomic reduction generator callback. The
// callable returns an InsertPointOrErrorTy and its parameter list begins with
// an insertion point, an llvm::Type pointer and an llvm::Value pointer. A
// default-constructed (empty) instance is returned further down in this file
// when a declaration's atomic reduction region is empty.
806 using OwningAtomicReductionGen =
807 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
808 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
815 static OwningReductionGen
821 OwningReductionGen gen =
822 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
823 llvm::Value *lhs, llvm::Value *rhs,
824 llvm::Value *&result)
mutable
825 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
826 moduleTranslation.
mapValue(decl.getReductionLhsArg(), lhs);
827 moduleTranslation.
mapValue(decl.getReductionRhsArg(), rhs);
828 builder.restoreIP(insertPoint);
831 "omp.reduction.nonatomic.body", builder,
832 moduleTranslation, &phis)))
833 return llvm::createStringError(
834 "failed to inline `combiner` region of `omp.declare_reduction`");
835 assert(phis.size() == 1);
837 return builder.saveIP();
846 static OwningAtomicReductionGen
848 llvm::IRBuilderBase &builder,
850 if (decl.getAtomicReductionRegion().empty())
851 return OwningAtomicReductionGen();
856 OwningAtomicReductionGen atomicGen =
857 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
858 llvm::Value *lhs, llvm::Value *rhs)
mutable
859 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
860 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(), lhs);
861 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(), rhs);
862 builder.restoreIP(insertPoint);
865 "omp.reduction.atomic.body", builder,
866 moduleTranslation, &phis)))
867 return llvm::createStringError(
868 "failed to inline `atomic` region of `omp.declare_reduction`");
869 assert(phis.empty());
870 return builder.saveIP();
879 auto orderedOp = cast<omp::OrderedOp>(opInst);
884 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
885 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
886 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
888 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
890 size_t indexVecValues = 0;
891 while (indexVecValues < vecValues.size()) {
893 storeValues.reserve(numLoops);
894 for (
unsigned i = 0; i < numLoops; i++) {
895 storeValues.push_back(vecValues[indexVecValues]);
898 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
900 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
901 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
902 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
912 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
913 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
918 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
920 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
921 builder.restoreIP(codeGenIP);
929 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
931 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
932 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
934 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
939 builder.restoreIP(*afterIP);
// Pairs an LLVM value with the address it is to be written to. Elsewhere in
// this file, collected entries are replayed as
// builder.CreateStore(value, address), so the store is emitted later than
// the point where the pair is recorded.
945 struct DeferredStore {
// Records both pointers as given; nothing is emitted here.
946 DeferredStore(llvm::Value *value, llvm::Value *address)
947 : value(value), address(address) {}
// Destination address of the pending store.
950 llvm::Value *address;
957 template <
typename T>
960 llvm::IRBuilderBase &builder,
962 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
968 llvm::IRBuilderBase::InsertPointGuard guard(builder);
969 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
972 deferredStores.reserve(loop.getNumReductionVars());
974 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
975 Region &allocRegion = reductionDecls[i].getAllocRegion();
977 if (allocRegion.
empty())
982 builder, moduleTranslation, &phis)))
983 return loop.emitError(
984 "failed to inline `alloc` region of `omp.declare_reduction`");
986 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
987 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
991 llvm::Value *var = builder.CreateAlloca(
992 moduleTranslation.
convertType(reductionDecls[i].getType()));
993 deferredStores.emplace_back(phis[0], var);
995 privateReductionVariables[i] = var;
996 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
997 reductionVariableMap.try_emplace(loop.getReductionVars()[i], phis[0]);
999 assert(allocRegion.
empty() &&
1000 "allocaction is implicit for by-val reduction");
1001 llvm::Value *var = builder.CreateAlloca(
1002 moduleTranslation.
convertType(reductionDecls[i].getType()));
1003 moduleTranslation.
mapValue(reductionArgs[i], var);
1004 privateReductionVariables[i] = var;
1005 reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
1013 template <
typename T>
1020 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1021 Region &initializerRegion = reduction.getInitializerRegion();
1024 mlir::Value mlirSource = loop.getReductionVars()[i];
1025 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1026 assert(llvmSource &&
"lookup reduction var");
1027 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), llvmSource);
1030 llvm::Value *allocation =
1031 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1032 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1036 template <
typename OP>
1037 static LogicalResult
1039 llvm::IRBuilderBase &builder,
1041 llvm::BasicBlock *latestAllocaBlock,
1047 if (op.getNumReductionVars() == 0)
1050 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1052 builder.SetInsertPoint(latestAllocaBlock->getTerminator());
1053 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1054 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1055 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1056 builder.restoreIP(allocaIP);
1059 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1061 if (!reductionDecls[i].getAllocRegion().empty())
1067 byRefVars[i] = builder.CreateAlloca(
1068 moduleTranslation.
convertType(reductionDecls[i].getType()));
1072 builder.SetInsertPoint(&*initBlock->getFirstNonPHIOrDbgOrAlloca());
1076 for (
auto [data, addr] : deferredStores)
1077 builder.CreateStore(data, addr);
1082 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1087 reductionVariableMap, i);
1090 "omp.reduction.neutral", builder,
1091 moduleTranslation, &phis)))
1094 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1095 "reduction neutral element declaration region");
1097 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
1100 if (!reductionDecls[i].getAllocRegion().empty())
1109 builder.CreateStore(phis[0], byRefVars[i]);
1111 privateReductionVariables[i] = byRefVars[i];
1112 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1113 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1116 builder.CreateStore(phis[0], privateReductionVariables[i]);
1123 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
1130 template <
typename T>
1132 T loop, llvm::IRBuilderBase &builder,
1139 unsigned numReductions = loop.getNumReductionVars();
1141 for (
unsigned i = 0; i < numReductions; ++i) {
1142 owningReductionGens.push_back(
1144 owningAtomicReductionGens.push_back(
1149 reductionInfos.reserve(numReductions);
1150 for (
unsigned i = 0; i < numReductions; ++i) {
1151 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen =
nullptr;
1152 if (owningAtomicReductionGens[i])
1153 atomicGen = owningAtomicReductionGens[i];
1154 llvm::Value *variable =
1155 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
1156 reductionInfos.push_back(
1158 privateReductionVariables[i],
1159 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1160 owningReductionGens[i],
1161 nullptr, atomicGen});
1166 static LogicalResult
1170 llvm::IRBuilderBase &builder, StringRef regionName,
1171 bool shouldLoadCleanupRegionArg =
true) {
1173 if (cleanupRegion->empty())
1179 llvm::Instruction *potentialTerminator =
1180 builder.GetInsertBlock()->empty() ? nullptr
1181 : &builder.GetInsertBlock()->back();
1182 if (potentialTerminator && potentialTerminator->isTerminator())
1183 builder.SetInsertPoint(potentialTerminator);
1184 llvm::Value *privateVarValue =
1185 shouldLoadCleanupRegionArg
1186 ? builder.CreateLoad(
1188 privateVariables[i])
1189 : privateVariables[i];
1194 moduleTranslation)))
1207 OP op, llvm::IRBuilderBase &builder,
1209 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1213 if (op.getNumReductionVars() == 0)
1224 owningReductionGens, owningAtomicReductionGens,
1225 privateReductionVariables, reductionInfos);
1230 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1231 builder.SetInsertPoint(tempTerminator);
1232 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1233 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1234 isByRef, op.getNowait());
1239 if (!contInsertPoint->getBlock())
1240 return op->emitOpError() <<
"failed to convert reductions";
1242 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1243 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1248 tempTerminator->eraseFromParent();
1249 builder.restoreIP(*afterIP);
1253 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1254 [](omp::DeclareReductionOp reductionDecl) {
1255 return &reductionDecl.getCleanupRegion();
1258 moduleTranslation, builder,
1259 "omp.reduction.cleanup");
1270 template <
typename OP>
1274 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1279 if (op.getNumReductionVars() == 0)
1282 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1286 allocaIP, reductionDecls,
1287 privateReductionVariables, reductionVariableMap,
1288 deferredStores, isByRef)))
1292 allocaIP.getBlock(), reductionDecls,
1293 privateReductionVariables, reductionVariableMap,
1294 isByRef, deferredStores);
1304 static llvm::Value *
1308 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1311 Value blockArg = (*mappedPrivateVars)[privateVar];
1314 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1315 "A block argument corresponding to a mapped var should have "
1318 if (privVarType == blockArgType)
1325 if (!isa<LLVM::LLVMPointerType>(privVarType))
1326 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1342 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1344 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1346 llvm::BranchInst *allocaTerminator =
1347 llvm::cast<llvm::BranchInst>(allocaIP.getBlock()->getTerminator());
1348 builder.SetInsertPoint(allocaTerminator);
1349 assert(allocaTerminator->getNumSuccessors() == 1 &&
1350 "This is an unconditional branch created by OpenMPIRBuilder");
1351 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1365 llvm::BasicBlock *privAllocBlock =
nullptr;
1366 if (!privateBlockArgs.empty())
1367 privAllocBlock = splitBB(builder,
true,
"omp.private.latealloc");
1368 for (
auto [privDecl, mlirPrivVar, blockArg] :
1369 llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
1370 Region &allocRegion = privDecl.getAllocRegion();
1374 mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1375 assert(nonPrivateVar);
1376 moduleTranslation.
mapValue(privDecl.getAllocMoldArg(), nonPrivateVar);
1380 if (privDecl.getAllocMoldArg().getUses().empty()) {
1385 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1387 builder.SetInsertPoint(privAllocBlock->getTerminator());
1391 builder, moduleTranslation, &phis)))
1392 return llvm::createStringError(
1393 "failed to inline `alloc` region of `omp.private`");
1395 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1397 moduleTranslation.
mapValue(blockArg, phis[0]);
1398 llvmPrivateVars.push_back(phis[0]);
1405 return afterAllocas;
1408 static LogicalResult
1414 llvm::BasicBlock *afterAllocas) {
1415 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1417 bool needsFirstprivate =
1418 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1419 return privOp.getDataSharingType() ==
1420 omp::DataSharingClauseType::FirstPrivate;
1423 if (!needsFirstprivate)
1426 assert(afterAllocas->getSinglePredecessor());
1429 builder.SetInsertPoint(afterAllocas->getSinglePredecessor()->getTerminator());
1430 llvm::BasicBlock *copyBlock =
1431 splitBB(builder,
true,
"omp.private.copy");
1432 builder.SetInsertPoint(copyBlock->getFirstNonPHIOrDbgOrAlloca());
1434 for (
auto [decl, mlirVar, llvmVar] :
1435 llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
1436 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1440 Region ©Region = decl.getCopyRegion();
1443 llvm::Value *nonPrivateVar = moduleTranslation.
lookupValue(mlirVar);
1444 assert(nonPrivateVar);
1445 moduleTranslation.
mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1448 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1451 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
1453 moduleTranslation)))
1454 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1467 static LogicalResult
1474 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1475 [](omp::PrivateClauseOp privatizer) {
1476 return &privatizer.getDeallocRegion();
1480 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1481 "omp.private.dealloc",
false)))
1482 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1483 "`omp.private` op in");
// Translation helper returning LogicalResult. `opInst` is cast to
// omp::SectionsOp, one callback is built per nested omp::SectionOp, and the
// callbacks are passed (with allocaIP, privCB and finiCB) to a builder call
// whose resulting insertion point is restored afterwards.
// NOTE(review): the function name and parameter list are not visible in this
// excerpt - confirm against the full file.
1488 static LogicalResult
1491 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1492 using StorableBodyGenCallbackTy =
1493 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1495 auto sectionsOp = cast<omp::SectionsOp>(opInst);
// isByRef must hold exactly one entry per reduction variable.
1501 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1505 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1509 sectionsOp.getNumReductionVars());
1513 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1516 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1517 reductionDecls, privateReductionVariables, reductionVariableMap,
1525 moduleTranslation, reductionVariableMap);
// dyn_cast yields a null op when `op` is not an omp::SectionOp.
1530 auto sectionOp = dyn_cast<omp::SectionOp>(op);
// NOTE(review): `®ion` here and `§ionsOp` / `®ion` in the capture list below
// look like mis-encoded `&region` / `&sectionsOp` - confirm and repair.
1534 Region ®ion = sectionOp.getRegion();
// Per-section callback: starts code generation at codeGenIP.
1535 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1536 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1537 builder.restoreIP(codeGenIP);
1544 sectionsOp.getRegion().getNumArguments());
// Pair each block argument of the enclosing sections region with the matching
// argument of this section's region (zip_equal requires equal counts) and map
// the section argument to the LLVM value already recorded for the sections
// argument.
1545 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1546 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1547 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
1549 moduleTranslation.
mapValue(sectionArg, llvmVal);
1556 sectionCBs.push_back(sectionCB);
// Special-case a sections op for which no callback was collected.
1562 if (sectionCBs.empty())
// The first operation of the sections region is expected to be a SectionOp.
1565 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
// Privatization callback: hands the incoming value back unchanged as the
// replacement.
1570 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1571 llvm::Value &vPtr, llvm::Value *&replacementValue)
1572 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1573 replacementValue = &vPtr;
// Finalization callback: nothing to do, always succeeds.
1579 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1582 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1583 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1585 ompLoc, allocaIP, sectionCBs, privCB, finiCB,
false,
1586 sectionsOp.getNowait());
// Continue emitting code at the insertion point returned by the builder call.
1591 builder.restoreIP(*afterIP);
1595 allocaIP, reductionDecls,
1596 privateReductionVariables, isByRef);
// Translation helper returning LogicalResult. The visible statements build a
// body callback and a no-op finalization callback, collect the copyprivate
// variables and functions of `singleOp`, and restore the insertion point
// produced by a builder call.
// NOTE(review): the function name, parameter list and the definition of
// `singleOp` are not visible in this excerpt - confirm against the full file.
1600 static LogicalResult
1603 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1604 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Body callback: code generation starts at codegenIP.
1609 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1610 builder.restoreIP(codegenIP);
1612 builder, moduleTranslation)
// Finalization callback: nothing to do, always succeeds.
1615 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
// Optional array attribute; entry i is read as the symbol paired with
// cpVars[i] in the loop below.
1619 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
// For every copyprivate variable, record its LLVM value and resolve the
// LLVM function op named by the matching symbol reference.
1622 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1623 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
1624 auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
1625 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1626 llvmCPFuncs.push_back(
1630 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1632 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
// Continue emitting code at the insertion point returned by the builder call.
1638 builder.restoreIP(*afterIP);
// Translation helper returning LogicalResult. The visible statements read the
// optional num-teams lower/upper bound, thread-limit and `if` operands of `op`
// and pass them, with a body callback, to a builder call.
// NOTE(review): the function name, parameter list and the definition of `op`
// are not visible in this excerpt - confirm against the full file.
1643 static LogicalResult
1646 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// Body callback: code generation starts at codegenIP.
1650 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1652 moduleTranslation, allocaIP);
1653 builder.restoreIP(codegenIP);
// Each optional operand below stays nullptr when the op does not carry it and
// is otherwise replaced by its already-translated LLVM value.
1659 llvm::Value *numTeamsLower =
nullptr;
1660 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
1661 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
1663 llvm::Value *numTeamsUpper =
nullptr;
1664 if (
Value numTeamsUpperVar = op.getNumTeamsUpper())
1665 numTeamsUpper = moduleTranslation.
lookupValue(numTeamsUpperVar);
1667 llvm::Value *threadLimit =
nullptr;
1668 if (
Value threadLimitVar = op.getThreadLimit())
1669 threadLimit = moduleTranslation.
lookupValue(threadLimitVar);
1671 llvm::Value *ifExpr =
nullptr;
1672 if (
Value ifVar = op.getIfExpr())
1675 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1676 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1678 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
// Continue emitting code at the insertion point returned by the builder call.
1683 builder.restoreIP(*afterIP);
// Body of a helper that converts task dependence operands into
// llvm::OpenMPIRBuilder::DependData entries appended to `dds`.
// NOTE(review): the function signature is not visible in this excerpt -
// confirm against the full file.
// Special-case an empty dependence list.
1691 if (dependVars.empty())
// Walk the dependence variables in lock-step with their kind attributes.
1693 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
1694 llvm::omp::RTLDependenceKindTy type;
// Map the MLIR dependence kind onto the runtime dependence kind.
1696 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
1697 case mlir::omp::ClauseTaskDepend::taskdependin:
1698 type = llvm::omp::RTLDependenceKindTy::DepIn;
// `out` and `inout` dependences share the DepInOut runtime kind.
1703 case mlir::omp::ClauseTaskDepend::taskdependout:
1704 case mlir::omp::ClauseTaskDepend::taskdependinout:
1705 type = llvm::omp::RTLDependenceKindTy::DepInOut;
// Record the kind together with the LLVM value of the variable and its type.
1708 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
1709 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
1710 dds.emplace_back(dd);
// Translation helper returning LogicalResult for `taskOp`. The visible
// statements collect the op's private variables, define a body callback that
// handles privatization and converts the task region, and pass that callback
// with the task's clauses to a builder call.
// NOTE(review): the function name and parameter list are not visible in this
// excerpt - confirm against the full file.
1715 static LogicalResult
1718 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1724 cast<omp::BlockArgOpenMPOpInterface>(*taskOp).getPrivateBlockArgs();
// Reserve one slot per private block argument up front.
1728 mlirPrivateVars.reserve(privateBlockArgs.size());
1729 llvmPrivateVars.reserve(privateBlockArgs.size());
1731 for (
mlir::Value privateVar : taskOp.getPrivateVars())
1732 mlirPrivateVars.push_back(privateVar);
// Body callback. Failures that were already diagnosed are propagated as
// PreviouslyReportedError.
1734 auto bodyCB = [&](InsertPointTy allocaIP,
1735 InsertPointTy codegenIP) -> llvm::Error {
1739 moduleTranslation, allocaIP);
1742 builder, moduleTranslation, privateBlockArgs, privateDecls,
1743 mlirPrivateVars, llvmPrivateVars, allocaIP);
1745 return llvm::make_error<PreviouslyReportedError>();
1748 llvmPrivateVars, privateDecls,
1749 afterAllocas.get())))
1750 return llvm::make_error<PreviouslyReportedError>();
// Emit the task region starting at codegenIP.
1753 builder.restoreIP(codegenIP);
1755 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
1756 if (failed(
handleError(continuationBlockOrError, *taskOp)))
1757 return llvm::make_error<PreviouslyReportedError>();
// Position the builder before the terminator of the continuation block for
// the statements that follow.
1759 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
1762 llvmPrivateVars, privateDecls)))
1763 return llvm::make_error<PreviouslyReportedError>();
1765 return llvm::Error::success();
1770 moduleTranslation, dds);
1772 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1774 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// The builder call receives the negated `untied` flag, the `if` and event
// handle values, the dependence data and the `mergeable` flag.
1775 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1777 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
1779 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dds,
1780 taskOp.getMergeable(),
1781 moduleTranslation.
lookupValue(taskOp.getEventHandle()));
// Continue emitting code at the insertion point returned by the builder call.
1786 builder.restoreIP(*afterIP);
// Translation helper returning LogicalResult. The visible statements define a
// body callback that starts code generation at codegenIP and restore the
// insertion point produced by a builder call.
// NOTE(review): the function name, parameter list and the operation being
// translated are not visible in this excerpt - confirm against the full file.
1791 static LogicalResult
1794 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1798 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1799 builder.restoreIP(codegenIP);
1801 builder, moduleTranslation)
1806 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1807 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
// Continue emitting code at the insertion point returned by the builder call.
1814 builder.restoreIP(*afterIP);
1818 static LogicalResult
// Translation helper returning LogicalResult. `opInst` is cast to
// omp::WsloopOp; the visible statements create one canonical loop per loop of
// the wrapped omp::LoopNestOp, collapse them, and apply a workshare-loop
// schedule through the OpenMP IR builder, with privatization and reduction
// handling around it.
// NOTE(review): the function name and parameter list are not visible in this
// excerpt - confirm against the full file.
1829 static LogicalResult
1832 auto wsloopOp = cast<omp::WsloopOp>(opInst);
1836 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
// isByRef must hold exactly one entry per reduction variable.
1838 assert(isByRef.size() == wsloopOp.getNumReductionVars());
// Static is used when the op carries no schedule kind.
1842 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
// The induction-variable type is taken from the LLVM type of the first step.
1845 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
1846 llvm::Type *ivType = step->getType();
// The chunk size, when present, is sign-extended or truncated to ivType.
1847 llvm::Value *chunk =
nullptr;
1848 if (wsloopOp.getScheduleChunk()) {
1849 llvm::Value *chunkVar =
1850 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
1851 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
1855 cast<omp::BlockArgOpenMPOpInterface>(*wsloopOp).getPrivateBlockArgs();
// Reserve one slot per private block argument up front.
1859 mlirPrivateVars.reserve(privateBlockArgs.size());
1860 llvmPrivateVars.reserve(privateBlockArgs.size());
1863 for (
mlir::Value privateVar : wsloopOp.getPrivateVars())
1864 mlirPrivateVars.push_back(privateVar);
1868 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1872 wsloopOp.getNumReductionVars());
// Split the alloca block at its terminator; the new block is named
// "omp.region.after_alloca".
1874 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(
1875 allocaIP.getBlock(),
1876 allocaIP.getBlock()->getTerminator()->getIterator()),
1877 true,
"omp.region.after_alloca");
1880 builder, moduleTranslation, privateBlockArgs, privateDecls,
1881 mlirPrivateVars, llvmPrivateVars, allocaIP);
1888 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1893 moduleTranslation, allocaIP, reductionDecls,
1894 privateReductionVariables, reductionVariableMap,
1895 deferredStores, isByRef)))
1899 llvmPrivateVars, privateDecls,
1900 afterAllocas.get())))
// The block after the allocas is expected to have exactly one predecessor.
1903 assert(afterAllocas.get()->getSinglePredecessor());
1906 afterAllocas.get()->getSinglePredecessor(),
1907 reductionDecls, privateReductionVariables,
1908 reductionVariableMap, isByRef, deferredStores)))
1922 moduleTranslation, reductionVariableMap);
1925 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Loop body callback, invoked with the body insertion point and the induction
// variable of the loop being created.
1930 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
1931 llvm::Value *iv) -> llvm::Error {
// The region argument at index loopInfos.size() is paired with `iv`.
1934 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
1939 bodyInsertPoints.push_back(ip);
// Return early unless loopInfos already holds every loop but the last one.
1941 if (loopInfos.size() != loopOp.getNumLoops() - 1)
1942 return llvm::Error::success();
1945 builder.restoreIP(ip);
// Create one canonical loop per loop of the nest.
1957 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
1958 llvm::Value *lowerBound =
1959 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
1960 llvm::Value *upperBound =
1961 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
1962 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
// Default location and compute insertion point come from ompLoc.
1967 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
1968 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
// Alternative: emit at the most recent body insertion point and compute in
// the preheader of the first loop.
// NOTE(review): the guarding condition is not visible here - presumably this
// applies to nested loops only; confirm.
1970 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
1971 computeIP = loopInfos.front()->getPreheaderIP();
1975 ompBuilder->createCanonicalLoop(
1976 loc, bodyGen, lowerBound, upperBound, step,
1977 true, loopOp.getLoopInclusive(), computeIP);
1982 loopInfos.push_back(*loopResult);
// Capture the after-insertion point of the first loop before the loops are
// collapsed into one.
1987 llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
1988 llvm::CanonicalLoopInfo *loopInfo =
1989 ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
1994 bool isOrdered = wsloopOp.getOrdered().has_value();
1995 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
1996 bool isSimd = wsloopOp.getScheduleSimd();
// The negated nowait flag, the converted schedule kind, the chunk and the
// modifier / ordered flags are forwarded to applyWorkshareLoop.
1998 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
1999 ompBuilder->applyWorkshareLoop(
2000 ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
2001 convertToScheduleKind(schedule), chunk, isSimd,
2002 scheduleMod == omp::ScheduleModifier::monotonic,
2003 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered);
// Continue emitting code after the loop nest.
2012 builder.restoreIP(afterIP);
2016 allocaIP, reductionDecls,
2017 privateReductionVariables, isByRef)))
2021 llvmPrivateVars, privateDecls);
// Translation helper returning LogicalResult for `opInst`. The visible
// statements define body, privatization and finalization callbacks, read the
// optional `if`, num-threads and proc-bind settings, and pass everything to
// OpenMPIRBuilder::createParallel.
// NOTE(review): the function name and parameter list are not visible in this
// excerpt - confirm against the full file.
2025 static LogicalResult
2028 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// isByRef must hold exactly one entry per reduction variable.
2030 assert(isByRef.size() == opInst.getNumReductionVars());
2038 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getPrivateBlockArgs();
// Reserve one slot per private block argument up front.
2042 mlirPrivateVars.reserve(privateBlockArgs.size());
2043 llvmPrivateVars.reserve(privateBlockArgs.size());
2045 for (
mlir::Value privateVar : opInst.getPrivateVars())
2046 mlirPrivateVars.push_back(privateVar);
2052 opInst.getNumReductionVars());
// Body callback. Failures that were already diagnosed are propagated as
// PreviouslyReportedError.
2055 auto bodyGenCB = [&](InsertPointTy allocaIP,
2056 InsertPointTy codeGenIP) -> llvm::Error {
2058 builder, moduleTranslation, privateBlockArgs, privateDecls,
2059 mlirPrivateVars, llvmPrivateVars, allocaIP);
2061 return llvm::make_error<PreviouslyReportedError>();
2067 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
// Insertion point placed at the terminator of the alloca block.
2070 InsertPointTy(allocaIP.getBlock(),
2071 allocaIP.getBlock()->getTerminator()->getIterator());
2074 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
2075 reductionDecls, privateReductionVariables, reductionVariableMap,
2076 deferredStores, isByRef)))
2077 return llvm::make_error<PreviouslyReportedError>();
2080 llvmPrivateVars, privateDecls,
2081 afterAllocas.get())))
2082 return llvm::make_error<PreviouslyReportedError>();
// The block after the allocas is expected to have exactly one predecessor.
2084 assert(afterAllocas.get()->getSinglePredecessor());
2087 afterAllocas.get()->getSinglePredecessor(),
2088 reductionDecls, privateReductionVariables,
2089 reductionVariableMap, isByRef, deferredStores)))
2090 return llvm::make_error<PreviouslyReportedError>();
2096 moduleTranslation, reductionVariableMap);
2101 moduleTranslation, allocaIP);
// Emit the parallel region starting at codeGenIP.
2104 builder.restoreIP(codeGenIP);
2106 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
2108 return regionBlock.takeError();
// Reduction handling is only needed when the op has reduction variables.
2111 if (opInst.getNumReductionVars() > 0) {
2117 owningReductionGens, owningAtomicReductionGens,
2118 privateReductionVariables, reductionInfos);
// Place a temporary `unreachable` before the region block's terminator and
// use it as the insertion point handed to createReductions; it is erased
// again once the reductions have been created.
2121 builder.SetInsertPoint((*regionBlock)->getTerminator());
2124 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
2125 builder.SetInsertPoint(tempTerminator);
2127 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
2128 ompBuilder->createReductions(builder.saveIP(), allocaIP,
2129 reductionInfos, isByRef,
false);
2130 if (!contInsertPoint)
2131 return contInsertPoint.takeError();
// An insertion point without a block is treated as an already-reported error.
2133 if (!contInsertPoint->getBlock())
2134 return llvm::make_error<PreviouslyReportedError>();
2136 tempTerminator->eraseFromParent();
2137 builder.restoreIP(*contInsertPoint);
2139 return llvm::Error::success();
// Privatization callback (its body is not visible in this excerpt).
2142 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2143 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
// Finalization callback: temporarily moves the builder to codeGenIP, handles
// the `cleanup` regions of the reduction declarations, then restores the
// previous insertion point.
2152 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2153 InsertPointTy oldIP = builder.saveIP();
2154 builder.restoreIP(codeGenIP);
// Collect a pointer to the cleanup region of every reduction declaration.
2159 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
2160 [](omp::DeclareReductionOp reductionDecl) {
2161 return &reductionDecl.getCleanupRegion();
2164 reductionCleanupRegions, privateReductionVariables,
2165 moduleTranslation, builder,
"omp.reduction.cleanup")))
2166 return llvm::createStringError(
2167 "failed to inline `cleanup` region of `omp.declare_reduction`");
2170 llvmPrivateVars, privateDecls)))
2171 return llvm::make_error<PreviouslyReportedError>();
2173 builder.restoreIP(oldIP);
2174 return llvm::Error::success();
// Optional clauses: each value keeps its default unless the op provides it.
2177 llvm::Value *ifCond =
nullptr;
2178 if (
auto ifVar = opInst.getIfExpr())
2180 llvm::Value *numThreads =
nullptr;
2181 if (
auto numThreadsVar = opInst.getNumThreads())
2182 numThreads = moduleTranslation.
lookupValue(numThreadsVar);
2183 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
2184 if (
auto bind = opInst.getProcBindKind())
// The cancellable flag passed to createParallel is always false here.
2187 bool isCancellable =
false;
2189 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2191 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2193 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2194 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
2195 ifCond, numThreads, pbKind, isCancellable);
// Continue emitting code at the insertion point returned by createParallel.
2200 builder.restoreIP(*afterIP);
// Converts an omp::ClauseOrderKind into the matching llvm::omp::OrderKind.
// NOTE(review): the function name, parameter list and the condition guarding
// the first return are not visible in this excerpt - confirm.
2205 static llvm::omp::OrderKind
2208 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2210 case omp::ClauseOrderKind::Concurrent:
2211 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
// Every handled enumerator returns above; reaching this point is a bug.
2213 llvm_unreachable(
"Unknown ClauseOrderKind kind");
// Translation helper returning LogicalResult. `opInst` is cast to omp::SimdOp;
// the visible statements create one canonical loop per loop of the wrapped
// omp::LoopNestOp, collapse them, and apply simd settings through the OpenMP
// IR builder.
// NOTE(review): the function name and parameter list are not visible in this
// excerpt - confirm against the full file.
2217 static LogicalResult
2220 auto simdOp = cast<omp::SimdOp>(opInst);
2221 auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());
2226 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Loop body callback, invoked with the body insertion point and the induction
// variable of the loop being created.
2231 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
2232 llvm::Value *iv) -> llvm::Error {
// The region argument at index loopInfos.size() is paired with `iv`.
2235 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
2240 bodyInsertPoints.push_back(ip);
// Return early unless loopInfos already holds every loop but the last one.
2242 if (loopInfos.size() != loopOp.getNumLoops() - 1)
2243 return llvm::Error::success();
2246 builder.restoreIP(ip);
// Create one canonical loop per loop of the nest.
2258 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
2259 llvm::Value *lowerBound =
2260 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
2261 llvm::Value *upperBound =
2262 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
2263 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
// Default location and compute insertion point come from ompLoc.
2268 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
2269 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
// Alternative: emit at the most recent body insertion point and compute in
// the preheader of the first loop.
// NOTE(review): the guarding condition is not visible here - presumably this
// applies to nested loops only; confirm.
2271 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
2273 computeIP = loopInfos.front()->getPreheaderIP();
2277 ompBuilder->createCanonicalLoop(
2278 loc, bodyGen, lowerBound, upperBound, step,
2279 true,
true, computeIP);
2284 loopInfos.push_back(*loopResult);
// Capture the after-insertion point of the first loop before the loops are
// collapsed into one.
2288 llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
2289 llvm::CanonicalLoopInfo *loopInfo =
2290 ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
// simdlen / safelen stay nullptr when absent and otherwise become 64-bit
// integer constants.
2292 llvm::ConstantInt *simdlen =
nullptr;
2293 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
2294 simdlen = builder.getInt64(simdlenVar.value());
2296 llvm::ConstantInt *safelen =
nullptr;
2297 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
2298 safelen = builder.getInt64(safelenVar.value());
// No entries are added to alignedVars in the visible lines.
2300 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
2302 ompBuilder->applySimd(loopInfo, alignedVars,
2304 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
2306 order, simdlen, safelen);
// Continue emitting code after the loop nest.
2308 builder.restoreIP(afterIP);
// Converts an omp::ClauseMemoryOrderKind into the matching
// llvm::AtomicOrdering.
// NOTE(review): the function name, parameter list and the condition guarding
// the first return are not visible in this excerpt - confirm.
2313 static llvm::AtomicOrdering
2316 return llvm::AtomicOrdering::Monotonic;
2319 case omp::ClauseMemoryOrderKind::Seq_cst:
2320 return llvm::AtomicOrdering::SequentiallyConsistent;
2321 case omp::ClauseMemoryOrderKind::Acq_rel:
2322 return llvm::AtomicOrdering::AcquireRelease;
2323 case omp::ClauseMemoryOrderKind::Acquire:
2324 return llvm::AtomicOrdering::Acquire;
2325 case omp::ClauseMemoryOrderKind::Release:
2326 return llvm::AtomicOrdering::Release;
// Relaxed is expressed as Monotonic.
2327 case omp::ClauseMemoryOrderKind::Relaxed:
2328 return llvm::AtomicOrdering::Monotonic;
// Every handled enumerator returns above; reaching this point is a bug.
2330 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
// Translation helper returning LogicalResult. `opInst` is cast to
// omp::AtomicReadOp and an atomic read is emitted through
// OpenMPIRBuilder::createAtomicRead.
// NOTE(review): the function name, parameter list and the definition of `AO`
// are not visible in this excerpt - confirm against the full file.
2334 static LogicalResult
2337 auto readOp = cast<omp::AtomicReadOp>(opInst);
2343 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Translated LLVM values of the op's `x` and `v` operands.
2346 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
2347 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
2349 llvm::Type *elementType =
2350 moduleTranslation.
convertType(readOp.getElementType());
// Both descriptors use the same element type.
2352 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
2353 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
// Continue emitting code at the insertion point returned by createAtomicRead.
2354 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO));
// Translation helper returning LogicalResult. `opInst` is cast to
// omp::AtomicWriteOp and an atomic write is emitted through
// OpenMPIRBuilder::createAtomicWrite.
// NOTE(review): the function name, parameter list and the definition of `ao`
// are not visible in this excerpt - confirm against the full file.
2359 static LogicalResult
2362 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
2368 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Translated LLVM values of the expression to store and of the destination.
2370 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
2371 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
// The descriptor's type is the converted type of the stored expression.
2372 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
2373 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
// Continue emitting code at the insertion point returned by createAtomicWrite.
2375 builder.restoreIP(ompBuilder->createAtomicWrite(ompLoc, x, expr, ao));
// Type-dispatch chain mapping an LLVM dialect binary operation onto the
// corresponding llvm::AtomicRMWInst::BinOp.
// NOTE(review): the enclosing function signature and the head of this
// expression are not visible in this excerpt - confirm against the full file.
2383 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
2384 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
2385 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
2386 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
2387 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
2388 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
2389 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
2390 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
2391 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
// Any operation not listed above yields BAD_BINOP.
2392 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
// Translation helper returning LogicalResult. The visible statements inspect
// the region of `opInst`, build an update callback that converts the region's
// block, and emit the update through OpenMPIRBuilder::createAtomicUpdate.
// NOTE(review): the function name and most of the parameter list are not
// visible in this excerpt - confirm against the full file.
2396 static LogicalResult
2398 llvm::IRBuilderBase &builder,
2405 auto &innerOpList = opInst.getRegion().front().getOperations();
2406 bool isXBinopExpr{
false};
2407 llvm::AtomicRMWInst::BinOp binop;
2409 llvm::Value *llvmExpr =
nullptr;
2410 llvm::Value *llvmX =
nullptr;
2411 llvm::Type *llvmXElementType =
nullptr;
// Special handling when the region's first block holds exactly two operations.
2412 if (innerOpList.size() == 2) {
2418 opInst.getRegion().getArgument(0))) {
2419 return opInst.emitError(
"no atomic update operation with region argument"
2420 " as operand found inside atomic.update region");
// True when the region argument is the first operand of the inner operation.
2423 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
2425 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
// NOTE(review): the branch assigning BAD_BINOP is presumably the fallback for
// regions of a different shape - the surrounding condition is not visible.
2429 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
2431 llvmX = moduleTranslation.
lookupValue(opInst.getX());
2433 opInst.getRegion().getArgument(0).getType());
2434 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
2438 llvm::AtomicOrdering atomicOrdering =
// Update callback: binds the region's first argument to the atomic value,
// maps the region block onto the current insertion block, converts the block,
// and returns the LLVM value of the single omp.yield result.
2443 [&opInst, &moduleTranslation](
2444 llvm::Value *atomicx,
2447 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
2448 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
2449 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
2450 return llvm::make_error<PreviouslyReportedError>();
2452 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
2453 assert(yieldop && yieldop.getResults().size() == 1 &&
2454 "terminator must be omp.yield op and it must have exactly one "
2456 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
2461 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2462 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2463 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
2464 atomicOrdering, binop, updateFn,
// Continue emitting code at the insertion point returned by createAtomicUpdate.
2470 builder.restoreIP(*afterIP);
// Translation helper returning LogicalResult for `atomicCaptureOp`. The
// visible statements inspect the nested atomic update / write op, build an
// update callback, and emit the capture through
// OpenMPIRBuilder::createAtomicCapture.
// NOTE(review): the function name and most of the parameter list are not
// visible in this excerpt - confirm against the full file.
2474 static LogicalResult
2476 llvm::IRBuilderBase &builder,
2483 bool isXBinopExpr =
false, isPostfixUpdate =
false;
2484 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
2486 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
2487 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
// One of the two nested op kinds must be present.
2489 assert((atomicUpdateOp || atomicWriteOp) &&
2490 "internal op must be an atomic.update or atomic.write op");
// Write form: treated as a postfix update whose expression is the written
// value.
2492 if (atomicWriteOp) {
2493 isPostfixUpdate =
true;
2494 mlirExpr = atomicWriteOp.getExpr();
// Update form: postfix when the update op is the second op of the capture.
2496 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
2497 atomicCaptureOp.getAtomicUpdateOp().getOperation();
2498 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
// Special handling when the region's first block holds exactly two operations.
2501 if (innerOpList.size() == 2) {
2504 atomicUpdateOp.getRegion().getArgument(0))) {
2505 return atomicUpdateOp.emitError(
2506 "no atomic update operation with region argument"
2507 " as operand found inside atomic.update region");
2511 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
2514 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
2518 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
// `x` and `v` come from the nested atomic read op; both descriptors share the
// read op's element type.
2519 llvm::Value *llvmX =
2520 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
2521 llvm::Value *llvmV =
2522 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
2523 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
2524 atomicCaptureOp.getAtomicReadOp().getElementType());
2525 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
2528 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
2532 llvm::AtomicOrdering atomicOrdering =
// Update callback: either returns the written expression's LLVM value, or
// binds the update region's first argument, converts its block and returns
// the LLVM value of the single omp.yield result.
// NOTE(review): the condition selecting the first return is not visible here.
2536 [&](llvm::Value *atomicx,
2539 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
2540 Block &bb = *atomicUpdateOp.getRegion().
begin();
2541 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
2543 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
2544 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
2545 return llvm::make_error<PreviouslyReportedError>();
2547 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
2548 assert(yieldop && yieldop.getResults().size() == 1 &&
2549 "terminator must be omp.yield op and it must have exactly one "
2551 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
2556 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2557 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2558 ompBuilder->createAtomicCapture(
2559 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
2560 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr);
// Check the builder result via handleError before using it.
2562 if (failed(
handleError(afterIP, *atomicCaptureOp)))
// Continue emitting code at the insertion point returned by createAtomicCapture.
2565 builder.restoreIP(*afterIP);
// Translation helper returning LogicalResult. `opInst` is cast to
// omp::ThreadprivateOp; the visible statements resolve the global behind the
// op's symbol address, compute its store size and a cache name, and pass them
// to a builder call.
// NOTE(review): the function name, parameter list and the definition of
// `symOp` are not visible in this excerpt - confirm against the full file.
2571 static LogicalResult
2574 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2575 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
2580 Value symAddr = threadprivateOp.getSymAddr();
// The symbol address must be produced by an LLVM::AddressOfOp.
2582 if (!isa<LLVM::AddressOfOp>(symOp))
2583 return opInst.
emitError(
"Addressing symbol not found");
2584 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
// Resolve the MLIR global and then its translated LLVM global value.
2586 LLVM::GlobalOp global =
2587 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
2588 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
// Store size of the global's value type, as a 64-bit integer constant.
2589 llvm::Type *type = globalValue->getValueType();
2590 llvm::TypeSize typeSize =
2591 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
2593 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
// Cache name is the global's symbol name followed by ".cache".
2594 llvm::StringRef suffix = llvm::StringRef(
".cache", 6);
2595 std::string cacheName = (Twine(global.getSymName()).concat(suffix)).str();
2596 llvm::Value *callInst =
2598 ompLoc, globalValue, size, cacheName);
// Converts an mlir::omp::DeclareTargetDeviceType (`deviceClause`) into the
// matching OffloadEntriesInfoManager::OMPTargetDeviceClauseKind.
// NOTE(review): the function name and parameter list are not visible in this
// excerpt - confirm against the full file.
2603 static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
2605 switch (deviceClause) {
2606 case mlir::omp::DeclareTargetDeviceType::host:
2607 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
2609 case mlir::omp::DeclareTargetDeviceType::nohost:
2610 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
2612 case mlir::omp::DeclareTargetDeviceType::any:
2613 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
// Every handled enumerator returns above; reaching this point is a bug.
2616 llvm_unreachable(
"unhandled device clause");
// Converts an mlir::omp::DeclareTargetCaptureClause (`captureClause`) into the
// matching OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind.
// NOTE(review): the function name is not visible in this excerpt - confirm
// against the full file.
2619 static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
2621 mlir::omp::DeclareTargetCaptureClause captureClause) {
2622 switch (captureClause) {
2623 case mlir::omp::DeclareTargetCaptureClause::to:
2624 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
2625 case mlir::omp::DeclareTargetCaptureClause::link:
2626 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
2627 case mlir::omp::DeclareTargetCaptureClause::enter:
2628 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
// Every handled enumerator returns above; reaching this point is a bug.
2630 llvm_unreachable(
"unhandled capture clause");
// Builds a name suffix in `suffix` through a raw_svector_ostream. The visible
// pieces are a "_%x"-formatted file ID obtained from the OpenMP IR builder and
// the trailing "_decl_tgt_ref_ptr".
// NOTE(review): the function name, the leading parameters, the origin of
// `loc` and the conditions around the pieces are not visible in this excerpt.
2635 llvm::OpenMPIRBuilder &ompBuilder) {
2637 llvm::raw_svector_ostream os(suffix);
// Callback returning the file name and line number of `loc` as a pair.
2640 auto fileInfoCallBack = [&loc]() {
2641 return std::pair<std::string, uint64_t>(
2642 llvm::StringRef(loc.getFilename()), loc.getLine());
2646 "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
2648 os <<
"_decl_tgt_ref_ptr";
// Checks whether `value` is defined by an LLVM::AddressOfOp whose referenced
// global implements DeclareTargetInterface with the `link` capture clause.
// NOTE(review): the function signature and its return statements are not
// visible in this excerpt - confirm against the full file.
// dyn_cast_if_present tolerates a value without a defining op.
2654 if (
auto addressOfOp =
2655 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.
getDefiningOp())) {
// Look the global up by name in the enclosing module.
2656 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
2657 Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
2658 if (
auto declareTargetGlobal =
2659 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
2660 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
2661 mlir::omp::DeclareTargetCaptureClause::link)
// Helper returning an llvm::Value pointer. For a `value` defined by an
// LLVM::AddressOfOp on a declare-target global, the visible statements derive
// a name from the global's symbol name plus `suffix`.
// NOTE(review): the function name, parameter list, the origin of `suffix` and
// the return statements are not visible in this excerpt - confirm.
2670 static llvm::Value *
// dyn_cast_if_present tolerates a value without a defining op.
2677 if (
auto addressOfOp =
2678 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.
getDefiningOp())) {
// Resolve the referenced global in the enclosing module; dyn_cast_or_null
// tolerates a failed symbol lookup.
2679 if (
auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
2680 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
2681 addressOfOp.getGlobalName()))) {
2683 if (
auto declareTargetGlobal =
2684 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
2685 gOp.getOperation())) {
// Applies to `link` globals, and to `to` globals when the builder
// configuration requires unified shared memory.
2689 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
2690 mlir::omp::DeclareTargetCaptureClause::link) ||
2691 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
2692 mlir::omp::DeclareTargetCaptureClause::to &&
2693 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
// Special-case a symbol name that already contains the suffix.
2697 if (gOp.getSymName().contains(suffix))
2702 (gOp.getSymName().str() + suffix.str()).str());
// Extension of llvm::OpenMPIRBuilder::MapInfosTy carrying additional
// per-entry vectors (IsDeclareTarget, MapClause, OriginalValue, BaseType).
// NOTE(review): the member declarations are not visible in this excerpt.
2718 struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
// Appends all entries of CurInfo: first the additional vectors declared by
// this struct, then the base-class data.
2730 void append(MapInfoData &CurInfo) {
2731 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
2732 CurInfo.IsDeclareTarget.end());
2733 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
2734 OriginalValue.append(CurInfo.OriginalValue.begin(),
2735 CurInfo.OriginalValue.end());
2736 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
2737 llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
// Tests whether the element type of `arrTy` is itself an LLVM array type.
// NOTE(review): the enclosing function signature and both branches are not
// visible in this excerpt - presumably nested arrays are handled recursively;
// confirm against the full file.
2743 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
2744 arrTy.getElementType()))
// Computes a size value for a mapped entity. When `clauseOp` is a MapInfoOp
// with bounds, the visible statements accumulate an element count over the
// bounds and multiply it by the underlying type size in bytes.
// NOTE(review): the function name, leading parameters and fallback return are
// not visible in this excerpt - confirm against the full file.
2760 Operation *clauseOp, llvm::Value *basePointer,
2761 llvm::Type *baseType, llvm::IRBuilderBase &builder,
// dyn_cast_if_present tolerates a null clauseOp.
2763 if (
auto memberClause =
2764 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
2769 if (!memberClause.getBounds().empty()) {
// Running product over all bounds, starting at 1.
2770 llvm::Value *elementCount = builder.getInt64(1);
2771 for (
auto bounds : memberClause.getBounds()) {
2772 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2773 bounds.getDefiningOp())) {
// NOTE(review): the factor is built from the upper bound, the lower bound and
// the constant 1 - presumably (upper - lower + 1); the combining operations
// are not visible here.
2778 elementCount = builder.CreateMul(
2782 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
2783 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
2784 builder.getInt64(1)));
2791 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
// Element count times the underlying type size converted from bits to bytes.
2799 return builder.CreateMul(elementCount,
2800 builder.getInt64(underlyingTypeSzInBits / 8));
// Populates `mapData` with one entry per map operand in `mapVars`, then
// processes the use-device-address and use-device-pointer operands: each one
// either updates an existing mapping entry or is appended as a new entry.
// NOTE(review): the function name and leading parameters are not visible in
// this excerpt - confirm against the full file.
2810 llvm::IRBuilderBase &builder,
const ArrayRef<Value> &useDevPtrOperands = {},
// Helper: scans the members of every map op in the given list, comparing each
// against `mapOp` (the return statements are not visible here).
2812 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
2820 for (
Value mapValue : mapVars) {
2821 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
2822 for (
auto member : map.getMembers())
2823 if (member == mapOp)
// One mapData entry per map operand.
2830 for (
Value mapValue : mapVars) {
2831 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
// Use the var-ptr-ptr operand when present, otherwise the var-ptr operand.
2833 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
2834 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
2835 mapData.Pointers.push_back(mapData.OriginalValue.back());
// With a non-null reference pointer the entry is flagged as declare target
// and that pointer becomes the base pointer; otherwise the original value is
// the base pointer.
2837 if (llvm::Value *refPtr =
2839 moduleTranslation)) {
2840 mapData.IsDeclareTarget.push_back(
true);
2841 mapData.BasePointers.push_back(refPtr);
2843 mapData.IsDeclareTarget.push_back(
false);
2844 mapData.BasePointers.push_back(mapData.OriginalValue.back());
2847 mapData.BaseType.push_back(
2848 moduleTranslation.
convertType(mapOp.getVarType()));
2849 mapData.Sizes.push_back(
2850 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
2851 mapData.BaseType.back(), builder, moduleTranslation));
2852 mapData.MapClause.push_back(mapOp.getOperation());
2853 mapData.Types.push_back(
2854 llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
2858 mapData.IsAMapping.push_back(
true);
2859 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
// Helper: looks for an existing mapping entry whose original value is `val`;
// a match gets the OMP_MAP_RETURN_PARAM flag and the given device info kind.
2862 auto findMapInfo = [&mapData](llvm::Value *val,
2863 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
2866 for (llvm::Value *basePtr : mapData.OriginalValue) {
2867 if (basePtr == val && mapData.IsAMapping[index]) {
2869 mapData.Types[index] |=
2870 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
2871 mapData.DevicePointers[index] = devInfoTy;
// Helper: for each use-device operand without an existing mapping entry,
// append a new non-mapping entry of size 0 flagged OMP_MAP_RETURN_PARAM.
2880 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
2881 for (
Value mapValue : useDevOperands) {
2882 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
2884 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
2885 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
2888 if (!findMapInfo(origValue, devInfoTy)) {
2889 mapData.OriginalValue.push_back(origValue);
2890 mapData.Pointers.push_back(mapData.OriginalValue.back());
2891 mapData.IsDeclareTarget.push_back(
false);
2892 mapData.BasePointers.push_back(mapData.OriginalValue.back());
2893 mapData.BaseType.push_back(
2894 moduleTranslation.
convertType(mapOp.getVarType()));
2895 mapData.Sizes.push_back(builder.getInt64(0));
2896 mapData.MapClause.push_back(mapOp.getOperation());
2897 mapData.Types.push_back(
2898 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
2901 mapData.DevicePointers.push_back(devInfoTy);
2902 mapData.IsAMapping.push_back(
false);
2903 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
// Address operands are processed before pointer operands.
2908 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
2909 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
// Returns the position of `memberOp` within mapData.MapClause; asserts that
// the op is present.
// NOTE(review): the function signature is not visible in this excerpt -
// confirm against the full file.
2913 auto *res = llvm::find(mapData.MapClause, memberOp);
2914 assert(res != mapData.MapClause.end() &&
2915 "MapInfoOp for member not found in MapData, cannot return index");
2916 return std::distance(mapData.MapClause.begin(), res);
// Selects one member MapInfoOp of `mapInfo` by ordering the members according
// to their member-index attributes and returning the one that sorts first.
// NOTE(review): the function signature, the declaration of `indices` and the
// values returned inside the comparator are not visible in this excerpt -
// confirm the exact ordering against the full file.
2921 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
// With a single index entry the first member is returned directly.
2923 if (indexAttr.size() == 1)
2924 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
// Fill `indices` with 0, 1, 2, ... and sort those positions.
2927 std::iota(indices.begin(), indices.end(), 0);
2929 llvm::sort(indices.begin(), indices.end(),
2930 [&](
const size_t a,
const size_t b) {
2931 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
2932 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
// Compare the two index arrays element by element over their common length.
2933 for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
2934 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
2935 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
2937 if (aIndex == bIndex)
2940 if (aIndex < bIndex)
2943 if (aIndex > bIndex)
// Otherwise the comparison falls back to the sizes of the two index arrays.
2950 return memberIndicesA.size() < memberIndicesB.size();
// The member at the front of the sorted positions is the result.
2953 return llvm::cast<omp::MapInfoOp>(
2954 mapInfo.getMembers()[indices.front()].getDefiningOp());
2976 std::vector<llvm::Value *>
2978 llvm::IRBuilderBase &builder,
bool isArrayTy,
2980 std::vector<llvm::Value *> idx;
2991 idx.push_back(builder.getInt64(0));
2992 for (
int i = bounds.size() - 1; i >= 0; --i) {
2993 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
2994 bounds[i].getDefiningOp())) {
2995 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
3017 std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
3018 for (
size_t i = 1; i < bounds.size(); ++i) {
3019 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3020 bounds[i].getDefiningOp())) {
3021 dimensionIndexSizeOffset.push_back(builder.CreateMul(
3022 moduleTranslation.
lookupValue(boundOp.getExtent()),
3023 dimensionIndexSizeOffset[i - 1]));
3031 for (
int i = bounds.size() - 1; i >= 0; --i) {
3032 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3033 bounds[i].getDefiningOp())) {
3035 idx.emplace_back(builder.CreateMul(
3036 moduleTranslation.
lookupValue(boundOp.getLowerBound()),
3037 dimensionIndexSizeOffset[i]));
3039 idx.back() = builder.CreateAdd(
3040 idx.back(), builder.CreateMul(moduleTranslation.
lookupValue(
3041 boundOp.getLowerBound()),
3042 dimensionIndexSizeOffset[i]));
3067 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl,
3068 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
3069 uint64_t mapDataIndex,
bool isTargetParams) {
3071 combinedInfo.Types.emplace_back(
3073 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
3074 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
3075 combinedInfo.DevicePointers.emplace_back(
3076 mapData.DevicePointers[mapDataIndex]);
3078 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3079 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3089 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3091 llvm::Value *lowAddr, *highAddr;
3092 if (!parentClause.getPartialMap()) {
3093 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
3094 builder.getPtrTy());
3095 highAddr = builder.CreatePointerCast(
3096 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
3097 mapData.Pointers[mapDataIndex], 1),
3098 builder.getPtrTy());
3099 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3101 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3104 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
3105 builder.getPtrTy());
3108 highAddr = builder.CreatePointerCast(
3109 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
3110 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
3111 builder.getPtrTy());
3112 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
3115 llvm::Value *size = builder.CreateIntCast(
3116 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
3117 builder.getInt64Ty(),
3119 combinedInfo.Sizes.push_back(size);
3121 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
3122 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
3130 if (!parentClause.getPartialMap()) {
3135 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
3136 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3137 combinedInfo.Types.emplace_back(mapFlag);
3138 combinedInfo.DevicePointers.emplace_back(
3141 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3142 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3143 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3144 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
3146 return memberOfFlag;
3158 if (mapOp.getVarPtrPtr())
3173 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl,
3174 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
3175 uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
3178 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3180 for (
auto mappedMembers : parentClause.getMembers()) {
3182 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
3185 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
3195 auto mapFlag = llvm::omp::OpenMPOffloadMappingFlags(
3196 memberClause.getMapType().value());
3197 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3198 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
3199 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3200 combinedInfo.Types.emplace_back(mapFlag);
3201 combinedInfo.DevicePointers.emplace_back(
3203 combinedInfo.Names.emplace_back(
3205 combinedInfo.BasePointers.emplace_back(
3206 mapData.BasePointers[mapDataIndex]);
3207 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
3208 combinedInfo.Sizes.emplace_back(builder.getInt64(
3209 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
3215 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType().value());
3216 mapFlag &= ~
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3217 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
3218 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3220 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
3222 combinedInfo.Types.emplace_back(mapFlag);
3223 combinedInfo.DevicePointers.emplace_back(
3224 mapData.DevicePointers[memberDataIdx]);
3225 combinedInfo.Names.emplace_back(
3227 uint64_t basePointerIndex =
3229 combinedInfo.BasePointers.emplace_back(
3230 mapData.BasePointers[basePointerIndex]);
3231 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
3232 combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
3238 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
3239 bool isTargetParams,
int mapDataParentIdx = -1) {
3243 auto mapFlag = mapData.Types[mapDataIdx];
3244 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
3248 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
3250 if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
3251 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3253 if (mapInfoOp.getMapCaptureType().value() ==
3254 omp::VariableCaptureKind::ByCopy &&
3256 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
3261 if (mapDataParentIdx >= 0)
3262 combinedInfo.BasePointers.emplace_back(
3263 mapData.BasePointers[mapDataParentIdx]);
3265 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
3267 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
3268 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
3269 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
3270 combinedInfo.Types.emplace_back(mapFlag);
3271 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
3276 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl,
3277 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
3278 uint64_t mapDataIndex,
bool isTargetParams) {
3280 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3285 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
3286 auto memberClause = llvm::cast<omp::MapInfoOp>(
3287 parentClause.getMembers()[0].getDefiningOp());
3304 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
3306 combinedInfo, mapData, mapDataIndex, isTargetParams);
3308 combinedInfo, mapData, mapDataIndex,
3309 memberOfParentFlag);
3319 llvm::IRBuilderBase &builder) {
3320 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
3322 if (!mapData.IsDeclareTarget[i]) {
3323 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
3324 omp::VariableCaptureKind captureKind =
3325 mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
3335 switch (captureKind) {
3336 case omp::VariableCaptureKind::ByRef: {
3337 llvm::Value *newV = mapData.Pointers[i];
3339 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
3342 newV = builder.CreateLoad(builder.getPtrTy(), newV);
3344 if (!offsetIdx.empty())
3345 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
3347 mapData.Pointers[i] = newV;
3349 case omp::VariableCaptureKind::ByCopy: {
3350 llvm::Type *type = mapData.BaseType[i];
3352 if (mapData.Pointers[i]->getType()->isPointerTy())
3353 newV = builder.CreateLoad(type, mapData.Pointers[i]);
3355 newV = mapData.Pointers[i];
3358 auto curInsert = builder.saveIP();
3360 auto *memTempAlloc =
3361 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
3362 builder.restoreIP(curInsert);
3364 builder.CreateStore(newV, memTempAlloc);
3365 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
3368 mapData.Pointers[i] = newV;
3369 mapData.BasePointers[i] = newV;
3371 case omp::VariableCaptureKind::This:
3372 case omp::VariableCaptureKind::VLAType:
3373 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
3384 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
3385 MapInfoData &mapData,
bool isTargetParams =
false) {
3407 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
3410 if (mapData.IsAMember[i])
3413 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
3414 if (!mapInfoOp.getMembers().empty()) {
3416 combinedInfo, mapData, i, isTargetParams);
3424 static LogicalResult
3427 llvm::Value *ifCond =
nullptr;
3428 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
3432 llvm::omp::RuntimeFunction RTLFn;
3436 llvm::OpenMPIRBuilder::TargetDataInfo info(
true,
3439 LogicalResult result =
3441 .Case([&](omp::TargetDataOp dataOp) {
3445 if (
auto ifVar = dataOp.getIfExpr())
3448 if (
auto devId = dataOp.getDevice())
3450 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
3451 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
3452 deviceID = intAttr.getInt();
3454 mapVars = dataOp.getMapVars();
3455 useDevicePtrVars = dataOp.getUseDevicePtrVars();
3456 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
3459 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
3463 if (
auto ifVar = enterDataOp.getIfExpr())
3466 if (
auto devId = enterDataOp.getDevice())
3468 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
3469 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
3470 deviceID = intAttr.getInt();
3472 enterDataOp.getNowait()
3473 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
3474 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
3475 mapVars = enterDataOp.getMapVars();
3476 info.HasNoWait = enterDataOp.getNowait();
3479 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
3483 if (
auto ifVar = exitDataOp.getIfExpr())
3486 if (
auto devId = exitDataOp.getDevice())
3488 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
3489 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
3490 deviceID = intAttr.getInt();
3492 RTLFn = exitDataOp.getNowait()
3493 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
3494 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
3495 mapVars = exitDataOp.getMapVars();
3496 info.HasNoWait = exitDataOp.getNowait();
3499 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
3503 if (
auto ifVar = updateDataOp.getIfExpr())
3506 if (
auto devId = updateDataOp.getDevice())
3508 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
3509 if (
auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
3510 deviceID = intAttr.getInt();
3513 updateDataOp.getNowait()
3514 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
3515 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
3516 mapVars = updateDataOp.getMapVars();
3517 info.HasNoWait = updateDataOp.getNowait();
3521 llvm_unreachable(
"unexpected operation");
3528 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3530 MapInfoData mapData;
3532 builder, useDevicePtrVars, useDeviceAddrVars);
3535 llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
3537 [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
3538 builder.restoreIP(codeGenIP);
3539 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
3540 return combinedInfo;
3546 [&moduleTranslation](
3547 llvm::OpenMPIRBuilder::DeviceInfoTy type,
3551 for (
auto [arg, useDevVar] :
3552 llvm::zip_equal(blockArgs, useDeviceVars)) {
3554 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
3555 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
3556 : mapInfoOp.getVarPtr();
3559 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
3560 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
3561 mapInfoData.MapClause, mapInfoData.DevicePointers,
3562 mapInfoData.BasePointers)) {
3563 auto mapOp = cast<omp::MapInfoOp>(mapClause);
3564 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
3565 devicePointer != type)
3568 if (llvm::Value *devPtrInfoMap =
3569 mapper ? mapper(basePointer) : basePointer) {
3570 moduleTranslation.
mapValue(arg, devPtrInfoMap);
3577 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
3578 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
3579 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
3580 assert(isa<omp::TargetDataOp>(op) &&
3581 "BodyGen requested for non TargetDataOp");
3582 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
3583 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
3584 switch (bodyGenType) {
3585 case BodyGenTy::Priv:
3587 if (!info.DevicePtrInfoMap.empty()) {
3588 builder.restoreIP(codeGenIP);
3590 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
3591 blockArgIface.getUseDeviceAddrBlockArgs(),
3592 useDeviceAddrVars, mapData,
3593 [&](llvm::Value *basePointer) -> llvm::Value * {
3594 if (!info.DevicePtrInfoMap[basePointer].second)
3596 return builder.CreateLoad(
3598 info.DevicePtrInfoMap[basePointer].second);
3600 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
3601 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
3602 mapData, [&](llvm::Value *basePointer) {
3603 return info.DevicePtrInfoMap[basePointer].second;
3607 moduleTranslation)))
3608 return llvm::make_error<PreviouslyReportedError>();
3611 case BodyGenTy::DupNoPriv:
3613 case BodyGenTy::NoPriv:
3615 if (info.DevicePtrInfoMap.empty()) {
3616 builder.restoreIP(codeGenIP);
3619 if (ompBuilder->Config.IsTargetDevice.value_or(
false)) {
3620 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
3621 blockArgIface.getUseDeviceAddrBlockArgs(),
3622 useDeviceAddrVars, mapData);
3623 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
3624 blockArgIface.getUseDevicePtrBlockArgs(),
3625 useDevicePtrVars, mapData);
3629 moduleTranslation)))
3630 return llvm::make_error<PreviouslyReportedError>();
3634 return builder.saveIP();
3637 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3638 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3640 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
3641 if (isa<omp::TargetDataOp>(op))
3642 return ompBuilder->createTargetData(
3643 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID),
3644 ifCond, info, genMapInfoCB,
nullptr, bodyGenCB);
3645 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
3646 builder.getInt64(deviceID), ifCond,
3647 info, genMapInfoCB, &RTLFn);
3653 builder.restoreIP(*afterIP);
3662 if (!cast<mlir::ModuleOp>(op))
3667 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
3668 attribute.getOpenmpDeviceVersion());
3670 if (attribute.getNoGpuLib())
3673 ompBuilder->createGlobalFlag(
3674 attribute.getDebugKind() ,
3675 "__omp_rtl_debug_kind");
3676 ompBuilder->createGlobalFlag(
3678 .getAssumeTeamsOversubscription()
3680 "__omp_rtl_assume_teams_oversubscription");
3681 ompBuilder->createGlobalFlag(
3683 .getAssumeThreadsOversubscription()
3685 "__omp_rtl_assume_threads_oversubscription");
3686 ompBuilder->createGlobalFlag(
3687 attribute.getAssumeNoThreadState() ,
3688 "__omp_rtl_assume_no_thread_state");
3689 ompBuilder->createGlobalFlag(
3691 .getAssumeNoNestedParallelism()
3693 "__omp_rtl_assume_no_nested_parallelism");
3698 omp::TargetOp targetOp,
3699 llvm::StringRef parentName =
"") {
3700 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
3702 assert(fileLoc &&
"No file found from location");
3703 StringRef fileName = fileLoc.getFilename().getValue();
3705 llvm::sys::fs::UniqueID id;
3706 if (
auto ec = llvm::sys::fs::getUniqueID(fileName,
id)) {
3707 targetOp.emitError(
"Unable to get unique ID for file");
3711 uint64_t line = fileLoc.getLine();
3712 targetInfo = llvm::TargetRegionEntryInfo(parentName,
id.getDevice(),
3713 id.getFile(), line);
3720 llvm::IRBuilderBase &builder, llvm::Function *func) {
3721 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
3734 if (mapData.IsDeclareTarget[i]) {
3741 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
3742 convertUsersOfConstantsToInstructions(constant, func,
false);
3749 for (llvm::User *user : mapData.OriginalValue[i]->users())
3750 userVec.push_back(user);
3752 for (llvm::User *user : userVec) {
3753 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
3754 if (insn->getFunction() == func) {
3755 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
3756 mapData.BasePointers[i]);
3757 load->moveBefore(insn);
3758 user->replaceUsesOfWith(mapData.OriginalValue[i], load);
3805 static llvm::IRBuilderBase::InsertPoint
3807 llvm::Value *input, llvm::Value *&retVal,
3808 llvm::IRBuilderBase &builder,
3809 llvm::OpenMPIRBuilder &ompBuilder,
3811 llvm::IRBuilderBase::InsertPoint allocaIP,
3812 llvm::IRBuilderBase::InsertPoint codeGenIP) {
3813 builder.restoreIP(allocaIP);
3815 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
3818 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
3819 if (mapData.OriginalValue[i] == input) {
3820 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
3822 mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
3827 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
3828 unsigned int defaultAS =
3829 ompBuilder.M.getDataLayout().getProgramAddressSpace();
3832 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
3834 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
3835 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
3837 builder.CreateStore(&arg, v);
3839 builder.restoreIP(codeGenIP);
3842 case omp::VariableCaptureKind::ByCopy: {
3846 case omp::VariableCaptureKind::ByRef: {
3847 retVal = builder.CreateAlignedLoad(
3849 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
3852 case omp::VariableCaptureKind::This:
3853 case omp::VariableCaptureKind::VLAType:
3856 assert(
false &&
"Currently unsupported capture kind");
3860 return builder.saveIP();
3863 static LogicalResult
3866 auto targetOp = cast<omp::TargetOp>(opInst);
3871 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
3873 auto &targetRegion = targetOp.getRegion();
3888 cast<omp::BlockArgOpenMPOpInterface>(opInst).getMapBlockArgs();
3889 llvm::Function *llvmOutlinedFn =
nullptr;
3893 bool isOffloadEntry =
3894 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
3901 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
3902 auto argIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
3904 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
3905 std::optional<DenseI64ArrayAttr> privateMapIndices =
3906 targetOp.getPrivateMapsAttr();
3908 for (
auto [privVarIdx, privVarSymPair] :
3910 auto privVar = std::get<0>(privVarSymPair);
3911 auto privSym = std::get<1>(privVarSymPair);
3913 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
3914 omp::PrivateClauseOp privatizer =
3917 if (!privatizer.needsMap())
3921 targetOp.getMappedValueForPrivateVar(privVarIdx);
3922 assert(mappedValue &&
"Expected to find mapped value for a privatized "
3923 "variable that needs mapping");
3928 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
3929 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
3933 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
3935 varType == privVar.getType() &&
3936 "Type of private var doesn't match the type of the mapped value");
3940 mappedPrivateVars.insert(
3942 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
3943 (*privateMapIndices)[privVarIdx])});
3947 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3948 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
3949 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
3952 llvm::Function *llvmParentFn =
3954 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
3955 assert(llvmParentFn && llvmOutlinedFn &&
3956 "Both parent and outlined functions must exist at this point");
3958 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
3959 attr.isStringAttribute())
3960 llvmOutlinedFn->addFnAttr(attr);
3962 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
3963 attr.isStringAttribute())
3964 llvmOutlinedFn->addFnAttr(attr);
3966 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
3967 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
3968 llvm::Value *mapOpValue =
3969 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
3970 moduleTranslation.
mapValue(arg, mapOpValue);
3976 cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs();
3980 mlirPrivateVars.reserve(privateBlockArgs.size());
3981 llvmPrivateVars.reserve(privateBlockArgs.size());
3983 for (
mlir::Value privateVar : targetOp.getPrivateVars())
3984 mlirPrivateVars.push_back(privateVar);
3987 builder, moduleTranslation, privateBlockArgs, privateDecls,
3988 mlirPrivateVars, llvmPrivateVars, allocaIP, &mappedPrivateVars);
3991 return llvm::make_error<PreviouslyReportedError>();
3994 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
3995 [](omp::PrivateClauseOp privatizer) {
3996 return &privatizer.getDeallocRegion();
3999 builder.restoreIP(codeGenIP);
4001 targetRegion,
"omp.target", builder, moduleTranslation);
4004 return exitBlock.takeError();
4006 builder.SetInsertPoint(*exitBlock);
4007 if (!privateCleanupRegions.empty()) {
4009 privateCleanupRegions, llvmPrivateVars, moduleTranslation,
4010 builder,
"omp.targetop.private.cleanup",
4012 return llvm::createStringError(
4013 "failed to inline `dealloc` region of `omp.private` "
4014 "op in the target region");
4018 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
4021 StringRef parentName = parentFn.getName();
4023 llvm::TargetRegionEntryInfo entryInfo;
4028 int32_t defaultValTeams = -1;
4029 int32_t defaultValThreads = 0;
4031 MapInfoData mapData;
4035 llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
4036 auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
4037 -> llvm::OpenMPIRBuilder::MapInfosTy & {
4038 builder.restoreIP(codeGenIP);
4039 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
true);
4040 return combinedInfos;
4043 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
4044 llvm::Value *&retVal, InsertPointTy allocaIP,
4045 InsertPointTy codeGenIP)
4046 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
4052 if (!isTargetDevice) {
4053 retVal = cast<llvm::Value>(&arg);
4058 *ompBuilder, moduleTranslation,
4059 allocaIP, codeGenIP);
4063 for (
size_t i = 0; i < mapVars.size(); ++i) {
4070 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
4071 kernelInput.push_back(mapData.OriginalValue[i]);
4076 moduleTranslation, dds);
4078 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4080 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4082 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4084 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
4085 defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
4086 argAccessorCB, dds, targetOp.getNowait());
4091 builder.restoreIP(*afterIP);
4102 static LogicalResult
4112 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
4113 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
4115 if (!offloadMod.getIsTargetDevice())
4118 omp::DeclareTargetDeviceType declareType =
4119 attribute.getDeviceType().getValue();
4121 if (declareType == omp::DeclareTargetDeviceType::host) {
4122 llvm::Function *llvmFunc =
4124 llvmFunc->dropAllReferences();
4125 llvmFunc->eraseFromParent();
4131 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
4132 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
4133 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
4135 bool isDeclaration = gOp.isDeclaration();
4136 bool isExternallyVisible =
4139 llvm::StringRef mangledName = gOp.getSymName();
4140 auto captureClause =
4146 std::vector<llvm::GlobalVariable *> generatedRefs;
4148 std::vector<llvm::Triple> targetTriple;
4149 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
4151 LLVM::LLVMDialect::getTargetTripleAttrName()));
4152 if (targetTripleAttr)
4153 targetTriple.emplace_back(targetTripleAttr.data());
4155 auto fileInfoCallBack = [&loc]() {
4156 std::string filename =
"";
4157 std::uint64_t lineNo = 0;
4160 filename = loc.getFilename().str();
4161 lineNo = loc.getLine();
4164 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
4168 ompBuilder->registerTargetGlobalVariable(
4169 captureClause, deviceClause, isDeclaration, isExternallyVisible,
4170 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
4171 generatedRefs,
false, targetTriple,
4173 gVal->getType(), gVal);
4175 if (ompBuilder->Config.isTargetDevice() &&
4176 (attribute.getCaptureClause().getValue() !=
4177 mlir::omp::DeclareTargetCaptureClause::to ||
4178 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
4179 ompBuilder->getAddrOfDeclareTargetVar(
4180 captureClause, deviceClause, isDeclaration, isExternallyVisible,
4181 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
4182 generatedRefs,
false, targetTriple, gVal->getType(),
4200 if (
auto declareTargetIface =
4201 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
4202 parentFn.getOperation()))
4203 if (declareTargetIface.isDeclareTarget() &&
4204 declareTargetIface.getDeclareTargetDeviceType() !=
4205 mlir::omp::DeclareTargetDeviceType::host)
4213 static LogicalResult
4220 .Case([&](omp::BarrierOp op) -> LogicalResult {
4224 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4225 ompBuilder->createBarrier(builder.saveIP(),
4226 llvm::omp::OMPD_barrier);
4229 .Case([&](omp::TaskyieldOp op) {
4233 ompBuilder->createTaskyield(builder.saveIP());
4236 .Case([&](omp::FlushOp op) {
4248 ompBuilder->createFlush(builder.saveIP());
4251 .Case([&](omp::ParallelOp op) {
4254 .Case([&](omp::MaskedOp) {
4257 .Case([&](omp::MasterOp) {
4260 .Case([&](omp::CriticalOp) {
4263 .Case([&](omp::OrderedRegionOp) {
4266 .Case([&](omp::OrderedOp) {
4269 .Case([&](omp::WsloopOp) {
4272 .Case([&](omp::SimdOp) {
4275 .Case([&](omp::AtomicReadOp) {
4278 .Case([&](omp::AtomicWriteOp) {
4281 .Case([&](omp::AtomicUpdateOp op) {
4284 .Case([&](omp::AtomicCaptureOp op) {
4287 .Case([&](omp::SectionsOp) {
4290 .Case([&](omp::SingleOp op) {
4293 .Case([&](omp::TeamsOp op) {
4296 .Case([&](omp::TaskOp op) {
4299 .Case([&](omp::TaskgroupOp op) {
4302 .Case([&](omp::TaskwaitOp op) {
4305 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareReductionOp,
4306 omp::CriticalDeclareOp>([](
auto op) {
4317 .Case([&](omp::ThreadprivateOp) {
4320 .Case<omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
4321 omp::TargetUpdateOp>([&](
auto op) {
4324 .Case([&](omp::TargetOp) {
4327 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
4339 static LogicalResult
4345 static LogicalResult
4348 if (isa<omp::TargetOp>(op))
4350 if (isa<omp::TargetDataOp>(op))
4354 if (isa<omp::TargetOp>(oper)) {
4356 return WalkResult::interrupt();
4357 return WalkResult::skip();
4359 if (isa<omp::TargetDataOp>(oper)) {
4361 return WalkResult::interrupt();
4362 return WalkResult::skip();
4364 return WalkResult::advance();
4365 }).wasInterrupted();
4366 return failure(interrupted);
4373 class OpenMPDialectLLVMIRTranslationInterface
4394 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
4400 .Case(
"omp.is_target_device",
4402 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
4403 llvm::OpenMPIRBuilderConfig &
config =
4405 config.setIsTargetDevice(deviceAttr.getValue());
4412 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
4413 llvm::OpenMPIRBuilderConfig &
config =
4415 config.setIsGPU(gpuAttr.getValue());
4420 .Case(
"omp.host_ir_filepath",
4422 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
4423 llvm::OpenMPIRBuilder *ompBuilder =
4425 ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
4432 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
4436 .Case(
"omp.version",
4438 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
4439 llvm::OpenMPIRBuilder *ompBuilder =
4441 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
4442 versionAttr.getVersion());
4447 .Case(
"omp.declare_target",
4449 if (
auto declareTargetAttr =
4450 dyn_cast<omp::DeclareTargetAttr>(attr))
4455 .Case(
"omp.requires",
4457 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
4458 using Requires = omp::ClauseRequires;
4459 Requires flags = requiresAttr.getValue();
4460 llvm::OpenMPIRBuilderConfig &
config =
4462 config.setHasRequiresReverseOffload(
4463 bitEnumContainsAll(flags, Requires::reverse_offload));
4464 config.setHasRequiresUnifiedAddress(
4465 bitEnumContainsAll(flags, Requires::unified_address));
4466 config.setHasRequiresUnifiedSharedMemory(
4467 bitEnumContainsAll(flags, Requires::unified_shared_memory));
4468 config.setHasRequiresDynamicAllocators(
4469 bitEnumContainsAll(flags, Requires::dynamic_allocators));
4474 .Case(
"omp.target_triples",
4476 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
4477 llvm::OpenMPIRBuilderConfig &
config =
4479 config.TargetTriples.clear();
4480 config.TargetTriples.reserve(triplesAttr.size());
4481 for (
Attribute tripleAttr : triplesAttr) {
4482 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
4483 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
4501 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
4502 Operation *op, llvm::IRBuilderBase &builder,
4506 if (ompBuilder->Config.isTargetDevice()) {
4517 registry.
insert<omp::OpenMPDialect>();
4519 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static llvm::Value * getRefPtrIfDeclareTarget(mlir::Value value, LLVM::ModuleTranslation &moduleTranslation)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static LogicalResult convertIgnoredWrappers(omp::LoopNestOp loopOp, omp::LoopWrapperInterface parentOp, LLVM::ModuleTranslation &moduleTranslation)
Helper function to call convertIgnoredWrapper() for all wrappers of the given loopOp nested inside of...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, const LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized.
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, MutableArrayRef< BlockArgument > privateBlockArgs, MutableArrayRef< omp::PrivateClauseOp > privateDecls, MutableArrayRef< mlir::Value > mlirPrivateVars, llvm::SmallVectorImpl< llvm::Value * > &llvmPrivateVars, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate delayed private variables.
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static bool isTargetDeviceOp(Operation *op)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
Handling of DeclareReductionOp's cleanup region.
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static LogicalResult initReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::BasicBlock *latestAllocaBlock, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef, SmallVectorImpl< DeferredStore > &deferredStores)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static void collectPrivatizationDecls(OP op, SmallVectorImpl< omp::PrivateClauseOp > &privatizations)
Populates privatizations with privatization declarations used for the given op.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams)
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Starting the lookup at the operation `from`, returns the PrivateClauseOp with name `symbolName`.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef)
LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, bool isTargetParams, int mapDataParentIdx=-1)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool >> attr)
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, bool isTargetParams=false)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static LogicalResult convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult convertIgnoredWrapper(omp::LoopWrapperInterface &opInst, LLVM::ModuleTranslation &moduleTranslation)
Helper function to map block arguments defined by ignored loop wrappers to LLVM values and prevent an...
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, const ArrayRef< Value > &useDevPtrOperands={}, const ArrayRef< Value > &useDevAddrOperands={})
static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void collectReductionInfo(T loop, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< OwningReductionGen > &owningReductionGens, SmallVectorImpl< OwningAtomicReductionGen > &owningAtomicReductionGens, const ArrayRef< llvm::Value * > privateReductionVariables, SmallVectorImpl< llvm::OpenMPIRBuilder::ReductionInfo > &reductionInfos)
Collect reduction info.
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Given an OpenMP MLIR operation, create the corresponding LLVM IR (including OpenMP runtime calls).
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult initFirstPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< mlir::Value > &mlirPrivateVars, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, llvm::BasicBlock *afterAllocas)
static LogicalResult convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isDeclareTargetLink(mlir::Value value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Base class for dialect interfaces providing translation to LLVM IR.
virtual LogicalResult amendOperation(Operation *op, ArrayRef< llvm::Instruction * > instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to act on an operation that has dialect attributes from the derive...
virtual LogicalResult convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to provide translation of the operations to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
Concrete CRTP base class for ModuleTranslation stack frames.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
WalkResult stackWalk(llvm::function_ref< WalkResult(const T &)> callback) const
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
unsigned getNumArguments()
BlockListType & getBlocks()
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Runtime
Potential runtimes for AMD GPU kernels.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
RAII object calling stackPush/stackPop on construction/destruction.