24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Frontend/OpenMP/OMPConstants.h"
28#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DebugInfoMetadata.h"
31#include "llvm/IR/DerivedTypes.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/MDBuilder.h"
34#include "llvm/IR/ReplaceConstant.h"
35#include "llvm/Support/FileSystem.h"
36#include "llvm/Support/VirtualFileSystem.h"
37#include "llvm/TargetParser/Triple.h"
38#include "llvm/Transforms/Utils/ModuleUtils.h"
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
67 llvm_unreachable(
"unhandled schedule clause argument");
// Stack frame that carries an OpenMPIRBuilder insertion point for allocas.
// The explicit constructor copies `allocaIP` into `allocaInsertPoint`; later
// code in this file reads `frame.allocaInsertPoint` from a stack-walk callback.
// NOTE(review): the base-class clause and access specifiers are not visible in
// this excerpt — confirm against the full file.
72class OpenMPAllocaStackFrame
77 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
78 : allocaInsertPoint(allocaIP) {}
// Insertion point recorded at construction time.
79 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
// Stack frame holding a pointer to a CanonicalLoopInfo. The pointer starts out
// null; a stack-walk callback later in this file reads `frame.loopInfo`.
// NOTE(review): the base-class clause is not visible in this excerpt, and who
// assigns `loopInfo` cannot be determined from here.
85class OpenMPLoopInfoStackFrame
89 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
// Custom llvm::Error payload type (CRTP over llvm::ErrorInfo). Elsewhere in
// this file it is created with llvm::make_error when block conversion fails,
// and a handleAllErrors handler for this type only sets `result = failure()`.
// NOTE(review): this suggests it marks failures whose diagnostic was already
// emitted — confirm with the full class comment, which is not visible here.
108class PreviouslyReportedError
109 :
public llvm::ErrorInfo<PreviouslyReportedError> {
// Override of ErrorInfoBase::log; its body is not visible in this excerpt.
111 void log(raw_ostream &)
const override {
// Override of ErrorInfoBase::convertToErrorCode. Only the message below is
// visible; presumably it is the argument of an unreachable/abort call — TODO
// confirm.
115 std::error_code convertToErrorCode()
const override {
117 "PreviouslyReportedError doesn't support ECError conversion");
// Definition of the static `ID` member of this error class.
124char PreviouslyReportedError::ID = 0;
// Helper that emits the LLVM IR needed for variables named in an OpenMP
// `linear` clause. All per-variable vectors below are indexed in parallel:
// entry i of each vector describes the i-th registered linear variable.
// NOTE(review): several original lines (closing braces, access specifiers,
// parts of some statements) are missing from this excerpt; comments describe
// only what the visible lines do.
135class LinearClauseProcessor {
// Allocas named ".linear_var"; initLinearVar stores the variable's value on
// loop entry into them.
138 SmallVector<llvm::Value *> linearPreconditionVars;
// Allocas named ".linear_result"; updateLinearVar stores start + iv * step
// into them.
139 SmallVector<llvm::Value *> linearLoopBodyTemps;
// The original linear variables (loaded from and stored to as pointers).
140 SmallVector<llvm::Value *> linearOrigVal;
// LLVM values of the linear steps.
141 SmallVector<llvm::Value *> linearSteps;
// Converted LLVM types of the linear variables.
142 SmallVector<llvm::Type *> linearVarTypes;
// Blocks produced by splitLinearFiniBB. They have no initializer on the
// visible lines, so they must not be read before splitLinearFiniBB runs.
143 llvm::BasicBlock *linearFinalizationBB;
144 llvm::BasicBlock *linearExitBB;
145 llvm::BasicBlock *linearLastIterExitBB;
// Records the type of a linear variable: casts `ty` to a TypeAttr, converts
// its value to an LLVM type and appends it to linearVarTypes.
149 void registerType(LLVM::ModuleTranslation &moduleTranslation,
150 mlir::Attribute &ty) {
151 linearVarTypes.push_back(moduleTranslation.
convertType(
152 mlir::cast<mlir::TypeAttr>(ty).getValue()));
// Creates the two allocas for one linear variable at the builder's current
// insertion point, using linearVarTypes[idx] (so registerType must already
// have been called for that index), and remembers `linearVar` as the original.
156 void createLinearVar(llvm::IRBuilderBase &builder,
157 LLVM::ModuleTranslation &moduleTranslation,
158 llvm::Value *linearVar,
int idx) {
159 linearPreconditionVars.push_back(
160 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_var"));
161 llvm::Value *linearLoopBodyTemp =
162 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_result");
163 linearOrigVal.push_back(linearVar);
164 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
// Looks up the already-translated LLVM value for `linearStep` and appends it
// to linearSteps.
168 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
169 mlir::Value &linearStep) {
170 linearSteps.push_back(moduleTranslation.
lookupValue(linearStep));
// Before the pre-header's terminator, copies each original variable's current
// value into its ".linear_var" alloca.
174 void initLinearVar(llvm::IRBuilderBase &builder,
175 LLVM::ModuleTranslation &moduleTranslation,
176 llvm::BasicBlock *loopPreHeader) {
177 builder.SetInsertPoint(loopPreHeader->getTerminator());
178 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
179 llvm::LoadInst *linearVarLoad =
180 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
181 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
// Before the loop body's terminator, computes start + iv * step for every
// linear variable and stores it into the matching ".linear_result" alloca.
// The induction variable must have integer type.
//  - Integer variable: iv and step are sign-extended/truncated to the
//    variable's type, then integer mul/add is used.
//  - Floating-point variable: step is sign-extended/truncated to iv's type,
//    the product is computed as an integer, converted with SIToFP and added
//    with FAdd.
// NOTE(review): the tail of the first diagnostic string and the statement
// that owns the final "Linear variable must be ..." message are not visible.
186 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
187 llvm::Value *loopInductionVar) {
188 builder.SetInsertPoint(loopBody->getTerminator());
189 for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
190 llvm::Type *linearVarType = linearVarTypes[index];
191 llvm::Value *iv = loopInductionVar;
192 llvm::Value *step = linearSteps[index];
194 if (!iv->getType()->isIntegerTy())
195 llvm_unreachable(
"OpenMP loop induction variable must be an integer "
198 if (linearVarType->isIntegerTy()) {
200 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
201 step = builder.CreateSExtOrTrunc(step, linearVarType);
203 llvm::LoadInst *linearVarStart =
204 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
205 llvm::Value *mulInst = builder.CreateMul(iv, step);
206 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
207 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
208 }
else if (linearVarType->isFloatingPointTy()) {
210 step = builder.CreateSExtOrTrunc(step, iv->getType());
211 llvm::Value *mulInst = builder.CreateMul(iv, step);
213 llvm::LoadInst *linearVarStart =
214 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
215 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
216 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
217 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
220 "Linear variable must be of integer or floating-point type");
// Splits `loopExit` at its terminator three times to create the finalization,
// exit and last-iteration-exit blocks. Both later splits are taken at the
// finalization block's terminator, so the last-iteration block is carved out
// after the exit block already exists.
// NOTE(review): resulting fall-through order looks like
// finalization -> lastiter_exit -> exit — confirm against splitBasicBlock.
227 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
228 llvm::BasicBlock *loopExit) {
229 linearFinalizationBB = loopExit->splitBasicBlock(
230 loopExit->getTerminator(),
"omp_loop.linear_finalization");
231 linearExitBB = linearFinalizationBB->splitBasicBlock(
232 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
233 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
234 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
// Emits the finalization logic; requires splitLinearFiniBB to have run.
//  1. In the finalization block, loads an i32 from `lastIter` and tests it
//     against zero (`isLast` is true when non-zero).
//  2. In the last-iteration block, copies every ".linear_result" value back
//     into the original variable.
//  3. Replaces the finalization block's terminator with a conditional branch:
//     last-iteration block when `isLast`, otherwise the exit block. The new
//     branch is inserted before the old terminator, which is then erased.
//  4. Positions the builder at the exit block's terminator.
// NOTE(review): the call that receives `builder.saveIP()` and `OMPD_barrier`
// is not visible here; presumably it creates a barrier and its result is
// returned — TODO confirm.
238 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
239 finalizeLinearVar(llvm::IRBuilderBase &builder,
240 LLVM::ModuleTranslation &moduleTranslation,
241 llvm::Value *lastIter) {
243 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
244 llvm::Value *loopLastIterLoad = builder.CreateLoad(
245 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
246 llvm::Value *isLast =
247 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
248 llvm::ConstantInt::get(
249 llvm::Type::getInt32Ty(builder.getContext()), 0));
251 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
252 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
253 llvm::LoadInst *linearVarTemp =
254 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
255 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
261 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
262 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
263 linearFinalizationBB->getTerminator()->eraseFromParent();
265 builder.SetInsertPoint(linearExitBB->getTerminator());
267 builder.saveIP(), llvm::omp::OMPD_barrier);
// At the builder's current insertion point, copies every ".linear_result"
// value back into its original variable (same copy loop as in
// finalizeLinearVar, without any block manipulation).
272 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
273 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
274 llvm::LoadInst *linearVarTemp =
275 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
276 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// Redirects uses of the original variable `varIndex` to its ".linear_result"
// alloca for instruction users selected by their parent block's name and
// `BBName`. Users are copied into a local vector first, and the replacement
// then iterates over that copy rather than over the live use list.
// NOTE(review): the declaration of `varIndex` and the right-hand side of the
// `find(BBName) !=` comparison are not visible; presumably the test is
// "block name contains BBName" — TODO confirm.
282 void rewriteInPlace(llvm::IRBuilderBase &builder,
const std::string &BBName,
284 llvm::SmallVector<llvm::User *> users;
285 for (llvm::User *user : linearOrigVal[varIndex]->users())
286 users.push_back(user);
287 for (
auto *user : users) {
288 if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
289 if (userInst->getParent()->getName().str().find(BBName) !=
291 user->replaceUsesOfWith(linearOrigVal[varIndex],
292 linearLoopBodyTemps[varIndex]);
303 SymbolRefAttr symbolName) {
304 omp::PrivateClauseOp privatizer =
307 assert(privatizer &&
"privatizer not found in the symbol table");
318 auto todo = [&op](StringRef clauseName) {
319 return op.
emitError() <<
"not yet implemented: Unhandled clause "
320 << clauseName <<
" in " << op.
getName()
324 auto checkAllocate = [&todo](
auto op, LogicalResult &
result) {
325 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
326 result = todo(
"allocate");
328 auto checkBare = [&todo](
auto op, LogicalResult &
result) {
330 result = todo(
"ompx_bare");
332 auto checkCollapse = [&todo](
auto op, LogicalResult &
result) {
333 if (op.getCollapseNumLoops() > 1)
334 result = todo(
"collapse");
336 auto checkDepend = [&todo](
auto op, LogicalResult &
result) {
337 if (!op.getDependVars().empty() || op.getDependKinds())
340 auto checkHint = [](
auto op, LogicalResult &) {
344 auto checkInReduction = [&todo](
auto op, LogicalResult &
result) {
345 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
346 op.getInReductionSyms())
347 result = todo(
"in_reduction");
349 auto checkNowait = [&todo](
auto op, LogicalResult &
result) {
353 auto checkOrder = [&todo](
auto op, LogicalResult &
result) {
354 if (op.getOrder() || op.getOrderMod())
357 auto checkParLevelSimd = [&todo](
auto op, LogicalResult &
result) {
358 if (op.getParLevelSimd())
359 result = todo(
"parallelization-level");
361 auto checkPrivate = [&todo](
auto op, LogicalResult &
result) {
362 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
363 result = todo(
"privatization");
365 auto checkReduction = [&todo](
auto op, LogicalResult &
result) {
366 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
367 if (!op.getReductionVars().empty() || op.getReductionByref() ||
368 op.getReductionSyms())
369 result = todo(
"reduction");
370 if (op.getReductionMod() &&
371 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
372 result = todo(
"reduction with modifier");
374 auto checkTaskReduction = [&todo](
auto op, LogicalResult &
result) {
375 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
376 op.getTaskReductionSyms())
377 result = todo(
"task_reduction");
382 .Case([&](omp::DistributeOp op) {
383 checkAllocate(op,
result);
386 .Case([&](omp::LoopNestOp op) {
387 if (mlir::isa<omp::TaskloopOp>(op.getOperation()->
getParentOp()))
388 checkCollapse(op,
result);
390 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op,
result); })
391 .Case([&](omp::SectionsOp op) {
392 checkAllocate(op,
result);
394 checkReduction(op,
result);
396 .Case([&](omp::SingleOp op) {
397 checkAllocate(op,
result);
400 .Case([&](omp::TeamsOp op) {
401 checkAllocate(op,
result);
404 .Case([&](omp::TaskOp op) {
405 checkAllocate(op,
result);
406 checkInReduction(op,
result);
408 .Case([&](omp::TaskgroupOp op) {
409 checkAllocate(op,
result);
410 checkTaskReduction(op,
result);
412 .Case([&](omp::TaskwaitOp op) {
416 .Case([&](omp::TaskloopOp op) {
417 checkAllocate(op,
result);
418 checkInReduction(op,
result);
419 checkReduction(op,
result);
421 .Case([&](omp::WsloopOp op) {
422 checkAllocate(op,
result);
424 checkReduction(op,
result);
426 .Case([&](omp::ParallelOp op) {
427 checkAllocate(op,
result);
428 checkReduction(op,
result);
430 .Case([&](omp::SimdOp op) { checkReduction(op,
result); })
431 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
432 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op,
result); })
433 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
434 [&](
auto op) { checkDepend(op,
result); })
435 .Case<omp::TargetUpdateOp>([&](
auto op) { checkDepend(op,
result); })
436 .Case([&](omp::TargetOp op) {
437 checkAllocate(op,
result);
439 checkInReduction(op,
result);
451 llvm::handleAllErrors(
453 [&](
const PreviouslyReportedError &) {
result = failure(); },
454 [&](
const llvm::ErrorInfoBase &err) {
471static llvm::OpenMPIRBuilder::InsertPointTy
477 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
479 [&](OpenMPAllocaStackFrame &frame) {
480 allocaInsertPoint = frame.allocaInsertPoint;
488 allocaInsertPoint.getBlock()->getParent() ==
489 builder.GetInsertBlock()->getParent())
490 return allocaInsertPoint;
499 if (builder.GetInsertBlock() ==
500 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
501 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
502 "Assuming end of basic block");
503 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
504 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
505 builder.GetInsertBlock()->getNextNode());
506 builder.CreateBr(entryBB);
507 builder.SetInsertPoint(entryBB);
510 llvm::BasicBlock &funcEntryBlock =
511 builder.GetInsertBlock()->getParent()->getEntryBlock();
512 return llvm::OpenMPIRBuilder::InsertPointTy(
513 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
519static llvm::CanonicalLoopInfo *
521 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
522 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
523 [&](OpenMPLoopInfoStackFrame &frame) {
524 loopInfo = frame.loopInfo;
536 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
539 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
541 llvm::BasicBlock *continuationBlock =
542 splitBB(builder,
true,
"omp.region.cont");
543 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
545 llvm::LLVMContext &llvmContext = builder.getContext();
546 for (
Block &bb : region) {
547 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
548 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
549 builder.GetInsertBlock()->getNextNode());
550 moduleTranslation.
mapBlock(&bb, llvmBB);
553 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
560 unsigned numYields = 0;
562 if (!isLoopWrapper) {
563 bool operandsProcessed =
false;
565 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
566 if (!operandsProcessed) {
567 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
568 continuationBlockPHITypes.push_back(
569 moduleTranslation.
convertType(yield->getOperand(i).getType()));
571 operandsProcessed =
true;
573 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
574 "mismatching number of values yielded from the region");
575 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
576 llvm::Type *operandType =
577 moduleTranslation.
convertType(yield->getOperand(i).getType());
579 assert(continuationBlockPHITypes[i] == operandType &&
580 "values of mismatching types yielded from the region");
590 if (!continuationBlockPHITypes.empty())
592 continuationBlockPHIs &&
593 "expected continuation block PHIs if converted regions yield values");
594 if (continuationBlockPHIs) {
595 llvm::IRBuilderBase::InsertPointGuard guard(builder);
596 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
597 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
598 for (llvm::Type *ty : continuationBlockPHITypes)
599 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
605 for (
Block *bb : blocks) {
606 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
609 if (bb->isEntryBlock()) {
610 assert(sourceTerminator->getNumSuccessors() == 1 &&
611 "provided entry block has multiple successors");
612 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
613 "ContinuationBlock is not the successor of the entry block");
614 sourceTerminator->setSuccessor(0, llvmBB);
617 llvm::IRBuilderBase::InsertPointGuard guard(builder);
619 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
620 return llvm::make_error<PreviouslyReportedError>();
625 builder.CreateBr(continuationBlock);
636 Operation *terminator = bb->getTerminator();
637 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
638 builder.CreateBr(continuationBlock);
640 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
641 (*continuationBlockPHIs)[i]->addIncoming(
655 return continuationBlock;
661 case omp::ClauseProcBindKind::Close:
662 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
663 case omp::ClauseProcBindKind::Master:
664 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
665 case omp::ClauseProcBindKind::Primary:
666 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
667 case omp::ClauseProcBindKind::Spread:
668 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
670 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
680 omp::BlockArgOpenMPOpInterface blockArgIface) {
682 blockArgIface.getBlockArgsPairs(blockArgsPairs);
683 for (
auto [var, arg] : blockArgsPairs)
691 auto maskedOp = cast<omp::MaskedOp>(opInst);
692 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
697 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
699 auto ®ion = maskedOp.getRegion();
700 builder.restoreIP(codeGenIP);
708 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
710 llvm::Value *filterVal =
nullptr;
711 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
712 filterVal = moduleTranslation.
lookupValue(filterVar);
714 llvm::LLVMContext &llvmContext = builder.getContext();
716 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
718 assert(filterVal !=
nullptr);
719 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
720 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
727 builder.restoreIP(*afterIP);
735 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
736 auto masterOp = cast<omp::MasterOp>(opInst);
741 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
743 auto ®ion = masterOp.getRegion();
744 builder.restoreIP(codeGenIP);
752 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
754 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
755 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
762 builder.restoreIP(*afterIP);
770 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
771 auto criticalOp = cast<omp::CriticalOp>(opInst);
776 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
778 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
779 builder.restoreIP(codeGenIP);
787 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
789 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
790 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
791 llvm::Constant *hint =
nullptr;
794 if (criticalOp.getNameAttr()) {
797 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
798 auto criticalDeclareOp =
802 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
803 static_cast<int>(criticalDeclareOp.getHint()));
805 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
807 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
812 builder.restoreIP(*afterIP);
819 template <
typename OP>
822 cast<
omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
825 collectPrivatizationDecls<OP>(op);
840 void collectPrivatizationDecls(OP op) {
841 std::optional<ArrayAttr> attr = op.getPrivateSyms();
846 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
857 std::optional<ArrayAttr> attr = op.getReductionSyms();
861 reductions.reserve(reductions.size() + op.getNumReductionVars());
862 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
863 reductions.push_back(
875 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
884 llvm::Instruction *potentialTerminator =
885 builder.GetInsertBlock()->empty() ?
nullptr
886 : &builder.GetInsertBlock()->back();
888 if (potentialTerminator && potentialTerminator->isTerminator())
889 potentialTerminator->removeFromParent();
890 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
893 region.
front(),
true, builder)))
897 if (continuationBlockArgs)
899 *continuationBlockArgs,
906 if (potentialTerminator && potentialTerminator->isTerminator()) {
907 llvm::BasicBlock *block = builder.GetInsertBlock();
908 if (block->empty()) {
914 potentialTerminator->insertInto(block, block->begin());
916 potentialTerminator->insertAfter(&block->back());
930 if (continuationBlockArgs)
931 llvm::append_range(*continuationBlockArgs, phis);
932 builder.SetInsertPoint(*continuationBlock,
933 (*continuationBlock)->getFirstInsertionPt());
940using OwningReductionGen =
941 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
942 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
944using OwningAtomicReductionGen =
945 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
946 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
948using OwningDataPtrPtrReductionGen =
949 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
950 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
956static OwningReductionGen
962 OwningReductionGen gen =
963 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
964 llvm::Value *
lhs, llvm::Value *
rhs,
965 llvm::Value *&
result)
mutable
966 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
967 moduleTranslation.
mapValue(decl.getReductionLhsArg(),
lhs);
968 moduleTranslation.
mapValue(decl.getReductionRhsArg(),
rhs);
969 builder.restoreIP(insertPoint);
972 "omp.reduction.nonatomic.body", builder,
973 moduleTranslation, &phis)))
974 return llvm::createStringError(
975 "failed to inline `combiner` region of `omp.declare_reduction`");
976 result = llvm::getSingleElement(phis);
977 return builder.saveIP();
986static OwningAtomicReductionGen
988 llvm::IRBuilderBase &builder,
990 if (decl.getAtomicReductionRegion().empty())
991 return OwningAtomicReductionGen();
997 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
998 llvm::Value *
lhs, llvm::Value *
rhs)
mutable
999 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1000 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(),
lhs);
1001 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(),
rhs);
1002 builder.restoreIP(insertPoint);
1005 "omp.reduction.atomic.body", builder,
1006 moduleTranslation, &phis)))
1007 return llvm::createStringError(
1008 "failed to inline `atomic` region of `omp.declare_reduction`");
1009 assert(phis.empty());
1010 return builder.saveIP();
1019static OwningDataPtrPtrReductionGen
1023 return OwningDataPtrPtrReductionGen();
1025 OwningDataPtrPtrReductionGen refDataPtrGen =
1026 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1027 llvm::Value *byRefVal, llvm::Value *&
result)
mutable
1028 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1029 moduleTranslation.
mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1030 builder.restoreIP(insertPoint);
1033 "omp.data_ptr_ptr.body", builder,
1034 moduleTranslation, &phis)))
1035 return llvm::createStringError(
1036 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1037 result = llvm::getSingleElement(phis);
1038 return builder.saveIP();
1041 return refDataPtrGen;
1048 auto orderedOp = cast<omp::OrderedOp>(opInst);
1053 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1054 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1055 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1057 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
1059 size_t indexVecValues = 0;
1060 while (indexVecValues < vecValues.size()) {
1062 storeValues.reserve(numLoops);
1063 for (
unsigned i = 0; i < numLoops; i++) {
1064 storeValues.push_back(vecValues[indexVecValues]);
1067 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1069 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1070 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
1071 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
1081 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1082 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1087 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1089 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1090 builder.restoreIP(codeGenIP);
1098 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1100 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1101 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1103 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1108 builder.restoreIP(*afterIP);
// A (value, address) pair. Elsewhere in this file such entries are appended
// with emplace_back(castPhi, castVar) and later unpacked as `[data, addr]` and
// passed to builder.CreateStore(data, addr).
// NOTE(review): the declaration of the `value` member is not visible in this
// excerpt; the constructor's initializer list shows that it exists.
1114struct DeferredStore {
// Initializes both members from the given pointers.
1115 DeferredStore(llvm::Value *value, llvm::Value *address)
1116 : value(value), address(address) {}
// Destination of the store.
1119 llvm::Value *address;
1126template <
typename T>
1129 llvm::IRBuilderBase &builder,
1131 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1137 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1138 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1141 deferredStores.reserve(loop.getNumReductionVars());
1143 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1144 Region &allocRegion = reductionDecls[i].getAllocRegion();
1146 if (allocRegion.
empty())
1151 builder, moduleTranslation, &phis)))
1152 return loop.emitError(
1153 "failed to inline `alloc` region of `omp.declare_reduction`");
1155 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1156 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1160 llvm::Value *var = builder.CreateAlloca(
1161 moduleTranslation.
convertType(reductionDecls[i].getType()));
1163 llvm::Type *ptrTy = builder.getPtrTy();
1164 llvm::Value *castVar =
1165 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1166 llvm::Value *castPhi =
1167 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1169 deferredStores.emplace_back(castPhi, castVar);
1171 privateReductionVariables[i] = castVar;
1172 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1173 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1175 assert(allocRegion.
empty() &&
1176 "allocaction is implicit for by-val reduction");
1177 llvm::Value *var = builder.CreateAlloca(
1178 moduleTranslation.
convertType(reductionDecls[i].getType()));
1180 llvm::Type *ptrTy = builder.getPtrTy();
1181 llvm::Value *castVar =
1182 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1184 moduleTranslation.
mapValue(reductionArgs[i], castVar);
1185 privateReductionVariables[i] = castVar;
1186 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1194template <
typename T>
1197 llvm::IRBuilderBase &builder,
1202 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1203 Region &initializerRegion = reduction.getInitializerRegion();
1206 mlir::Value mlirSource = loop.getReductionVars()[i];
1207 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1208 llvm::Value *origVal = llvmSource;
1210 if (!isa<LLVM::LLVMPointerType>(
1211 reduction.getInitializerMoldArg().getType()) &&
1212 isa<LLVM::LLVMPointerType>(mlirSource.
getType())) {
1215 reduction.getInitializerMoldArg().getType()),
1216 llvmSource,
"omp_orig");
1218 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), origVal);
1221 llvm::Value *allocation =
1222 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1223 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1229 llvm::BasicBlock *block =
nullptr) {
1230 if (block ==
nullptr)
1231 block = builder.GetInsertBlock();
1233 if (block->empty() || block->getTerminator() ==
nullptr)
1234 builder.SetInsertPoint(block);
1236 builder.SetInsertPoint(block->getTerminator());
1244template <
typename OP>
1247 llvm::IRBuilderBase &builder,
1249 llvm::BasicBlock *latestAllocaBlock,
1255 if (op.getNumReductionVars() == 0)
1258 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1259 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1260 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1261 builder.restoreIP(allocaIP);
1264 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1266 if (!reductionDecls[i].getAllocRegion().empty())
1272 byRefVars[i] = builder.CreateAlloca(
1273 moduleTranslation.
convertType(reductionDecls[i].getType()));
1281 for (
auto [data, addr] : deferredStores)
1282 builder.CreateStore(data, addr);
1287 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1292 reductionVariableMap, i);
1300 "omp.reduction.neutral", builder,
1301 moduleTranslation, &phis)))
1304 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1305 "reduction neutral element declaration region");
1310 if (!reductionDecls[i].getAllocRegion().empty())
1319 builder.CreateStore(phis[0], byRefVars[i]);
1321 privateReductionVariables[i] = byRefVars[i];
1322 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1323 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1326 builder.CreateStore(phis[0], privateReductionVariables[i]);
1333 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
// Builds the per-reduction generator callbacks and the `reductionInfos`
// entries for the reduction variables of `loop`.
// NOTE(review): most of the parameter list and several statements are missing
// from this excerpt; the comments below cover the visible lines only.
1340template <
typename T>
1341static void collectReductionInfo(
1342 T loop, llvm::IRBuilderBase &builder,
1351 unsigned numReductions = loop.getNumReductionVars();
// First pass: create the owning generator objects, one per reduction, from
// the matching declaration. Only the atomic-generator push_back and the
// argument list of a by-ref-aware call are visible here.
1353 for (
unsigned i = 0; i < numReductions; ++i) {
1356 owningAtomicReductionGens.push_back(
1359 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
// Second pass: assemble one reductionInfos entry per reduction.
1363 reductionInfos.reserve(numReductions);
1364 for (
unsigned i = 0; i < numReductions; ++i) {
// The atomic callback stays null unless an owning atomic generator was
// created for this reduction.
1365 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1366 if (owningAtomicReductionGens[i])
1367 atomicGen = owningAtomicReductionGens[i];
1368 llvm::Value *variable =
1369 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
// When `op` is an LLVM alloca, its element type is recorded.
// NOTE(review): `op` and `allocatedType` are declared on lines not shown.
1372 if (
auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1373 allocatedType = alloca.getElemType();
// The entry includes the private copy, EvalKind::Scalar, the converted
// allocated type (null when none was recorded) and an expression involving
// the declaration's by-ref element type; the other arguments are not visible.
1380 reductionInfos.push_back(
1382 privateReductionVariables[i],
1383 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1387 allocatedType ? moduleTranslation.
convertType(allocatedType) :
nullptr,
1388 reductionDecls[i].getByrefElementType()
1390 *reductionDecls[i].getByrefElementType())
1400 llvm::IRBuilderBase &builder, StringRef regionName,
1401 bool shouldLoadCleanupRegionArg =
true) {
1402 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1403 if (cleanupRegion->empty())
1409 llvm::Instruction *potentialTerminator =
1410 builder.GetInsertBlock()->empty() ?
nullptr
1411 : &builder.GetInsertBlock()->back();
1412 if (potentialTerminator && potentialTerminator->isTerminator())
1413 builder.SetInsertPoint(potentialTerminator);
1414 llvm::Value *privateVarValue =
1415 shouldLoadCleanupRegionArg
1416 ? builder.CreateLoad(
1418 privateVariables[i])
1419 : privateVariables[i];
1424 moduleTranslation)))
1437 OP op, llvm::IRBuilderBase &builder,
1439 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1442 bool isNowait =
false,
bool isTeamsReduction =
false) {
1444 if (op.getNumReductionVars() == 0)
1456 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1458 owningReductionGenRefDataPtrGens,
1459 privateReductionVariables, reductionInfos, isByRef);
1464 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1465 builder.SetInsertPoint(tempTerminator);
1466 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1467 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1468 isByRef, isNowait, isTeamsReduction);
1473 if (!contInsertPoint->getBlock())
1474 return op->emitOpError() <<
"failed to convert reductions";
1476 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1477 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1482 tempTerminator->eraseFromParent();
1483 builder.restoreIP(*afterIP);
1487 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1488 [](omp::DeclareReductionOp reductionDecl) {
1489 return &reductionDecl.getCleanupRegion();
1492 moduleTranslation, builder,
1493 "omp.reduction.cleanup");
1504template <
typename OP>
1508 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1513 if (op.getNumReductionVars() == 0)
1519 allocaIP, reductionDecls,
1520 privateReductionVariables, reductionVariableMap,
1521 deferredStores, isByRef)))
1524 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1525 allocaIP.getBlock(), reductionDecls,
1526 privateReductionVariables, reductionVariableMap,
1527 isByRef, deferredStores);
1541 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1544 Value blockArg = (*mappedPrivateVars)[privateVar];
1547 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1548 "A block argument corresponding to a mapped var should have "
1551 if (privVarType == blockArgType)
1558 if (!isa<LLVM::LLVMPointerType>(privVarType))
1559 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1572 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1574 llvm::BasicBlock *privInitBlock,
1576 Region &initRegion = privDecl.getInitRegion();
1577 if (initRegion.
empty())
1578 return llvmPrivateVar;
1580 assert(nonPrivateVar);
1581 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1582 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1587 moduleTranslation, &phis)))
1588 return llvm::createStringError(
1589 "failed to inline `init` region of `omp.private`");
1591 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1608 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1611 builder, moduleTranslation, privDecl,
1614 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1623 return llvm::Error::success();
1625 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1628 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1631 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1633 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1634 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1637 return privVarOrErr.takeError();
1639 llvmPrivateVar = privVarOrErr.get();
1640 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1645 return llvm::Error::success();
1655 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// Fragment: emits one stack allocation per privatized variable in the block
// designated by `allocaIP` and returns the block that follows the allocas.
// NOTE(review): the enclosing signature and some statements are not visible
// in this excerpt; comments describe only the lines shown.
//
// Split the alloca block right before its terminator so that everything that
// used to follow the allocas lives in a new "omp.region.after_alloca" block.
1658 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1659 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1660 allocaTerminator->getIterator()),
1661 true, allocaTerminator->getStableDebugLoc(),
1662 "omp.region.after_alloca");
// The builder's insertion point is saved here and restored when `guard` goes
// out of scope.
1664 llvm::IRBuilderBase::InsertPointGuard guard(builder);
// The terminator is fetched again after the split and becomes the insertion
// point, so new instructions go in front of it.
1666 allocaTerminator = allocaIP.getBlock()->getTerminator();
1667 builder.SetInsertPoint(allocaTerminator);
1669 assert(allocaTerminator->getNumSuccessors() == 1 &&
1670 "This is an unconditional branch created by splitBB");
1672 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
// The single successor of the alloca block is what this fragment returns.
1673 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
// Address space used for allocas according to the module's data layout.
1675 unsigned int allocaAS =
1676 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
// NOTE(review): presumably the tail of the `defaultAS` initializer (program
// address space); the declaration line itself is not visible - confirm.
1679 .getProgramAddressSpace();
1681 for (
auto [privDecl, mlirPrivVar, blockArg] :
// Allocate storage of the privatizer's converted type just before the alloca
// block's terminator.
1684 llvm::Type *llvmAllocType =
1685 moduleTranslation.
convertType(privDecl.getType());
1686 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1687 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1688 llvmAllocType,
nullptr,
"omp.private.alloc");
// When allocas live in a different address space than `defaultAS`, the
// pointer is cast to a pointer in `defaultAS` before being recorded.
1689 if (allocaAS != defaultAS)
1690 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1691 builder.getPtrTy(defaultAS));
// Record the (possibly cast) private variable for later use by the caller.
1693 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1696 return afterAllocas;
1704 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1713 if (mlir::isa<omp::ParallelOp>(parent))
// Fragment: inlines the `copy` region of every `firstprivate` privatizer.
// NOTE(review): the enclosing signature and several statements are not visible
// in this excerpt; comments describe only the lines shown.
//
// True when at least one privatizer has the FirstPrivate data-sharing type.
1727 bool needsFirstprivate =
1728 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1729 return privOp.getDataSharingType() ==
1730 omp::DataSharingClauseType::FirstPrivate;
// NOTE(review): the statement guarded by this check is not visible;
// presumably an early successful return - confirm.
1733 if (!needsFirstprivate)
// The copy code is emitted in a dedicated "omp.private.copy" block.
1736 llvm::BasicBlock *copyBlock =
1737 splitBB(builder,
true,
"omp.private.copy");
// Privatizers, mold (source) values and private variables are walked in
// lockstep; zip_equal requires all three ranges to have the same length.
1740 for (
auto [decl, moldVar, llvmVar] :
1741 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
// NOTE(review): the guarded statement is not visible; presumably non
// firstprivate entries are skipped - confirm.
1742 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
// NOTE(review): `©Region` looks like a mis-decoded `&copyRegion` (HTML entity
// `&copy`) introduced by text extraction - confirm against the real file.
1746 Region ©Region = decl.getCopyRegion();
// Bind the copy region's arguments: mold argument -> source value, private
// argument -> private variable.
1748 moduleTranslation.
mapValue(decl.getCopyMoldArg(), moldVar);
1751 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
// Tail of the inlining call (its head is not visible here); failure is
// reported as a diagnostic on the privatizer op.
1755 moduleTranslation)))
1756 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
// A barrier is created at the current insertion point.
// NOTE(review): the condition under which this runs is not visible here.
1770 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1771 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1787 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](
mlir::Value mlirVar) {
1789 llvm::Value *moldVar = findAssociatedValue(
1790 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1795 llvmPrivateVars, privateDecls, insertBarrier,
1806 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1807 [](omp::PrivateClauseOp privatizer) {
1808 return &privatizer.getDeallocRegion();
1812 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1813 "omp.private.dealloc",
false)))
1814 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1815 "`omp.private` op in");
1827 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
// Fragment: lowering of an `omp.sections` operation.
// NOTE(review): the enclosing signature and many statements are not visible
// in this excerpt; comments describe only the lines shown.
1837 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1838 using StorableBodyGenCallbackTy =
1839 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1841 auto sectionsOp = cast<omp::SectionsOp>(opInst);
// One by-ref flag is expected per reduction variable.
1847 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1851 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1855 sectionsOp.getNumReductionVars());
1859 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1862 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1863 reductionDecls, privateReductionVariables, reductionVariableMap,
// dyn_cast: yields a null op when `op` is not an `omp.section`.
1870 auto sectionOp = dyn_cast<omp::SectionOp>(op);
// NOTE(review): `®ion` here and `§ionsOp` / `®ion` in the capture list below
// look like mis-decoded `&region` / `&sectionsOp` (HTML entities `&reg`,
// `&sect`) introduced by text extraction - confirm against the real file.
1874 Region ®ion = sectionOp.getRegion();
// Body-generation callback for one section; it starts emitting at the
// code-generation insertion point it is handed.
1875 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1876 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1877 builder.restoreIP(codeGenIP);
1884 sectionsOp.getRegion().getNumArguments());
// Each block argument of the section region is mapped to the LLVM value that
// is already associated with the matching argument of the enclosing sections
// region.
1885 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1886 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1887 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
1889 moduleTranslation.
mapValue(sectionArg, llvmVal);
1896 sectionCBs.push_back(sectionCB);
// NOTE(review): the statement guarded by this check is not visible here.
1902 if (sectionCBs.empty())
// The first operation in the sections region must be an `omp.section`.
1905 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
// Privatization callback: the replacement value is set to `vPtr` itself.
1910 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1911 llvm::Value &vPtr, llvm::Value *&replacementValue)
1912 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1913 replacementValue = &vPtr;
// Finalization callback: emits nothing and reports success.
1919 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1923 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1924 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1926 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1927 sectionsOp.getNowait());
// Continue emitting code after the generated construct.
1932 builder.restoreIP(*afterIP);
1936 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1937 privateReductionVariables, isByRef, sectionsOp.getNowait());
1944 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1945 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1950 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1951 builder.restoreIP(codegenIP);
1953 builder, moduleTranslation)
1956 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1960 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1963 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1964 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
1966 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1967 llvmCPFuncs.push_back(
1971 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1973 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1979 builder.restoreIP(*afterIP);
1985 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1990 for (
auto ra : iface.getReductionBlockArgs())
1991 for (
auto &use : ra.getUses()) {
1992 auto *useOp = use.getOwner();
1994 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1995 debugUses.push_back(useOp);
1999 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
2004 Operation *currentOp = currentDistOp.getOperation();
2005 if (distOp && (distOp != currentOp))
2014 for (
auto *use : debugUses)
2023 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2028 unsigned numReductionVars = op.getNumReductionVars();
2032 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2038 if (doTeamsReduction) {
2039 isByRef =
getIsByRef(op.getReductionByref());
2041 assert(isByRef.size() == op.getNumReductionVars());
2044 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2049 op, reductionArgs, builder, moduleTranslation, allocaIP,
2050 reductionDecls, privateReductionVariables, reductionVariableMap,
2055 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2057 moduleTranslation, allocaIP);
2058 builder.restoreIP(codegenIP);
2064 llvm::Value *numTeamsLower =
nullptr;
2065 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
2066 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
2068 llvm::Value *numTeamsUpper =
nullptr;
2069 if (
Value numTeamsUpperVar = op.getNumTeamsUpper())
2070 numTeamsUpper = moduleTranslation.
lookupValue(numTeamsUpperVar);
2072 llvm::Value *threadLimit =
nullptr;
2073 if (
Value threadLimitVar = op.getThreadLimit())
2074 threadLimit = moduleTranslation.
lookupValue(threadLimitVar);
2076 llvm::Value *ifExpr =
nullptr;
2077 if (
Value ifVar = op.getIfExpr())
2080 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2081 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2083 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2088 builder.restoreIP(*afterIP);
2089 if (doTeamsReduction) {
2092 op, builder, moduleTranslation, allocaIP, reductionDecls,
2093 privateReductionVariables, isByRef,
// Fragment: translates task `depend` clause entries into OpenMPIRBuilder
// `DependData` records appended to `dds`.
// NOTE(review): the enclosing signature and some statements (including the
// `switch` head and `break`s) are not visible in this excerpt.
//
// NOTE(review): the statement guarded by this check is not visible;
// presumably an early return when there is nothing to translate - confirm.
2103 if (dependVars.empty())
// Each dependence variable is paired with its dependence-kind attribute.
2105 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2106 llvm::omp::RTLDependenceKindTy type;
2108 cast<mlir::omp::ClauseTaskDepend>(std::get<1>(dep)).getValue()) {
2109 case mlir::omp::ClauseTaskDepend::taskdependin:
2110 type = llvm::omp::RTLDependenceKindTy::DepIn;
// `out` and `inout` share one case body: both map to DepInOut.
2115 case mlir::omp::ClauseTaskDepend::taskdependout:
2116 case mlir::omp::ClauseTaskDepend::taskdependinout:
2117 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2119 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2120 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2122 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2123 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
// The record stores the kind, the LLVM type of the dependence value and the
// value itself.
2126 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2127 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2128 dds.emplace_back(dd);
2140 llvm::IRBuilderBase &llvmBuilder,
2142 llvm::omp::Directive cancelDirective) {
// Fragment: builds a finalization callback used for cancellation and pushes
// it on the OpenMPIRBuilder's finalization stack.
// NOTE(review): the enclosing signature and the remaining arguments of the
// push call are not visible in this excerpt.
//
// The callback emits a branch at the insertion point it is handed and records
// that branch in `cancelTerminators`. The builder's previous insertion point
// is restored when `guard` goes out of scope.
2143 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2144 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2148 llvmBuilder.restoreIP(ip);
// The branch initially targets the block it is emitted in.
// NOTE(review): presumably a placeholder that is retargeted later - confirm.
2154 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2155 return llvm::Error::success();
2160 ompBuilder.pushFinalizationCB(
2170 llvm::OpenMPIRBuilder &ompBuilder,
2171 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
// Fragment body: removes the top finalization callback from the
// OpenMPIRBuilder and retargets every recorded cancel branch.
// NOTE(review): closing braces are not visible in this excerpt.
2172 ompBuilder.popFinalizationCB();
// The new branch target is the single predecessor of the block that holds the
// "after" insertion point (null if that block has zero or several
// predecessors).
2173 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2174 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2175 assert(cancelBranch->getNumSuccessors() == 1 &&
2176 "cancel branch should have one target");
2177 cancelBranch->setSuccessor(0, constructFini);
/// Manages a heap-allocated structure ("task context") that holds private
/// variables of a task, together with the GEPs used to address its fields.
/// NOTE(review): access specifiers and some closing braces are not visible in
/// this excerpt; comments describe only the lines shown.
2184class TaskContextStructManager {
/// Stores references to the IR builder and module translation, and a view of
/// the privatizer ops; nothing is emitted on construction.
2186 TaskContextStructManager(llvm::IRBuilderBase &builder,
2187 LLVM::ModuleTranslation &moduleTranslation,
2188 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2189 : builder{builder}, moduleTranslation{moduleTranslation},
2190 privateDecls{privateDecls} {}
/// Builds the structure type from the privatizers and emits the allocation
/// that sets the struct pointer.
2196 void generateTaskContextStruct();
/// Fills the stored GEP list by addressing the structure owned by this
/// manager.
2202 void createGEPsToPrivateVars();
/// Builds GEPs against `altStructPtr` instead of the stored struct pointer;
/// being const, it does not touch the stored GEP list.
2208 SmallVector<llvm::Value *>
2209 createGEPsToPrivateVars(llvm::Value *altStructPtr)
const;
/// Emits the deallocation of the structure.
2212 void freeStructPtr();
/// Mutable view of the stored GEP list.
2214 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2215 return llvmPrivateVarGEPs;
/// Pointer to the allocated structure; stays null until one is allocated.
2218 llvm::Value *getStructPtr() {
return structPtr; }
2221 llvm::IRBuilderBase &builder;
2222 LLVM::ModuleTranslation &moduleTranslation;
2223 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
// LLVM types of the structure's fields, in field order.
2226 SmallVector<llvm::Type *> privateVarTypes;
// GEPs produced by createGEPsToPrivateVars().
2230 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
// Result of the allocation emitted by generateTaskContextStruct().
2233 llvm::Value *structPtr =
nullptr;
// Struct type built by generateTaskContextStruct().
2235 llvm::Type *structTy =
nullptr;
// Builds the task context struct type and emits a malloc for it, storing the
// result in `structPtr`.
// NOTE(review): the statements guarded by the `if`s below and the closing
// braces are not visible in this excerpt.
2239void TaskContextStructManager::generateTaskContextStruct() {
// NOTE(review): guarded statement not visible; presumably returns early when
// there are no privatizers - confirm.
2240 if (privateDecls.empty())
2242 privateVarTypes.reserve(privateDecls.size());
2244 for (omp::PrivateClauseOp &privOp : privateDecls) {
// NOTE(review): guarded statement not visible; presumably privatizers that do
// not read from their mold are skipped and get no field - confirm.
2247 if (!privOp.readsFromMold())
// The field type is the privatizer's type converted to LLVM IR.
2249 Type mlirType = privOp.getType();
2250 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
// NOTE(review): guarded statement not visible; presumably nothing is
// allocated when no field was collected - confirm.
2253 if (privateVarTypes.empty())
2256 structTy = llvm::StructType::get(moduleTranslation.
getLLVMContext(),
// The allocation size is the constant size of the struct type; the
// pointer-sized integer type comes from the module's data layout.
2259 llvm::DataLayout dataLayout =
2260 builder.GetInsertBlock()->getModule()->getDataLayout();
2261 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2262 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
// Emitted at the builder's current insertion point.
2265 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2267 "omp.task.context_ptr");
// Builds GEPs into the structure pointed to by `altStructPtr`, using
// `structTy` as the source element type, and collects them in `ret`.
// NOTE(review): the declaration/update of the field index `i`, the push of
// `gep`, the return statement and closing braces are not visible in this
// excerpt.
2270SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2271 llvm::Value *altStructPtr)
const {
2272 SmallVector<llvm::Value *> ret;
// Space is reserved for one entry per privatizer.
2275 ret.reserve(privateDecls.size());
// First GEP index.
2276 llvm::Value *zero = builder.getInt32(0);
2278 for (
auto privDecl : privateDecls) {
// Privatizers that do not read from their mold get a null entry.
2279 if (!privDecl.readsFromMold()) {
2281 ret.push_back(
nullptr);
// Second GEP index: `i`.
// NOTE(review): presumably the struct field number - confirm.
2284 llvm::Value *iVal = builder.getInt32(i);
2285 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
// Fills `llvmPrivateVarGEPs` with GEPs into the structure owned by this
// manager (`structPtr`), by delegating to the overload that takes an explicit
// struct pointer.
// NOTE(review): the condition around the assert and the closing brace are not
// visible; presumably the assert covers the case where no structure was
// allocated - confirm.
2292void TaskContextStructManager::createGEPsToPrivateVars() {
2294 assert(privateVarTypes.empty());
2298 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
// Emits a `free` of `structPtr` immediately before the terminator of the
// builder's current block.
// NOTE(review): any guard for a null `structPtr` and the closing brace are not
// visible in this excerpt.
2301void TaskContextStructManager::freeStructPtr() {
// The builder's insertion point is restored when `guard` goes out of scope.
2305 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2307 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2308 builder.CreateFree(structPtr);
2315 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2320 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2332 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2337 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2338 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2339 builder.getContext(),
"omp.task.start",
2340 builder.GetInsertBlock()->getParent());
2341 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2342 builder.SetInsertPoint(branchToTaskStartBlock);
2345 llvm::BasicBlock *copyBlock =
2346 splitBB(builder,
true,
"omp.private.copy");
2347 llvm::BasicBlock *initBlock =
2348 splitBB(builder,
true,
"omp.private.init");
2364 moduleTranslation, allocaIP);
2367 builder.SetInsertPoint(initBlock->getTerminator());
2370 taskStructMgr.generateTaskContextStruct();
2377 taskStructMgr.createGEPsToPrivateVars();
2379 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2382 taskStructMgr.getLLVMPrivateVarGEPs())) {
2384 if (!privDecl.readsFromMold())
2386 assert(llvmPrivateVarAlloc &&
2387 "reads from mold so shouldn't have been skipped");
2390 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2391 blockArg, llvmPrivateVarAlloc, initBlock);
2392 if (!privateVarOrErr)
2393 return handleError(privateVarOrErr, *taskOp.getOperation());
2402 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2403 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2404 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2406 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2407 llvmPrivateVarAlloc);
2409 assert(llvmPrivateVarAlloc->getType() ==
2410 moduleTranslation.
convertType(blockArg.getType()));
2420 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2421 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2422 taskOp.getPrivateNeedsBarrier())))
2423 return llvm::failure();
2426 builder.SetInsertPoint(taskStartBlock);
2428 auto bodyCB = [&](InsertPointTy allocaIP,
2429 InsertPointTy codegenIP) -> llvm::Error {
2433 moduleTranslation, allocaIP);
2436 builder.restoreIP(codegenIP);
2438 llvm::BasicBlock *privInitBlock =
nullptr;
2440 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2443 auto [blockArg, privDecl, mlirPrivVar] = zip;
2445 if (privDecl.readsFromMold())
2448 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2449 llvm::Type *llvmAllocType =
2450 moduleTranslation.
convertType(privDecl.getType());
2451 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2452 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2453 llvmAllocType,
nullptr,
"omp.private.alloc");
2456 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2457 blockArg, llvmPrivateVar, privInitBlock);
2458 if (!privateVarOrError)
2459 return privateVarOrError.takeError();
2460 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2461 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2464 taskStructMgr.createGEPsToPrivateVars();
2465 for (
auto [i, llvmPrivVar] :
2466 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2468 assert(privateVarsInfo.
llvmVars[i] &&
2469 "This is added in the loop above");
2472 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2477 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2481 if (!privateDecl.readsFromMold())
2484 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2485 llvmPrivateVar = builder.CreateLoad(
2486 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2488 assert(llvmPrivateVar->getType() ==
2489 moduleTranslation.
convertType(blockArg.getType()));
2490 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2494 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
2495 if (failed(
handleError(continuationBlockOrError, *taskOp)))
2496 return llvm::make_error<PreviouslyReportedError>();
2498 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2503 return llvm::make_error<PreviouslyReportedError>();
2506 taskStructMgr.freeStructPtr();
2508 return llvm::Error::success();
2517 llvm::omp::Directive::OMPD_taskgroup);
2521 moduleTranslation, dds);
2523 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2524 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2526 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2528 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dds,
2529 taskOp.getMergeable(),
2530 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
2531 moduleTranslation.
lookupValue(taskOp.getPriority()));
2539 builder.restoreIP(*afterIP);
2547 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2548 auto taskloopOp = cast<omp::TaskloopOp>(opInst);
2556 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2559 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2562 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2563 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
2564 builder.getContext(),
"omp.taskloop.start",
2565 builder.GetInsertBlock()->getParent());
2566 llvm::Instruction *branchToTaskloopStartBlock =
2567 builder.CreateBr(taskloopStartBlock);
2568 builder.SetInsertPoint(branchToTaskloopStartBlock);
2570 llvm::BasicBlock *copyBlock =
2571 splitBB(builder,
true,
"omp.private.copy");
2572 llvm::BasicBlock *initBlock =
2573 splitBB(builder,
true,
"omp.private.init");
2576 moduleTranslation, allocaIP);
2579 builder.SetInsertPoint(initBlock->getTerminator());
2582 taskStructMgr.generateTaskContextStruct();
2583 taskStructMgr.createGEPsToPrivateVars();
2585 llvmFirstPrivateVars.resize(privateVarsInfo.
blockArgs.size());
2587 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2589 privateVarsInfo.
blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
2590 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
2592 if (!privDecl.readsFromMold())
2594 assert(llvmPrivateVarAlloc &&
2595 "reads from mold so shouldn't have been skipped");
2598 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2599 blockArg, llvmPrivateVarAlloc, initBlock);
2600 if (!privateVarOrErr)
2601 return handleError(privateVarOrErr, *taskloopOp.getOperation());
2603 llvmFirstPrivateVars[i] = privateVarOrErr.get();
2605 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2606 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2608 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2609 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2610 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2612 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2613 llvmPrivateVarAlloc);
2615 assert(llvmPrivateVarAlloc->getType() ==
2616 moduleTranslation.
convertType(blockArg.getType()));
2622 taskloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2623 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2624 taskloopOp.getPrivateNeedsBarrier())))
2625 return llvm::failure();
2628 builder.SetInsertPoint(taskloopStartBlock);
2630 auto bodyCB = [&](InsertPointTy allocaIP,
2631 InsertPointTy codegenIP) -> llvm::Error {
2635 moduleTranslation, allocaIP);
2638 builder.restoreIP(codegenIP);
2640 llvm::BasicBlock *privInitBlock =
nullptr;
2642 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2645 auto [blockArg, privDecl, mlirPrivVar] = zip;
2647 if (privDecl.readsFromMold())
2650 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2651 llvm::Type *llvmAllocType =
2652 moduleTranslation.
convertType(privDecl.getType());
2653 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2654 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2655 llvmAllocType,
nullptr,
"omp.private.alloc");
2658 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2659 blockArg, llvmPrivateVar, privInitBlock);
2660 if (!privateVarOrError)
2661 return privateVarOrError.takeError();
2662 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2663 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2666 taskStructMgr.createGEPsToPrivateVars();
2667 for (
auto [i, llvmPrivVar] :
2668 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2670 assert(privateVarsInfo.
llvmVars[i] &&
2671 "This is added in the loop above");
2674 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2679 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2683 if (!privateDecl.readsFromMold())
2686 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2687 llvmPrivateVar = builder.CreateLoad(
2688 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2690 assert(llvmPrivateVar->getType() ==
2691 moduleTranslation.
convertType(blockArg.getType()));
2692 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2695 auto continuationBlockOrError =
2697 builder, moduleTranslation);
2699 if (failed(
handleError(continuationBlockOrError, opInst)))
2700 return llvm::make_error<PreviouslyReportedError>();
2702 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2710 taskloopOp.getLoc(), privateVarsInfo.
llvmVars,
2712 return llvm::make_error<PreviouslyReportedError>();
2715 taskStructMgr.freeStructPtr();
2717 return llvm::Error::success();
2723 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2724 llvm::Value *destPtr, llvm::Value *srcPtr)
2726 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2727 builder.restoreIP(codegenIP);
2730 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
2732 builder.CreateLoad(ptrTy, srcPtr,
"omp.taskloop.context.src");
2734 TaskContextStructManager &srcStructMgr = taskStructMgr;
2735 TaskContextStructManager destStructMgr(builder, moduleTranslation,
2737 destStructMgr.generateTaskContextStruct();
2738 llvm::Value *dest = destStructMgr.getStructPtr();
2739 dest->setName(
"omp.taskloop.context.dest");
2740 builder.CreateStore(dest, destPtr);
2743 srcStructMgr.createGEPsToPrivateVars(src);
2745 destStructMgr.createGEPsToPrivateVars(dest);
2748 for (
auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
2749 llvm::zip_equal(privateVarsInfo.
privatizers, srcGEPs,
2752 if (!privDecl.readsFromMold())
2754 assert(llvmPrivateVarAlloc &&
2755 "reads from mold so shouldn't have been skipped");
2758 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
2759 llvmPrivateVarAlloc, builder.GetInsertBlock());
2760 if (!privateVarOrErr)
2761 return privateVarOrErr.takeError();
2770 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2771 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2772 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2774 llvmPrivateVarAlloc = builder.CreateLoad(
2775 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
2777 assert(llvmPrivateVarAlloc->getType() ==
2778 moduleTranslation.
convertType(blockArg.getType()));
2786 &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
2787 privateVarsInfo.
privatizers, taskloopOp.getPrivateNeedsBarrier())))
2788 return llvm::make_error<PreviouslyReportedError>();
2790 return builder.saveIP();
2793 auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());
2800 llvm::Value *ifCond =
nullptr;
2801 llvm::Value *grainsize =
nullptr;
2803 mlir::Value grainsizeVal = taskloopOp.getGrainsize();
2804 mlir::Value numTasksVal = taskloopOp.getNumTasks();
2805 if (
Value ifVar = taskloopOp.getIfExpr())
2808 grainsize = moduleTranslation.
lookupValue(grainsizeVal);
2810 }
else if (numTasksVal) {
2811 grainsize = moduleTranslation.
lookupValue(numTasksVal);
2815 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull =
nullptr;
2816 if (taskStructMgr.getStructPtr())
2817 taskDupOrNull = taskDupCB;
2827 llvm::omp::Directive::OMPD_taskgroup);
2829 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2830 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2832 ompLoc, allocaIP, bodyCB, loopInfo,
2833 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[0]),
2834 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[0]),
2835 moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]),
2836 taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
2837 sched, moduleTranslation.
lookupValue(taskloopOp.getFinal()),
2838 taskloopOp.getMergeable(),
2839 moduleTranslation.
lookupValue(taskloopOp.getPriority()),
2840 taskDupOrNull, taskStructMgr.getStructPtr());
2847 builder.restoreIP(*afterIP);
2855 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2859 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2860 builder.restoreIP(codegenIP);
2862 builder, moduleTranslation)
2867 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2868 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2875 builder.restoreIP(*afterIP);
2894 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2898 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2900 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2904 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2907 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
2908 llvm::Type *ivType = step->getType();
2909 llvm::Value *chunk =
nullptr;
2910 if (wsloopOp.getScheduleChunk()) {
2911 llvm::Value *chunkVar =
2912 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
2913 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2916 omp::DistributeOp distributeOp =
nullptr;
2917 llvm::Value *distScheduleChunk =
nullptr;
2918 bool hasDistSchedule =
false;
2919 if (llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())) {
2920 distributeOp = cast<omp::DistributeOp>(opInst.
getParentOp());
2921 hasDistSchedule = distributeOp.getDistScheduleStatic();
2922 if (distributeOp.getDistScheduleChunkSize()) {
2923 llvm::Value *chunkVar = moduleTranslation.
lookupValue(
2924 distributeOp.getDistScheduleChunkSize());
2925 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2933 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2937 wsloopOp.getNumReductionVars());
2940 builder, moduleTranslation, privateVarsInfo, allocaIP);
2947 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2952 moduleTranslation, allocaIP, reductionDecls,
2953 privateReductionVariables, reductionVariableMap,
2954 deferredStores, isByRef)))
2963 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2965 wsloopOp.getPrivateNeedsBarrier())))
2968 assert(afterAllocas.get()->getSinglePredecessor());
2969 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
2971 afterAllocas.get()->getSinglePredecessor(),
2972 reductionDecls, privateReductionVariables,
2973 reductionVariableMap, isByRef, deferredStores)))
2977 bool isOrdered = wsloopOp.getOrdered().has_value();
2978 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2979 bool isSimd = wsloopOp.getScheduleSimd();
2980 bool loopNeedsBarrier = !wsloopOp.getNowait();
2985 llvm::omp::WorksharingLoopType workshareLoopType =
2986 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
2987 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2988 : llvm::omp::WorksharingLoopType::ForStaticLoop;
2992 llvm::omp::Directive::OMPD_for);
2994 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2997 LinearClauseProcessor linearClauseProcessor;
2999 if (!wsloopOp.getLinearVars().empty()) {
3000 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3002 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3004 for (
auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3005 linearClauseProcessor.createLinearVar(
3006 builder, moduleTranslation, moduleTranslation.
lookupValue(linearVar),
3008 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
3009 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3013 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
3021 if (!wsloopOp.getLinearVars().empty()) {
3022 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3023 loopInfo->getPreheader());
3024 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3026 builder.saveIP(), llvm::omp::OMPD_barrier);
3029 builder.restoreIP(*afterBarrierIP);
3030 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3031 loopInfo->getIndVar());
3032 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3035 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3038 bool noLoopMode =
false;
3039 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3041 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3045 if (loopOp == targetCapturedOp) {
3046 omp::TargetRegionFlags kernelFlags =
3047 targetOp.getKernelExecFlags(targetCapturedOp);
3048 if (omp::bitEnumContainsAll(kernelFlags,
3049 omp::TargetRegionFlags::spmd |
3050 omp::TargetRegionFlags::no_loop) &&
3051 !omp::bitEnumContainsAny(kernelFlags,
3052 omp::TargetRegionFlags::generic))
3057 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3058 ompBuilder->applyWorkshareLoop(
3059 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3060 convertToScheduleKind(schedule), chunk, isSimd,
3061 scheduleMod == omp::ScheduleModifier::monotonic,
3062 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3063 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3069 if (!wsloopOp.getLinearVars().empty()) {
3070 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3071 assert(loopInfo->getLastIter() &&
3072 "`lastiter` in CanonicalLoopInfo is nullptr");
3073 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3074 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3075 loopInfo->getLastIter());
3078 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
3079 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
3081 builder.restoreIP(oldIP);
3089 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3090 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3103 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3105 assert(isByRef.size() == opInst.getNumReductionVars());
3118 opInst.getNumReductionVars());
3121 auto bodyGenCB = [&](InsertPointTy allocaIP,
3122 InsertPointTy codeGenIP) -> llvm::Error {
3124 builder, moduleTranslation, privateVarsInfo, allocaIP);
3126 return llvm::make_error<PreviouslyReportedError>();
3132 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3135 InsertPointTy(allocaIP.getBlock(),
3136 allocaIP.getBlock()->getTerminator()->getIterator());
3139 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3140 reductionDecls, privateReductionVariables, reductionVariableMap,
3141 deferredStores, isByRef)))
3142 return llvm::make_error<PreviouslyReportedError>();
3144 assert(afterAllocas.get()->getSinglePredecessor());
3145 builder.restoreIP(codeGenIP);
3151 return llvm::make_error<PreviouslyReportedError>();
3154 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3156 opInst.getPrivateNeedsBarrier())))
3157 return llvm::make_error<PreviouslyReportedError>();
3160 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3161 afterAllocas.get()->getSinglePredecessor(),
3162 reductionDecls, privateReductionVariables,
3163 reductionVariableMap, isByRef, deferredStores)))
3164 return llvm::make_error<PreviouslyReportedError>();
3169 moduleTranslation, allocaIP);
3173 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
3175 return regionBlock.takeError();
3178 if (opInst.getNumReductionVars() > 0) {
3183 owningReductionGenRefDataPtrGens;
3185 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3187 owningReductionGenRefDataPtrGens,
3188 privateReductionVariables, reductionInfos, isByRef);
3191 builder.SetInsertPoint((*regionBlock)->getTerminator());
3194 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3195 builder.SetInsertPoint(tempTerminator);
3197 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3198 ompBuilder->createReductions(
3199 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3201 if (!contInsertPoint)
3202 return contInsertPoint.takeError();
3204 if (!contInsertPoint->getBlock())
3205 return llvm::make_error<PreviouslyReportedError>();
3207 tempTerminator->eraseFromParent();
3208 builder.restoreIP(*contInsertPoint);
3211 return llvm::Error::success();
3214 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3215 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
3224 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
3225 InsertPointTy oldIP = builder.saveIP();
3226 builder.restoreIP(codeGenIP);
3231 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
3232 [](omp::DeclareReductionOp reductionDecl) {
3233 return &reductionDecl.getCleanupRegion();
3236 reductionCleanupRegions, privateReductionVariables,
3237 moduleTranslation, builder,
"omp.reduction.cleanup")))
3238 return llvm::createStringError(
3239 "failed to inline `cleanup` region of `omp.declare_reduction`");
3244 return llvm::make_error<PreviouslyReportedError>();
3248 if (isCancellable) {
3249 auto IPOrErr = ompBuilder->createBarrier(
3250 llvm::OpenMPIRBuilder::LocationDescription(builder),
3251 llvm::omp::Directive::OMPD_unknown,
3255 return IPOrErr.takeError();
3258 builder.restoreIP(oldIP);
3259 return llvm::Error::success();
3262 llvm::Value *ifCond =
nullptr;
3263 if (
auto ifVar = opInst.getIfExpr())
3265 llvm::Value *numThreads =
nullptr;
3266 if (
auto numThreadsVar = opInst.getNumThreads())
3267 numThreads = moduleTranslation.
lookupValue(numThreadsVar);
3268 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
3269 if (
auto bind = opInst.getProcBindKind())
3272 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3274 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3276 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3277 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
3278 ifCond, numThreads, pbKind, isCancellable);
3283 builder.restoreIP(*afterIP);
3288static llvm::omp::OrderKind
3291 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3293 case omp::ClauseOrderKind::Concurrent:
3294 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3296 llvm_unreachable(
"Unknown ClauseOrderKind kind");
3304 auto simdOp = cast<omp::SimdOp>(opInst);
3312 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3315 simdOp.getNumReductionVars());
3320 assert(isByRef.size() == simdOp.getNumReductionVars());
3322 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3326 builder, moduleTranslation, privateVarsInfo, allocaIP);
3331 LinearClauseProcessor linearClauseProcessor;
3333 if (!simdOp.getLinearVars().empty()) {
3334 auto linearVarTypes = simdOp.getLinearVarTypes().value();
3336 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3337 for (
auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
3338 bool isImplicit =
false;
3339 for (
auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
3343 if (linearVar == mlirPrivVar) {
3345 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
3346 llvmPrivateVar, idx);
3352 linearClauseProcessor.createLinearVar(
3353 builder, moduleTranslation,
3356 for (
mlir::Value linearStep : simdOp.getLinearStepVars())
3357 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3361 moduleTranslation, allocaIP, reductionDecls,
3362 privateReductionVariables, reductionVariableMap,
3363 deferredStores, isByRef)))
3374 assert(afterAllocas.get()->getSinglePredecessor());
3375 if (failed(initReductionVars(simdOp, reductionArgs, builder,
3377 afterAllocas.get()->getSinglePredecessor(),
3378 reductionDecls, privateReductionVariables,
3379 reductionVariableMap, isByRef, deferredStores)))
3382 llvm::ConstantInt *simdlen =
nullptr;
3383 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
3384 simdlen = builder.getInt64(simdlenVar.value());
3386 llvm::ConstantInt *safelen =
nullptr;
3387 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
3388 safelen = builder.getInt64(safelenVar.value());
3390 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
3393 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3394 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
3396 for (
size_t i = 0; i < operands.size(); ++i) {
3397 llvm::Value *alignment =
nullptr;
3398 llvm::Value *llvmVal = moduleTranslation.
lookupValue(operands[i]);
3399 llvm::Type *ty = llvmVal->getType();
3401 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
3402 alignment = builder.getInt64(intAttr.getInt());
3403 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
3404 assert(alignment &&
"Invalid alignment value");
3408 if (!intAttr.getValue().isPowerOf2())
3411 auto curInsert = builder.saveIP();
3412 builder.SetInsertPoint(sourceBlock);
3413 llvmVal = builder.CreateLoad(ty, llvmVal);
3414 builder.restoreIP(curInsert);
3415 alignedVars[llvmVal] = alignment;
3419 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
3426 if (simdOp.getLinearVars().size()) {
3427 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3428 loopInfo->getPreheader());
3430 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3431 loopInfo->getIndVar());
3433 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3435 ompBuilder->applySimd(loopInfo, alignedVars,
3437 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
3439 order, simdlen, safelen);
3441 linearClauseProcessor.emitStoresForLinearVar(builder);
3442 for (
size_t index = 0;
index < simdOp.getLinearVars().size();
index++)
3443 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
3450 for (
auto [i, tuple] : llvm::enumerate(
3451 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
3452 privateReductionVariables))) {
3453 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
3455 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
3456 llvm::Value *originalVariable = moduleTranslation.
lookupValue(reductionVar);
3457 llvm::Type *reductionType = moduleTranslation.
convertType(decl.getType());
3461 llvm::Value *redValue = originalVariable;
3464 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
3465 llvm::Value *privateRedValue = builder.CreateLoad(
3466 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
3467 llvm::Value *reduced;
3469 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
3472 builder.restoreIP(res.get());
3476 builder.CreateStore(reduced, originalVariable);
3481 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
3482 [](omp::DeclareReductionOp reductionDecl) {
3483 return &reductionDecl.getCleanupRegion();
3486 moduleTranslation, builder,
3487 "omp.reduction.cleanup")))
3500 auto loopOp = cast<omp::LoopNestOp>(opInst);
3506 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3511 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
3512 llvm::Value *iv) -> llvm::Error {
3515 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
3520 bodyInsertPoints.push_back(ip);
3522 if (loopInfos.size() != loopOp.getNumLoops() - 1)
3523 return llvm::Error::success();
3526 builder.restoreIP(ip);
3528 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
3530 return regionBlock.takeError();
3532 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3533 return llvm::Error::success();
3541 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3542 llvm::Value *lowerBound =
3543 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
3544 llvm::Value *upperBound =
3545 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
3546 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
3551 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3552 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3554 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3556 computeIP = loopInfos.front()->getPreheaderIP();
3560 ompBuilder->createCanonicalLoop(
3561 loc, bodyGen, lowerBound, upperBound, step,
3562 true, loopOp.getLoopInclusive(), computeIP);
3567 loopInfos.push_back(*loopResult);
3570 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
3571 loopInfos.front()->getAfterIP();
3574 if (
const auto &tiles = loopOp.getTileSizes()) {
3575 llvm::Type *ivType = loopInfos.front()->getIndVarType();
3578 for (
auto tile : tiles.value()) {
3579 llvm::Value *tileVal = llvm::ConstantInt::get(ivType,
tile);
3580 tileSizes.push_back(tileVal);
3583 std::vector<llvm::CanonicalLoopInfo *> newLoops =
3584 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
3588 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
3589 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
3590 afterIP = {afterAfterBB, afterAfterBB->begin()};
3594 for (
const auto &newLoop : newLoops)
3595 loopInfos.push_back(newLoop);
3599 const auto &numCollapse = loopOp.getCollapseNumLoops();
3601 loopInfos.begin(), loopInfos.begin() + (numCollapse));
3603 auto newTopLoopInfo =
3604 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
3606 assert(newTopLoopInfo &&
"New top loop information is missing");
3607 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
3608 [&](OpenMPLoopInfoStackFrame &frame) {
3609 frame.loopInfo = newTopLoopInfo;
3617 builder.restoreIP(afterIP);
3627 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
3628 Value loopIV = op.getInductionVar();
3629 Value loopTC = op.getTripCount();
3631 llvm::Value *llvmTC = moduleTranslation.
lookupValue(loopTC);
3634 ompBuilder->createCanonicalLoop(
3636 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
3639 moduleTranslation.
mapValue(loopIV, llvmIV);
3641 builder.restoreIP(ip);
3646 return bodyGenStatus.takeError();
3648 llvmTC,
"omp.loop");
3650 return op.emitError(llvm::toString(llvmOrError.takeError()));
3652 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
3653 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
3654 builder.restoreIP(afterIP);
3657 if (
Value cli = op.getCli())
3670 Value applyee = op.getApplyee();
3671 assert(applyee &&
"Loop to apply unrolling on required");
3673 llvm::CanonicalLoopInfo *consBuilderCLI =
3675 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3676 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
3684static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3687 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3692 for (
Value size : op.getSizes()) {
3693 llvm::Value *translatedSize = moduleTranslation.
lookupValue(size);
3694 assert(translatedSize &&
3695 "sizes clause arguments must already be translated");
3696 translatedSizes.push_back(translatedSize);
3699 for (
Value applyee : op.getApplyees()) {
3700 llvm::CanonicalLoopInfo *consBuilderCLI =
3702 assert(applyee &&
"Canonical loop must already been translated");
3703 translatedLoops.push_back(consBuilderCLI);
3706 auto generatedLoops =
3707 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3708 if (!op.getGeneratees().empty()) {
3709 for (
auto [mlirLoop,
genLoop] :
3710 zip_equal(op.getGeneratees(), generatedLoops))
3715 for (
Value applyee : op.getApplyees())
3722static llvm::AtomicOrdering
3725 return llvm::AtomicOrdering::Monotonic;
3728 case omp::ClauseMemoryOrderKind::Seq_cst:
3729 return llvm::AtomicOrdering::SequentiallyConsistent;
3730 case omp::ClauseMemoryOrderKind::Acq_rel:
3731 return llvm::AtomicOrdering::AcquireRelease;
3732 case omp::ClauseMemoryOrderKind::Acquire:
3733 return llvm::AtomicOrdering::Acquire;
3734 case omp::ClauseMemoryOrderKind::Release:
3735 return llvm::AtomicOrdering::Release;
3736 case omp::ClauseMemoryOrderKind::Relaxed:
3737 return llvm::AtomicOrdering::Monotonic;
3739 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
3746 auto readOp = cast<omp::AtomicReadOp>(opInst);
3751 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3754 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3757 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
3758 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
3760 llvm::Type *elementType =
3761 moduleTranslation.
convertType(readOp.getElementType());
3763 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
3764 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
3765 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3773 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3778 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3781 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3783 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
3784 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
3785 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
3786 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
3789 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3797 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
3798 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
3799 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
3800 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
3801 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
3802 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
3803 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
3804 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
3805 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
3806 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3810 bool &isIgnoreDenormalMode,
3811 bool &isFineGrainedMemory,
3812 bool &isRemoteMemory) {
3813 isIgnoreDenormalMode =
false;
3814 isFineGrainedMemory =
false;
3815 isRemoteMemory =
false;
3816 if (atomicUpdateOp &&
3817 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3818 mlir::omp::AtomicControlAttr atomicControlAttr =
3819 atomicUpdateOp.getAtomicControlAttr();
3820 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3821 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3822 isRemoteMemory = atomicControlAttr.getRemoteMemory();
3829 llvm::IRBuilderBase &builder,
3836 auto &innerOpList = opInst.getRegion().front().getOperations();
3837 bool isXBinopExpr{
false};
3838 llvm::AtomicRMWInst::BinOp binop;
3840 llvm::Value *llvmExpr =
nullptr;
3841 llvm::Value *llvmX =
nullptr;
3842 llvm::Type *llvmXElementType =
nullptr;
3843 if (innerOpList.size() == 2) {
3849 opInst.getRegion().getArgument(0))) {
3850 return opInst.emitError(
"no atomic update operation with region argument"
3851 " as operand found inside atomic.update region");
3854 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
3856 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
3860 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3862 llvmX = moduleTranslation.
lookupValue(opInst.getX());
3864 opInst.getRegion().getArgument(0).getType());
3865 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3869 llvm::AtomicOrdering atomicOrdering =
3874 [&opInst, &moduleTranslation](
3875 llvm::Value *atomicx,
3878 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
3879 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
3880 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
3881 return llvm::make_error<PreviouslyReportedError>();
3883 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
3884 assert(yieldop && yieldop.getResults().size() == 1 &&
3885 "terminator must be omp.yield op and it must have exactly one "
3887 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
3890 bool isIgnoreDenormalMode;
3891 bool isFineGrainedMemory;
3892 bool isRemoteMemory;
3897 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3898 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3899 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
3900 atomicOrdering, binop, updateFn,
3901 isXBinopExpr, isIgnoreDenormalMode,
3902 isFineGrainedMemory, isRemoteMemory);
3907 builder.restoreIP(*afterIP);
3913 llvm::IRBuilderBase &builder,
3920 bool isXBinopExpr =
false, isPostfixUpdate =
false;
3921 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3923 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
3924 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
3926 assert((atomicUpdateOp || atomicWriteOp) &&
3927 "internal op must be an atomic.update or atomic.write op");
3929 if (atomicWriteOp) {
3930 isPostfixUpdate =
true;
3931 mlirExpr = atomicWriteOp.getExpr();
3933 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
3934 atomicCaptureOp.getAtomicUpdateOp().getOperation();
3935 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
3938 if (innerOpList.size() == 2) {
3941 atomicUpdateOp.getRegion().getArgument(0))) {
3942 return atomicUpdateOp.emitError(
3943 "no atomic update operation with region argument"
3944 " as operand found inside atomic.update region");
3948 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
3951 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3955 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
3956 llvm::Value *llvmX =
3957 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
3958 llvm::Value *llvmV =
3959 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
3960 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
3961 atomicCaptureOp.getAtomicReadOp().getElementType());
3962 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
3965 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
3969 llvm::AtomicOrdering atomicOrdering =
3973 [&](llvm::Value *atomicx,
3976 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
3977 Block &bb = *atomicUpdateOp.getRegion().
begin();
3978 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
3980 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
3981 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
3982 return llvm::make_error<PreviouslyReportedError>();
3984 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
3985 assert(yieldop && yieldop.getResults().size() == 1 &&
3986 "terminator must be omp.yield op and it must have exactly one "
3988 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
3991 bool isIgnoreDenormalMode;
3992 bool isFineGrainedMemory;
3993 bool isRemoteMemory;
3995 isFineGrainedMemory, isRemoteMemory);
3998 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3999 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4000 ompBuilder->createAtomicCapture(
4001 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4002 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4003 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4005 if (failed(
handleError(afterIP, *atomicCaptureOp)))
4008 builder.restoreIP(*afterIP);
4013 omp::ClauseCancellationConstructType directive) {
4014 switch (directive) {
4015 case omp::ClauseCancellationConstructType::Loop:
4016 return llvm::omp::Directive::OMPD_for;
4017 case omp::ClauseCancellationConstructType::Parallel:
4018 return llvm::omp::Directive::OMPD_parallel;
4019 case omp::ClauseCancellationConstructType::Sections:
4020 return llvm::omp::Directive::OMPD_sections;
4021 case omp::ClauseCancellationConstructType::Taskgroup:
4022 return llvm::omp::Directive::OMPD_taskgroup;
4024 llvm_unreachable(
"Unhandled cancellation construct type");
4033 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4036 llvm::Value *ifCond =
nullptr;
4037 if (
Value ifVar = op.getIfExpr())
4040 llvm::omp::Directive cancelledDirective =
4043 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4044 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4046 if (failed(
handleError(afterIP, *op.getOperation())))
4049 builder.restoreIP(afterIP.get());
4056 llvm::IRBuilderBase &builder,
4061 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4064 llvm::omp::Directive cancelledDirective =
4067 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4068 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4070 if (failed(
handleError(afterIP, *op.getOperation())))
4073 builder.restoreIP(afterIP.get());
4083 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4085 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4090 Value symAddr = threadprivateOp.getSymAddr();
4093 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4096 if (!isa<LLVM::AddressOfOp>(symOp))
4097 return opInst.
emitError(
"Addressing symbol not found");
4098 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4100 LLVM::GlobalOp global =
4101 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
4102 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
4104 if (!ompBuilder->Config.isTargetDevice()) {
4105 llvm::Type *type = globalValue->getValueType();
4106 llvm::TypeSize typeSize =
4107 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4109 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4110 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4111 ompLoc, globalValue, size, global.getSymName() +
".cache");
4120static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4122 switch (deviceClause) {
4123 case mlir::omp::DeclareTargetDeviceType::host:
4124 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4126 case mlir::omp::DeclareTargetDeviceType::nohost:
4127 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4129 case mlir::omp::DeclareTargetDeviceType::any:
4130 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4133 llvm_unreachable(
"unhandled device clause");
4136static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4138 mlir::omp::DeclareTargetCaptureClause captureClause) {
4139 switch (captureClause) {
4140 case mlir::omp::DeclareTargetCaptureClause::to:
4141 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4142 case mlir::omp::DeclareTargetCaptureClause::link:
4143 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4144 case mlir::omp::DeclareTargetCaptureClause::enter:
4145 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4146 case mlir::omp::DeclareTargetCaptureClause::none:
4147 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4149 llvm_unreachable(
"unhandled capture clause");
4154 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4156 if (
auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4157 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4158 return modOp.lookupSymbol(addressOfOp.getGlobalName());
4163static llvm::SmallString<64>
4165 llvm::OpenMPIRBuilder &ompBuilder) {
4167 llvm::raw_svector_ostream os(suffix);
4170 auto fileInfoCallBack = [&loc]() {
4171 return std::pair<std::string, uint64_t>(
4172 llvm::StringRef(loc.getFilename()), loc.getLine());
4175 auto vfs = llvm::vfs::getRealFileSystem();
4178 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
4180 os <<
"_decl_tgt_ref_ptr";
4186 if (
auto declareTargetGlobal =
4187 dyn_cast_if_present<omp::DeclareTargetInterface>(
4189 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4190 omp::DeclareTargetCaptureClause::link)
4196 if (
auto declareTargetGlobal =
4197 dyn_cast_if_present<omp::DeclareTargetInterface>(
4199 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4200 omp::DeclareTargetCaptureClause::to ||
4201 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4202 omp::DeclareTargetCaptureClause::enter)
4216 if (
auto declareTargetGlobal =
4217 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
4220 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
4221 omp::DeclareTargetCaptureClause::link) ||
4222 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4223 omp::DeclareTargetCaptureClause::to &&
4224 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
4228 if (gOp.getSymName().contains(suffix))
4233 (gOp.getSymName().str() + suffix.str()).str());
/// Map-information record that adds a `Mappers` list on top of
/// llvm::OpenMPIRBuilder::MapInfosTy.
4242struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
// Operation pointers stored alongside the inherited map-info arrays; the
// collection code in this file pushes nullptr when a map operation has no
// mapper id.
4243 SmallVector<Operation *, 4> Mappers;
// Appends curInfo.Mappers to this object's Mappers, then delegates to the
// base class's append for the inherited data.
4246 void append(MapInfosTy &curInfo) {
4247 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
4248 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
/// Extra per-entry bookkeeping layered on MapInfosTy. The collection code in
/// this file pushes one element onto each of these vectors for every map-info
/// operand it processes.
4257struct MapInfoData : MapInfosTy {
4258 llvm::SmallVector<bool, 4> IsDeclareTarget;
// Filled with the result of checkIsAMember(...) when entries are collected.
4259 llvm::SmallVector<bool, 4> IsAMember;
// Pushed as true for entries built from the map operands, and as false for
// entries added from useDevAddrOperands / useDevPtrOperands /
// hasDevAddrOperands.
4261 llvm::SmallVector<bool, 4> IsAMapping;
// The map-info operation each entry was created from.
4262 llvm::SmallVector<mlir::Operation *, 4> MapClause;
// LLVM value looked up for the entry's offload pointer.
4263 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
// LLVM type obtained by converting the map operation's variable type.
4266 llvm::SmallVector<llvm::Type *, 4> BaseType;
// Appends CurInfo's IsDeclareTarget, MapClause, OriginalValue and BaseType to
// this object, then delegates to MapInfosTy::append.
// NOTE(review): IsAMember and IsAMapping are not appended here, so after
// append() they can be shorter than the other vectors -- confirm whether
// that is intended.
4269 void append(MapInfoData &CurInfo) {
4270 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
4271 CurInfo.IsDeclareTarget.end());
4272 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
4273 OriginalValue.append(CurInfo.OriginalValue.begin(),
4274 CurInfo.OriginalValue.end());
4275 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
4276 MapInfosTy::append(CurInfo);
4280enum class TargetDirectiveEnumTy : uint32_t {
4284 TargetEnterData = 3,
/// Classifies `op` by its concrete operation type.
///
/// Returns the TargetDirectiveEnumTy value matching omp::TargetDataOp,
/// omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp or
/// omp::TargetOp; every other operation yields TargetDirectiveEnumTy::None.
4289static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4290 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4291 .Case([](omp::TargetDataOp) {
return TargetDirectiveEnumTy::TargetData; })
4292 .Case([](omp::TargetEnterDataOp) {
4293 return TargetDirectiveEnumTy::TargetEnterData;
4295 .Case([&](omp::TargetExitDataOp) {
4296 return TargetDirectiveEnumTy::TargetExitData;
4298 .Case([&](omp::TargetUpdateOp) {
4299 return TargetDirectiveEnumTy::TargetUpdate;
4301 .Case([&](omp::TargetOp) {
return TargetDirectiveEnumTy::Target; })
// Fallback for any operation type not matched by the cases above.
4302 .Default([&](Operation *op) {
return TargetDirectiveEnumTy::None; });
4309 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4310 arrTy.getElementType()))
4327 llvm::Value *basePointer,
4328 llvm::Type *baseType,
4329 llvm::IRBuilderBase &builder,
4331 if (
auto memberClause =
4332 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
4337 if (!memberClause.getBounds().empty()) {
4338 llvm::Value *elementCount = builder.getInt64(1);
4339 for (
auto bounds : memberClause.getBounds()) {
4340 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
4341 bounds.getDefiningOp())) {
4346 elementCount = builder.CreateMul(
4350 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
4351 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
4352 builder.getInt64(1)));
4359 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
4367 return builder.CreateMul(elementCount,
4368 builder.getInt64(underlyingTypeSzInBits / 8));
4379static llvm::omp::OpenMPOffloadMappingFlags
4381 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
4382 return (mlirFlags & flag) == flag;
4384 const bool hasExplicitMap =
4385 (mlirFlags &
~omp::ClauseMapFlags::is_device_ptr) !=
4386 omp::ClauseMapFlags::none;
4388 llvm::omp::OpenMPOffloadMappingFlags mapType =
4389 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4392 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
4395 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
4398 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4401 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4404 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4407 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4410 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4413 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4416 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4419 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4422 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4425 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4428 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4429 if (!hasExplicitMap)
4430 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4440 ArrayRef<Value> useDevAddrOperands = {},
4441 ArrayRef<Value> hasDevAddrOperands = {}) {
4442 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
4450 for (Value mapValue : mapVars) {
4451 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4452 for (
auto member : map.getMembers())
4453 if (member == mapOp)
4460 for (Value mapValue : mapVars) {
4461 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4463 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4464 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
4465 mapData.Pointers.push_back(mapData.OriginalValue.back());
4467 if (llvm::Value *refPtr =
4469 mapData.IsDeclareTarget.push_back(
true);
4470 mapData.BasePointers.push_back(refPtr);
4472 mapData.IsDeclareTarget.push_back(
true);
4473 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4475 mapData.IsDeclareTarget.push_back(
false);
4476 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4479 mapData.BaseType.push_back(
4480 moduleTranslation.
convertType(mapOp.getVarType()));
4481 mapData.Sizes.push_back(
4482 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
4483 mapData.BaseType.back(), builder, moduleTranslation));
4484 mapData.MapClause.push_back(mapOp.getOperation());
4486 mapData.Names.push_back(LLVM::createMappingInformation(
4488 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4489 if (mapOp.getMapperId())
4490 mapData.Mappers.push_back(
4492 mapOp, mapOp.getMapperIdAttr()));
4494 mapData.Mappers.push_back(
nullptr);
4495 mapData.IsAMapping.push_back(
true);
4496 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
4499 auto findMapInfo = [&mapData](llvm::Value *val,
4500 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4503 for (llvm::Value *basePtr : mapData.OriginalValue) {
4504 if (basePtr == val && mapData.IsAMapping[index]) {
4506 mapData.Types[index] |=
4507 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4508 mapData.DevicePointers[index] = devInfoTy;
4516 auto addDevInfos = [&](
const llvm::ArrayRef<Value> &useDevOperands,
4517 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4518 for (Value mapValue : useDevOperands) {
4519 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4521 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4522 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
4525 if (!findMapInfo(origValue, devInfoTy)) {
4526 mapData.OriginalValue.push_back(origValue);
4527 mapData.Pointers.push_back(mapData.OriginalValue.back());
4528 mapData.IsDeclareTarget.push_back(
false);
4529 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4530 mapData.BaseType.push_back(
4531 moduleTranslation.
convertType(mapOp.getVarType()));
4532 mapData.Sizes.push_back(builder.getInt64(0));
4533 mapData.MapClause.push_back(mapOp.getOperation());
4534 mapData.Types.push_back(
4535 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
4536 mapData.Names.push_back(LLVM::createMappingInformation(
4538 mapData.DevicePointers.push_back(devInfoTy);
4539 mapData.Mappers.push_back(
nullptr);
4540 mapData.IsAMapping.push_back(
false);
4541 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
4546 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4547 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
4549 for (Value mapValue : hasDevAddrOperands) {
4550 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4552 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4553 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
4555 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4557 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4558 omp::ClauseMapFlags::none;
4560 mapData.OriginalValue.push_back(origValue);
4561 mapData.BasePointers.push_back(origValue);
4562 mapData.Pointers.push_back(origValue);
4563 mapData.IsDeclareTarget.push_back(
false);
4564 mapData.BaseType.push_back(
4565 moduleTranslation.
convertType(mapOp.getVarType()));
4566 mapData.Sizes.push_back(
4567 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
4568 mapData.MapClause.push_back(mapOp.getOperation());
4569 if (llvm::to_underlying(mapType & mapTypeAlways)) {
4573 mapData.Types.push_back(mapType);
4577 if (mapOp.getMapperId()) {
4578 mapData.Mappers.push_back(
4580 mapOp, mapOp.getMapperIdAttr()));
4582 mapData.Mappers.push_back(
nullptr);
4587 mapData.Types.push_back(
4588 isDevicePtr ? mapType
4589 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4590 mapData.Mappers.push_back(
nullptr);
4592 mapData.Names.push_back(LLVM::createMappingInformation(
4594 mapData.DevicePointers.push_back(
4595 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4596 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4597 mapData.IsAMapping.push_back(
false);
4598 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
4603 auto *res = llvm::find(mapData.MapClause, memberOp);
4604 assert(res != mapData.MapClause.end() &&
4605 "MapInfoOp for member not found in MapData, cannot return index");
4606 return std::distance(mapData.MapClause.begin(), res);
4610 omp::MapInfoOp mapInfo) {
4611 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4621 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
4622 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
4624 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
4625 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
4626 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
4628 if (aIndex == bIndex)
4631 if (aIndex < bIndex)
4634 if (aIndex > bIndex)
4641 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
4643 occludedChildren.push_back(
b);
4645 occludedChildren.push_back(a);
4646 return memberAParent;
4652 for (
auto v : occludedChildren)
4659 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4661 if (indexAttr.size() == 1)
4662 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
4666 return llvm::cast<omp::MapInfoOp>(
4691static std::vector<llvm::Value *>
4693 llvm::IRBuilderBase &builder,
bool isArrayTy,
4695 std::vector<llvm::Value *> idx;
4706 idx.push_back(builder.getInt64(0));
4707 for (
int i = bounds.size() - 1; i >= 0; --i) {
4708 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4709 bounds[i].getDefiningOp())) {
4710 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
4728 std::vector<llvm::Value *> dimensionIndexSizeOffset;
4729 for (
int i = bounds.size() - 1; i >= 0; --i) {
4730 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4731 bounds[i].getDefiningOp())) {
4732 if (i == ((
int)bounds.size() - 1))
4734 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
4736 idx.back() = builder.CreateAdd(
4737 builder.CreateMul(idx.back(), moduleTranslation.
lookupValue(
4738 boundOp.getExtent())),
4739 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
4748 llvm::transform(values, std::back_inserter(ints), [](
Attribute value) {
4749 return cast<IntegerAttr>(value).getInt();
4757 omp::MapInfoOp parentOp) {
4759 if (parentOp.getMembers().empty())
4763 if (parentOp.getMembers().size() == 1) {
4764 overlapMapDataIdxs.push_back(0);
4770 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
4771 for (
auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
4772 memberByIndex.push_back(
4773 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
4778 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
4779 [&](
auto a,
auto b) { return a.second.size() < b.second.size(); });
4785 for (
auto v : memberByIndex) {
4789 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](
auto x) {
4792 llvm::SmallVector<int64_t> xArr(x.second.size());
4793 getAsIntegers(x.second, xArr);
4794 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
4795 xArr.size() >= vArr.size();
4801 for (
auto v : memberByIndex)
4802 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
4803 overlapMapDataIdxs.push_back(v.first);
4815 if (mapOp.getVarPtrPtr())
4844 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
4845 MapInfoData &mapData, uint64_t mapDataIndex,
4846 TargetDirectiveEnumTy targetDirective) {
4847 assert(!ompBuilder.Config.isTargetDevice() &&
4848 "function only supported for host device codegen");
4854 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
4855 (targetDirective == TargetDirectiveEnumTy::Target &&
4856 !mapData.IsDeclareTarget[mapDataIndex])
4857 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
4858 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4861 bool hasUserMapper = mapData.Mappers[mapDataIndex] !=
nullptr;
4862 if (hasUserMapper) {
4863 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
4867 mapFlags parentFlags = mapData.Types[mapDataIndex];
4868 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
4869 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
4870 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
4871 baseFlag |= (parentFlags & preserve);
4874 combinedInfo.Types.emplace_back(baseFlag);
4875 combinedInfo.DevicePointers.emplace_back(
4876 mapData.DevicePointers[mapDataIndex]);
4877 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4879 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4880 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
4890 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4891 llvm::Value *lowAddr, *highAddr;
4892 if (!parentClause.getPartialMap()) {
4893 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4894 builder.getPtrTy());
4895 highAddr = builder.CreatePointerCast(
4896 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4897 mapData.Pointers[mapDataIndex], 1),
4898 builder.getPtrTy());
4899 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4901 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4904 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
4905 builder.getPtrTy());
4908 highAddr = builder.CreatePointerCast(
4909 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
4910 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
4911 builder.getPtrTy());
4912 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
4915 llvm::Value *size = builder.CreateIntCast(
4916 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
4917 builder.getInt64Ty(),
4919 combinedInfo.Sizes.push_back(size);
4921 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
4922 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
4930 if (!parentClause.getPartialMap()) {
4935 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
4936 bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
4937 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
4938 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4939 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4941 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
4942 combinedInfo.Types.emplace_back(mapFlag);
4943 combinedInfo.DevicePointers.emplace_back(
4944 mapData.DevicePointers[mapDataIndex]);
4946 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4947 combinedInfo.BasePointers.emplace_back(
4948 mapData.BasePointers[mapDataIndex]);
4949 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4950 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
4951 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4962 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4963 builder.getPtrTy());
4964 highAddr = builder.CreatePointerCast(
4965 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4966 mapData.Pointers[mapDataIndex], 1),
4967 builder.getPtrTy());
4974 for (
auto v : overlapIdxs) {
4977 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
4978 combinedInfo.Types.emplace_back(mapFlag);
4979 combinedInfo.DevicePointers.emplace_back(
4980 mapData.DevicePointers[mapDataOverlapIdx]);
4982 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4983 combinedInfo.BasePointers.emplace_back(
4984 mapData.BasePointers[mapDataIndex]);
4985 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4986 combinedInfo.Pointers.emplace_back(lowAddr);
4987 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
4988 builder.CreatePtrDiff(builder.getInt8Ty(),
4989 mapData.OriginalValue[mapDataOverlapIdx],
4991 builder.getInt64Ty(),
true));
4992 lowAddr = builder.CreateConstGEP1_32(
4994 mapData.MapClause[mapDataOverlapIdx]))
4995 ? builder.getPtrTy()
4996 : mapData.BaseType[mapDataOverlapIdx],
4997 mapData.BasePointers[mapDataOverlapIdx], 1);
5000 combinedInfo.Types.emplace_back(mapFlag);
5001 combinedInfo.DevicePointers.emplace_back(
5002 mapData.DevicePointers[mapDataIndex]);
5004 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5005 combinedInfo.BasePointers.emplace_back(
5006 mapData.BasePointers[mapDataIndex]);
5007 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
5008 combinedInfo.Pointers.emplace_back(lowAddr);
5009 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5010 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5011 builder.getInt64Ty(),
true));
5014 return memberOfFlag;
5020 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
5021 MapInfoData &mapData, uint64_t mapDataIndex,
5022 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5023 TargetDirectiveEnumTy targetDirective) {
5024 assert(!ompBuilder.Config.isTargetDevice() &&
5025 "function only supported for host device codegen");
5028 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5030 for (
auto mappedMembers : parentClause.getMembers()) {
5032 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
5035 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
5046 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5047 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5048 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5049 combinedInfo.Types.emplace_back(mapFlag);
5050 combinedInfo.DevicePointers.emplace_back(
5051 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5052 combinedInfo.Mappers.emplace_back(
nullptr);
5053 combinedInfo.Names.emplace_back(
5055 combinedInfo.BasePointers.emplace_back(
5056 mapData.BasePointers[mapDataIndex]);
5057 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
5058 combinedInfo.Sizes.emplace_back(builder.getInt64(
5059 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
5065 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5066 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5067 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5069 ? parentClause.getVarPtr()
5070 : parentClause.getVarPtrPtr());
5073 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5074 targetDirective != TargetDirectiveEnumTy::TargetData))) {
5075 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5078 combinedInfo.Types.emplace_back(mapFlag);
5079 combinedInfo.DevicePointers.emplace_back(
5080 mapData.DevicePointers[memberDataIdx]);
5081 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5082 combinedInfo.Names.emplace_back(
5084 uint64_t basePointerIndex =
5086 combinedInfo.BasePointers.emplace_back(
5087 mapData.BasePointers[basePointerIndex]);
5088 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
5090 llvm::Value *size = mapData.Sizes[memberDataIdx];
5092 size = builder.CreateSelect(
5093 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5094 builder.getInt64(0), size);
5097 combinedInfo.Sizes.emplace_back(size);
5102 MapInfosTy &combinedInfo,
5103 TargetDirectiveEnumTy targetDirective,
5104 int mapDataParentIdx = -1) {
5108 auto mapFlag = mapData.Types[mapDataIdx];
5109 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5113 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5115 if (targetDirective == TargetDirectiveEnumTy::Target &&
5116 !mapData.IsDeclareTarget[mapDataIdx])
5117 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5119 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5121 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5126 if (mapDataParentIdx >= 0)
5127 combinedInfo.BasePointers.emplace_back(
5128 mapData.BasePointers[mapDataParentIdx]);
5130 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5132 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5133 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5134 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5135 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5136 combinedInfo.Types.emplace_back(mapFlag);
5137 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
5141 llvm::IRBuilderBase &builder,
5142 llvm::OpenMPIRBuilder &ompBuilder,
5144 MapInfoData &mapData, uint64_t mapDataIndex,
5145 TargetDirectiveEnumTy targetDirective) {
5146 assert(!ompBuilder.Config.isTargetDevice() &&
5147 "function only supported for host device codegen");
5150 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5155 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
5156 auto memberClause = llvm::cast<omp::MapInfoOp>(
5157 parentClause.getMembers()[0].getDefiningOp());
5174 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
5176 combinedInfo, mapData, mapDataIndex,
5179 combinedInfo, mapData, mapDataIndex,
5180 memberOfParentFlag, targetDirective);
5190 llvm::IRBuilderBase &builder) {
5192 "function only supported for host device codegen");
5193 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5195 if (!mapData.IsDeclareTarget[i]) {
5196 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5197 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
5207 switch (captureKind) {
5208 case omp::VariableCaptureKind::ByRef: {
5209 llvm::Value *newV = mapData.Pointers[i];
5211 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
5214 newV = builder.CreateLoad(builder.getPtrTy(), newV);
5216 if (!offsetIdx.empty())
5217 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
5219 mapData.Pointers[i] = newV;
5221 case omp::VariableCaptureKind::ByCopy: {
5222 llvm::Type *type = mapData.BaseType[i];
5224 if (mapData.Pointers[i]->getType()->isPointerTy())
5225 newV = builder.CreateLoad(type, mapData.Pointers[i]);
5227 newV = mapData.Pointers[i];
5230 auto curInsert = builder.saveIP();
5231 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
5233 auto *memTempAlloc =
5234 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
5235 builder.SetCurrentDebugLocation(DbgLoc);
5236 builder.restoreIP(curInsert);
5238 builder.CreateStore(newV, memTempAlloc);
5239 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
5242 mapData.Pointers[i] = newV;
5243 mapData.BasePointers[i] = newV;
5245 case omp::VariableCaptureKind::This:
5246 case omp::VariableCaptureKind::VLAType:
5247 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
5258 MapInfoData &mapData,
5259 TargetDirectiveEnumTy targetDirective) {
5261 "function only supported for host device codegen");
5282 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5285 if (mapData.IsAMember[i])
5288 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5289 if (!mapInfoOp.getMembers().empty()) {
5291 combinedInfo, mapData, i, targetDirective);
5299static llvm::Expected<llvm::Function *>
5301 LLVM::ModuleTranslation &moduleTranslation,
5302 llvm::StringRef mapperFuncName,
5303 TargetDirectiveEnumTy targetDirective);
5305static llvm::Expected<llvm::Function *>
5308 TargetDirectiveEnumTy targetDirective) {
5310 "function only supported for host device codegen");
5311 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5312 std::string mapperFuncName =
5314 {
"omp_mapper", declMapperOp.getSymName()});
5316 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
5320 mapperFuncName, targetDirective);
5323static llvm::Expected<llvm::Function *>
5326 llvm::StringRef mapperFuncName,
5327 TargetDirectiveEnumTy targetDirective) {
5329 "function only supported for host device codegen");
5330 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5331 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
5334 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
5337 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5340 MapInfosTy combinedInfo;
5342 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
5343 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
5344 builder.restoreIP(codeGenIP);
5345 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
5346 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
5347 builder.GetInsertBlock());
5348 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
5351 return llvm::make_error<PreviouslyReportedError>();
5352 MapInfoData mapData;
5355 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
5361 return combinedInfo;
5365 if (!combinedInfo.Mappers[i])
5368 moduleTranslation, targetDirective);
5372 genMapInfoCB, varType, mapperFuncName, customMapperCB);
5374 return newFn.takeError();
5375 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
5382 llvm::Value *ifCond =
nullptr;
5383 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
5387 llvm::omp::RuntimeFunction RTLFn;
5389 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
5392 llvm::OpenMPIRBuilder::TargetDataInfo info(
5395 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
5396 bool isOffloadEntry =
5397 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
5399 auto getDeviceID = [&](
mlir::Value dev) -> llvm::Value * {
5400 llvm::Value *v = moduleTranslation.
lookupValue(dev);
5401 return builder.CreateIntCast(v, builder.getInt64Ty(),
true);
5406 .Case([&](omp::TargetDataOp dataOp) {
5410 if (
auto ifVar = dataOp.getIfExpr())
5414 deviceID = getDeviceID(devId);
5416 mapVars = dataOp.getMapVars();
5417 useDevicePtrVars = dataOp.getUseDevicePtrVars();
5418 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
5421 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
5425 if (
auto ifVar = enterDataOp.getIfExpr())
5429 deviceID = getDeviceID(devId);
5432 enterDataOp.getNowait()
5433 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
5434 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
5435 mapVars = enterDataOp.getMapVars();
5436 info.HasNoWait = enterDataOp.getNowait();
5439 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
5443 if (
auto ifVar = exitDataOp.getIfExpr())
5447 deviceID = getDeviceID(devId);
5449 RTLFn = exitDataOp.getNowait()
5450 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
5451 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
5452 mapVars = exitDataOp.getMapVars();
5453 info.HasNoWait = exitDataOp.getNowait();
5456 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
5460 if (
auto ifVar = updateDataOp.getIfExpr())
5464 deviceID = getDeviceID(devId);
5467 updateDataOp.getNowait()
5468 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
5469 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
5470 mapVars = updateDataOp.getMapVars();
5471 info.HasNoWait = updateDataOp.getNowait();
5474 .DefaultUnreachable(
"unexpected operation");
5479 if (!isOffloadEntry)
5480 ifCond = builder.getFalse();
5482 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5483 MapInfoData mapData;
5485 builder, useDevicePtrVars, useDeviceAddrVars);
5488 MapInfosTy combinedInfo;
5489 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
5490 builder.restoreIP(codeGenIP);
5491 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
5493 return combinedInfo;
5499 [&moduleTranslation](
5500 llvm::OpenMPIRBuilder::DeviceInfoTy type,
5504 for (
auto [arg, useDevVar] :
5505 llvm::zip_equal(blockArgs, useDeviceVars)) {
5507 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
5508 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
5509 : mapInfoOp.getVarPtr();
5512 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
5513 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
5514 mapInfoData.MapClause, mapInfoData.DevicePointers,
5515 mapInfoData.BasePointers)) {
5516 auto mapOp = cast<omp::MapInfoOp>(mapClause);
5517 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
5518 devicePointer != type)
5521 if (llvm::Value *devPtrInfoMap =
5522 mapper ? mapper(basePointer) : basePointer) {
5523 moduleTranslation.
mapValue(arg, devPtrInfoMap);
5530 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
5531 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
5532 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5535 builder.restoreIP(codeGenIP);
5536 assert(isa<omp::TargetDataOp>(op) &&
5537 "BodyGen requested for non TargetDataOp");
5538 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
5539 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
5540 switch (bodyGenType) {
5541 case BodyGenTy::Priv:
5543 if (!info.DevicePtrInfoMap.empty()) {
5544 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5545 blockArgIface.getUseDeviceAddrBlockArgs(),
5546 useDeviceAddrVars, mapData,
5547 [&](llvm::Value *basePointer) -> llvm::Value * {
5548 if (!info.DevicePtrInfoMap[basePointer].second)
5550 return builder.CreateLoad(
5552 info.DevicePtrInfoMap[basePointer].second);
5554 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5555 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5556 mapData, [&](llvm::Value *basePointer) {
5557 return info.DevicePtrInfoMap[basePointer].second;
5561 moduleTranslation)))
5562 return llvm::make_error<PreviouslyReportedError>();
5565 case BodyGenTy::DupNoPriv:
5566 if (info.DevicePtrInfoMap.empty()) {
5569 if (!ompBuilder->Config.IsTargetDevice.value_or(
false)) {
5570 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5571 blockArgIface.getUseDeviceAddrBlockArgs(),
5572 useDeviceAddrVars, mapData);
5573 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5574 blockArgIface.getUseDevicePtrBlockArgs(),
5575 useDevicePtrVars, mapData);
5579 case BodyGenTy::NoPriv:
5581 if (info.DevicePtrInfoMap.empty()) {
5584 if (ompBuilder->Config.IsTargetDevice.value_or(
false)) {
5585 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5586 blockArgIface.getUseDeviceAddrBlockArgs(),
5587 useDeviceAddrVars, mapData);
5588 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5589 blockArgIface.getUseDevicePtrBlockArgs(),
5590 useDevicePtrVars, mapData);
5594 moduleTranslation)))
5595 return llvm::make_error<PreviouslyReportedError>();
5599 return builder.saveIP();
5602 auto customMapperCB =
5604 if (!combinedInfo.Mappers[i])
5606 info.HasMapper =
true;
5608 moduleTranslation, targetDirective);
5611 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5612 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5614 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
5615 if (isa<omp::TargetDataOp>(op))
5616 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5617 deviceID, ifCond, info, genMapInfoCB,
5621 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5622 deviceID, ifCond, info, genMapInfoCB,
5623 customMapperCB, &RTLFn);
5629 builder.restoreIP(*afterIP);
5637 auto distributeOp = cast<omp::DistributeOp>(opInst);
5644 bool doDistributeReduction =
5648 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
5653 if (doDistributeReduction) {
5654 isByRef =
getIsByRef(teamsOp.getReductionByref());
5655 assert(isByRef.size() == teamsOp.getNumReductionVars());
5658 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5662 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
5663 .getReductionBlockArgs();
5666 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
5667 reductionDecls, privateReductionVariables, reductionVariableMap,
5672 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5673 auto bodyGenCB = [&](InsertPointTy allocaIP,
5674 InsertPointTy codeGenIP) -> llvm::Error {
5678 moduleTranslation, allocaIP);
5681 builder.restoreIP(codeGenIP);
5687 return llvm::make_error<PreviouslyReportedError>();
5692 return llvm::make_error<PreviouslyReportedError>();
5695 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
5697 distributeOp.getPrivateNeedsBarrier())))
5698 return llvm::make_error<PreviouslyReportedError>();
5701 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5704 builder, moduleTranslation);
5706 return regionBlock.takeError();
5707 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
5712 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
5715 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
5716 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
5717 : omp::ClauseScheduleKind::Static;
5719 bool isOrdered = hasDistSchedule;
5720 std::optional<omp::ScheduleModifier> scheduleMod;
5721 bool isSimd =
false;
5722 llvm::omp::WorksharingLoopType workshareLoopType =
5723 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
5724 bool loopNeedsBarrier =
false;
5725 llvm::Value *chunk = moduleTranslation.
lookupValue(
5726 distributeOp.getDistScheduleChunkSize());
5727 llvm::CanonicalLoopInfo *loopInfo =
5729 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
5730 ompBuilder->applyWorkshareLoop(
5731 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
5732 convertToScheduleKind(schedule), chunk, isSimd,
5733 scheduleMod == omp::ScheduleModifier::monotonic,
5734 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
5735 workshareLoopType,
false, hasDistSchedule, chunk);
5738 return wsloopIP.takeError();
5741 distributeOp.getLoc(), privVarsInfo.
llvmVars,
5743 return llvm::make_error<PreviouslyReportedError>();
5745 return llvm::Error::success();
5748 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5750 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5751 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5752 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
5757 builder.restoreIP(*afterIP);
5759 if (doDistributeReduction) {
5762 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
5763 privateReductionVariables, isByRef,
5775 if (!cast<mlir::ModuleOp>(op))
5780 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
5781 attribute.getOpenmpDeviceVersion());
5783 if (attribute.getNoGpuLib())
5786 ompBuilder->createGlobalFlag(
5787 attribute.getDebugKind() ,
5788 "__omp_rtl_debug_kind");
5789 ompBuilder->createGlobalFlag(
5791 .getAssumeTeamsOversubscription()
5793 "__omp_rtl_assume_teams_oversubscription");
5794 ompBuilder->createGlobalFlag(
5796 .getAssumeThreadsOversubscription()
5798 "__omp_rtl_assume_threads_oversubscription");
5799 ompBuilder->createGlobalFlag(
5800 attribute.getAssumeNoThreadState() ,
5801 "__omp_rtl_assume_no_thread_state");
5802 ompBuilder->createGlobalFlag(
5804 .getAssumeNoNestedParallelism()
5806 "__omp_rtl_assume_no_nested_parallelism");
5811 omp::TargetOp targetOp,
5812 llvm::StringRef parentName =
"") {
5813 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
5815 assert(fileLoc &&
"No file found from location");
5816 StringRef fileName = fileLoc.getFilename().getValue();
5818 llvm::sys::fs::UniqueID id;
5819 uint64_t line = fileLoc.getLine();
5820 if (
auto ec = llvm::sys::fs::getUniqueID(fileName,
id)) {
5821 size_t fileHash = llvm::hash_value(fileName.str());
5822 size_t deviceId = 0xdeadf17e;
5824 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
5826 targetInfo = llvm::TargetRegionEntryInfo(parentName,
id.
getDevice(),
5827 id.getFile(), line);
5834 llvm::IRBuilderBase &builder, llvm::Function *
func) {
5836 "function only supported for target device codegen");
5837 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5838 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5851 if (mapData.IsDeclareTarget[i]) {
5858 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
5859 convertUsersOfConstantsToInstructions(constant,
func,
false);
5866 for (llvm::User *user : mapData.OriginalValue[i]->users())
5867 userVec.push_back(user);
5869 for (llvm::User *user : userVec) {
5870 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
5871 if (insn->getFunction() ==
func) {
5872 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5873 llvm::Value *substitute = mapData.BasePointers[i];
5875 : mapOp.getVarPtr())) {
5876 builder.SetCurrentDebugLocation(insn->getDebugLoc());
5877 substitute = builder.CreateLoad(
5878 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
5879 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
5881 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
5928static llvm::IRBuilderBase::InsertPoint
5930 llvm::Value *input, llvm::Value *&retVal,
5931 llvm::IRBuilderBase &builder,
5932 llvm::OpenMPIRBuilder &ompBuilder,
5934 llvm::IRBuilderBase::InsertPoint allocaIP,
5935 llvm::IRBuilderBase::InsertPoint codeGenIP) {
5936 assert(ompBuilder.Config.isTargetDevice() &&
5937 "function only supported for target device codegen");
5938 builder.restoreIP(allocaIP);
5940 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
5942 ompBuilder.M.getContext());
5943 unsigned alignmentValue = 0;
5945 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
5946 if (mapData.OriginalValue[i] == input) {
5947 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5948 capture = mapOp.getMapCaptureType();
5951 mapOp.getVarType(), ompBuilder.M.getDataLayout());
5955 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
5956 unsigned int defaultAS =
5957 ompBuilder.M.getDataLayout().getProgramAddressSpace();
5960 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
5962 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
5963 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
5965 builder.CreateStore(&arg, v);
5967 builder.restoreIP(codeGenIP);
5970 case omp::VariableCaptureKind::ByCopy: {
5974 case omp::VariableCaptureKind::ByRef: {
5975 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
5977 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
5992 if (v->getType()->isPointerTy() && alignmentValue) {
5993 llvm::MDBuilder MDB(builder.getContext());
5994 loadInst->setMetadata(
5995 llvm::LLVMContext::MD_align,
5996 llvm::MDNode::get(builder.getContext(),
5997 MDB.createConstant(llvm::ConstantInt::get(
5998 llvm::Type::getInt64Ty(builder.getContext()),
6005 case omp::VariableCaptureKind::This:
6006 case omp::VariableCaptureKind::VLAType:
6009 assert(
false &&
"Currently unsupported capture kind");
6013 return builder.saveIP();
6030 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6031 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6032 blockArgIface.getHostEvalBlockArgs())) {
6033 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6037 .Case([&](omp::TeamsOp teamsOp) {
6038 if (teamsOp.getNumTeamsLower() == blockArg)
6039 numTeamsLower = hostEvalVar;
6040 else if (teamsOp.getNumTeamsUpper() == blockArg)
6041 numTeamsUpper = hostEvalVar;
6042 else if (teamsOp.getThreadLimit() == blockArg)
6043 threadLimit = hostEvalVar;
6045 llvm_unreachable(
"unsupported host_eval use");
6047 .Case([&](omp::ParallelOp parallelOp) {
6048 if (parallelOp.getNumThreads() == blockArg)
6049 numThreads = hostEvalVar;
6051 llvm_unreachable(
"unsupported host_eval use");
6053 .Case([&](omp::LoopNestOp loopOp) {
6054 auto processBounds =
6058 for (
auto [i, lb] : llvm::enumerate(opBounds)) {
6059 if (lb == blockArg) {
6062 (*outBounds)[i] = hostEvalVar;
6068 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6069 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6071 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6073 assert(found &&
"unsupported host_eval use");
6075 .DefaultUnreachable(
"unsupported host_eval use");
6087template <
typename OpTy>
6092 if (OpTy casted = dyn_cast<OpTy>(op))
6095 if (immediateParent)
6096 return dyn_cast_if_present<OpTy>(op->
getParentOp());
6105 return std::nullopt;
6108 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6109 return constAttr.getInt();
6111 return std::nullopt;
6116 uint64_t sizeInBytes = sizeInBits / 8;
6120template <
typename OpTy>
6122 if (op.getNumReductionVars() > 0) {
6127 members.reserve(reductions.size());
6128 for (omp::DeclareReductionOp &red : reductions)
6129 members.push_back(red.getType());
6131 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
6147 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
6148 bool isTargetDevice,
bool isGPU) {
6151 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
6152 if (!isTargetDevice) {
6160 numTeamsLower = teamsOp.getNumTeamsLower();
6161 numTeamsUpper = teamsOp.getNumTeamsUpper();
6162 threadLimit = teamsOp.getThreadLimit();
6166 numThreads = parallelOp.getNumThreads();
6171 int32_t minTeamsVal = 1, maxTeamsVal = -1;
6175 if (numTeamsUpper) {
6177 minTeamsVal = maxTeamsVal = *val;
6179 minTeamsVal = maxTeamsVal = 0;
6185 minTeamsVal = maxTeamsVal = 1;
6187 minTeamsVal = maxTeamsVal = -1;
6192 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &
result) {
6206 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
6207 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
6208 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
6211 int32_t maxThreadsVal = -1;
6213 setMaxValueFromClause(numThreads, maxThreadsVal);
6221 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
6222 if (combinedMaxThreadsVal < 0 ||
6223 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
6224 combinedMaxThreadsVal = teamsThreadLimitVal;
6226 if (combinedMaxThreadsVal < 0 ||
6227 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
6228 combinedMaxThreadsVal = maxThreadsVal;
6230 int32_t reductionDataSize = 0;
6231 if (isGPU && capturedOp) {
6237 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
6239 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
6240 omp::TargetRegionFlags::spmd) &&
6241 "invalid kernel flags");
6243 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
6244 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
6245 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
6246 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
6247 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
6248 if (omp::bitEnumContainsAll(kernelFlags,
6249 omp::TargetRegionFlags::spmd |
6250 omp::TargetRegionFlags::no_loop) &&
6251 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
6252 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
6254 attrs.MinTeams = minTeamsVal;
6255 attrs.MaxTeams.front() = maxTeamsVal;
6256 attrs.MinThreads = 1;
6257 attrs.MaxThreads.front() = combinedMaxThreadsVal;
6258 attrs.ReductionDataSize = reductionDataSize;
6261 if (attrs.ReductionDataSize != 0)
6262 attrs.ReductionBufferLength = 1024;
6274 omp::TargetOp targetOp,
Operation *capturedOp,
6275 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
6277 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
6279 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
6283 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
6286 if (
Value targetThreadLimit = targetOp.getThreadLimit())
6287 attrs.TargetThreadLimit.front() =
6291 attrs.MinTeams = moduleTranslation.
lookupValue(numTeamsLower);
6294 attrs.MaxTeams.front() = moduleTranslation.
lookupValue(numTeamsUpper);
6296 if (teamsThreadLimit)
6297 attrs.TeamsThreadLimit.front() =
6301 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
6303 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
6304 omp::TargetRegionFlags::trip_count)) {
6306 attrs.LoopTripCount =
nullptr;
6311 for (
auto [loopLower, loopUpper, loopStep] :
6312 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
6313 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
6314 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
6315 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
6317 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
6318 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
6319 loc, lowerBound, upperBound, step,
true,
6320 loopOp.getLoopInclusive());
6322 if (!attrs.LoopTripCount) {
6323 attrs.LoopTripCount = tripCount;
6328 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
6333 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6335 attrs.DeviceID = moduleTranslation.
lookupValue(devId);
6337 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
6344 auto targetOp = cast<omp::TargetOp>(opInst);
6348 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
6357 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
6358 assert(parentBB &&
"No insert block is set for the builder");
6359 llvm::Function *parentLLVMFn = parentBB->getParent();
6360 assert(parentLLVMFn &&
"Parent Function must be valid");
6361 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
6362 builder.SetCurrentDebugLocation(llvm::DILocation::get(
6363 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
6364 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
6367 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
6368 bool isGPU = ompBuilder->Config.isGPU();
6371 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
6372 auto &targetRegion = targetOp.getRegion();
6389 llvm::Function *llvmOutlinedFn =
nullptr;
6390 TargetDirectiveEnumTy targetDirective =
6391 getTargetDirectiveEnumTyFromOp(&opInst);
6395 bool isOffloadEntry =
6396 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
6403 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
6405 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
6406 std::optional<DenseI64ArrayAttr> privateMapIndices =
6407 targetOp.getPrivateMapsAttr();
6409 for (
auto [privVarIdx, privVarSymPair] :
6410 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
6411 auto privVar = std::get<0>(privVarSymPair);
6412 auto privSym = std::get<1>(privVarSymPair);
6414 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
6415 omp::PrivateClauseOp privatizer =
6418 if (!privatizer.needsMap())
6422 targetOp.getMappedValueForPrivateVar(privVarIdx);
6423 assert(mappedValue &&
"Expected to find mapped value for a privatized "
6424 "variable that needs mapping");
6429 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
6430 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
6434 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
6436 varType == privVar.getType() &&
6437 "Type of private var doesn't match the type of the mapped value");
6441 mappedPrivateVars.insert(
6443 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
6444 (*privateMapIndices)[privVarIdx])});
6448 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6449 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
6450 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6451 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6452 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6455 llvm::Function *llvmParentFn =
6457 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
6458 assert(llvmParentFn && llvmOutlinedFn &&
6459 "Both parent and outlined functions must exist at this point");
6461 if (outlinedFnLoc && llvmParentFn->getSubprogram())
6462 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
6464 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
6465 attr.isStringAttribute())
6466 llvmOutlinedFn->addFnAttr(attr);
6468 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
6469 attr.isStringAttribute())
6470 llvmOutlinedFn->addFnAttr(attr);
6472 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
6473 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6474 llvm::Value *mapOpValue =
6475 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6476 moduleTranslation.
mapValue(arg, mapOpValue);
6478 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
6479 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6480 llvm::Value *mapOpValue =
6481 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6482 moduleTranslation.
mapValue(arg, mapOpValue);
6491 allocaIP, &mappedPrivateVars);
6494 return llvm::make_error<PreviouslyReportedError>();
6496 builder.restoreIP(codeGenIP);
6498 &mappedPrivateVars),
6501 return llvm::make_error<PreviouslyReportedError>();
6504 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
6506 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
6507 return llvm::make_error<PreviouslyReportedError>();
6511 std::back_inserter(privateCleanupRegions),
6512 [](omp::PrivateClauseOp privatizer) {
6513 return &privatizer.getDeallocRegion();
6517 targetRegion,
"omp.target", builder, moduleTranslation);
6520 return exitBlock.takeError();
6522 builder.SetInsertPoint(*exitBlock);
6523 if (!privateCleanupRegions.empty()) {
6525 privateCleanupRegions, privateVarsInfo.
llvmVars,
6526 moduleTranslation, builder,
"omp.targetop.private.cleanup",
6528 return llvm::createStringError(
6529 "failed to inline `dealloc` region of `omp.private` "
6530 "op in the target region");
6532 return builder.saveIP();
6535 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
6538 StringRef parentName = parentFn.getName();
6540 llvm::TargetRegionEntryInfo entryInfo;
6544 MapInfoData mapData;
6549 MapInfosTy combinedInfos;
6551 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
6552 builder.restoreIP(codeGenIP);
6553 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
6555 return combinedInfos;
6558 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
6559 llvm::Value *&retVal, InsertPointTy allocaIP,
6560 InsertPointTy codeGenIP)
6561 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6562 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6563 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6569 if (!isTargetDevice) {
6570 retVal = cast<llvm::Value>(&arg);
6575 *ompBuilder, moduleTranslation,
6576 allocaIP, codeGenIP);
6579 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
6580 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
6581 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
6583 isTargetDevice, isGPU);
6587 if (!isTargetDevice)
6589 targetCapturedOp, runtimeAttrs);
6597 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
6598 llvm::Value *value = moduleTranslation.
lookupValue(var);
6599 moduleTranslation.
mapValue(arg, value);
6601 if (!llvm::isa<llvm::Constant>(value))
6602 kernelInput.push_back(value);
6605 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
6612 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
6613 kernelInput.push_back(mapData.OriginalValue[i]);
6618 moduleTranslation, dds);
6620 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6622 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6624 llvm::OpenMPIRBuilder::TargetDataInfo info(
6628 auto customMapperCB =
6630 if (!combinedInfos.Mappers[i])
6632 info.HasMapper =
true;
6634 moduleTranslation, targetDirective);
6637 llvm::Value *ifCond =
nullptr;
6638 if (
Value targetIfCond = targetOp.getIfExpr())
6639 ifCond = moduleTranslation.
lookupValue(targetIfCond);
6641 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6643 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
6644 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
6645 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
6650 builder.restoreIP(*afterIP);
6671 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
6672 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
6674 if (!offloadMod.getIsTargetDevice())
6677 omp::DeclareTargetDeviceType declareType =
6678 attribute.getDeviceType().getValue();
6680 if (declareType == omp::DeclareTargetDeviceType::host) {
6681 llvm::Function *llvmFunc =
6683 llvmFunc->dropAllReferences();
6684 llvmFunc->eraseFromParent();
6690 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
6691 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
6692 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
6694 bool isDeclaration = gOp.isDeclaration();
6695 bool isExternallyVisible =
6698 llvm::StringRef mangledName = gOp.getSymName();
6699 auto captureClause =
6705 std::vector<llvm::GlobalVariable *> generatedRefs;
6707 std::vector<llvm::Triple> targetTriple;
6708 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
6710 LLVM::LLVMDialect::getTargetTripleAttrName()));
6711 if (targetTripleAttr)
6712 targetTriple.emplace_back(targetTripleAttr.data());
6714 auto fileInfoCallBack = [&loc]() {
6715 std::string filename =
"";
6716 std::uint64_t lineNo = 0;
6719 filename = loc.getFilename().str();
6720 lineNo = loc.getLine();
6723 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
6727 auto vfs = llvm::vfs::getRealFileSystem();
6729 ompBuilder->registerTargetGlobalVariable(
6730 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6731 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6732 mangledName, generatedRefs,
false, targetTriple,
6734 gVal->getType(), gVal);
6736 if (ompBuilder->Config.isTargetDevice() &&
6737 (attribute.getCaptureClause().getValue() !=
6738 mlir::omp::DeclareTargetCaptureClause::to ||
6739 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
6740 ompBuilder->getAddrOfDeclareTargetVar(
6741 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6742 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6743 mangledName, generatedRefs,
false, targetTriple,
6744 gVal->getType(),
nullptr,
6765 if (mlir::isa<omp::ThreadprivateOp>(op))
6768 if (mlir::isa<omp::TargetAllocMemOp>(op) ||
6769 mlir::isa<omp::TargetFreeMemOp>(op))
6773 if (
auto declareTargetIface =
6774 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
6775 parentFn.getOperation()))
6776 if (declareTargetIface.isDeclareTarget() &&
6777 declareTargetIface.getDeclareTargetDeviceType() !=
6778 mlir::omp::DeclareTargetDeviceType::host)
6785 llvm::Module *llvmModule) {
6786 llvm::Type *i64Ty = builder.getInt64Ty();
6787 llvm::Type *i32Ty = builder.getInt32Ty();
6788 llvm::Type *returnType = builder.getPtrTy(0);
6789 llvm::FunctionType *fnType =
6790 llvm::FunctionType::get(returnType, {i64Ty, i32Ty},
false);
6791 llvm::Function *
func = cast<llvm::Function>(
6792 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
6799 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
6804 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
6808 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
6810 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
6811 mlir::Type heapTy = allocMemOp.getAllocatedType();
6812 llvm::Type *llvmHeapTy = moduleTranslation.
convertType(heapTy);
6813 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
6814 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
6815 for (
auto typeParam : allocMemOp.getTypeparams())
6817 builder.CreateMul(allocSize, moduleTranslation.
lookupValue(typeParam));
6819 llvm::CallInst *call =
6820 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
6821 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
6824 moduleTranslation.
mapValue(allocMemOp.getResult(), resultI64);
6829 llvm::Module *llvmModule) {
6830 llvm::Type *ptrTy = builder.getPtrTy(0);
6831 llvm::Type *i32Ty = builder.getInt32Ty();
6832 llvm::Type *voidTy = builder.getVoidTy();
6833 llvm::FunctionType *fnType =
6834 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty},
false);
6835 llvm::Function *
func = dyn_cast<llvm::Function>(
6836 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
6843 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
6848 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
6852 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
6855 llvm::Value *llvmHeapref = moduleTranslation.
lookupValue(heapref);
6857 llvm::Value *intToPtr =
6858 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
6859 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
6876 bool isOutermostLoopWrapper =
6877 isa_and_present<omp::LoopWrapperInterface>(op) &&
6878 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
6880 if (isOutermostLoopWrapper)
6881 moduleTranslation.
stackPush<OpenMPLoopInfoStackFrame>();
6885 .Case([&](omp::BarrierOp op) -> LogicalResult {
6889 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6890 ompBuilder->createBarrier(builder.saveIP(),
6891 llvm::omp::OMPD_barrier);
6893 if (res.succeeded()) {
6896 builder.restoreIP(*afterIP);
6900 .Case([&](omp::TaskyieldOp op) {
6904 ompBuilder->createTaskyield(builder.saveIP());
6907 .Case([&](omp::FlushOp op) {
6919 ompBuilder->createFlush(builder.saveIP());
6922 .Case([&](omp::ParallelOp op) {
6925 .Case([&](omp::MaskedOp) {
6928 .Case([&](omp::MasterOp) {
6931 .Case([&](omp::CriticalOp) {
6934 .Case([&](omp::OrderedRegionOp) {
6937 .Case([&](omp::OrderedOp) {
6940 .Case([&](omp::WsloopOp) {
6943 .Case([&](omp::SimdOp) {
6946 .Case([&](omp::AtomicReadOp) {
6949 .Case([&](omp::AtomicWriteOp) {
6952 .Case([&](omp::AtomicUpdateOp op) {
6955 .Case([&](omp::AtomicCaptureOp op) {
6958 .Case([&](omp::CancelOp op) {
6961 .Case([&](omp::CancellationPointOp op) {
6964 .Case([&](omp::SectionsOp) {
6967 .Case([&](omp::SingleOp op) {
6970 .Case([&](omp::TeamsOp op) {
6973 .Case([&](omp::TaskOp op) {
6976 .Case([&](omp::TaskloopOp op) {
6979 .Case([&](omp::TaskgroupOp op) {
6982 .Case([&](omp::TaskwaitOp op) {
6985 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
6986 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
6987 omp::CriticalDeclareOp>([](
auto op) {
7000 .Case([&](omp::ThreadprivateOp) {
7003 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
7004 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
7007 .Case([&](omp::TargetOp) {
7010 .Case([&](omp::DistributeOp) {
7013 .Case([&](omp::LoopNestOp) {
7016 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
7023 .Case([&](omp::NewCliOp op) {
7028 .Case([&](omp::CanonicalLoopOp op) {
7031 .Case([&](omp::UnrollHeuristicOp op) {
7040 .Case([&](omp::TileOp op) {
7041 return applyTile(op, builder, moduleTranslation);
7043 .Case([&](omp::TargetAllocMemOp) {
7046 .Case([&](omp::TargetFreeMemOp) {
7051 <<
"not yet implemented: " << inst->
getName();
7054 if (isOutermostLoopWrapper)
7069 if (isa<omp::TargetOp>(op))
7071 if (isa<omp::TargetDataOp>(op))
7075 if (isa<omp::TargetOp>(oper)) {
7080 if (isa<omp::TargetDataOp>(oper)) {
7090 if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
7091 oper->getParentOfType<LLVM::LLVMFuncOp>() &&
7092 !oper->getRegions().empty()) {
7093 if (
auto blockArgsIface =
7094 dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
7100 if (isa<mlir::omp::AtomicUpdateOp>(oper))
7101 for (
auto [operand, arg] :
7102 llvm::zip_equal(oper->getOperands(),
7103 oper->getRegion(0).getArguments())) {
7105 arg, builder.CreateLoad(
7111 if (
auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
7112 assert(builder.GetInsertBlock() &&
7113 "No insert block is set for the builder");
7114 for (
auto iv : loopNest.getIVs()) {
7117 iv, llvm::PoisonValue::get(
7122 for (
Region ®ion : oper->getRegions()) {
7129 region, oper->getName().getStringRef().str() +
".fake.region",
7130 builder, moduleTranslation, &phis);
7134 builder.SetInsertPoint(
result.get(),
result.get()->end());
7141 }).wasInterrupted();
7142 return failure(interrupted);
7149class OpenMPDialectLLVMIRTranslationInterface
7150 :
public LLVMTranslationDialectInterface {
7157 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
7158 LLVM::ModuleTranslation &moduleTranslation)
const final;
7163 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
7164 NamedAttribute attribute,
7165 LLVM::ModuleTranslation &moduleTranslation)
const final;
7170LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
7171 Operation *op, ArrayRef<llvm::Instruction *> instructions,
7172 NamedAttribute attribute,
7173 LLVM::ModuleTranslation &moduleTranslation)
const {
7174 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
7176 .Case(
"omp.is_target_device",
7177 [&](Attribute attr) {
7178 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
7179 llvm::OpenMPIRBuilderConfig &
config =
7181 config.setIsTargetDevice(deviceAttr.getValue());
7187 [&](Attribute attr) {
7188 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
7189 llvm::OpenMPIRBuilderConfig &
config =
7191 config.setIsGPU(gpuAttr.getValue());
7196 .Case(
"omp.host_ir_filepath",
7197 [&](Attribute attr) {
7198 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
7199 llvm::OpenMPIRBuilder *ompBuilder =
7201 auto VFS = llvm::vfs::getRealFileSystem();
7202 ompBuilder->loadOffloadInfoMetadata(*VFS,
7203 filepathAttr.getValue());
7209 [&](Attribute attr) {
7210 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
7214 .Case(
"omp.version",
7215 [&](Attribute attr) {
7216 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
7217 llvm::OpenMPIRBuilder *ompBuilder =
7219 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
7220 versionAttr.getVersion());
7225 .Case(
"omp.declare_target",
7226 [&](Attribute attr) {
7227 if (
auto declareTargetAttr =
7228 dyn_cast<omp::DeclareTargetAttr>(attr))
7233 .Case(
"omp.requires",
7234 [&](Attribute attr) {
7235 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
7236 using Requires = omp::ClauseRequires;
7237 Requires flags = requiresAttr.getValue();
7238 llvm::OpenMPIRBuilderConfig &
config =
7240 config.setHasRequiresReverseOffload(
7241 bitEnumContainsAll(flags, Requires::reverse_offload));
7242 config.setHasRequiresUnifiedAddress(
7243 bitEnumContainsAll(flags, Requires::unified_address));
7244 config.setHasRequiresUnifiedSharedMemory(
7245 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7246 config.setHasRequiresDynamicAllocators(
7247 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7252 .Case(
"omp.target_triples",
7253 [&](Attribute attr) {
7254 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7255 llvm::OpenMPIRBuilderConfig &
config =
7257 config.TargetTriples.clear();
7258 config.TargetTriples.reserve(triplesAttr.size());
7259 for (Attribute tripleAttr : triplesAttr) {
7260 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7261 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7269 .Default([](Attribute) {
7279LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7280 Operation *op, llvm::IRBuilderBase &builder,
7281 LLVM::ModuleTranslation &moduleTranslation)
const {
7284 if (ompBuilder->Config.isTargetDevice()) {
7294 registry.
insert<omp::OpenMPDialect>();
7296 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static bool isTargetDeviceOp(Operation *op)
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation 'from' and returns the PrivateClauseOp with name 'symbolName'.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation, omp::BlockArgOpenMPOpInterface blockArgIface)
Maps block arguments from blockArgIface (which are MLIR values) to the corresponding LLVM values of t...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static LogicalResult convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Given an OpenMP MLIR operation, create the corresponding LLVM IR (including OpenMP runtime calls).
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static LogicalResult convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
LLVMTranslationDialectInterface(Dialect *dialect)
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
MLIRContext * getContext()
Return the context this operation is associated with.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars